From 60f62482b71066cefd500ddcf9494858ad706ecd Mon Sep 17 00:00:00 2001 From: lwwang1995 Date: Sat, 5 Dec 2020 22:36:04 +0800 Subject: [PATCH] Update test_dataset --- tests/test_dataset.py | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) mode change 100644 => 100755 tests/test_dataset.py diff --git a/tests/test_dataset.py b/tests/test_dataset.py old mode 100644 new mode 100755 index dc3042175..3c55db54b --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -6,6 +6,8 @@ import sys from qlib.tests import TestAutoData from qlib.data.dataset import TSDatasetH import numpy as np +from torch.utils.data import DataLoader +import time class TestDataset(TestAutoData): @@ -20,6 +22,36 @@ class TestDataset(TestAutoData): "fit_start_time": "2008-01-01", "fit_end_time": "2014-12-31", "instruments": "csi300", + "infer_processors": [ + { + "class" : "DropCol", + "kwargs":{"col_list": ["VWAP0"]} + }, + { + "class" : "FilterCol", + "kwargs":{"col_list": ["RESI5", "WVMA5", "RSQR5"]} + }, + { + "class" : "CSZFillna", + "kwargs":{"fields_group": "feature"} + } + ], + "learn_processors": [ + { + "class" : "DropCol", + "kwargs":{"col_list": ["VWAP0"]} + }, + { + "class" : "DropnaProcessor", + "kwargs":{"fields_group": "feature"} + }, + "DropnaLabel", + { + "class": "CSZScoreNorm", + "kwargs": {"fields_group": "label"} + } + ], + "process_type": "independent" }, }, segments={ @@ -28,8 +60,12 @@ class TestDataset(TestAutoData): "test": ("2017-01-01", "2020-08-01"), }, ) - _ = tsdh.prepare("train") # Test the correctness + tsds_train = tsdh.prepare("train") # Test the correctness tsds = tsdh.prepare("valid") # prepare a dataset with is friendly to converting tabular data to time-series + train_loader = DataLoader(tsds_train, batch_size=800, shuffle=True) + for data in train_loader: + now = time.localtime() + print(time.strftime("%Y-%m-%d-%H_%M_%S", now)) # The dimension of sample is same as tabular data, but it will return timeseries data of the sample