diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index cf33732b9..1d3012331 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -398,7 +398,7 @@ class SFM(Model): # update learning rate self.scheduler.step(cur_loss_val) - if device != 'cpu': + if self.device != 'cpu': torch.cuda.empty_cache() def get_loss(self, pred, target, loss_type):