diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index e5c37dac6..25b8b2a0c 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -3,20 +3,19 @@ """ This example shows how OnlineManager works with rolling tasks. -There are two parts including first train and routine. +There are four parts including first train, routine 1, add strategy and routine 2. Firstly, the OnlineManager will finish the first training and set trained models to `online` models. -Next, the OnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models +Next, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals +Then, we will add some new strategies to the OnlineManager. This will finish first training of new strategies. +Finally, the OnlineManager will finish second routine and update all strategies. """ import os -from pathlib import Path -import pickle import fire import qlib from qlib.workflow import R from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen -from qlib.workflow.task.manage import TaskManager from qlib.workflow.online.manager import OnlineManager data_handler_config = { @@ -84,7 +83,8 @@ class RollingOnlineExample: task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550, - tasks=[task_xgboost_config, task_lgb_config], + tasks=[task_xgboost_config], + add_tasks=[task_lgb_config], ): mongo_conf = { "task_url": task_url, # your MongoDB url @@ -92,11 +92,12 @@ class RollingOnlineExample: } qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf) self.tasks = tasks + self.add_tasks = add_tasks self.rolling_step = rolling_step - strategy = [] + strategies = [] for task in tasks: name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy - strategy.append( + strategies.append( RollingStrategy( name_id, task, @@ -104,8 +105,7 @@ class RollingOnlineExample: ) ) - self.rolling_online_manager = OnlineManager(strategy) - self.collector = self.rolling_online_manager.get_collector() + self.rolling_online_manager = OnlineManager(strategies) _ROLLING_MANAGER_PATH = ( ".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine. @@ -113,40 +113,60 @@ class RollingOnlineExample: # Reset all things to the first status, be careful to save important data def reset(self): - for task in self.tasks: + for task in self.tasks + self.add_tasks: name_id = task["model"]["class"] - TaskManager(name_id).remove() exp = R.get_exp(experiment_name=name_id) for rid in exp.list_recorders(): exp.delete_recorder(rid) - if os.path.exists(self._ROLLING_MANAGER_PATH): - os.remove(self._ROLLING_MANAGER_PATH) + if os.path.exists(self._ROLLING_MANAGER_PATH): + os.remove(self._ROLLING_MANAGER_PATH) def first_run(self): print("========== reset ==========") self.reset() print("========== first_run ==========") self.rolling_online_manager.first_train() + print("========== collect results ==========") + print(self.rolling_online_manager.get_collector()()) print("========== dump ==========") self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH) - print("========== collect results ==========") - print(self.collector()) def routine(self): print("========== load ==========") - with Path(self._ROLLING_MANAGER_PATH).open("rb") as f: - self.rolling_online_manager = pickle.load(f) + self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH) print("========== routine ==========") self.rolling_online_manager.routine() print("========== collect results ==========") - print(self.collector()) + print(self.rolling_online_manager.get_collector()()) print("========== signals ==========") print(self.rolling_online_manager.get_signals()) + print("========== dump ==========") + self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH) + + def add_strategy(self): + print("========== load ==========") + self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH) + print("========== add strategy ==========") + strategies = [] + for task in self.add_tasks: + name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy + strategies.append( + RollingStrategy( + name_id, + task, + RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD), + ) + ) + self.rolling_online_manager.add_strategy(strategies=strategies) + print("========== dump ==========") + self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH) def main(self): self.first_run() self.routine() + self.add_strategy() + self.routine() if __name__ == "__main__": diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index f2a576560..6947d6678 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -262,12 +262,29 @@ class OnlineManager(Serializable): Prepare all models and signals if something is waiting for preparation. Args: - model_kwargs: the params for `prepare_online_models` + model_kwargs: the params for `end_train` signal_kwargs: the params for `prepare_signals` """ + last_models = {} + signals_time = D.calendar()[0] + need_prepare = False for cur_time, strategy_models in self.history.items(): self.cur_time = cur_time + for strategy, models in strategy_models.items(): - models = self.trainer.end_train(models, experiment_name=strategy.name_id) - # NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way. - self.prepare_signals(**signal_kwargs) + # only new online models need to prepare + if last_models.setdefault(strategy, set()) != set(models): + models = self.trainer.end_train(models, experiment_name=strategy.name_id, **model_kwargs) + strategy.tool.reset_online_tag(models) + need_prepare = True + last_models[strategy] = set(models) + + if need_prepare: + # NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way. + self.prepare_signals(**signal_kwargs) + if signals_time > cur_time: + self.logger.warn( + f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models." + ) + need_prepare = False + signals_time = self.signals.index.get_level_values("datetime").max() diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index b9b2fd1b3..0c9abf731 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -39,6 +39,9 @@ class Recorder: def __str__(self): return str(self.info) + def __hash__(self) -> int: + return hash(self.info["id"]) + @property def info(self): output = dict() @@ -232,6 +235,14 @@ class MLflowRecorder(Recorder): client=self.client, ) + def __hash__(self) -> int: + return hash(self.info["id"]) + + def __eq__(self, o: object) -> bool: + if isinstance(o, MLflowRecorder): + return self.info["id"] == o.info["id"] + return False + @property def uri(self): return self._uri