diff --git a/examples/workflow_by_code_sfm.py b/examples/workflow_by_code_sfm.py index 6a72db3a1..45f34f012 100644 --- a/examples/workflow_by_code_sfm.py +++ b/examples/workflow_by_code_sfm.py @@ -66,7 +66,7 @@ if __name__ == "__main__": "freq_dim" : 15, "dropout_W": 0.5, "dropout_U": 0.5, - "n_epochs": 200, + "n_epochs": 10, "lr": 1e-3, "batch_size": 800, "early_stop": 20, diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index 1d3012331..04a5ad33f 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -420,11 +420,11 @@ class SFM(Model): index = x_test.index x_test = torch.from_numpy(x_test.values).float() - x_test = x_test.to(device) + x_test = x_test.to(self.device) self.sfm_model.eval() with torch.no_grad(): - if device != 'cpu': + if self.device != 'cpu': preds = self.sfm_model(x_test).detach().cpu().numpy() else: preds = self.sfm_model(x_test).detach().numpy()