diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index f6bd427f5..5d2dbd9a4 100755 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -169,7 +169,7 @@ class GAT(Model): daily_shuffle = list(zip(daily_index, daily_count)) np.random.shuffle(daily_shuffle) daily_index, daily_count = zip(*daily_shuffle) - return daily_index, daily_count + return daily_index, daily_count def train_epoch(self, x_train, y_train): diff --git a/qlib/contrib/model/pytorch_hats.py b/qlib/contrib/model/pytorch_hats.py index 3ba2d676e..bdb68be28 100644 --- a/qlib/contrib/model/pytorch_hats.py +++ b/qlib/contrib/model/pytorch_hats.py @@ -175,7 +175,7 @@ class HATS(Model): daily_shuffle = list(zip(daily_index, daily_count)) np.random.shuffle(daily_shuffle) daily_index, daily_count = zip(*daily_shuffle) - return daily_index, daily_count + return daily_index, daily_count def train_epoch(self, x_train, y_train):