diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 5f024192f..8650859ff 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -5,6 +5,7 @@ This example is about how can simulate the OnlineManager based on rolling tasks. """ +from pprint import pprint import fire import qlib from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM @@ -13,7 +14,63 @@ from qlib.workflow.online.manager import OnlineManager from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager -from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG +from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE + +data_handler_config = { + "start_time": "2018-01-01", + "end_time": "2018-10-31", + "fit_start_time": "2018-01-01", + "fit_end_time": "2018-03-31", + "instruments": "csi100", +} + +dataset_config = { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2018-01-01", "2018-03-31"), + "valid": ("2018-04-01", "2018-05-31"), + "test": ("2018-06-01", "2018-09-10"), + }, + }, +} + +record_config = [ + { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + }, + { + "class": "SigAnaRecord", + "module_path": "qlib.workflow.record_temp", + }, +] + +# use lgb model +task_lgb_config = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + }, + "dataset": dataset_config, + "record": record_config, +} + +# use xgboost model +task_xgboost_config = { + "model": { + "class": "XGBModel", + "module_path": "qlib.contrib.model.xgboost", + }, + "dataset": dataset_config, + "record": record_config, +} class OnlineSimulationExample: @@ -46,7 +103,10 @@ class OnlineSimulationExample: tasks (dict or list[dict]): a set of the task config waiting for rolling and training """ if tasks is None: - tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG] + #tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE] + tasks = [task_xgboost_config, task_lgb_config] + #pprint(CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE) + #pprint(task_xgboost_config) self.exp_name = exp_name self.task_pool = task_pool self.start_time = start_time diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index b4f7245b7..99a91e027 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -18,7 +18,7 @@ from qlib.workflow import R from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.online.manager import OnlineManager -from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG +from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING class RollingOnlineExample: @@ -34,9 +34,9 @@ class RollingOnlineExample: add_tasks=None, ): if add_tasks is None: - add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG] + add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING] if tasks is None: - tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG] + tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING] mongo_conf = { "task_url": task_url, # your MongoDB url "task_db_name": task_db_name, # database name diff --git a/qlib/tests/config.py b/qlib/tests/config.py index 80461f6f9..c61b5651e 100644 --- a/qlib/tests/config.py +++ b/qlib/tests/config.py @@ -43,17 +43,29 @@ RECORD_CONFIG = [ ] -def get_data_handler_config(market=CSI300_MARKET): +def get_data_handler_config( + start_time="2008-01-01", + end_time="2020-08-01", + fit_start_time="2008-01-01", + fit_end_time="2014-12-31", + instruments=CSI300_MARKET, +): return { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, + "start_time": start_time, + "end_time": end_time, + "fit_start_time": fit_start_time, + "fit_end_time": fit_end_time, + "instruments": instruments, } -def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS): +def get_dataset_config( + dataset_class=DATASET_ALPHA158_CLASS, + train=("2008-01-01", "2014-12-31"), + valid=("2015-01-01", "2016-12-31"), + test=("2017-01-01", "2020-08-01"), + handler_kwargs={"instruments": CSI300_MARKET}, +): return { "class": "DatasetH", "module_path": "qlib.data.dataset", @@ -61,48 +73,88 @@ def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLAS "handler": { "class": dataset_class, "module_path": "qlib.contrib.data.handler", - "kwargs": get_data_handler_config(market), + "kwargs": get_data_handler_config(**handler_kwargs), }, "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), + "train": train, + "valid": valid, + "test": test, }, }, } -def get_gbdt_task(market=CSI300_MARKET): +def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}): return { "model": GBDT_MODEL, - "dataset": get_dataset_config(market), + "dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs), } -def get_record_lgb_config(market=CSI300_MARKET): +def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}): return { "model": { "class": "LGBModel", "module_path": "qlib.contrib.model.gbdt", }, - "dataset": get_dataset_config(market), + "dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs), "record": RECORD_CONFIG, } -def get_record_xgboost_config(market=CSI300_MARKET): +def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}): return { "model": { "class": "XGBModel", "module_path": "qlib.contrib.model.xgboost", }, - "dataset": get_dataset_config(market), + "dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs), "record": RECORD_CONFIG, } -CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET) -CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET) +CSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={"instruments": CSI300_MARKET}) +CSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={"instruments": CSI300_MARKET}) -CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET) -CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET) +CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={"instruments": CSI100_MARKET}) +CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={"instruments": CSI100_MARKET}) + +# use for rolling_online_managment.py +ROLLING_HANDLER_CONFIG = { + "start_time": "2013-01-01", + "end_time": "2020-09-25", + "fit_start_time": "2013-01-01", + "fit_end_time": "2014-12-31", + "instruments": CSI100_MARKET, +} +ROLLING_DATASET_CONFIG = { + "train": ("2013-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2015-12-31"), + "test": ("2016-01-01", "2020-07-10"), +} +CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING = get_record_xgboost_config( + dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG +) +CSI100_RECORD_LGB_TASK_CONFIG_ROLLING = get_record_lgb_config( + dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG +) + +# use for online_management_simulate.py +ONLINE_HANDLER_CONFIG = { + "start_time": "2018-01-01", + "end_time": "2018-10-31", + "fit_start_time": "2018-01-01", + "fit_end_time": "2018-03-31", + "instruments": CSI100_MARKET, +} +ONLINE_DATASET_CONFIG = { + "train": ("2018-01-01", "2018-03-31"), + "valid": ("2018-04-01", "2018-05-31"), + "test": ("2018-06-01", "2018-09-10"), +} +CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE = get_record_xgboost_config( + dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG +) +CSI100_RECORD_LGB_TASK_CONFIG_ONLINE = get_record_lgb_config( + dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG +)