1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Fix bugs for Gats model.

This commit is contained in:
lwwang1995
2020-12-07 23:07:33 +08:00
committed by you-n-g
parent 666e1ffcbd
commit dcfa8110e8

View File

@@ -32,15 +32,17 @@ from ...contrib.model.pytorch_gru import GRUModel
class DailyBatchSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
self.data = self.data_source.data.loc[self.data_source.get_index()]
self.daily_count = self.data.groupby(level=0).size().values[1:]
self.daily_index = np.roll(np.cumsum(self.daily_count), 1)[1:]
self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
self.daily_index[0] = 0
def __iter__(self):
for idx, count in zip(self.daily_index, self.daily_count):
yield np.arange(idx, idx + count)