From cae4c9c9249fa2cb0bb64b6e3e5ac235794bbedf Mon Sep 17 00:00:00 2001 From: you-n-g Date: Wed, 10 Nov 2021 14:35:27 +0800 Subject: [PATCH] An example to get index from TSDataSampler (#679) --- tests/test_dataset.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index ed2f14d2f..47950f6ae 100755 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -75,6 +75,35 @@ class TestDataset(TestAutoData): equal = np.isclose(data_from_df, data_from_ds) self.assertTrue(equal[~np.isnan(data_from_df)].all()) + if False: + # 3) get both index and data + # NOTE: We don't want to reply on pytorch, so this test can't be included. It is just a example + from torch.utils.data import DataLoader + + class IdxSampler: + def __init__(self, sampler): + self.sampler = sampler + + def __getitem__(self, i: int): + return self.sampler[i], i + + def __len__(self): + return len(self.sampler) + + i = len(tsds) - 1 + idx = tsds.get_index() + tsds[i] + idx[i] + + s_w_i = IdxSampler(tsds) + test_loader = DataLoader(s_w_i) + + s_w_i[3] + for data, i in test_loader: + break + print(data.shape) + print(idx[i]) + if __name__ == "__main__": unittest.main(verbosity=10)