From 9dfd001f6fafbe077dcdc30feacd0dbeb7bf31e6 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Fri, 7 May 2021 09:59:15 +0000 Subject: [PATCH] online serving v10 --- docs/advanced/task_management.rst | 1 + .../online_srv/online_management_simulate.py | 21 +- .../online_srv/rolling_online_management.py | 13 +- qlib/model/ens/ensemble.py | 12 +- qlib/model/ens/group.py | 20 +- qlib/model/trainer.py | 135 +++++++---- qlib/utils/__init__.py | 2 +- qlib/utils/serial.py | 36 ++- qlib/workflow/online/manager.py | 209 +++++++++++++----- qlib/workflow/online/strategy.py | 172 +++----------- qlib/workflow/online/utils.py | 4 +- qlib/workflow/task/collect.py | 85 +++---- qlib/workflow/task/gen.py | 53 +++-- qlib/workflow/task/manage.py | 8 +- 14 files changed, 426 insertions(+), 345 deletions(-) diff --git a/docs/advanced/task_management.rst b/docs/advanced/task_management.rst index a68c12627..d60049455 100644 --- a/docs/advanced/task_management.rst +++ b/docs/advanced/task_management.rst @@ -55,6 +55,7 @@ More information of ``Task Manager`` can be found in `here <../reference/api.htm Task Training =============== +#FIXME: Trainer After generating and storing those ``task``, it's time to run the ``task`` which are in the *WAITING* status. ``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed. An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly. diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 7be46d999..5583ee160 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -8,6 +8,7 @@ This examples is about how can simulate the OnlineManager based on rolling tasks import fire import qlib from qlib.model.trainer import DelayTrainerRM +from qlib.workflow import R from qlib.workflow.online.manager import OnlineManager from qlib.workflow.online.strategy import RollingAverageStrategy from qlib.workflow.task.gen import RollingGen @@ -110,23 +111,29 @@ class OnlineSimulationExample: } qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf) self.rolling_gen = RollingGen( - step=rolling_step, rtype=RollingGen.ROLL_SD, modify_end_time=False - ) # The rolling tasks generator, modify_end_time is false because we just need simulate to 2018-10-31. + step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None + ) # The rolling tasks generator, ds_extra_mod_func is None because we just need simulate to 2018-10-31 and needn't change handler end time. self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks self.rolling_online_manager = OnlineManager( - RollingAverageStrategy( - exp_name, task_template=tasks, rolling_gen=self.rolling_gen, trainer=self.trainer, need_log=False - ), + RollingAverageStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen, need_log=False), + trainer=self.trainer, begin_time=self.start_time, need_log=False, ) self.tasks = tasks + # Reset all things to the first status, be careful to save important data + def reset(self): + TaskManager(self.task_pool).remove() + exp = R.get_exp(experiment_name=self.exp_name) + for rid in exp.list_recorders(): + exp.delete_recorder(rid) + # Run this to run all workflow automatically def main(self): print("========== reset ==========") - self.rolling_online_manager.reset() + self.reset() print("========== simulate ==========") self.rolling_online_manager.simulate(end_time=self.end_time) print("========== collect results ==========") @@ -134,7 +141,7 @@ class OnlineSimulationExample: print("========== signals ==========") print(self.rolling_online_manager.get_signals()) print("========== online history ==========") - print(self.rolling_online_manager.get_online_history(self.exp_name)) + print(self.rolling_online_manager.history) if __name__ == "__main__": diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index 25b6fc4da..ebf1ab59a 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -18,8 +18,6 @@ from qlib.workflow.online.strategy import RollingAverageStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager from qlib.workflow.online.manager import OnlineManager -from qlib.workflow.task.utils import list_recorders -from qlib.model.trainer import TrainerRM data_handler_config = { "start_time": "2013-01-01", @@ -86,7 +84,7 @@ class RollingOnlineExample: task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550, - tasks=[task_xgboost_config], # , task_lgb_config], + tasks=[task_xgboost_config, task_lgb_config], ): mongo_conf = { "task_url": task_url, # your MongoDB url @@ -103,7 +101,6 @@ class RollingOnlineExample: name_id, task, RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD), - TrainerRM(experiment_name=name_id, task_pool=name_id), ) ) @@ -116,9 +113,8 @@ class RollingOnlineExample: # Reset all things to the first status, be careful to save important data def reset(self): - print("========== reset ==========") for task in self.tasks: - name_id = task["model"]["class"] + "_" + str(self.rolling_step) + name_id = task["model"]["class"] TaskManager(name_id).remove() exp = R.get_exp(experiment_name=name_id) for rid in exp.list_recorders(): @@ -127,12 +123,9 @@ class RollingOnlineExample: if os.path.exists(self._ROLLING_MANAGER_PATH): os.remove(self._ROLLING_MANAGER_PATH) - for rid in list_recorders("OnlineManagerSignals", lambda x: True if x.info["name"] == name_id else False): - exp.delete_recorder(rid) - def first_run(self): print("========== reset ==========") - self.rolling_online_manager.reset() + self.reset() print("========== first_run ==========") self.rolling_online_manager.first_train() print("========== dump ==========") diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index 1fb14a37b..a7b837ea5 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -7,6 +7,7 @@ Ensemble can merge the objects in an Ensemble. For example, if there are many su from typing import Union import pandas as pd +from qlib.utils import flatten_dict class Ensemble: @@ -77,19 +78,22 @@ class RollingEnsemble(Ensemble): class AverageEnsemble(Ensemble): def __call__(self, ensemble_dict: dict): """ - Average a dict of same shape dataframe like `prediction` or `IC` into an ensemble. + Average and standardize a dict of same shape dataframe like `prediction` or `IC` into an ensemble. - NOTE: The values of dict must be pd.DataFrame, and have the index "datetime" + NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". If it is a nested dict, then flat it. Args: ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}. The key of the dict will be ignored. Returns: - pd.DataFrame: the complete result of averaging. + pd.DataFrame: the complete result of averaging and standardizing. """ + # need to flatten the nested dict + ensemble_dict = flatten_dict(ensemble_dict) values = list(ensemble_dict.values()) results = pd.concat(values, axis=1) - results = results.mean(axis=1).to_frame("score") + results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std()) + results = results.mean(axis=1) results = results.sort_index() return results diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py index d8f174105..a00a8ea0e 100644 --- a/qlib/model/ens/group.py +++ b/qlib/model/ens/group.py @@ -36,20 +36,36 @@ class Group: self._ens_func = ens def group(self, *args, **kwargs) -> dict: - # TODO: such design is weird when `_group_func` is the only configurable part in the class + """ + Group a set of object and change them to a dict. + + For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}} + + Returns: + dict: grouped dict + """ if isinstance(getattr(self, "_group_func", None), Callable): return self._group_func(*args, **kwargs) else: raise NotImplementedError(f"Please specify valid `group_func`.") def reduce(self, *args, **kwargs) -> dict: + """ + Reduce grouped dict in some way. + + For example: {(A,B): {C1: object, C2: object}} -> {(A,B): object} + + Returns: + dict: reduced dict + """ if isinstance(getattr(self, "_ens_func", None), Callable): return self._ens_func(*args, **kwargs) else: raise NotImplementedError(f"Please specify valid `_ens_func`.") def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs) -> dict: - """Group the ungrouped_dict into different groups. + """ + Group the ungrouped_dict into different groups. Args: ungrouped_dict (dict): the ungrouped dict waiting for grouping like {name: things} diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 7680674a6..68b78d9df 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -12,7 +12,6 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model """ import socket -import time from typing import Callable, List from qlib.data.dataset import Dataset @@ -145,12 +144,6 @@ class Trainer: """ return self.delay - def reset(self): - """ - Reset the Trainer status. - """ - pass - class TrainerR(Trainer): """ @@ -160,42 +153,52 @@ class TrainerR(Trainer): Assumption: models were defined by `task` and the results will saved to `Recorder` """ - def __init__(self, experiment_name: str, train_func: Callable = task_train): + # Those tag will help you distinguish whether the Recorder has finished traning + STATUS_KEY = "train_status" + STATUS_BEGIN = "begin_task_train" + STATUS_END = "end_task_train" + + def __init__(self, experiment_name: str = None, train_func: Callable = task_train): """ Init TrainerR. Args: - experiment_name (str): the name of experiment. + experiment_name (str, optional): the default name of experiment. train_func (Callable, optional): default training method. Defaults to `task_train`. """ super().__init__() self.experiment_name = experiment_name self.train_func = train_func - def train(self, tasks: list, train_func: Callable = None, **kwargs) -> List[Recorder]: + def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]: """ Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. Args: tasks (list): a list of definition based on `task` dict train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method. + experiment_name (str): the experiment name, None for use default name. kwargs: the params for train_func. Returns: list: a list of Recorders """ + if len(tasks) == 0: + return [] if train_func is None: train_func = self.train_func + if experiment_name is None: + experiment_name = self.experiment_name recs = [] for task in tasks: - rec = train_func(task, self.experiment_name, **kwargs) - rec.set_tags(**{"train_status": "begin_task_train"}) + rec = train_func(task, experiment_name, **kwargs) + rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN}) recs.append(rec) return recs - def end_train(self, recs: list, **kwargs) -> list: + def end_train(self, recs: list, **kwargs) -> List[Recorder]: for rec in recs: - rec.set_tags(**{"train_status": "end_task_train"}) + rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) return recs @@ -204,12 +207,12 @@ class DelayTrainerR(TrainerR): A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting. """ - def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train): + def __init__(self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train): """ Init TrainerRM. Args: - experiment_name (str): the name of experiment. + experiment_name (str): the default name of experiment. train_func (Callable, optional): default train method. Defaults to `begin_task_train`. end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`. """ @@ -217,7 +220,7 @@ class DelayTrainerR(TrainerR): self.end_train_func = end_train_func self.delay = True - def end_train(self, recs, end_train_func=None, **kwargs) -> List[Recorder]: + def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]: """ Given a list of Recorder and return a list of trained Recorder. This class will finish real data loading and model fitting. @@ -225,6 +228,7 @@ class DelayTrainerR(TrainerR): Args: recs (list): a list of Recorder, the tasks have been saved to them end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func. + experiment_name (str): the experiment name, None for use default name. kwargs: the params for end_train_func. Returns: @@ -232,9 +236,13 @@ class DelayTrainerR(TrainerR): """ if end_train_func is None: end_train_func = self.end_train_func + if experiment_name is None: + experiment_name = self.experiment_name for rec in recs: - end_train_func(rec, **kwargs) - rec.set_tags(**{"train_status": "end_task_train"}) + if rec.list_tags()[self.STATUS_KEY] == self.STATUS_END: + continue + end_train_func(rec, experiment_name, **kwargs) + rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) return recs @@ -246,13 +254,18 @@ class TrainerRM(Trainer): Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager """ - def __init__(self, experiment_name: str, task_pool: str, train_func=task_train): + # Those tag will help you distinguish whether the Recorder has finished traning + STATUS_KEY = "train_status" + STATUS_BEGIN = "begin_task_train" + STATUS_END = "end_task_train" + + def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train): """ Init TrainerR. Args: - experiment_name (str): the name of experiment. - task_pool (str): task pool name in TaskManager. + experiment_name (str): the default name of experiment. + task_pool (str): task pool name in TaskManager. None for use same name as experiment_name. train_func (Callable, optional): default training method. Defaults to `task_train`. """ super().__init__() @@ -264,6 +277,7 @@ class TrainerRM(Trainer): self, tasks: list, train_func: Callable = None, + experiment_name: str = None, before_status: str = TaskManager.STATUS_WAITING, after_status: str = TaskManager.STATUS_DONE, **kwargs, @@ -277,6 +291,7 @@ class TrainerRM(Trainer): Args: tasks (list): a list of definition based on `task` dict train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method. + experiment_name (str): the experiment name, None for use default name. before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE. after_status (str): the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE. kwargs: the params for train_func. @@ -284,14 +299,21 @@ class TrainerRM(Trainer): Returns: list: a list of Recorders """ + if len(tasks) == 0: + return [] if train_func is None: train_func = self.train_func - tm = TaskManager(task_pool=self.task_pool) + if experiment_name is None: + experiment_name = self.experiment_name + task_pool = self.task_pool + if task_pool is None: + task_pool = experiment_name + tm = TaskManager(task_pool=task_pool) _id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB run_task( train_func, - self.task_pool, - experiment_name=self.experiment_name, + task_pool, + experiment_name=experiment_name, before_status=before_status, after_status=after_status, **kwargs, @@ -300,23 +322,15 @@ class TrainerRM(Trainer): recs = [] for _id in _id_list: rec = tm.re_query(_id)["res"] - rec.set_tags(**{"train_status": "begin_task_train"}) + rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN}) recs.append(rec) return recs def end_train(self, recs: list, **kwargs) -> list: for rec in recs: - rec.set_tags(**{"train_status": "end_task_train"}) + rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) return recs - def reset(self): - """ - .. note:: - this method will delete all task in this task_pool! - """ - tm = TaskManager(task_pool=self.task_pool) - tm.remove() - class DelayTrainerRM(TrainerRM): """ @@ -324,30 +338,57 @@ class DelayTrainerRM(TrainerRM): """ - def __init__(self, experiment_name, task_pool: str, train_func=begin_task_train, end_train_func=end_task_train): + def __init__( + self, + experiment_name: str = None, + task_pool: str = None, + train_func=begin_task_train, + end_train_func=end_task_train, + ): + """ + Init DelayTrainerRM. + + Args: + experiment_name (str): the default name of experiment. + task_pool (str): task pool name in TaskManager. None for use same name as experiment_name. + train_func (Callable, optional): default train method. Defaults to `begin_task_train`. + end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`. + """ super().__init__(experiment_name, task_pool, train_func) self.end_train_func = end_train_func self.delay = True - def train(self, tasks: list, train_func=None, **kwargs): + def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs): """ Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE. Args: tasks (list): a list of definition based on `task` dict train_func (Callable): the train method which need at least `task`s and `experiment_name`. Defaults to None for using self.train_func. + experiment_name (str): the experiment name, None for use default name. Returns: list: a list of Recorders """ - return super().train(tasks, train_func=train_func, after_status=TaskManager.STATUS_PART_DONE, **kwargs) + if len(tasks) == 0: + return [] + return super().train( + tasks, + train_func=train_func, + experiment_name=experiment_name, + after_status=TaskManager.STATUS_PART_DONE, + **kwargs, + ) - def end_train(self, recs, end_train_func=None, **kwargs): + def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs): """ Given a list of Recorder and return a list of trained Recorder. This class will finish real data loading and model fitting. + NOTE: This method will train all STATUS_PART_DONE tasks in task pool, not only the ``recs``. + Args: recs (list): a list of Recorder, the tasks have been saved to them. end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func. + experiment_name (str): the experiment name, None for use default name. kwargs: the params for end_train_func. Returns: @@ -356,13 +397,23 @@ class DelayTrainerRM(TrainerRM): if end_train_func is None: end_train_func = self.end_train_func + if experiment_name is None: + experiment_name = self.experiment_name + task_pool = self.task_pool + if task_pool is None: + task_pool = experiment_name + tasks = [] + for rec in recs: + tasks.append(rec.load_object("task")) + run_task( end_train_func, - self.task_pool, - experiment_name=self.experiment_name, + task_pool, + tasks=tasks, + experiment_name=experiment_name, before_status=TaskManager.STATUS_PART_DONE, **kwargs, ) for rec in recs: - rec.set_tags(**{"train_status": "end_task_train"}) + rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) return recs diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 3ebc6fc1c..8583e946f 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -732,7 +732,7 @@ def flatten_dict(d, parent_key="", sep="."): """ items = [] for k, v in d.items(): - new_key = parent_key + sep + k if parent_key else k + new_key = parent_key + sep + str(k) if parent_key else k if isinstance(v, collections.abc.MutableMapping): items.extend(flatten_dict(v, new_key, sep=sep).items()) else: diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index 52d326c2a..9c5fc9ac2 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -3,6 +3,7 @@ from pathlib import Path import pickle +import dill from typing import Union @@ -14,6 +15,8 @@ class Serializable: - For examples, a learnable Datahandler just wants to save the parameters without data when dumping to disk """ + pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python. + def __init__(self): self._dump_all = False self._exclude = [] @@ -74,4 +77,35 @@ class Serializable: def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None): self.config(dump_all=dump_all, exclude=exclude) with Path(path).open("wb") as f: - pickle.dump(self, f) + if self.pickle_backend == "pickle": + pickle.dump(self, f) + elif self.pickle_backend == "dill": + dill.dump(self, f) + else: + raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.") + + @classmethod + def load(cls, filepath): + """ + load the collector from a file + + Args: + filepath (str): the path of file + + Raises: + TypeError: the pickled file must be `Collector` + + Returns: + Collector: the instance of Collector + """ + with open(filepath, "rb") as f: + if cls.pickle_backend == "pickle": + object = pickle.load(f) + elif cls.pickle_backend == "dill": + object = dill.load(f) + else: + raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.") + if isinstance(object, cls): + return object + else: + raise TypeError(f"The instance of {type(object)} is not a valid `{type(cls)}`!") diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index 6c62fbce9..a282865e6 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -12,15 +12,17 @@ This module also provide a method to simulate `Online Strategy <#Online Strategy Which means you can verify your strategy or find a better one. """ -from typing import Dict, List, Union +from typing import Callable, Dict, List, Union import pandas as pd from qlib import get_module_logger from qlib.data.data import D -from qlib.model.ens.ensemble import AverageEnsemble, SingleKeyEnsemble +from qlib.model.ens.ensemble import AverageEnsemble +from qlib.model.trainer import DelayTrainerR, Trainer +from qlib.utils import flatten_dict from qlib.utils.serial import Serializable from qlib.workflow.online.strategy import OnlineStrategy -from qlib.workflow.task.collect import HyperCollector +from qlib.workflow.task.collect import MergeCollector class OnlineManager(Serializable): @@ -32,6 +34,7 @@ class OnlineManager(Serializable): def __init__( self, strategy: Union[OnlineStrategy, List[OnlineStrategy]], + trainer: Trainer = None, begin_time: Union[str, pd.Timestamp] = None, freq="day", need_log=True, @@ -43,6 +46,7 @@ class OnlineManager(Serializable): Args: strategy (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using latest date. + trainer (Trainer): the trainer to train task. None for using DelayTrainerR. freq (str, optional): data frequency. Defaults to "day". need_log (bool, optional): print log or not. Defaults to True. """ @@ -56,96 +60,166 @@ class OnlineManager(Serializable): begin_time = D.calendar(freq=self.freq).max() self.begin_time = pd.Timestamp(begin_time) self.cur_time = self.begin_time - self.history = {} + # The history of online models, which is a dict like {begin_time, {strategy, [online_models]}} + # begin_time means when online_models are onlined + self.history = {} + if trainer is None: + trainer = DelayTrainerR() + self.trainer = trainer + self.signals = None - def first_train(self): + def first_train(self, strategies:List[OnlineStrategy]=None, model_kwargs: dict = {}): """ - Run every strategy first_train method and record the online history. + Get tasks from every strategy's first_tasks method and train them. + If using DelayTrainer, it can finish training all together after every strategy's first_tasks. + + Args: + strategies (List[OnlineStrategy]): the strategies list (need this param when adding strategies). None for use default strategies. + model_kwargs (dict): the params for `prepare_online_models` """ - for strategy in self.strategy: + models_list = [] + if strategies is None: + strategies = self.strategy + for strategy in strategies: self.logger.info(f"Strategy `{strategy.name_id}` begins first training...") - online_models = strategy.first_train() - self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models + tasks = strategy.first_tasks() + models = self.trainer.train(tasks, experiment_name=strategy.name_id) + models_list.append(models) - def routine(self, cur_time: Union[str, pd.Timestamp] = None, task_kwargs: dict = {}, model_kwargs: dict = {}): + for strategy, models in zip(strategies, models_list): + self.prepare_online_models(strategy, models, model_kwargs=model_kwargs) + + def routine( + self, + cur_time: Union[str, pd.Timestamp] = None, + delay: bool = False, + task_kwargs: dict = {}, + model_kwargs: dict = {}, + signal_kwargs: dict = {}, + ): """ Run typical update process for every strategy and record the online history. The typical update process after a routine, such as day by day or month by month. The process is: Prepare signals -> Prepare tasks -> Prepare online models. + If using DelayTrainer, it can finish training all together after every strategy's prepare_tasks. + Args: cur_time (Union[str,pd.Timestamp], optional): run routine method in this time. Defaults to None. + delay (bool): if delay prepare signals and models task_kwargs (dict): the params for `prepare_tasks` model_kwargs (dict): the params for `prepare_online_models` + signal_kwargs (dict): the params for `prepare_signals` """ if cur_time is None: cur_time = D.calendar(freq=self.freq).max() self.cur_time = pd.Timestamp(cur_time) # None for latest date + models_list = [] for strategy in self.strategy: + if not delay: + strategy.tool.update_online_pred() if self.need_log: self.logger.info(f"Strategy `{strategy.name_id}` begins routine...") - if not strategy.trainer.is_delay(): - strategy.prepare_signals() - tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs) - online_models = strategy.prepare_online_models(tasks, **model_kwargs) - if len(online_models) > 0: - self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models - def get_collector(self) -> HyperCollector: + tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs) + models = self.trainer.train(tasks) + models_list.append(models) + + if not delay: + self.prepare_signals(**signal_kwargs) + + for strategy, models in zip(self.strategy, models_list): + self.prepare_online_models(strategy, models, delay=delay, model_kwargs=model_kwargs) + + def prepare_online_models( + self, strategy: OnlineStrategy, models: list, delay: bool = False, model_kwargs: dict = {} + ): + """ + Prepare online model for strategy, including end_train, reset_online_tag and add history. + + Args: + strategy (OnlineStrategy): the instance of strategy. + models (list): a list of models. + delay (bool, optional): if delay prepare models. Defaults to False. + model_kwargs (dict, optional): the params for `prepare_online_models`. + """ + if not delay: + models = self.trainer.end_train(models, experiment_name=strategy.name_id) + online_models = strategy.prepare_online_models(models, **model_kwargs) + else: + # just set every models as online models temporarily before ``prepare_online_models`` + online_models = models + if len(online_models) > 0: + strategy.tool.reset_online_tag(online_models) + self.history.setdefault(self.cur_time, {})[strategy] = online_models + + def get_collector(self) -> MergeCollector: """ Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy. Returns: - HyperCollector: the collector to collect other collectors (using SingleKeyEnsemble() to make results more readable). + MergeCollector: the collector to merge other collectors. """ collector_dict = {} for strategy in self.strategy: collector_dict[strategy.name_id] = strategy.get_collector() - return HyperCollector(collector_dict, process_list=SingleKeyEnsemble()) + return MergeCollector(collector_dict, process_list=[]) - def get_online_history(self, strategy_name_id: str) -> list: + def add_strategy(self, strategy: Union[OnlineStrategy, List[OnlineStrategy]]): """ - Get the online history based on strategy_name_id. + Add some new strategies to online manager. Args: - strategy_name_id (str): the name_id of strategy - - Returns: - list: a list like [(begin_time, [online_models])] + strategy (Union[OnlineStrategy, List[OnlineStrategy]]): a list of OnlineStrategy """ - history_dict = self.history[strategy_name_id] - history = [] - for time in sorted(history_dict): - models = history_dict[time] - history.append((time, models)) - return history + if not isinstance(strategy, list): + strategy = [strategy] + self.first_train(strategy) + self.strategy.extend(strategy) - def delay_prepare(self, delay_kwargs={}): + def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False): """ - Prepare all models and signals if there are something waiting for prepare. + After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine. + + NOTE: Given a set prediction, all signals before these prediction end time will be prepared well. + + Even if the latest signal already exists, the latest calculation result will be overwritten. + + .. note:: + + Given a prediction of a certain time, all signals before this time will be prepared well. Args: - delay_kwargs: the params for `delay_prepare` - """ - for strategy in self.strategy: - strategy.delay_prepare(self.get_online_history(strategy.name_id), **delay_kwargs) - - def get_signals(self) -> pd.DataFrame: - """ - Average all strategy signals as the online signals. - - Assumption: the signals from every strategy is pd.DataFrame. Override this function to change. + prepare_func (Callable, optional): Get signals from a dict after collecting. Defaults to AverageEnsemble(), the results after mergecollector must be {xxx:pred}. + over_write (bool, optional): If True, the new signals will overwrite. If False, the new signals will append to the end of signals. Defaults to False. Returns: - pd.DataFrame: signals + pd.DataFrame: the signals. """ - signals_dict = {} - for strategy in self.strategy: - signals_dict[strategy.name_id] = strategy.get_signals() - return AverageEnsemble()(signals_dict) + signals = prepare_func(self.get_collector()()) + old_signals = self.signals + if old_signals is not None and not over_write: + old_max = old_signals.index.get_level_values("datetime").max() + new_signals = signals.loc[old_max:] + signals = pd.concat([old_signals, new_signals], axis=0) + else: + new_signals = signals + if self.need_log: + self.logger.info(f"Finished preparing new {len(new_signals)} signals.") + self.signals = signals + return new_signals - def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, delay_kwargs={}) -> HyperCollector: + def get_signals(self) -> pd.Series: + """ + Get prepared online signals. + + Returns: + pd.Series: signals + """ + return self.signals + + def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}): """ Starting from current time, this method will simulate every routine in OnlineManager until end time. @@ -153,6 +227,13 @@ class OnlineManager(Serializable): The delay training way can be ``DelayTrainer`` and the delay preparing signals way can be ``delay_prepare``. + Args: + end_time: the time the simulation will end + frequency: the calendar frequency + task_kwargs (dict): the params for `prepare_tasks` + model_kwargs (dict): the params for `prepare_online_models` + signal_kwargs (dict): the params for `prepare_signals` + Returns: HyperCollector: the OnlineManager's collector """ @@ -160,18 +241,30 @@ class OnlineManager(Serializable): self.first_train() for cur_time in cal: self.logger.info(f"Simulating at {str(cur_time)}......") - self.routine(cur_time, task_kwargs=task_kwargs, model_kwargs=model_kwargs) - self.delay_prepare(delay_kwargs=delay_kwargs) + self.routine( + cur_time, + delay=self.trainer.is_delay(), + task_kwargs=task_kwargs, + model_kwargs=model_kwargs, + signal_kwargs=signal_kwargs, + ) + # delay prepare the models and signals + if self.trainer.is_delay(): + self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs) self.logger.info(f"Finished preparing signals") return self.get_collector() - def reset(self): + def delay_prepare(self, model_kwargs={}, signal_kwargs={}): """ - This method will reset all strategy! + Prepare all models and signals if there are something waiting for prepare. - **Be careful to use it.** + Args: + model_kwargs: the params for `prepare_online_models` + signal_kwargs: the params for `prepare_signals` """ - self.cur_time = self.begin_time - self.history = {} - for strategy in self.strategy: - strategy.reset() + for cur_time, strategy_models in self.history.items(): + self.cur_time = cur_time + for strategy, models in strategy_models.items(): + self.prepare_online_models(strategy, models, delay=False, model_kwargs=model_kwargs) + # NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way. + self.prepare_signals(**signal_kwargs) diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index 0cae11b7f..1184553bd 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -7,19 +7,14 @@ OnlineStrategy is a set of strategy for online serving. from copy import deepcopy from typing import List, Tuple, Union - -import pandas as pd from qlib.data.data import D from qlib.log import get_module_logger -from qlib.model.ens.ensemble import AverageEnsemble, SingleKeyEnsemble from qlib.model.ens.group import RollingGroup -from qlib.model.trainer import Trainer, TrainerR -from qlib.workflow import R from qlib.workflow.online.utils import OnlineTool, OnlineToolR from qlib.workflow.recorder import Recorder -from qlib.workflow.task.collect import Collector, HyperCollector, RecorderCollector +from qlib.workflow.task.collect import Collector, RecorderCollector from qlib.workflow.task.gen import RollingGen, task_generator -from qlib.workflow.task.utils import TimeAdjuster, list_recorders +from qlib.workflow.task.utils import TimeAdjuster class OnlineStrategy: @@ -27,7 +22,7 @@ class OnlineStrategy: OnlineStrategy is working with `Online Manager <#Online Manager>`_, responsing how the tasks are generated, the models are updated and signals are perpared. """ - def __init__(self, name_id: str, trainer: Trainer = None, need_log=True): + def __init__(self, name_id: str, need_log=True): """ Init OnlineStrategy. This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finishing model training. @@ -38,34 +33,22 @@ class OnlineStrategy: need_log (bool, optional): print log or not. Defaults to True. """ self.name_id = name_id - self.trainer = trainer self.logger = get_module_logger(self.__class__.__name__) self.need_log = need_log - self.tool = OnlineTool() + self.tool = OnlineTool(need_log) - def prepare_signals(self, delay: bool = False): + def prepare_tasks(self, cur_time, **kwargs) -> List[dict]: """ - After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine. - - NOTE: Given a set prediction, all signals before these prediction end time will be prepared well. - - Args: - delay: bool - If this method was called by `delay_prepare` - """ - raise NotImplementedError(f"Please implement the `prepare_signals` method.") - - def prepare_tasks(self, *args, **kwargs): - """ - After the end of a routine, check whether we need to prepare and train some new tasks. + After the end of a routine, check whether we need to prepare and train some new tasks based on cur_time (None for latest).. Return the new tasks waiting for training. You can find last online models by OnlineTool.online_models. """ raise NotImplementedError(f"Please implement the `prepare_tasks` method.") - def prepare_online_models(self, tasks, check_func=None, **kwargs): + def prepare_online_models(self, models, cur_time=None, check_func=None, **kwargs): """ + A typically implementation, but maybe you will need old models by online_tool. Use trainer to train a list of tasks and set the trained model to `online`. NOTE: This method will first offline all models and online the online models prepared by this method. So you can find last online models by OnlineTool.online_models if you still need them. @@ -78,64 +61,34 @@ class OnlineStrategy: **kwargs: will be passed to end_train which means will be passed to customized train method. """ - if check_func is None: - check_func = lambda x: True - online_models = [] - if len(tasks) > 0: - new_models = self.trainer.train(tasks, **kwargs) - for model in new_models: - if check_func(model): + if check_func is not None: + online_models = [] + for model in models: + if check_func(model, cur_time): online_models.append(model) - self.tool.reset_online_tag(online_models) - return online_models + models = online_models + self.tool.reset_online_tag(models) + return models - def first_train(self): + def first_tasks(self) -> List[dict]: """ - Train a series of models firstly and set some of them as online models. + Generate a series of tasks firstly and return them. """ - raise NotImplementedError(f"Please implement the `first_train` method.") + raise NotImplementedError(f"Please implement the `first_tasks` method.") def get_collector(self) -> Collector: """ - Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results of online serving. - + Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect different results of this strategy. For example: 1) collect predictions in Recorder - 2) collect signals in .txt file + 2) collect signals in a txt file Returns: Collector """ raise NotImplementedError(f"Please implement the `get_collector` method.") - def delay_prepare(self, history: list, **kwargs): - """ - Prepare all models and signals if there are something waiting for prepare. - - Assumption: the predictions of online models need less than next begin_time, or this method will work in a wrong way. - - Args: - history (list): an online models list likes [begin_time:[online models]]. - **kwargs: will be passed to end_train which means will be passed to customized train method. - """ - for begin_time, recs_list in history: - self.trainer.end_train(recs_list, **kwargs) - self.tool.reset_online_tag(recs_list) - self.prepare_signals(delay=True) - - def get_signals(self): - """ - Get prepared signals. - """ - raise NotImplementedError(f"Please implement the `get_signals` method.") - - def reset(self): - """ - Delete all things and set them to default status. This method is convenient to explore the strategy for online simulation. - """ - pass - class RollingAverageStrategy(OnlineStrategy): @@ -148,9 +101,7 @@ class RollingAverageStrategy(OnlineStrategy): name_id: str, task_template: Union[dict, List[dict]], rolling_gen: RollingGen, - trainer: Trainer = None, need_log=True, - signal_exp_name="OnlineManagerSignals", ): """ Init RollingAverageStrategy. @@ -161,22 +112,16 @@ class RollingAverageStrategy(OnlineStrategy): name_id (str): a unique name or id. Will be also the name of Experiment. task_template (Union[dict,List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen. rolling_gen (RollingGen): an instance of RollingGen - trainer (Trainer, optional): a instance of Trainer. Defaults to None. need_log (bool, optional): print log or not. Defaults to True. - signal_exp_path (str): a specific experiment to save signals of different experiment. """ - super().__init__(name_id=name_id, trainer=trainer, need_log=need_log) + super().__init__(name_id=name_id, need_log=need_log) self.exp_name = self.name_id if not isinstance(task_template, list): task_template = [task_template] self.task_template = task_template - self.signal_exp_name = signal_exp_name self.rg = rolling_gen - self.tool = OnlineToolR(self.exp_name) + self.tool = OnlineToolR(self.exp_name, need_log) self.ta = TimeAdjuster() - with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True): - self.signal_rec = R.get_recorder() # the recorder to record signals - self.signal_rec.save_objects(**{"signals": None}) def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None): """ @@ -209,18 +154,17 @@ class RollingAverageStrategy(OnlineStrategy): return artifacts_collector - def first_train(self) -> List[Recorder]: + def first_tasks(self) -> List[dict]: """ - Use rolling_gen to generate different tasks based on task_template and trained them. + Use rolling_gen to generate different tasks based on task_template. Returns: - List[Recorder]: a list of Recorder. + List[dict]: a list of tasks """ - tasks = task_generator( + return task_generator( tasks=self.task_template, generators=self.rg, # generate different date segment ) - return self.prepare_online_models(tasks) def prepare_tasks(self, cur_time) -> List[dict]: """ @@ -255,57 +199,6 @@ class RollingAverageStrategy(OnlineStrategy): return new_tasks return [] - def prepare_signals(self, delay=False, over_write=False) -> pd.DataFrame: - """ - Average the predictions of online models and offer a trading signals every routine. - The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP` - Even if the latest signal already exists, the latest calculation result will be overwritten. - - .. note:: - - Given a prediction of a certain time, all signals before this time will be prepared well. - - Args: - over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False. - Returns: - pd.DataFrame: the signals. - """ - if not delay: - self.tool.update_online_pred() - - # Get a collector to average online models predictions - online_collector = self.get_collector( - process_list=[AverageEnsemble()], - rec_filter_func=lambda x: True if self.tool.get_online_tag(x) == self.tool.ONLINE_TAG else False, - artifacts_key="pred", - ) - online_results = online_collector() - signals = online_results["pred"] - - old_signals = self.get_signals() - if old_signals is not None and not over_write: - old_max = old_signals.index.get_level_values("datetime").max() - new_signals = signals.loc[old_max:] - signals = pd.concat([old_signals, new_signals], axis=0) - else: - new_signals = signals - if self.need_log: - self.logger.info( - f"Finished preparing new {len(new_signals)} signals to {self.signal_exp_name}/{self.exp_name}." - ) - self.signal_rec.save_objects(**{"signals": signals}) - return signals - - def get_signals(self) -> object: - """ - Get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP) - - Returns: - object: signals - """ - signals = self.signal_rec.load_object("signals") - return signals - def _list_latest(self, rec_list: List[Recorder]): """ List latest recorder form rec_list @@ -324,16 +217,3 @@ class RollingAverageStrategy(OnlineStrategy): if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test: latest_rec.append(rec) return latest_rec, max_test - - def reset(self): - """ - NOTE: This method will delete all recorder in Experiment and reset the Trainer! - """ - self.trainer.reset() - # delete models - exp = R.get_exp(experiment_name=self.exp_name) - for rid in exp.list_recorders(): - exp.delete_recorder(rid) - # delete signals - for rid in list_recorders(self.signal_exp_name, lambda x: True if x.info["name"] == self.exp_name else False): - exp.delete_recorder(rid) diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py index 296ca3ea6..c79a5dc00 100644 --- a/qlib/workflow/online/utils.py +++ b/qlib/workflow/online/utils.py @@ -17,7 +17,7 @@ from qlib.workflow.task.utils import list_recorders class OnlineTool: """ - OnlineTool. + OnlineTool will manage `online` models in an experiment which includes the models recorder. """ ONLINE_KEY = "online_status" # the online status key in recorder @@ -92,7 +92,7 @@ class OnlineToolR(OnlineTool): The implementation of OnlineTool based on (R)ecorder. """ - def __init__(self, experiment_name: str, need_log=True): + def __init__(self, experiment_name:str, need_log=True): """ Init OnlineToolR. diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 28320e2ce..b40ee0164 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -5,14 +5,16 @@ Collector can collect object from everywhere and process them such as merging, grouping, averaging and so on. """ -from qlib.model.ens.ensemble import SingleKeyEnsemble +from typing import Callable, Dict, List +from qlib.utils.serial import Serializable from qlib.workflow import R -import dill as pickle -class Collector: +class Collector(Serializable): """The collector to collect different results""" + pickle_backend = "dill" # use dill to dump user method + def __init__(self, process_list=[]): """ Args: @@ -74,65 +76,42 @@ class Collector: collected = self.collect() return self.process_collect(collected, self.process_list, *args, **kwargs) - def save(self, filepath): - """ - save the collector into a file - Args: - filepath (str): the path of file - - Returns: - bool: if succeeded - """ - try: - with open(filepath, "wb") as f: - pickle.dump(self, f) - except Exception: - return False - return True - - @staticmethod - def load(filepath): - """ - load the collector from a file - - Args: - filepath (str): the path of file - - Raises: - TypeError: the pickled file must be `Collector` - - Returns: - Collector: the instance of Collector - """ - with open(filepath, "rb") as f: - collector = pickle.load(f) - if isinstance(collector, Collector): - return collector - else: - raise TypeError(f"The instance of {type(collector)} is not a valid `Collector`!") - - -class HyperCollector(Collector): +class MergeCollector(Collector): """ A collector to collect the results of other Collectors + + For example: + + We have 2 collector, which named A and B. + A can collect {"prediction": pd.Series} and B can collect {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}. + Then after this class's collect, we can collect {"A_prediction": pd.Series, "B_IC": {"Xgboost": pd.Series, "LSTM": pd.Series}} + + ...... + """ - def __init__(self, collector_dict, process_list=[]): + def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = []): """ Args: - collector_dict (dict): the dict like {collector_key, Collector} - process_list (list or Callable): the list of processors or the instance of processor to process dict. - NOTE: process_list = [SingleKeyEnsemble()] can ignore key and use value directly if there is only one {k,v} in a dict. - This can make result more readable. If you want to maintain as it should be, just give a empty process list. + collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector} + process_list (List[Callable]): the list of processors or the instance of processor to process dict. """ super().__init__(process_list=process_list) self.collector_dict = collector_dict def collect(self) -> dict: + """ + Collect all result of collector_dict and change the outermost key to "``collector_key``_``key``" (like merge them, but rename every key) + + Returns: + dict: the dict after collecting. + """ collect_dict = {} - for key, collector in self.collector_dict.items(): - collect_dict[key] = collector() + for collector_key, collector in self.collector_dict.items(): + tmp_dict = collector() + for key, value in tmp_dict.items(): + collect_dict[collector_key + "_" + str(key)] = value return collect_dict @@ -145,7 +124,7 @@ class RecorderCollector(Collector): process_list=[], rec_key_func=None, rec_filter_func=None, - artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}, + artifacts_path={"pred": "pred.pkl"}, artifacts_key=None, ): """init RecorderCollector @@ -203,7 +182,11 @@ class RecorderCollector(Collector): if self.ART_KEY_RAW == key: artifact = rec else: - artifact = rec.load_object(self.artifacts_path[key]) + # only collect existing artifact + try: + artifact = rec.load_object(self.artifacts_path[key]) + except Exception: + continue collect_dict.setdefault(key, {})[rec_key] = artifact return collect_dict diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index c4c6bab7f..7e08c76f4 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -5,7 +5,7 @@ Task generator can generate many tasks based on TaskGen and some task templates. """ import abc import copy -import typing +from typing import List, Union, Callable from .utils import TimeAdjuster @@ -64,7 +64,7 @@ class TaskGen(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def generate(self, task: dict) -> typing.List[dict]: + def generate(self, task: dict) -> List[dict]: """ generate different tasks based on a task template @@ -87,11 +87,34 @@ class TaskGen(metaclass=abc.ABCMeta): return self.generate(*args, **kwargs) +def handler_mod(task: dict, rg): + """ + Help to modify the handler end time when using RollingGen + + Args: + task (dict): a task template + rg (RollingGen): an instance of RollingGen + """ + try: + interval = rg.ta.cal_interval( + task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"], + task["dataset"]["kwargs"]["segments"][rg.test_key][1], + ) + # if end_time < the end of test_segments, then change end_time to allow load more data + if interval < 0: + task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy( + task["dataset"]["kwargs"]["segments"][rg.test_key][1] + ) + except KeyError: + # Maybe dataset do not have handler, then do nothing. + pass + + class RollingGen(TaskGen): ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date - def __init__(self, step: int = 40, rtype: str = ROLL_EX, modify_end_time=True): + def __init__(self, step: int = 40, rtype: str = ROLL_EX, ds_extra_mod_func: Union[None, Callable] = handler_mod): """ Generate tasks for rolling @@ -101,19 +124,19 @@ class RollingGen(TaskGen): step to rolling rtype : str rolling type (expanding, sliding) - modify_end_time: bool - Whether the data set configuration needs to be modified when the required scope exceeds the original data set scope + ds_extra_mod_func: Callable + A method like: handler_mod(task: dict, rg: RollingGen) + Do some extra action after generating a task. For example, use ``handler_mod`` to modify the end time of handler of dataset. """ self.step = step self.rtype = rtype - self.modify_end_time = modify_end_time - # TODO: Ask pengrong to update future date in dataset + self.ds_extra_mod_func = ds_extra_mod_func self.ta = TimeAdjuster(future=True) self.test_key = "test" self.train_key = "train" - def generate(self, task: dict) -> typing.List[dict]: + def generate(self, task: dict) -> List[dict]: """ Converting the task into a rolling task. @@ -200,18 +223,8 @@ class RollingGen(TaskGen): # update segments of this task t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments) - - try: - interval = self.ta.cal_interval( - t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"], - t["dataset"]["kwargs"]["segments"][self.test_key][1], - ) - # if end_time < the end of test_segments, then change end_time to allow load more data - if self.modify_end_time and interval < 0: - t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1]) - except KeyError: - # Maybe the user dataset has no handler or end_time - pass prev_seg = segments + if self.ds_extra_mod_func is not None: + self.ds_extra_mod_func(t, self) res.append(t) return res diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index c71be7d39..025dfa85c 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -388,6 +388,7 @@ class TaskManager: def run_task( task_func: Callable, task_pool: str, + tasks: List[dict] = None, force_release: bool = False, before_status: str = TaskManager.STATUS_WAITING, after_status: str = TaskManager.STATUS_DONE, @@ -413,6 +414,8 @@ def run_task( the function to run the task task_pool : str the name of the task pool (Collection in MongoDB) + tasks: List[dict] + will only train these tasks config, None for train all tasks. force_release : bool will the program force to release the resource before_status : str: @@ -425,9 +428,12 @@ def run_task( tm = TaskManager(task_pool) ever_run = False + query = {} + if tasks is not None: + query = {"filter": {"$in": tasks}} while True: - with tm.safe_fetch_task(status=before_status) as task: + with tm.safe_fetch_task(status=before_status, query=query) as task: if task is None: break get_module_logger("run_task").info(task["def"])