From 68246b3b6d7037f3134ceb6e59aef869e96f1d8f Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 19:58:55 +0800 Subject: [PATCH] update workflow --- examples/rolling_process_data/workflow.py | 87 +++++------------------ 1 file changed, 18 insertions(+), 69 deletions(-) diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 9dd4285da..2f48662bd 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import qlib +import fire import pickle import pandas as pd @@ -12,12 +13,11 @@ from qlib.contrib.data.handler import Alpha158 from qlib.utils import exists_qlib_data, init_instance_by_config from qlib.tests.data import GetData - class RollingDataWorkflow(object): MARKET = "csi300" start_time = "2010-01-01" - end_time = "2019-12-31" + end_time = "2019-12-31" rolling_cnt = 5 def _init_qlib(self): @@ -28,7 +28,7 @@ class RollingDataWorkflow(object): print(f"Qlib data is not found in {provider_uri}") GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) - + def _dump_pre_handler(self, path): handler_config = { "class": "Alpha158", @@ -52,13 +52,13 @@ class RollingDataWorkflow(object): self._dump_pre_handler("pre_handler.py") pre_handler = self._load_pre_handler("pre_handler.py") - train_start_time = (2010, 1, 1) - train_end_time = (2012, 12, 31) - valid_start_time = (2013, 1, 1) - valid_end_time = (2013, 12, 31) - test_start_time = (2014, 1, 1) - test_end_time = (2014, 12, 31) - + train_start_time = (2010,1,1) + train_end_time = (2012,12,31) + valid_start_time = (2013,1,1) + valid_end_time = (2013,12,31) + test_start_time = (2014,1,1) + test_end_time = (2014,12,31) + dataset_config = { "class": "DatasetH", "module_path": "qlib.data.dataset", @@ -71,9 +71,9 @@ class RollingDataWorkflow(object): "end_time": datetime(*test_end_time), "fit_start_time": datetime(*train_start_time), "fit_end_time": datetime(*train_end_time), - "data_loader_kwargs": { + "data_loader_kwargs":{ "handler_config": pre_handler, - }, + } }, }, "segments": { @@ -95,65 +95,14 @@ class RollingDataWorkflow(object): "end_time": datetime(test_end_time[0] + 1, *test_end_time[1:]), }, segment_kwargs={ - "train": ( - datetime(train_start_time[0] + 1, *train_start_time[1:]), - datetime(train_end_time[0], *train_end_time[1:]), - ), - "valid": ( - datetime(valid_start_time[0] + 1, *valid_start_time[1:]), - datetime(valid_end_time[0], *valid_end_time[1:]), - ), - "test": ( - datetime(test_start_time[0] + 1, *test_start_time[1:]), - datetime(test_end_time[0], *test_end_time[1:]), - ), + "train": (datetime(train_start_time[0] + 1, *train_start_time[1:]), datetime(train_end_time[0], *train_end_time[1:])), + "valid": (datetime(valid_start_time[0] + 1, *valid_start_time[1:]), datetime(valid_end_time[0], *valid_end_time[1:])), + "test": (datetime(test_start_time[0] + 1, *test_start_time[1:]), datetime(test_end_time[0], *test_end_time[1:])), }, ) - dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) - + dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) + if __name__ == "__main__": - - # use default data - provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) - - qlib.init(provider_uri=provider_uri, region=REG_CN) - - market = "csi300" - benchmark = "SH000300" - - ################################### - # train model - ################################### - data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, - } - - task = { - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, - } - - dataset = init_instance_by_config(task["dataset"]) + fire.Fire(RollingDataWorkflow)