mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
update
This commit is contained in:
@@ -66,7 +66,7 @@ if __name__ == "__main__":
|
||||
"freq_dim" : 15,
|
||||
"dropout_W": 0.5,
|
||||
"dropout_U": 0.5,
|
||||
"n_epochs": 200,
|
||||
"n_epochs": 10,
|
||||
"lr": 1e-3,
|
||||
"batch_size": 800,
|
||||
"early_stop": 20,
|
||||
|
||||
@@ -420,11 +420,11 @@ class SFM(Model):
|
||||
index = x_test.index
|
||||
x_test = torch.from_numpy(x_test.values).float()
|
||||
|
||||
x_test = x_test.to(device)
|
||||
x_test = x_test.to(self.device)
|
||||
self.sfm_model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
if device != 'cpu':
|
||||
if self.device != 'cpu':
|
||||
preds = self.sfm_model(x_test).detach().cpu().numpy()
|
||||
else:
|
||||
preds = self.sfm_model(x_test).detach().numpy()
|
||||
|
||||
Reference in New Issue
Block a user