mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
Fix bug in MLP
This commit is contained in:
@@ -158,7 +158,7 @@ class DNNModelPytorch(Model):
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
self.device == torch.device("cpu")
|
||||
self.device != torch.device("cpu")
|
||||
|
||||
def fit(
|
||||
self,
|
||||
@@ -222,7 +222,8 @@ class DNNModelPytorch(Model):
|
||||
|
||||
# validation
|
||||
train_loss += loss.val
|
||||
if step and step % self.eval_steps == 0:
|
||||
# for evert `eval_steps` steps or at the last steps, we will evaluate the model.
|
||||
if step % self.eval_steps == 0 or step + 1 == self.max_steps:
|
||||
stop_steps += 1
|
||||
train_loss /= self.eval_steps
|
||||
|
||||
@@ -255,7 +256,7 @@ class DNNModelPytorch(Model):
|
||||
# update learning rate
|
||||
self.scheduler.step(cur_loss_val)
|
||||
|
||||
# restore the optimal parameters after training ??
|
||||
# restore the optimal parameters after training
|
||||
self.dnn_model.load_state_dict(torch.load(save_path))
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user