diff --git a/qlib/contrib/model/pytorch_hats.py b/qlib/contrib/model/pytorch_hats.py index 1eff35203..a0da88dbf 100644 --- a/qlib/contrib/model/pytorch_hats.py +++ b/qlib/contrib/model/pytorch_hats.py @@ -180,7 +180,7 @@ class HATS(Model): def train_epoch(self, x_train, y_train): x_train_values = x_train.values - y_train_values = np.squeeze(y_train.values) + y_train_values = np.squeeze(y_train.values) self.HATS_model.train()