diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index 7fbbd7c6e..e013238fa 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -340,7 +340,7 @@ class SFM(Model): def train_epoch(self, x_train, y_train): x_train_values = x_train.values - y_train_values = np.squeeze(y_train.values) * 100 + y_train_values = np.squeeze(y_train.values) self.sfm_model.train()