1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
This commit is contained in:
Alex Wang
2020-11-25 14:27:19 +08:00
parent 6b90c6d066
commit fcbafde741
2 changed files with 3 additions and 3 deletions

View File

@@ -66,7 +66,7 @@ if __name__ == "__main__":
"freq_dim" : 15,
"dropout_W": 0.5,
"dropout_U": 0.5,
"n_epochs": 200,
"n_epochs": 10,
"lr": 1e-3,
"batch_size": 800,
"early_stop": 20,

View File

@@ -420,11 +420,11 @@ class SFM(Model):
index = x_test.index
x_test = torch.from_numpy(x_test.values).float()
x_test = x_test.to(device)
x_test = x_test.to(self.device)
self.sfm_model.eval()
with torch.no_grad():
if device != 'cpu':
if self.device != 'cpu':
preds = self.sfm_model(x_test).detach().cpu().numpy()
else:
preds = self.sfm_model(x_test).detach().numpy()