From fcbafde741f8577284f36f5e3d7141d126a9b486 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 25 Nov 2020 14:27:19 +0800 Subject: [PATCH] update --- examples/workflow_by_code_sfm.py | 2 +- qlib/contrib/model/pytorch_sfm.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/workflow_by_code_sfm.py b/examples/workflow_by_code_sfm.py index 6a72db3a1..45f34f012 100644 --- a/examples/workflow_by_code_sfm.py +++ b/examples/workflow_by_code_sfm.py @@ -66,7 +66,7 @@ if __name__ == "__main__": "freq_dim" : 15, "dropout_W": 0.5, "dropout_U": 0.5, - "n_epochs": 200, + "n_epochs": 10, "lr": 1e-3, "batch_size": 800, "early_stop": 20, diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index 1d3012331..04a5ad33f 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -420,11 +420,11 @@ class SFM(Model): index = x_test.index x_test = torch.from_numpy(x_test.values).float() - x_test = x_test.to(device) + x_test = x_test.to(self.device) self.sfm_model.eval() with torch.no_grad(): - if device != 'cpu': + if self.device != 'cpu': preds = self.sfm_model(x_test).detach().cpu().numpy() else: preds = self.sfm_model(x_test).detach().numpy()