diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 15ee7ef71..868ab1513 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -297,7 +297,7 @@ class DNNModelPytorch(Model): _model_path = os.path.join(model_dir, _model_name) # Load model self.dnn_model.load_state_dict(torch.load(_model_path)) - self._fitted = True + self.fitted = True class AverageMeter: