From bd7a1c11b981099cdbe9a69429e3566be36854be Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Fri, 2 Apr 2021 04:27:14 +0000 Subject: [PATCH] trainer & group & collect & ensemble --- .../model_rolling/task_manager_rolling.py | 35 ++-- .../task_manager_rolling_with_updating.py | 8 +- .../update_online_pred.py | 6 +- qlib/model/ens/ensemble.py | 98 ++++++++++ qlib/model/ens/group.py | 68 +++++++ qlib/model/trainer.py | 68 +++++++ qlib/workflow/online/__init__.py | 0 .../{task/online.py => online/manager.py} | 33 ++-- qlib/workflow/{task => online}/update.py | 4 +- qlib/workflow/task/collect.py | 58 +++--- qlib/workflow/task/ensemble.py | 176 ------------------ qlib/workflow/task/manage.py | 26 +-- 12 files changed, 319 insertions(+), 261 deletions(-) rename examples/{online_svr => online_srv}/task_manager_rolling_with_updating.py (97%) rename examples/{online_svr => online_srv}/update_online_pred.py (90%) create mode 100644 qlib/model/ens/ensemble.py create mode 100644 qlib/model/ens/group.py create mode 100644 qlib/workflow/online/__init__.py rename qlib/workflow/{task/online.py => online/manager.py} (87%) rename qlib/workflow/{task => online}/update.py (98%) delete mode 100644 qlib/workflow/task/ensemble.py diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 75d360fa1..3eb05de72 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -8,9 +8,11 @@ from qlib.workflow import R from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.manage import TaskManager, run_task from qlib.workflow.task.collect import RecorderCollector -from qlib.workflow.task.ensemble import RollingEnsemble +from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow import pandas as pd from qlib.workflow.task.utils import list_recorders +from qlib.model.ens.group import RollingGroup +from qlib.model.trainer import TrainerRM data_handler_config = { "start_time": "2008-01-01", @@ -94,24 +96,16 @@ def task_generating(): return tasks -# This part corresponds to "Task Storing" in the document -def task_storing(tasks, task_pool, exp_name): - print("========== task_storing ==========") - tm = TaskManager(task_pool=task_pool) - tm.create_task(tasks) # all tasks will be saved to MongoDB - - -# This part corresponds to "Task Running" in the document -def task_running(task_pool, exp_name): - print("========== task_running ==========") - run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method +def task_training(tasks, task_pool, exp_name): + trainer = TrainerRM() + trainer.train(tasks, exp_name, task_pool) # This part corresponds to "Task Collecting" in the document def task_collecting(task_pool, exp_name): print("========== task_collecting ==========") - def get_group_key_func(recorder): + def rec_key(recorder): task_config = recorder.load_object("task") model_key = task_config["model"]["class"] rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] @@ -119,14 +113,14 @@ def task_collecting(task_pool, exp_name): def my_filter(recorder): # only choose the results of "LGBModel" - model_key, rolling_key = get_group_key_func(recorder) + model_key, rolling_key = rec_key(recorder) if model_key == "LGBModel": return True return False - collector = RecorderCollector(exp_name) - # group tasks by "get_task_key" and filter tasks by "my_filter" - artifact = collector.collect(RollingEnsemble(), get_group_key_func, rec_filter_func=my_filter) + artifact = ens_workflow( + RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter + ) print(artifact) @@ -143,10 +137,9 @@ def main( } qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf) - reset(task_pool, exp_name) - tasks = task_generating() - task_storing(tasks, task_pool, exp_name) - task_running(task_pool, exp_name) + # reset(task_pool, exp_name) + # tasks = task_generating() + # task_training(tasks, task_pool, exp_name) task_collecting(task_pool, exp_name) diff --git a/examples/online_svr/task_manager_rolling_with_updating.py b/examples/online_srv/task_manager_rolling_with_updating.py similarity index 97% rename from examples/online_svr/task_manager_rolling_with_updating.py rename to examples/online_srv/task_manager_rolling_with_updating.py index fff470c86..32f582b4c 100644 --- a/examples/online_svr/task_manager_rolling_with_updating.py +++ b/examples/online_srv/task_manager_rolling_with_updating.py @@ -6,10 +6,10 @@ from qlib.config import REG_CN from qlib.model.trainer import task_train from qlib.workflow import R from qlib.workflow.task.collect import RecorderCollector -from qlib.workflow.task.ensemble import RollingEnsemble +from qlib.model.ens.ensemble import RollingEnsemble from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.manage import TaskManager, run_task -from qlib.workflow.task.online import RollingOnlineManager +from qlib.workflow.online.manager import RollingOnlineManager from qlib.workflow.task.utils import list_recorders data_handler_config = { @@ -155,10 +155,10 @@ def first_run(): rolling_online_manager.reset_online_tag(latest_rec.values()) -def after_day(): +def routine(): print("========== after_day ==========") print_online_model() - rolling_online_manager.after_day() + rolling_online_manager.routine() print_online_model() task_collecting() diff --git a/examples/online_svr/update_online_pred.py b/examples/online_srv/update_online_pred.py similarity index 90% rename from examples/online_svr/update_online_pred.py rename to examples/online_srv/update_online_pred.py index ac86b48e8..7bce82ac8 100644 --- a/examples/online_svr/update_online_pred.py +++ b/examples/online_srv/update_online_pred.py @@ -2,7 +2,7 @@ import fire import qlib from qlib.config import REG_CN from qlib.model.trainer import task_train -from qlib.workflow.task.online import OnlineManagerR +from qlib.workflow.online.manager import OnlineManagerR from qlib.workflow.task.utils import list_recorders data_handler_config = { @@ -52,7 +52,7 @@ task = { } -def first_train(experiment_name="online_svr"): +def first_train(experiment_name="online_srv"): rid = task_train(task_config=task, experiment_name=experiment_name) @@ -60,7 +60,7 @@ def first_train(experiment_name="online_svr"): online_manager.reset_online_tag(rid) -def update_online_pred(experiment_name="online_svr"): +def update_online_pred(experiment_name="online_srv"): online_manager = OnlineManagerR(experiment_name) diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py new file mode 100644 index 000000000..dcc4ba5d3 --- /dev/null +++ b/qlib/model/ens/ensemble.py @@ -0,0 +1,98 @@ +from abc import abstractmethod +from typing import Callable, Union + +import pandas as pd +from qlib.workflow.task.collect import Collector + + +def ens_workflow(collector: Collector, process_list, artifacts_key=None, rec_filter_func=None, *args, **kwargs): + """the ensemble workflow based on collector and different dict processors. + + Args: + collector (Collector): the collector to collect the result into {result_key: things} + process_list (list or Callable): the list of processors or the instance of processor to process dict. + The processor order is same as the list order. + + For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())] + + artifacts_key (list, optional): the artifacts key you want to get. If None, get all artifacts. + rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. + + Returns: + dict: the ensemble dict + """ + collect_dict = collector.collect(artifacts_key=artifacts_key, rec_filter_func=rec_filter_func) + if not isinstance(process_list, list): + process_list = [process_list] + + ensemble = {} + for artifact in collect_dict: + value = collect_dict[artifact] + for process in process_list: + if not callable(process): + raise NotImplementedError(f"{type(process)} is not supported in `ens_workflow`.") + value = process(value, *args, **kwargs) + ensemble[artifact] = value + + return ensemble + + +class Ensemble: + """Merge the objects in an Ensemble.""" + + def __init__(self, merge_func=None): + """init Ensemble + + Args: + merge_func (Callable, optional): Given a dict and return the ensemble. + + For example: {Rollinga_b: object, Rollingb_c: object} -> object + + Defaults to None. + """ + self._merge = merge_func + + def __call__(self, ensemble_dict: dict, *args, **kwargs): + """Merge the ensemble_dict into an ensemble object. + + Args: + ensemble_dict (dict): the ensemble dict waiting for merging like {name: things} + + Returns: + object: the ensemble object + """ + if isinstance(getattr(self, "_merge", None), Callable): + return self._merge(ensemble_dict, *args, **kwargs) + else: + raise NotImplementedError(f"Please specify valid merge_func.") + + +class RollingEnsemble(Ensemble): + + """Merge the rolling objects in an Ensemble""" + + @staticmethod + def rolling_merge(rolling_dict: dict): + """Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble. + + NOTE: The values of dict must be pd.Dataframe, and have the index "datetime" + + Args: + rolling_dict (dict): a dict like {"A": pd.Dataframe, "B": pd.Dataframe}. + The key of the dict will be ignored. + + Returns: + pd.Dataframe: the complete result of rolling. + """ + artifact_list = list(rolling_dict.values()) + artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min()) + artifact = pd.concat(artifact_list) + # If there are duplicated predition, use the latest perdiction + artifact = artifact[~artifact.index.duplicated(keep="last")] + artifact = artifact.sort_index() + return artifact + + def __init__(self, merge_func=None): + super().__init__(merge_func=merge_func) + if merge_func is None: + self._merge = RollingEnsemble.rolling_merge \ No newline at end of file diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py new file mode 100644 index 000000000..1ef3da77f --- /dev/null +++ b/qlib/model/ens/group.py @@ -0,0 +1,68 @@ +from qlib.model.ens.ensemble import Ensemble, RollingEnsemble +from typing import Callable, Union + + +class Group: + """Group the objects based on dict""" + + def __init__(self, group_func=None, ens: Ensemble = None): + """init Group. + + Args: + group_func (Callable, optional): Given a dict and return the group key and one of group elements. + + For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}} + + Defaults to None. + + ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping. + """ + self._group = group_func + self._ens = ens + + def __call__(self, ungrouped_dict: dict, *args, **kwargs): + """Group the ungrouped_dict into different groups. + + Args: + ungrouped_dict (dict): the ungrouped dict waiting for grouping like {name: things} + + Returns: + dict: grouped_dict like {G1: object, G2: object} + """ + if isinstance(getattr(self, "_group", None), Callable): + grouped_dict = self._group(ungrouped_dict, *args, **kwargs) + if self._ens is not None: + ens_dict = {} + for key, value in grouped_dict.items(): + ens_dict[key] = self._ens(value) + grouped_dict = ens_dict + return grouped_dict + else: + raise NotImplementedError(f"Please specify valid merge_func.") + + +class RollingGroup(Group): + """group the rolling dict""" + + @staticmethod + def rolling_group(rolling_dict: dict): + """Given an rolling dict likes {(A,B,R): things}, return the grouped dict likes {(A,B): {R:things}} + + NOTE: There is a assumption which is the rolling key is at the end of key tuple, because the rolling results always need to be ensemble firstly. + + Args: + rolling_dict (dict): an rolling dict. If the key is not a tuple, then do nothing. + + Returns: + dict: grouped dict + """ + grouped_dict = {} + for key, values in rolling_dict.items(): + if isinstance(key, tuple): + grouped_dict.setdefault(key[:-1], {})[key[-1]] = values + return grouped_dict + + def __init__(self, group_func=None, ens: Ensemble = RollingEnsemble()): + super().__init__(group_func=group_func, ens=ens) + if group_func is None: + self._group = RollingGroup.rolling_group \ No newline at end of file diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 45650c0c7..e128e700d 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -4,6 +4,7 @@ from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord +from qlib.workflow.task.manage import TaskManager, run_task def task_train(task_config: dict, experiment_name: str) -> str: @@ -57,3 +58,70 @@ def task_train(task_config: dict, experiment_name: str) -> str: ar.generate() return recorder + + +class Trainer: + """ + The trainer which can train a list of model + """ + + def train(self, *args, **kwargs): + """Given a list of model definition, finished training and return the results of them. + + Returns: + list: a list of trained results + """ + raise NotImplementedError(f"Please implement the `train` method.") + + +class TrainerR(Trainer): + """Trainer based on (R)ecorder. + + Assumption: models were defined by `task` and the results will saved to `Recorder` + """ + + def train(self, tasks: list, experiment_name: str, train_func=task_train, *args, **kwargs): + """Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. + + Args: + tasks (list): a list of definition based on `task` dict + experiment_name (str): the experiment name + train_func (Callable): the train method which need at least `task` and `experiment_name` + + Returns: + list: a list of Recorders + """ + recs = [] + for task in tasks: + recs.append(train_func(task, experiment_name, *args, **kwargs)) + return recs + + +class TrainerRM(TrainerR): + """Trainer based on (R)ecorder and Task(M)anager + + Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager + """ + + def train(self, tasks: list, experiment_name: str, task_pool: str, train_func=task_train, *args, **kwargs): + """Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. + + This method defaults to a single process, but TaskManager offered a great way to parallel training. + Users can customize their train_func to realize multiple processes or even multiple machines. + + Args: + tasks (list): a list of definition based on `task` dict + experiment_name (str): the experiment name + train_func (Callable): the train method which need at least `task` and `experiment_name` + + Returns: + list: a list of Recorders + """ + tm = TaskManager(task_pool=task_pool) + _id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB + run_task(train_func, task_pool, experiment_name=experiment_name, *args, **kwargs) + + recs = [] + for _id in _id_list: + recs.append(tm.re_query(_id)["res"]) + return recs \ No newline at end of file diff --git a/qlib/workflow/online/__init__.py b/qlib/workflow/online/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qlib/workflow/task/online.py b/qlib/workflow/online/manager.py similarity index 87% rename from qlib/workflow/task/online.py rename to qlib/workflow/online/manager.py index f7ffbd18a..fbee0d707 100644 --- a/qlib/workflow/task/online.py +++ b/qlib/workflow/online/manager.py @@ -3,7 +3,7 @@ from qlib import get_module_logger from qlib.workflow import R from qlib.model.trainer import task_train from qlib.workflow.recorder import MLflowRecorder, Recorder -from qlib.workflow.task.update import ModelUpdater +from qlib.workflow.online.update import ModelUpdater from qlib.workflow.task.utils import TimeAdjuster from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.manage import TaskManager @@ -37,6 +37,16 @@ class OnlineManager(Serializable): def get_online_tag(self, *args, **kwargs): raise NotImplementedError(f"Please implement the `get_online_tag` method.") + def reset_online_tag(self, *args, **kwargs): + raise NotImplementedError(f"Please implement the `reset_online_tag` method.") + + def routine(self, *args, **kwargs): + self.prepare_signals(*args, **kwargs) + self.prepare_tasks(*args, **kwargs) + self.prepare_new_models(*args, **kwargs) + self.update_online_pred(*args, **kwargs) + self.reset_online_tag(*args, **kwargs) + class OnlineManagerR(OnlineManager): """ @@ -86,21 +96,18 @@ class OnlineManagerR(OnlineManager): cnt = mu.update_all_pred(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG) self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.") - def after_day(self, *args, **kwargs): - self.prepare_signals(*args, **kwargs) - self.prepare_tasks(*args, **kwargs) - self.prepare_new_models(*args, **kwargs) - self.update_online_pred(*args, **kwargs) - self.reset_online_tag() - class RollingOnlineManager(OnlineManagerR): - def __init__(self, experiment_name: str, rolling_gen: RollingGen, task_pool) -> None: + # FIXME: TaskManager不应该与onlinemanager强耦合 + def __init__( + self, experiment_name: str, rolling_gen: RollingGen, task_manager: TaskManager, trainer=run_task + ) -> None: super().__init__(experiment_name) self.ta = TimeAdjuster() self.rg = rolling_gen - self.tm = TaskManager(task_pool=task_pool) + self.tm = task_manager self.logger = get_module_logger(self.__class__.__name__) + self.trainer = trainer def prepare_signals(self): pass @@ -122,13 +129,13 @@ class RollingOnlineManager(OnlineManagerR): task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest) old_tasks.append(task) new_tasks = task_generator(old_tasks, self.rg) - new_num = self.tm.create_task(new_tasks) - self.logger.info(f"Finished prepare {new_num} tasks.") + self.tm.create_task(new_tasks) def prepare_new_models(self): """prepare(train) new models based on online model""" - run_task(task_train, self.tm.task_pool, experiment_name=self.exp_name) + run_task(task_train, task_pool=self.tm.task_pool, experiment_name=self.exp_name) latest_records, _ = self.list_latest_recorders() + # FIXME: 现有的流程,如果没有可更新的模型,仍会调用这个,导致会先将以前的模型设置成nextonline再去更新pred,但这个时候online已经没有了,pred无法更新 self.set_online_tag(OnlineManager.NEXT_ONLINE_TAG, latest_records.values()) self.logger.info(f"Finished prepare {len(latest_records)} new models and set them to next_online.") diff --git a/qlib/workflow/task/update.py b/qlib/workflow/online/update.py similarity index 98% rename from qlib/workflow/task/update.py rename to qlib/workflow/online/update.py index 002f1128f..1a6897d02 100644 --- a/qlib/workflow/task/update.py +++ b/qlib/workflow/online/update.py @@ -45,8 +45,8 @@ class ModelUpdater: """ segments = {"test": (start_time, end_time)} dataset = recorder.load_object("dataset") - dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}) - dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS}, segments=segments) + dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments) + dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS}) return dataset def update_pred(self, recorder: Recorder, frequency="day"): diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 91b713ef8..7e555ed06 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -1,49 +1,54 @@ from abc import abstractmethod from typing import Callable, Union - -import pandas as pd -from qlib import get_module_logger from qlib.workflow.task.utils import list_recorders class Collector: - """The collector to collect different results based on experiment backend and ensemble method""" + """The collector to collect different results""" - def collect(self, ensemble, get_group_key_func, *args, **kwargs): - """To collect the results, we need to get the experiment record firstly and divided them into - different groups. Then use ensemble methods to merge the group. + def collect(self, *args, **kwargs): + """Collect the results and return a dict like {key: things} - Args: - ensemble (Ensemble): an instance of Ensemble - get_group_key_func (Callable): a function to get the group of a experiment record + Returns: + dict: the dict after collected. + For example: + + {"prediction": pd.Series} + + {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}} + + ...... """ raise NotImplementedError(f"Please implement the `collect` method.") class RecorderCollector(Collector): - def __init__(self, exp_name, artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}) -> None: + def __init__( + self, exp_name, artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}, rec_key_func=None + ) -> None: """init RecorderCollector Args: exp_name (str): the name of Experiment artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}. + rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. """ self.exp_name = exp_name self.artifacts_path = artifacts_path + if rec_key_func is None: + rec_key_func = lambda rec: rec.info["id"] + self._get_key = rec_key_func - def collect(self, ensemble, get_group_key_func, artifacts_key=None, rec_filter_func=None): - """Collect different artifacts based on recorder after filtering and ensemble method. - Group recorder by get_group_key_func. + def collect(self, artifacts_key=None, rec_filter_func=None): # ensemble, get_group_key_func, + """Collect different artifacts based on recorder after filtering. Args: - ensemble (Ensemble): an instance of Ensemble - get_group_key_func (Callable): a function to get the group of a experiment record - artifacts_key (str or List, optional): the artifacts key you want to get. Defaults to None. + artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts. rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. Returns: - dict: the dict after collected. + dict: the dict after collected like {artifact: {rec_key: object}} """ if artifacts_key is None: artifacts_key = self.artifacts_path.keys() @@ -51,22 +56,13 @@ class RecorderCollector(Collector): if isinstance(artifacts_key, str): artifacts_key = [artifacts_key] - # prepare_ensemble - ensemble_dict = {} - for key in artifacts_key: - ensemble_dict.setdefault(key, {}) + collect_dict = {} # filter records recs_flt = list_recorders(self.exp_name, rec_filter_func) for _, rec in recs_flt.items(): - group_key = get_group_key_func(rec) + rec_key = self._get_key(rec) for key in artifacts_key: artifact = rec.load_object(self.artifacts_path[key]) - ensemble_dict[key][group_key] = artifact + collect_dict.setdefault(key, {})[rec_key] = artifact - if isinstance(artifacts_key, str): - return ensemble(ensemble_dict[artifacts_key]) - - collect_dict = {} - for key in artifacts_key: - collect_dict[key] = ensemble(ensemble_dict[key]) - return collect_dict + return collect_dict \ No newline at end of file diff --git a/qlib/workflow/task/ensemble.py b/qlib/workflow/task/ensemble.py deleted file mode 100644 index dca0dee3e..000000000 --- a/qlib/workflow/task/ensemble.py +++ /dev/null @@ -1,176 +0,0 @@ -from abc import abstractmethod -from typing import Callable, Union - -import pandas as pd -from qlib import get_module_logger -from qlib.workflow.task.utils import list_recorders -from typing import Dict - - -class Ensemble: - """Merge the objects in an Ensemble.""" - - def __init__(self, merge_func=None, get_grouped_key_func=None) -> None: - """init Ensemble - - Args: - merge_func (Callable, optional): The specific merge function. Defaults to None. - get_grouped_key_func (Callable, optional): Get group_inner_key and group_outer_key by group_key. Defaults to None. - """ - self.logger = get_module_logger(self.__class__.__name__) - if merge_func is not None: - self.merge_func = merge_func - if get_grouped_key_func is not None: - self.get_grouped_key_func = get_grouped_key_func - - def merge_func(self, group_inner_dict): - """Given a group_inner_dict such as {Rollinga_b: object, Rollingb_c: object}, - merge it to object - - Args: - group_inner_dict (dict): the inner group dict - - """ - raise NotImplementedError(f"Please implement the `merge_func` method.") - - def get_grouped_key_func(self, group_key): - """Given a group_key and return the group_outer_key, group_inner_key. - - For example: - (A,B,Rolling) -> (A,B):Rolling - (A,B) -> C:(A,B) - - Args: - group_key (tuple or str): the group key - """ - raise NotImplementedError(f"Please implement the `get_grouped_key_func` method.") - - def group(self, group_dict: Dict[tuple or str, object]) -> Dict[tuple or str, Dict[tuple or str, object]]: - """In a group of dict, further divide them into outgroups and innergroup. - - For example: - - .. code-block:: python - - RollingEnsemble: - input: - { - (ModelA,Horizon5,Rollinga_b): object - (ModelA,Horizon5,Rollingb_c): object - (ModelA,Horizon10,Rollinga_b): object - (ModelA,Horizon10,Rollingb_c): object - (ModelB,Horizon5,Rollinga_b): object - (ModelB,Horizon5,Rollingb_c): object - (ModelB,Horizon10,Rollinga_b): object - (ModelB,Horizon10,Rollingb_c): object - } - - output: - { - (ModelA,Horizon5): {Rollinga_b: object, Rollingb_c: object} - (ModelA,Horizon10): {Rollinga_b: object, Rollingb_c: object} - (ModelB,Horizon5): {Rollinga_b: object, Rollingb_c: object} - (ModelB,Horizon10): {Rollinga_b: object, Rollingb_c: object} - } - - Args: - group_dict (Dict[tuple or str, object]): a group of dict - - Returns: - Dict[tuple or str, Dict[tuple or str, object]]: the dict after `group` - """ - grouped_dict = {} - for group_key, artifact in group_dict.items(): - group_outer_key, group_inner_key = self.get_grouped_key_func(group_key) # (A,B,Rolling) -> (A,B):Rolling - grouped_dict.setdefault(group_outer_key, {})[group_inner_key] = artifact - return grouped_dict - - def reduce(self, grouped_dict: dict): - """After grouping, reduce the innergroup. - - For example: - - .. code-block:: python - - RollingEnsemble: - input: - { - (ModelA,Horizon5): {Rollinga_b: object, Rollingb_c: object} - (ModelA,Horizon10): {Rollinga_b: object, Rollingb_c: object} - (ModelB,Horizon5): {Rollinga_b: object, Rollingb_c: object} - (ModelB,Horizon10): {Rollinga_b: object, Rollingb_c: object} - } - - output: - { - (ModelA,Horizon5): object - (ModelA,Horizon10): object - (ModelB,Horizon5): object - (ModelB,Horizon10): object - } - - Args: - grouped_dict (dict): the dict after `group` - - Returns: - dict: the dict after `reduce` - """ - reduce_group = {} - for group_outer_key, group_inner_dict in grouped_dict.items(): - artifact = self.merge_func(group_inner_dict) - reduce_group[group_outer_key] = artifact - return reduce_group - - def __call__(self, group_dict): - """The process of Ensemble is group it firstly and then reduce it, like MapReduce. - - Args: - group_dict (Dict[tuple or str, object]): a group of dict - - Returns: - dict: the dict after `reduce` - """ - grouped_dict = self.group(group_dict) - return self.reduce(grouped_dict) - - -class RollingEnsemble(Ensemble): - """A specific implementation of Ensemble for Rolling.""" - - def merge_func(self, group_inner_dict): - """merge group_inner_dict by datetime. - - Args: - group_inner_dict (dict): the inner group dict - - Returns: - object: the artifact after merging - """ - artifact_list = list(group_inner_dict.values()) - artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min()) - artifact = pd.concat(artifact_list) - # If there are duplicated predition, use the latest perdiction - artifact = artifact[~artifact.index.duplicated(keep="last")] - artifact = artifact.sort_index() - return artifact - - def get_grouped_key_func(self, group_key): - """The final axis of group_key must be the Rolling key. - When `collect`, get_group_key_func can add the statement below. - - .. code-block:: python - - def get_group_key_func(recorder): - task_config = recorder.load_object("task") - ...... - rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] - return ......, rolling_key - - Args: - group_key (tuple or str): the group key - - Returns: - tuple or str, tuple or str: group_outer_key, group_inner_key - """ - assert len(group_key) >= 2 - return group_key[:-1], group_key[-1] diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index a62164207..ddd833aa4 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -60,7 +60,7 @@ class TaskManager: """ self.mdb = get_mongodb() self.task_pool = task_pool - self.logger = get_module_logger("TaskManager") + self.logger = get_module_logger(self.__class__.__name__) def list(self): return self.mdb.list_collection_names() @@ -105,10 +105,11 @@ class TaskManager: def insert_task(self, task, task_pool=None): task_pool = self._get_task_pool(task_pool) try: - task_pool.insert_one(task) + insert_result = task_pool.insert_one(task) except InvalidDocument: task["filter"] = self._dict_to_str(task["filter"]) - task_pool.insert_one(task) + insert_result = task_pool.insert_one(task) + return insert_result def insert_task_def(self, task_def, task_pool=None): """ @@ -133,7 +134,8 @@ class TaskManager: "status": self.STATUS_WAITING, } ) - self.insert_task(task, task_pool) + insert_result = self.insert_task(task, task_pool) + return insert_result def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False): """ @@ -151,8 +153,8 @@ class TaskManager: if print new task Returns ------- - int - the length of new tasks + list + a list of the _id of new tasks """ task_pool = self._get_task_pool(task_pool) new_tasks = [] @@ -163,7 +165,7 @@ class TaskManager: r = task_pool.find_one({"filter": self._dict_to_str(t)}) if r is None: new_tasks.append(t) - print("Total Tasks, New Tasks:", len(task_def_l), len(new_tasks)) + self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}") if print_nt: # print new task for t in new_tasks: @@ -172,10 +174,12 @@ class TaskManager: if dry_run: return + _id_list = [] for t in new_tasks: - self.insert_task_def(t, task_pool) + insert_result = self.insert_task_def(t, task_pool) + _id_list.append(insert_result.inserted_id) - return len(new_tasks) + return _id_list def fetch_task(self, query={}, task_pool=None): task_pool = self._get_task_pool(task_pool) @@ -248,9 +252,9 @@ class TaskManager: for t in task_pool.find(query): yield self._decode_task(t) - def re_query(self, task, task_pool=None): + def re_query(self, _id, task_pool=None): task_pool = self._get_task_pool(task_pool) - return task_pool.find_one({"_id": ObjectId(task["_id"])}) + return task_pool.find_one({"_id": ObjectId(_id)}) def commit_task_res(self, task, res, status=None, task_pool=None): task_pool = self._get_task_pool(task_pool)