mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-06 04:20:57 +08:00
support multi indexing of TSDatasetSample
This commit is contained in:
@@ -48,6 +48,12 @@ class TestDataset(TestAutoData):
|
||||
_ = tsds_train[idx]
|
||||
print(f"2000 sample takes {time.time() - t}s")
|
||||
|
||||
t = time.time()
|
||||
for _ in range(20):
|
||||
data = tsds_train[np.random.randint(0, len(tsds_train), size=2000)]
|
||||
print(data.shape)
|
||||
print(f"2000 sample(batch index) * 20 times takes {time.time() - t}s")
|
||||
|
||||
# FIXME: Please remove pytorch related function. Otherwise the CI tests will fail
|
||||
train_loader = DataLoader(tsds_train, batch_size=800, shuffle=True, num_workers=10)
|
||||
t = time.time()
|
||||
@@ -88,3 +94,8 @@ class TestDataset(TestAutoData):
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=10)
|
||||
|
||||
# User could use following code to run test when using line_profiler
|
||||
# td = TestDataset()
|
||||
# td.setUpClass()
|
||||
# td.testTSDataset()
|
||||
|
||||
Reference in New Issue
Block a user