From dcfa8110e8763a8c0f010089ed0de2dc42c0df1e Mon Sep 17 00:00:00 2001 From: lwwang1995 Date: Mon, 7 Dec 2020 23:07:33 +0800 Subject: [PATCH] Fix bugs for Gats model. --- qlib/contrib/model/pytorch_gats_ts.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 76b94aa21..7706c4d28 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -32,15 +32,17 @@ from ...contrib.model.pytorch_gru import GRUModel class DailyBatchSampler(Sampler): - + def __init__(self, data_source): self.data_source = data_source self.data = self.data_source.data.loc[self.data_source.get_index()] - self.daily_count = self.data.groupby(level=0).size().values[1:] - self.daily_index = np.roll(np.cumsum(self.daily_count), 1)[1:] + self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch + self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch + self.daily_index[0] = 0 def __iter__(self): + for idx, count in zip(self.daily_index, self.daily_count): yield np.arange(idx, idx + count)