diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 9b61af47e..9dd4285da 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -12,11 +12,12 @@ from qlib.contrib.data.handler import Alpha158 from qlib.utils import exists_qlib_data, init_instance_by_config from qlib.tests.data import GetData + class RollingDataWorkflow(object): MARKET = "csi300" start_time = "2010-01-01" - end_time = "2019-12-31" + end_time = "2019-12-31" rolling_cnt = 5 def _init_qlib(self): @@ -27,7 +28,7 @@ class RollingDataWorkflow(object): print(f"Qlib data is not found in {provider_uri}") GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) - + def _dump_pre_handler(self, path): handler_config = { "class": "Alpha158", @@ -51,13 +52,13 @@ class RollingDataWorkflow(object): self._dump_pre_handler("pre_handler.py") pre_handler = self._load_pre_handler("pre_handler.py") - train_start_time = (2010,1,1) - train_end_time = (2012,12,31) - valid_start_time = (2013,1,1) - valid_end_time = (2013,12,31) - test_start_time = (2014,1,1) - test_end_time = (2014,12,31) - + train_start_time = (2010, 1, 1) + train_end_time = (2012, 12, 31) + valid_start_time = (2013, 1, 1) + valid_end_time = (2013, 12, 31) + test_start_time = (2014, 1, 1) + test_end_time = (2014, 12, 31) + dataset_config = { "class": "DatasetH", "module_path": "qlib.data.dataset", @@ -70,9 +71,9 @@ class RollingDataWorkflow(object): "end_time": datetime(*test_end_time), "fit_start_time": datetime(*train_start_time), "fit_end_time": datetime(*train_end_time), - "data_loader_kwargs":{ + "data_loader_kwargs": { "handler_config": pre_handler, - } + }, }, }, "segments": { @@ -94,14 +95,23 @@ class RollingDataWorkflow(object): "end_time": datetime(test_end_time[0] + 1, *test_end_time[1:]), }, segment_kwargs={ - "train": (datetime(train_start_time[0] + 1, *train_start_time[1:]), datetime(train_end_time[0], *train_end_time[1:])), - "valid": (datetime(valid_start_time[0] + 1, *valid_start_time[1:]), datetime(valid_end_time[0], *valid_end_time[1:])), - "test": (datetime(test_start_time[0] + 1, *test_start_time[1:]), datetime(test_end_time[0], *test_end_time[1:])), + "train": ( + datetime(train_start_time[0] + 1, *train_start_time[1:]), + datetime(train_end_time[0], *train_end_time[1:]), + ), + "valid": ( + datetime(valid_start_time[0] + 1, *valid_start_time[1:]), + datetime(valid_end_time[0], *valid_end_time[1:]), + ), + "test": ( + datetime(test_start_time[0] + 1, *test_start_time[1:]), + datetime(test_end_time[0], *test_end_time[1:]), + ), }, ) - dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) - + dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) + if __name__ == "__main__": @@ -147,4 +157,3 @@ if __name__ == "__main__": } dataset = init_instance_by_config(task["dataset"]) -