From a88697151aa34293984c5ef816c6848d8928e44f Mon Sep 17 00:00:00 2001 From: lwwang1995 Date: Sun, 6 Dec 2020 17:24:58 +0800 Subject: [PATCH] Test CSRankNorm. --- tests/test_dataset.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 36234c879..01454fff8 100755 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -8,7 +8,7 @@ from qlib.data.dataset import TSDatasetH import numpy as np from torch.utils.data import DataLoader import time - +from qlib.data.dataset.handler import DataHandlerLP class TestDataset(TestAutoData): def testTSDataset(self): @@ -23,17 +23,14 @@ class TestDataset(TestAutoData): "fit_end_time": "2014-12-31", "instruments": "csi300", "infer_processors": [ - {"class": "DropCol", "kwargs": {"col_list": ["VWAP0"]}}, {"class": "FilterCol", "kwargs": {"col_list": ["RESI5", "WVMA5", "RSQR5"]}}, - {"class": "CSZFillna", "kwargs": {"fields_group": "feature"}}, + {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier":"true"}}, + {"class": "Fillna", "kwargs": {"fields_group": "feature"}}, ], "learn_processors": [ - {"class": "DropCol", "kwargs": {"col_list": ["VWAP0"]}}, - {"class": "DropnaProcessor", "kwargs": {"fields_group": "feature"}}, "DropnaLabel", - {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}, + {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}, # CSRankNorm ], - "process_type": "independent", }, }, segments={ @@ -42,8 +39,8 @@ class TestDataset(TestAutoData): "test": ("2017-01-01", "2020-08-01"), }, ) - tsds_train = tsdh.prepare("train") # Test the correctness - tsds = tsdh.prepare("valid") # prepare a dataset with is friendly to converting tabular data to time-series + tsds_train = tsdh.prepare("train", data_key=DataHandlerLP.DK_L) # Test the correctness + tsds = tsdh.prepare("valid", data_key=DataHandlerLP.DK_L) t = time.time() for idx in np.random.randint(0, len(tsds_train), size=2000):