diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 8650859ff..bd7c4675d 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -16,62 +16,6 @@ from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE -data_handler_config = { - "start_time": "2018-01-01", - "end_time": "2018-10-31", - "fit_start_time": "2018-01-01", - "fit_end_time": "2018-03-31", - "instruments": "csi100", -} - -dataset_config = { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2018-01-01", "2018-03-31"), - "valid": ("2018-04-01", "2018-05-31"), - "test": ("2018-06-01", "2018-09-10"), - }, - }, -} - -record_config = [ - { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, - { - "class": "SigAnaRecord", - "module_path": "qlib.workflow.record_temp", - }, -] - -# use lgb model -task_lgb_config = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - }, - "dataset": dataset_config, - "record": record_config, -} - -# use xgboost model -task_xgboost_config = { - "model": { - "class": "XGBModel", - "module_path": "qlib.contrib.model.xgboost", - }, - "dataset": dataset_config, - "record": record_config, -} - class OnlineSimulationExample: def __init__( @@ -103,10 +47,7 @@ class OnlineSimulationExample: tasks (dict or list[dict]): a set of the task config waiting for rolling and training """ if tasks is None: - #tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE] - tasks = [task_xgboost_config, task_lgb_config] - #pprint(CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE) - #pprint(task_xgboost_config) + tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE] self.exp_name = exp_name self.task_pool = task_pool self.start_time = start_time diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index 99a91e027..6abbbfb0e 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -19,6 +19,7 @@ from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.online.manager import OnlineManager from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING +from qlib.workflow.task.manage import TaskManager class RollingOnlineExample: diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index ef6cb8dfa..dc1186038 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -163,17 +163,20 @@ class OnlineManager(Serializable): models = self.trainer.end_train(models, experiment_name=strategy.name_id) self.prepare_signals(**signal_kwargs) - def get_collector(self) -> MergeCollector: + def get_collector(self, **kwargs) -> MergeCollector: """ Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy. This collector can be a basis as the signals preparation. + + Args: + **kwargs: the params for get_collector. Returns: MergeCollector: the collector to merge other collectors. """ collector_dict = {} for strategy in self.strategies: - collector_dict[strategy.name_id] = strategy.get_collector() + collector_dict[strategy.name_id] = strategy.get_collector(**kwargs) return MergeCollector(collector_dict, process_list=[]) def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):