From 593553f57310afdce69e2f1cf93a8948126b111f Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 11 Mar 2021 19:15:18 -0800 Subject: [PATCH] Fix bug in MLP --- qlib/contrib/model/pytorch_nn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index fd712848e..623ca5950 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -158,7 +158,7 @@ class DNNModelPytorch(Model): @property def use_gpu(self): - self.device == torch.device("cpu") + self.device != torch.device("cpu") def fit( self, @@ -222,7 +222,8 @@ class DNNModelPytorch(Model): # validation train_loss += loss.val - if step and step % self.eval_steps == 0: + # for evert `eval_steps` steps or at the last steps, we will evaluate the model. + if step % self.eval_steps == 0 or step + 1 == self.max_steps: stop_steps += 1 train_loss /= self.eval_steps @@ -255,7 +256,7 @@ class DNNModelPytorch(Model): # update learning rate self.scheduler.step(cur_loss_val) - # restore the optimal parameters after training ?? + # restore the optimal parameters after training self.dnn_model.load_state_dict(torch.load(save_path)) if self.use_gpu: torch.cuda.empty_cache()