1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00

Fix pytorch_nn.py step bug (#864)

* Update pytorch_nn.py

* Update pytorch_nn.py
This commit is contained in:
you-n-g
2022-01-18 22:39:19 +08:00
committed by GitHub
parent 86e1265f69
commit bdf1fb29a6

View File

@@ -201,7 +201,7 @@ class DNNModelPytorch(Model):
y_val_auto = torch.from_numpy(y_valid.values).float().to(self.device)
w_val_auto = torch.from_numpy(w_valid.values).float().to(self.device)
for step in range(self.max_steps):
for step in range(1, self.max_steps + 1):
if stop_steps >= self.early_stop_rounds:
if verbose:
self.logger.info("\tearly stop")
@@ -225,7 +225,7 @@ class DNNModelPytorch(Model):
# validation
train_loss += loss.val
# for evert `eval_steps` steps or at the last steps, we will evaluate the model.
if step % self.eval_steps == 0 or step + 1 == self.max_steps:
if step % self.eval_steps == 0 or step == self.max_steps:
stop_steps += 1
train_loss /= self.eval_steps