mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
An example to get index from TSDataSampler (#679)
This commit is contained in:
@@ -75,6 +75,35 @@ class TestDataset(TestAutoData):
|
||||
equal = np.isclose(data_from_df, data_from_ds)
|
||||
self.assertTrue(equal[~np.isnan(data_from_df)].all())
|
||||
|
||||
if False:
|
||||
# 3) get both index and data
|
||||
# NOTE: We don't want to reply on pytorch, so this test can't be included. It is just a example
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
class IdxSampler:
|
||||
def __init__(self, sampler):
|
||||
self.sampler = sampler
|
||||
|
||||
def __getitem__(self, i: int):
|
||||
return self.sampler[i], i
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sampler)
|
||||
|
||||
i = len(tsds) - 1
|
||||
idx = tsds.get_index()
|
||||
tsds[i]
|
||||
idx[i]
|
||||
|
||||
s_w_i = IdxSampler(tsds)
|
||||
test_loader = DataLoader(s_w_i)
|
||||
|
||||
s_w_i[3]
|
||||
for data, i in test_loader:
|
||||
break
|
||||
print(data.shape)
|
||||
print(idx[i])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=10)
|
||||
|
||||
Reference in New Issue
Block a user