diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 76b94aa21..7706c4d28 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -32,15 +32,17 @@ from ...contrib.model.pytorch_gru import GRUModel class DailyBatchSampler(Sampler): - + def __init__(self, data_source): self.data_source = data_source self.data = self.data_source.data.loc[self.data_source.get_index()] - self.daily_count = self.data.groupby(level=0).size().values[1:] - self.daily_index = np.roll(np.cumsum(self.daily_count), 1)[1:] + self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch + self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch + self.daily_index[0] = 0 def __iter__(self): + for idx, count in zip(self.daily_index, self.daily_count): yield np.arange(idx, idx + count)