From 6b053137fde8d622e69120abbd2663f8e9bad66f Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 26 Nov 2020 12:42:20 +0000 Subject: [PATCH] fix format --- qlib/contrib/model/pytorch_gats.py | 2 +- qlib/contrib/model/pytorch_hats.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index f6bd427f5..5d2dbd9a4 100755 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -169,7 +169,7 @@ class GAT(Model): daily_shuffle = list(zip(daily_index, daily_count)) np.random.shuffle(daily_shuffle) daily_index, daily_count = zip(*daily_shuffle) - return daily_index, daily_count + return daily_index, daily_count def train_epoch(self, x_train, y_train): diff --git a/qlib/contrib/model/pytorch_hats.py b/qlib/contrib/model/pytorch_hats.py index 3ba2d676e..bdb68be28 100644 --- a/qlib/contrib/model/pytorch_hats.py +++ b/qlib/contrib/model/pytorch_hats.py @@ -175,7 +175,7 @@ class HATS(Model): daily_shuffle = list(zip(daily_index, daily_count)) np.random.shuffle(daily_shuffle) daily_index, daily_count = zip(*daily_shuffle) - return daily_index, daily_count + return daily_index, daily_count def train_epoch(self, x_train, y_train):