diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 8581f149b..62523aefd 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -3,8 +3,9 @@ import qlib import pickle -import datetime import pandas as pd + +from datetime import datetime from qlib.config import REG_CN from qlib.data.dataset.handler import DataHandlerLP from qlib.contrib.data.handler import Alpha158 @@ -14,7 +15,6 @@ from qlib.tests.data import GetData class RollingDataWorkflow(object): MARKET = "csi300" - start_time = "2010-01-01" end_time = "2019-12-31" rolling_cnt = 5 @@ -33,9 +33,9 @@ class RollingDataWorkflow(object): "class": "Alpha158", "module_path": "qlib.contrib.data.handler", "kwargs": { - "start_time": start_time, - "end_time": end_time, - "instruments": MARKET, + "start_time": self.start_time, + "end_time": self.end_time, + "instruments": self.MARKET, }, } pre_handler = init_instance_by_config(handler_config) @@ -51,10 +51,13 @@ class RollingDataWorkflow(object): self._dump_pre_handler("pre_handler.py") pre_handler = self._load_pre_handler("pre_handler.py") - init_start_time = datetime.datetime(2010,1,1) - init_end_time = datetime.datetime(2014,12,31) - init_fit_end_time = datetime.datetime(2012,12,31) - + train_start_time = (2010,1,1) + train_end_time = (2012,12,31) + valid_start_time = (2013,1,1) + valid_end_time = (2013,12,31) + test_start_time = (2014,1,1) + test_end_time = (2014,12,31) + dataset_config = { "class": "DatasetH", "module_path": "qlib.data.dataset", @@ -63,19 +66,19 @@ class RollingDataWorkflow(object): "class": "RollingDataHandler", "module_path": "rolling_handler", "kwargs": { - "start_time": init_start_time, - "end_time": init_start_time, - "fit_start_time": init_fit_start_time, - "fit_end_time": init_fit_end_time, + "start_time": datetime(*train_start_time), + "end_time": datetime(*test_end_time), + "fit_start_time": datetime(*train_start_time), + "fit_end_time": datetime(*train_end_time), "data_loader_kwargs":{ "handler_config": pre_handler, } }, }, "segments": { - "train": (init_start_time, init_fit_end_time), - "valid": (init_start_time, "2013-12-31"), - "test": (init_start_time, init_end_time), + "train": (datetime(*train_start_time), datetime(*train_end_time)), + "valid": (datetime(*valid_start_time), datetime(*valid_end_time)), + "test": (datetime(*test_start_time), datetime(*test_end_time)), }, }, } @@ -86,17 +89,19 @@ class RollingDataWorkflow(object): if rolling_offset: dataset.init( handler_kwargs={ - "init_type": DataHandlerLP.IT_FIT_IND, - "start_time": "2021-01-19 00:00:00", - "end_time": "2021-01-25 16:00:00", + "init_type": DataHandlerLP.IT_FIT_SEQ, + "start_time": datetime(train_start_time[0] + 1, *train_start_time[1:]), + "end_time": datetime(test_end_time[0] + 1, *test_end_time[1:]), }, segment_kwargs={ - "train": ("2010-01-01", "2012-12-31"), - "valid": ("2013-01-01", "2013-12-31"), - "test": ("2014-01-01", "2014-12-31"), + "train": (datetime(train_start_time[0] + 1, *train_start_time[1:]), datetime(train_end_time[0], *train_end_time[1:])), + "valid": (datetime(valid_start_time[0] + 1, *valid_start_time[1:]), datetime(valid_end_time[0], *valid_end_time[1:])), + "test": (datetime(test_start_time[0] + 1, *test_start_time[1:]), datetime(test_end_time[0], *test_end_time[1:])), }, ) + dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) + if __name__ == "__main__":