diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 62e6096ca..969cca63d 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -201,7 +201,7 @@ class DNNModelPytorch(Model): y_val_auto = torch.from_numpy(y_valid.values).float().to(self.device) w_val_auto = torch.from_numpy(w_valid.values).float().to(self.device) - for step in range(self.max_steps): + for step in range(1, self.max_steps + 1): if stop_steps >= self.early_stop_rounds: if verbose: self.logger.info("\tearly stop") @@ -225,7 +225,7 @@ class DNNModelPytorch(Model): # validation train_loss += loss.val # for evert `eval_steps` steps or at the last steps, we will evaluate the model. - if step % self.eval_steps == 0 or step + 1 == self.max_steps: + if step % self.eval_steps == 0 or step == self.max_steps: stop_steps += 1 train_loss /= self.eval_steps