1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

Fix bug in MLP

This commit is contained in:
D-X-Y
2021-03-11 19:15:18 -08:00
parent d38b8d6001
commit 593553f573

View File

@@ -158,7 +158,7 @@ class DNNModelPytorch(Model):
@property
def use_gpu(self):
self.device == torch.device("cpu")
self.device != torch.device("cpu")
def fit(
self,
@@ -222,7 +222,8 @@ class DNNModelPytorch(Model):
# validation
train_loss += loss.val
if step and step % self.eval_steps == 0:
# for evert `eval_steps` steps or at the last steps, we will evaluate the model.
if step % self.eval_steps == 0 or step + 1 == self.max_steps:
stop_steps += 1
train_loss /= self.eval_steps
@@ -255,7 +256,7 @@ class DNNModelPytorch(Model):
# update learning rate
self.scheduler.step(cur_loss_val)
# restore the optimal parameters after training ??
# restore the optimal parameters after training
self.dnn_model.load_state_dict(torch.load(save_path))
if self.use_gpu:
torch.cuda.empty_cache()