1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

An example to get index from TSDataSampler (#679)

This commit is contained in:
you-n-g
2021-11-10 14:35:27 +08:00
committed by GitHub
parent a2be6e28e9
commit cae4c9c924

View File

@@ -75,6 +75,35 @@ class TestDataset(TestAutoData):
equal = np.isclose(data_from_df, data_from_ds)
self.assertTrue(equal[~np.isnan(data_from_df)].all())
if False:
# 3) get both index and data
# NOTE: We don't want to reply on pytorch, so this test can't be included. It is just a example
from torch.utils.data import DataLoader
class IdxSampler:
def __init__(self, sampler):
self.sampler = sampler
def __getitem__(self, i: int):
return self.sampler[i], i
def __len__(self):
return len(self.sampler)
i = len(tsds) - 1
idx = tsds.get_index()
tsds[i]
idx[i]
s_w_i = IdxSampler(tsds)
test_loader = DataLoader(s_w_i)
s_w_i[3]
for data, i in test_loader:
break
print(data.shape)
print(idx[i])
if __name__ == "__main__":
unittest.main(verbosity=10)