mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
Fix pytorch_nn.py step bug (#864)
* Update pytorch_nn.py * Update pytorch_nn.py
This commit is contained in:
@@ -201,7 +201,7 @@ class DNNModelPytorch(Model):
|
||||
y_val_auto = torch.from_numpy(y_valid.values).float().to(self.device)
|
||||
w_val_auto = torch.from_numpy(w_valid.values).float().to(self.device)
|
||||
|
||||
for step in range(self.max_steps):
|
||||
for step in range(1, self.max_steps + 1):
|
||||
if stop_steps >= self.early_stop_rounds:
|
||||
if verbose:
|
||||
self.logger.info("\tearly stop")
|
||||
@@ -225,7 +225,7 @@ class DNNModelPytorch(Model):
|
||||
# validation
|
||||
train_loss += loss.val
|
||||
# for evert `eval_steps` steps or at the last steps, we will evaluate the model.
|
||||
if step % self.eval_steps == 0 or step + 1 == self.max_steps:
|
||||
if step % self.eval_steps == 0 or step == self.max_steps:
|
||||
stop_steps += 1
|
||||
train_loss /= self.eval_steps
|
||||
|
||||
|
||||
Reference in New Issue
Block a user