1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 01:51:18 +08:00
Files
qlib/tests/test_dataset.py
2020-12-09 17:20:36 +08:00

102 lines
3.7 KiB
Python
Executable File

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import sys
from qlib.tests import TestAutoData
from qlib.data.dataset import TSDatasetH
import numpy as np
from torch.utils.data import DataLoader
import time
from qlib.data.dataset.handler import DataHandlerLP
class TestDataset(TestAutoData):
    """Exercise ``TSDatasetH`` on the Alpha158 handler: sampling speed, DataLoader passes and value correctness."""

    @staticmethod
    def _time_full_pass(dataset, desc, batch_size=800, num_workers=10):
        """Iterate once over ``dataset`` with a shuffled torch ``DataLoader`` and print the elapsed time.

        Args:
            dataset: object handed to ``DataLoader`` (here a sampler returned by ``TSDatasetH.prepare``).
            desc: message prefix; the printed line is ``f"{desc} takes {seconds}s"``.
            batch_size: ``DataLoader`` batch size.
            num_workers: number of ``DataLoader`` worker processes.
        """
        # FIXME: Please remove pytorch related function. Otherwise the CI tests will fail
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        start = time.perf_counter()
        for _ in loader:
            pass
        print(f"{desc} takes {time.perf_counter() - start}s")

    def testTSDataset(self):
        """Build a TSDatasetH, benchmark sample access, and check one sample against the handler's DataFrame."""
        # Length of the time window of each sample. Passed explicitly so that the
        # DataFrame slice used in the correctness check below cannot drift from it.
        step_len = 30
        tsdh = TSDatasetH(
            handler={
                "class": "Alpha158",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": {
                    "start_time": "2008-01-01",
                    "end_time": "2020-08-01",
                    "fit_start_time": "2008-01-01",
                    "fit_end_time": "2014-12-31",
                    "instruments": "csi300",
                    "infer_processors": [
                        {"class": "FilterCol", "kwargs": {"col_list": ["RESI5", "WVMA5", "RSQR5"]}},
                        # A real boolean: the string "true" only worked because any
                        # non-empty string (including "false") is truthy.
                        {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": True}},
                        {"class": "Fillna", "kwargs": {"fields_group": "feature"}},
                    ],
                    "learn_processors": [
                        "DropnaLabel",
                        {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}},
                    ],
                },
            },
            segments={
                "train": ("2008-01-01", "2014-12-31"),
                "valid": ("2015-01-01", "2016-12-31"),
                "test": ("2017-01-01", "2020-08-01"),
            },
            step_len=step_len,
        )
        # Both segments are prepared from the learn-processed data key (DK_L).
        tsds_train = tsdh.prepare("train", data_key=DataHandlerLP.DK_L)
        tsds_valid = tsdh.prepare("valid", data_key=DataHandlerLP.DK_L)

        # Benchmark: 2000 random single-integer lookups.
        start = time.perf_counter()
        for idx in np.random.randint(0, len(tsds_train), size=2000):
            _ = tsds_train[idx]
        print(f"2000 sample takes {time.perf_counter() - start}s")

        # Benchmark: 20 batched lookups, each with an array of 2000 indices.
        start = time.perf_counter()
        for _ in range(20):
            data = tsds_train[np.random.randint(0, len(tsds_train), size=2000)]
        print(data.shape)
        print(f"2000 sample(batch index) * 20 times takes {time.perf_counter() - start}s")

        # Full passes through a torch DataLoader, first as-is and then with index filling enabled.
        self._time_full_pass(tsds_train, "Passing all training batches")
        # Here is an example of ffill+bfill for index
        tsds_train.config(fillna_type="ffill+bfill")
        self._time_full_pass(tsds_train, "Passing all training batches with fill")

        # Samples are indexed like the tabular data, but each one returns time-series data.
        # There are two ways to fetch a sample:
        # 1) by integer position
        _ = tsds_valid[len(tsds_valid) - 1]
        # 2) by a <datetime, instrument> pair
        data_from_ds = tsds_valid["2016-12-31", "SZ300315"]

        # Correctness: compare against the last `step_len` rows taken directly from the handler's DataFrame.
        data_from_df = (
            tsdh._handler.fetch(data_key=DataHandlerLP.DK_L)
            .loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"]
            .iloc[-step_len:]
            .values
        )
        # Fail with a clear message on a window mismatch instead of relying on np.isclose broadcasting.
        self.assertEqual(data_from_df.shape, data_from_ds.shape)
        equal = np.isclose(data_from_df, data_from_ds)
        # NaN never compares close under np.isclose's defaults, so only non-NaN DataFrame entries are checked.
        self.assertTrue(equal[~np.isnan(data_from_df)].all())
if __name__ == "__main__":
    # unittest documents verbosity levels 0 (quiet), 1 (default) and 2 (verbose);
    # 2 gives the same per-test output the previous out-of-range value of 10 did.
    unittest.main(verbosity=2)

    # To profile with line_profiler, call the test directly instead of via unittest.main:
    # td = TestDataset()
    # td.setUpClass()
    # td.testTSDataset()