From 0058f7d0dcf29106f245ac4d69ec8e84ac2dcfa5 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Mon, 26 Apr 2021 09:31:47 +0000 Subject: [PATCH] Online Serving V8 --- .../online_srv/online_management_simulate.py | 68 ++---- .../online_srv/rolling_online_management.py | 5 + qlib/model/trainer.py | 224 ++++++++++++++---- qlib/workflow/__init__.py | 2 +- qlib/workflow/online/manager.py | 168 +++++++++---- qlib/workflow/online/simulator.py | 13 +- qlib/workflow/task/gen.py | 9 +- qlib/workflow/task/manage.py | 38 ++- 8 files changed, 368 insertions(+), 159 deletions(-) diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 9b5fbcc03..1b1fed660 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -1,14 +1,14 @@ import fire import qlib from qlib.model.ens.ensemble import ens_workflow -from qlib.model.ens.group import RollingGroup -from qlib.model.trainer import TrainerRM +from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM from qlib.workflow import R from qlib.workflow.online.manager import RollingOnlineManager from qlib.workflow.online.simulator import OnlineSimulator from qlib.workflow.task.collect import RecorderCollector from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.manage import TaskManager +from qlib.workflow.task.utils import list_recorders """ This examples is about the OnlineManager and OnlineSimulator based on rolling tasks. @@ -19,7 +19,7 @@ The OnlineSimulator will focus on the simulating real updating routine of your o data_handler_config = { "start_time": "2018-01-01", - "end_time": None, # "2018-10-31", + "end_time": "2018-10-31", "fit_start_time": "2018-01-01", "fit_end_time": "2018-03-31", "instruments": "csi100", @@ -74,7 +74,7 @@ task_xgboost_config = { } -class OnlineManagerExample: +class OnlineSimulationExample: def __init__( self, provider_uri="~/.qlib/qlib_data/cn_data", @@ -86,6 +86,7 @@ class OnlineManagerExample: rolling_step=80, start_time="2018-09-10", end_time="2018-10-31", + tasks=[task_xgboost_config], # , task_lgb_config] ): """ init OnlineManagerExample. @@ -100,6 +101,7 @@ class OnlineManagerExample: rolling_step (int, optional): the step for rolling. Defaults to 80. start_time (str, optional): the start time of simulating. Defaults to "2018-09-10". end_time (str, optional): the end time of simulating. Defaults to "2018-10-31". + tasks (dict or list[dict]): a set of the task config waiting for rolling and training """ self.exp_name = exp_name self.task_pool = task_pool @@ -108,76 +110,49 @@ class OnlineManagerExample: "task_db_name": task_db_name, } qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf) - - self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD) # The rolling tasks generator - self.trainer = TrainerRM(self.exp_name, self.task_pool) # The trainer based on (R)ecorder and Task(M)anager + self.rolling_gen = RollingGen( + step=rolling_step, rtype=RollingGen.ROLL_SD, modify_end_time=False + ) # The rolling tasks generator, modify_end_time is false because we just need simulate to 2018-10-31. + self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks - self.collector = RecorderCollector(exp_name=self.exp_name, rec_key_func=self.rec_key) # The result collector - self.grouper = RollingGroup() # Divide your results into different rolling group self.rolling_online_manager = RollingOnlineManager( experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer, - collector=self.collector, need_log=False, ) # The OnlineManager based on Rolling self.onlinesimulator = OnlineSimulator( start_time=start_time, end_time=end_time, - onlinemanager=self.rolling_online_manager, + online_manager=self.rolling_online_manager, ) + self.tasks = tasks # Reset all things to the first status, be careful to save important data def reset(self): print("========== reset ==========") self.task_manager.remove() + exp = R.get_exp(experiment_name=self.exp_name) for rid in exp.list_recorders(): exp.delete_recorder(rid) - @staticmethod - def rec_key(recorder): - """ - given a Recorder and return its key to identify it - - Args: - recorder (Recorder): a instance of the Recorder - - Returns: - tuple: (model_key, rolling_key) - """ - task_config = recorder.load_object("task") - model_key = task_config["model"]["class"] - rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] - return model_key, rolling_key - - def result_collecting(self): - print("========== result collecting ==========") - - # ens_workflow can help collect, group and ensemble results in a easy way - artifact = ens_workflow(self.rolling_online_manager.get_collector(), self.grouper) - print(artifact) + for rid in list_recorders( + RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False + ): + exp.delete_recorder(rid) # Run this firstly to see the workflow in OnlineManager def first_train(self): print("========== first train ==========") self.reset() - - tasks = task_generator( - tasks=[task_xgboost_config, task_lgb_config], - generators=[self.rolling_gen], # generate different date segment - ) - - self.rolling_online_manager.prepare_new_models(tasks=tasks, tag=RollingOnlineManager.ONLINE_TAG) - self.result_collecting() + self.rolling_online_manager.first_train(self.tasks) # Run this secondly to see the simulating in OnlineSimulator def simulate(self): - print("========== simulate ==========") self.onlinesimulator.simulate() - - self.result_collecting() + print(self.rolling_online_manager.collect_artifact()) print("========== online models ==========") recs_dict = self.onlinesimulator.online_models() @@ -186,6 +161,9 @@ class OnlineManagerExample: for rec in recs: print(rec.info["id"]) + print("========== online signals ==========") + print(self.rolling_online_manager.get_signals()) + # Run this to run all workflow automaticly def main(self): self.first_train() @@ -195,4 +173,4 @@ class OnlineManagerExample: if __name__ == "__main__": ## to run all workflow automaticly with your own parameters, use the command below # python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60 - fire.Fire(OnlineManagerExample) + fire.Fire(OnlineSimulationExample) diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index 6c30f3af3..d118afe75 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -111,6 +111,11 @@ class RollingOnlineExample: if os.path.exists(self._ROLLING_MANAGER_PATH): os.remove(self._ROLLING_MANAGER_PATH) + for rid in list_recorders( + RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False + ): + exp.delete_recorder(rid) + def first_run(self): print("========== first_run ==========") self.reset() diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 348f6b521..af65c5886 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -1,6 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import copy +import time +from xxlimited import Str from qlib.utils import init_instance_by_config, flatten_dict, get_cls_kwargs from qlib.workflow import R from qlib.workflow.recorder import Recorder @@ -11,6 +14,63 @@ from qlib.model.base import Model import socket +def begin_task_train(task_config: dict, experiment_name: str, *args, **kwargs) -> Recorder: + """ + Begin a task training with starting a recorder and saving the task config. + + Args: + task_config (dict) + experiment_name (str) + + Returns: + Recorder + """ + with R.start(experiment_name=experiment_name, recorder_name=str(time.time())): + R.log_params(**flatten_dict(task_config)) + R.save_objects(**{"task": task_config}) # keep the original format and datatype + R.set_tags(**{"hostname": socket.gethostname(), "train_status": "begin_task_train"}) + recorder: Recorder = R.get_recorder() + return recorder + + +def end_task_train(rec: Recorder, experiment_name: str, *args, **kwargs): + """ + Finished task training with real model fitting and saving. + + Args: + rec (Recorder): This recorder will be resumed + experiment_name (str) + + Returns: + Recorder + """ + with R.start(experiment_name=experiment_name, recorder_name=rec.info["name"], resume=True): + task_config = R.load_object("task") + # model & dataset initiaiton + model: Model = init_instance_by_config(task_config["model"]) + dataset: Dataset = init_instance_by_config(task_config["dataset"]) + # model training + model.fit(dataset) + R.save_objects(**{"params.pkl": model}) + # This dataset is saved for online inference. So the concrete data should not be dumped + dataset.config(dump_all=False, recursive=True) + R.save_objects(**{"dataset": dataset}) + # generate records: prediction, backtest, and analysis + records = task_config.get("record", []) + if isinstance(records, dict): # prevent only one dict + records = [records] + for record in records: + cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp") + if cls is SignalRecord: + rconf = {"model": model, "dataset": dataset, "recorder": rec} + else: + rconf = {"recorder": rec} + r = cls(**kwargs, **rconf) + r.generate() + R.set_tags(**{"train_status": "end_task_train"}) + return rec + + def task_train(task_config: dict, experiment_name: str) -> Recorder: """ task based training @@ -26,36 +86,8 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder: ---------- Recorder : The instance of the recorder """ - # model initiaiton - model: Model = init_instance_by_config(task_config["model"]) - dataset: Dataset = init_instance_by_config(task_config["dataset"]) - - # start exp - with R.start(experiment_name=experiment_name): - - # train model - R.log_params(**flatten_dict(task_config)) - R.save_objects(**{"task": task_config}) # keep the original format and datatype - R.set_tags(hostname=socket.gethostname()) - model.fit(dataset) - R.save_objects(**{"params.pkl": model}) - # This dataset is saved for online inference. So the concrete data should not be dumped - dataset.config(dump_all=False, recursive=True) - R.save_objects(**{"dataset": dataset}) - - # generate records: prediction, backtest, and analysis - records = task_config.get("record", []) - recorder: Recorder = R.get_recorder() - if isinstance(records, dict): # prevent only one dict - records = [records] - for record in records: - cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp") - if cls is SignalRecord: - rconf = {"model": model, "dataset": dataset, "recorder": recorder} - else: - rconf = {"recorder": recorder} - r = cls(**kwargs, **rconf) - r.generate() + recorder = begin_task_train(task_config, experiment_name) + recorder = end_task_train(recorder, experiment_name) return recorder @@ -64,14 +96,22 @@ class Trainer: The trainer which can train a list of model """ - def train(self, *args, **kwargs): - """Given a list of model definition, finished training and return the results of them. + def train(self, tasks: list, *args, **kwargs): + """Given a list of model definition, begin a training and return the models. Returns: - list: a list of trained results + list: a list of models """ raise NotImplementedError(f"Please implement the `train` method.") + def end_train(self, models, *args, **kwargs): + """Given a list of models, finished something in the end of training if you need. + + Returns: + list: a list of models + """ + pass + class TrainerR(Trainer): """Trainer based on (R)ecorder. @@ -112,7 +152,15 @@ class TrainerRM(Trainer): self.task_pool = task_pool self.train_func = train_func - def train(self, tasks: list, train_func=None, *args, **kwargs): + def train( + self, + tasks: list, + train_func=None, + before_status=TaskManager.STATUS_WAITING, + after_status=TaskManager.STATUS_DONE, + *args, + **kwargs, + ): """Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. This method defaults to a single process, but TaskManager offered a great way to parallel training. @@ -129,7 +177,15 @@ class TrainerRM(Trainer): train_func = self.train_func tm = TaskManager(task_pool=self.task_pool) _id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB - run_task(train_func, self.task_pool, experiment_name=self.experiment_name, *args, **kwargs) + run_task( + train_func, + self.task_pool, + experiment_name=self.experiment_name, + before_status=before_status, + after_status=after_status, + *args, + **kwargs, + ) recs = [] for _id in _id_list: @@ -137,10 +193,96 @@ class TrainerRM(Trainer): return recs -class DelayTrainer(Trainer): - def fake_train(self): - self.fake_trained = [] +class DelayTrainerR(TrainerR): + """ + A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting. - def train(self): - for rec in self.fake_trained: - pass + """ + + def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train): + super().__init__(experiment_name, train_func) + self.end_train_func = end_train_func + self.recs = [] + + def train(self, tasks: list, train_func, *args, **kwargs): + """ + Same as `train` of TrainerR, the results will be recorded in self.recs + + Args: + tasks (list): a list of definition based on `task` dict + train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default. + + Returns: + list: a list of Recorders + """ + self.recs = super().train(tasks, train_func=train_func, *args, **kwargs) + return self.recs + + def end_train(self, recs=None, end_train_func=None): + """ + Given a list of Recorder and return a list of trained Recorder. + This class will finished real data loading and model fitting. + + Args: + recs (list, optional): a list of Recorder, the tasks have been saved to them. Defaults to None for using self.recs. + end_train_func (Callable, optional): the end_train method which need at least `rec` and `experiment_name`. Defaults to None for using self.end_train_func. + + Returns: + list: a list of Recorders + """ + if recs is None: + recs = copy.deepcopy(self.recs) + # the models will be only trained once + self.recs = [] + if end_train_func is None: + end_train_func = self.end_train_func + for rec in recs: + end_train_func(rec) + return recs + + +class DelayTrainerRM(TrainerRM): + """ + A delayed implementation based on TrainerRM, which means `train` method may only do some preparation and `end_train` method can do the real model fitting. + + """ + + def __init__(self, experiment_name, task_pool: str, train_func=begin_task_train, end_train_func=end_task_train): + super().__init__(experiment_name, task_pool, train_func) + self.end_train_func = end_train_func + + def train(self, tasks: list, train_func=None, *args, **kwargs): + """ + Same as `train` of TrainerRM, the results will be recorded in self.recs + + Args: + tasks (list): a list of definition based on `task` dict + train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default. + + Returns: + list: a list of Recorders + """ + return super().train(tasks, train_func=train_func, after_status=TaskManager.STATUS_PART_DONE, *args, **kwargs) + + def end_train(self, recs, end_train_func=None): + """ + Given a list of Recorder and return a list of trained Recorder. + This class will finished real data loading and model fitting. + + Args: + recs (list, optional): a list of Recorder, the tasks have been saved to them. Defaults to None for using self.recs.. + end_train_func (Callable, optional): the end_train method which need at least `rec` and `experiment_name`. Defaults to None for using self.end_train_func. + + Returns: + list: a list of Recorders + """ + + if end_train_func is None: + end_train_func = self.end_train_func + run_task( + end_train_func, + self.task_pool, + experiment_name=self.experiment_name, + before_status=TaskManager.STATUS_PART_DONE, + ) + return recs diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index a03665626..46f9c563f 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -304,7 +304,7 @@ class QlibRecorder: """ self.exp_manager.set_uri(uri) - def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None): + def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder: """ Method for retrieving a recorder. diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index e74488040..c94cf2455 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -44,17 +44,21 @@ class OnlineManager(Serializable): self.trainer = trainer self.logger = get_module_logger(self.__class__.__name__) self.need_log = need_log - self.delay_signals = {} self.cur_time = None - def prepare_signals(self, *args, **kwargs): + def prepare_signals(self): """ After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine. Must use `pass` even though there is nothing to do. """ - raise NotImplementedError(f"Please implement the `prepare_signals` method.") + def get_signals(self): + """ + After preparing signals, here is the method to get them. + """ + raise NotImplementedError(f"Please implement the `get_signals` method.") + def prepare_tasks(self, *args, **kwargs): """ After the end of a routine, check whether we need to prepare and train some new tasks. @@ -62,7 +66,7 @@ class OnlineManager(Serializable): """ raise NotImplementedError(f"Please implement the `prepare_tasks` method.") - def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None): + def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None, *args, **kwargs): """ Use trainer to train a list of tasks and set the trained model to `tag`. @@ -75,13 +79,14 @@ class OnlineManager(Serializable): check_func: the method to judge if a model can be online. The parameter is the model record and return True for online. None for online every models. + *args, **kwargs: will be passed to end_train which means will be passed to customized train method. """ if check_func is None: check_func = lambda x: True if len(tasks) > 0: if self.trainer is not None: - new_models = self.trainer.train(tasks) + new_models = self.trainer.train(tasks, *args, **kwargs) if check_func(new_models): self.set_online_tag(tag, new_models) if self.need_log: @@ -89,13 +94,13 @@ class OnlineManager(Serializable): else: self.logger.warn("No trainer to train new tasks.") - def update_online_pred(self, *args, **kwargs): + def update_online_pred(self): """ After the end of a routine, update the predictions of online models to latest. """ raise NotImplementedError(f"Please implement the `update_online_pred` method.") - def set_online_tag(self, tag, *args, **kwargs): + def set_online_tag(self, tag, recorder): """ Set `tag` to the model to sign whether online. @@ -104,15 +109,21 @@ class OnlineManager(Serializable): """ raise NotImplementedError(f"Please implement the `set_online_tag` method.") - def get_online_tag(self, *args, **kwargs): + def get_online_tag(self): """ Given a model and return its online tag. """ raise NotImplementedError(f"Please implement the `get_online_tag` method.") - def reset_online_tag(self, *args, **kwargs): - """ - Offline all models and set the models to 'online'. + def reset_online_tag(self, recorders=None): + """offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing. + + Args: + recorders (List, optional): + the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. + + Returns: + list: new online recorder. [] if there is no update. """ raise NotImplementedError(f"Please implement the `reset_online_tag` method.") @@ -137,31 +148,46 @@ class OnlineManager(Serializable): """ raise NotImplementedError(f"Please implement the `get_collector` method.") - def run_delay_signals(self): + def delay_prepare(self, rec_dict, *args, **kwargs): """ - Prepare all signals if there are some dates waiting for prepare. + Prepare all models and signals if there are something waiting for prepare. + NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way. + + Args: + rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}. + *args, **kwargs: will be passed to end_train which means will be passed to customized train method. """ - for cur_time, params in self.delay_signals.items(): - self.cur_time = cur_time - self.prepare_signals(*params[0], **params[1]) - self.delay_signals = {} + for time_segment, recs_list in rec_dict.items(): + self.trainer.end_train(recs_list, *args, **kwargs) + self.reset_online_tag(recs_list) + self.prepare_signals() + signal_max = self.get_signals().index.get_level_values("datetime").max() + if time_segment[1] is not None and signal_max > time_segment[1]: + raise ValueError( + f"The max time of signals prepared by online models is {signal_max}, but those models only online in {time_segment}" + ) def routine(self, cur_time=None, delay_prepare=False, *args, **kwargs): """ The typical update process after a routine, such as day by day or month by month. update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models + + NOTE: Assumption: if using simulator (delay_prepare is True), the prediction will be prepared well after every training, so there is no need to update predictions. + + Args: + cur_time ([type], optional): [description]. Defaults to None. + delay_prepare (bool, optional): [description]. Defaults to False. + *args, **kwargs: will be passed to `prepare_tasks` and `prepare_new_models`. It can be some hyper parameter or training config. + + Returns: + [type]: [description] """ self.cur_time = cur_time # None for latest date - self.update_online_pred() if not delay_prepare: - self.prepare_signals(*args, **kwargs) - else: - if cur_time is not None: - self.delay_signals[cur_time] = (args, kwargs) - else: - raise ValueError("Can not delay prepare when cur_time is None") + self.update_online_pred() + self.prepare_signals() tasks = self.prepare_tasks(*args, **kwargs) - self.prepare_new_models(tasks) + self.prepare_new_models(tasks, *args, **kwargs) return self.reset_online_tag() @@ -185,8 +211,16 @@ class OnlineManagerR(OnlineManager): trainer = TrainerR(experiment_name) super().__init__(trainer=trainer, need_log=need_log) self.exp_name = experiment_name + self.signal_rec = None def set_online_tag(self, tag, recorder: Union[Recorder, List]): + """ + Set `tag` to the model to sign whether online. + + Args: + tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG` + recorder (Union[Recorder, List]) + """ if isinstance(recorder, Recorder): recorder = [recorder] for rec in recorder: @@ -195,6 +229,15 @@ class OnlineManagerR(OnlineManager): self.logger.info(f"Set {len(recorder)} models to '{tag}'.") def get_online_tag(self, recorder: Recorder): + """ + Given a model and return its online tag. + + Args: + recorder (Recorder): a instance of recorder + + Returns: + str: the tag + """ tags = recorder.list_tags() return tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) @@ -202,7 +245,7 @@ class OnlineManagerR(OnlineManager): """offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing. Args: - recorders (Union[List, Dict], optional): + recorders (Union[Recorder, List], optional): the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. Returns: @@ -225,7 +268,30 @@ class OnlineManagerR(OnlineManager): self.set_online_tag(OnlineManager.ONLINE_TAG, recorder) return recorder + def get_signals(self): + """ + get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP) + + Returns: + signals + """ + if self.signal_rec is None: + with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True): + self.signal_rec = R.get_recorder() + signals = None + try: + signals = self.signal_rec.load_object("signals") + except OSError: + self.logger.warn("Can not find `signals`, have you called `prepare_signals` before?") + return signals + def online_models(self): + """ + Return online models. + + Returns: + list: the list of online models + """ return list( list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG).values() ) @@ -245,34 +311,35 @@ class OnlineManagerR(OnlineManager): """ Average the predictions of online models and offer a trading signals every routine. The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP` - + Even if the latest signal already exists, the latest calculation result will be overwritten. + NOTE: Given a prediction of a certain time, all signals before this time will be prepared well. Args: over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False. """ + if self.signal_rec is None: + with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True): + self.signal_rec = R.get_recorder() - with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True): - recorder = R.get_recorder() - pred = [] + pred = [] + try: + old_signals = self.signal_rec.load_object("signals") + except OSError: + old_signals = None - try: - old_signals = recorder.load_object("signals") - except OSError: - old_signals = None + for rec in self.online_models(): + pred.append(rec.load_object("pred.pkl")) - for rec in self.online_models(): - pred.append(rec.load_object("pred.pkl")) - - signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score") - signals = signals.sort_index() - if old_signals is not None and not over_write: - # signals = old_signals.reindex(signals.index).combine_first(signals) - old_max = old_signals.index.get_level_values("datetime").max() - new_signals = signals.loc[old_max:] - signals = pd.concat([old_signals, new_signals], axis=0) - else: - new_signals = signals + signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score") + signals = signals.sort_index() + if old_signals is not None and not over_write: + old_max = old_signals.index.get_level_values("datetime").max() + new_signals = signals.loc[old_max:] + signals = pd.concat([old_signals, new_signals], axis=0) + else: + new_signals = signals + if self.need_log: self.logger.info(f"Finished preparing new {len(new_signals)} signals to {self.SIGNAL_EXP}/{self.exp_name}.") - recorder.save_objects(**{"signals": signals}) + self.signal_rec.save_objects(**{"signals": signals}) class RollingOnlineManager(OnlineManagerR): @@ -304,7 +371,9 @@ class RollingOnlineManager(OnlineManagerR): def get_collector(self, rec_key_func=None, rec_filter_func=None): """ - get the instance of collector to collect results + Get the instance of collector to collect results. The returned collector must can distinguish results in different models. + Assumption: the models can be distinguished based on model name and rolling test segments. + If you do not want this assumption, please implement your own method or use another rec_key_func. Args: rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. @@ -353,10 +422,9 @@ class RollingOnlineManager(OnlineManagerR): generators=self.rg, # generate different date segment ) self.prepare_new_models(tasks, tag=self.ONLINE_TAG) - self.prepare_signals(over_write=True) return self.get_collector() - def prepare_tasks(self, *args, **kwargs): + def prepare_tasks(self): """ Prepare new tasks based on new date. diff --git a/qlib/workflow/online/simulator.py b/qlib/workflow/online/simulator.py index 16628c240..d45b7d99d 100644 --- a/qlib/workflow/online/simulator.py +++ b/qlib/workflow/online/simulator.py @@ -12,7 +12,7 @@ class OnlineSimulator: self, start_time, end_time, - onlinemanager: OnlineManager, + online_manager: OnlineManager, frequency="day", ): """ @@ -28,15 +28,14 @@ class OnlineSimulator: self.cal = D.calendar(start_time=start_time, end_time=end_time, freq=frequency) self.start_time = self.cal[0] self.end_time = self.cal[-1] - self.olm = onlinemanager - + self.olm = online_manager if len(self.cal) == 0: self.logger.warn(f"There is no need to simulate bacause start_time is larger than end_time.") def simulate(self, *args, **kwargs): """ Starting from start time, this method will simulate every routine in OnlineManager. - NOTE: Considering the parallel training, the signals will be perpared after all routine simulating. + NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating. Returns: Collector: the OnlineManager's collector @@ -54,12 +53,10 @@ class OnlineSimulator: self.rec_dict[(tmp_begin, tmp_end)] = prev_recorders tmp_begin = cur_time prev_recorders = recorders - self.rec_dict[(tmp_begin, self.end_time)] = prev_recorders - # prepare signals again incase there is no trained model when call it - self.olm.run_delay_signals() + # finished perparing models (and pred) and signals + self.olm.delay_prepare(self.rec_dict) self.logger.info(f"Finished preparing signals") - return self.olm.get_collector() def online_models(self): diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 9e273b74f..158bc9916 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -91,7 +91,7 @@ class RollingGen(TaskGen): ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date - def __init__(self, step: int = 40, rtype: str = ROLL_EX): + def __init__(self, step: int = 40, rtype: str = ROLL_EX, modify_end_time=True): """ Generate tasks for rolling @@ -101,9 +101,12 @@ class RollingGen(TaskGen): step to rolling rtype : str rolling type (expanding, sliding) + modify_end_time: bool + Whether the data set configuration needs to be modified when the required scope exceeds the original data set scope """ self.step = step self.rtype = rtype + self.modify_end_time = modify_end_time # TODO: Ask pengrong to update future date in dataset self.ta = TimeAdjuster(future=True) @@ -113,7 +116,6 @@ class RollingGen(TaskGen): def generate(self, task: dict): """ Converting the task into a rolling task. - # FIXME: only modify dataset layer, user need to change datahandler firstly. Parameters ---------- @@ -196,7 +198,8 @@ class RollingGen(TaskGen): t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments) # if end_time < the end of test_segments, then change end_time to allow load more data if ( - self.ta.cal_interval( + self.modify_end_time + and self.ta.cal_interval( t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"], t["dataset"]["kwargs"]["segments"][self.test_key][1], ) diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index b144a8872..9d50d8563 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -174,11 +174,11 @@ class TaskManager: return _id_list - def fetch_task(self, query={}): + def fetch_task(self, query={}, status=STATUS_WAITING): query = query.copy() if "_id" in query: query["_id"] = ObjectId(query["_id"]) - query.update({"status": self.STATUS_WAITING}) + query.update({"status": status}) task = self.task_pool.find_one_and_update( query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)] ) @@ -189,7 +189,7 @@ class TaskManager: return self._decode_task(task) @contextmanager - def safe_fetch_task(self, query={}): + def safe_fetch_task(self, query={}, status=STATUS_WAITING): """ fetch task from task_pool using query with contextmanager @@ -202,7 +202,7 @@ class TaskManager: ------- """ - task = self.fetch_task(query=query) + task = self.fetch_task(query=query, status=status) try: yield task except Exception: @@ -330,7 +330,15 @@ class TaskManager: return f"TaskManager({self.task_pool})" -def run_task(task_func, task_pool, force_release=False, *args, **kwargs): +def run_task( + task_func, + task_pool, + force_release=False, + before_status=TaskManager.STATUS_WAITING, + after_status=TaskManager.STATUS_DONE, + *args, + **kwargs, +): """ While task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool @@ -352,16 +360,24 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs): ever_run = False while True: - with tm.safe_fetch_task() as task: + with tm.safe_fetch_task(status=before_status) as task: if task is None: break get_module_logger("run_task").info(task["def"]) - if force_release: - with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: # what this means? - res = executor.submit(task_func, task["def"], *args, **kwargs).result() + # when fetching `WAITING` task, use task_def to train + if before_status == TaskManager.STATUS_WAITING: + param = task["def"] + # when fetching `PART_DONE` task, use task_res to train for the result has been saved + elif before_status == TaskManager.STATUS_PART_DONE: + param = task["res"] else: - res = task_func(task["def"], *args, **kwargs) - tm.commit_task_res(task, res) + raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!") + if force_release: + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: + res = executor.submit(task_func, param, *args, **kwargs).result() + else: + res = task_func(param, *args, **kwargs) + tm.commit_task_res(task, res, status=after_status) ever_run = True return ever_run