mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Fix bugs for Gats model.
This commit is contained in:
@@ -32,15 +32,17 @@ from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
|
||||
class DailyBatchSampler(Sampler):
|
||||
|
||||
|
||||
def __init__(self, data_source):
|
||||
|
||||
self.data_source = data_source
|
||||
self.data = self.data_source.data.loc[self.data_source.get_index()]
|
||||
self.daily_count = self.data.groupby(level=0).size().values[1:]
|
||||
self.daily_index = np.roll(np.cumsum(self.daily_count), 1)[1:]
|
||||
self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch
|
||||
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
|
||||
self.daily_index[0] = 0
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
for idx, count in zip(self.daily_index, self.daily_count):
|
||||
yield np.arange(idx, idx + count)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user