diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 3eb05de72..907086487 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -97,8 +97,8 @@ def task_generating(): def task_training(tasks, task_pool, exp_name): - trainer = TrainerRM() - trainer.train(tasks, exp_name, task_pool) + trainer = TrainerRM(exp_name, task_pool) + trainer.train(tasks) # This part corresponds to "Task Collecting" in the document @@ -119,7 +119,7 @@ def task_collecting(task_pool, exp_name): return False artifact = ens_workflow( - RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter + RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup(), ) print(artifact) @@ -128,7 +128,7 @@ def main( provider_uri="~/.qlib/qlib_data/cn_data", task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", - exp_name="rolling_exp", + experiment_name="rolling_exp", task_pool="rolling_task", ): mongo_conf = { @@ -137,11 +137,13 @@ def main( } qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf) - # reset(task_pool, exp_name) - # tasks = task_generating() - # task_training(tasks, task_pool, exp_name) - task_collecting(task_pool, exp_name) + reset(task_pool, experiment_name) + tasks = task_generating() + task_training(tasks, task_pool, experiment_name) + task_collecting(task_pool, experiment_name) if __name__ == "__main__": + ## to see the whole process with your own parameters, use the command below + # python update_online_pred.py main --experiment_name="your_exp_name" fire.Fire() diff --git a/examples/online_srv/task_manager_rolling_with_updating.py b/examples/online_srv/task_manager_rolling_with_updating.py index d8bd95927..5b80f9133 100644 --- a/examples/online_srv/task_manager_rolling_with_updating.py +++ b/examples/online_srv/task_manager_rolling_with_updating.py @@ -70,89 +70,106 @@ task_xgboost_config = { "record": record_config, } +class RollingOnlineExample: -def print_online_model(): - print("========== print_online_model ==========") - print("Current 'online' model:") - for rid, rec in list_recorders(exp_name).items(): - if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.ONLINE_TAG: - print(rid) - print("Current 'next online' model:") - for rid, rec in list_recorders(exp_name).items(): - if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.NEXT_ONLINE_TAG: - print(rid) + def __init__(self, exp_name="rolling_exp", task_pool="rolling_task", provider_uri="~/.qlib/qlib_data/cn_data", region="cn", task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550): + self.exp_name = exp_name + self.task_pool = task_pool + mongo_conf = { + "task_url": task_url, # your MongoDB url + "task_db_name": task_db_name, # database name + } + qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf) + + self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD) + self.trainer = TrainerRM(self.exp_name, self.task_pool) + self.task_manager = TaskManager(self.task_pool) + self.rolling_online_manager = RollingOnlineManager(experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer) + + + + def print_online_model(self): + print("========== print_online_model ==========") + print("Current 'online' model:") + for rid, rec in list_recorders(self.exp_name).items(): + if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.ONLINE_TAG: + print(rid) + print("Current 'next online' model:") + for rid, rec in list_recorders(self.exp_name).items(): + if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.NEXT_ONLINE_TAG: + print(rid) -# This part corresponds to "Task Generating" in the document -def task_generating(): + # This part corresponds to "Task Generating" in the document + def task_generating(self): - print("========== task_generating ==========") + print("========== task_generating ==========") - tasks = task_generator( - tasks=[task_xgboost_config, task_lgb_config], - generators=rolling_gen, # generate different date segment - ) + tasks = task_generator( + tasks=[task_xgboost_config, task_lgb_config], + generators=self.rolling_gen, # generate different date segment + ) - pprint(tasks) + pprint(tasks) - return tasks + return tasks -def task_training(tasks): - trainer.train(tasks, exp_name, task_pool) + def task_training(self, tasks): + self.trainer.train(tasks) -# This part corresponds to "Task Collecting" in the document -def task_collecting(): - print("========== task_collecting ==========") + # This part corresponds to "Task Collecting" in the document + def task_collecting(self): + print("========== task_collecting ==========") - def rec_key(recorder): - task_config = recorder.load_object("task") - model_key = task_config["model"]["class"] - rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] - return model_key, rolling_key + def rec_key(recorder): + task_config = recorder.load_object("task") + model_key = task_config["model"]["class"] + rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] + return model_key, rolling_key - def my_filter(recorder): - # only choose the results of "LGBModel" - model_key, rolling_key = rec_key(recorder) - if model_key == "LGBModel": - return True - return False + def my_filter(recorder): + # only choose the results of "LGBModel" + model_key, rolling_key = rec_key(recorder) + if model_key == "LGBModel": + return True + return False - artifact = ens_workflow( - RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter - ) - print(artifact) + artifact = ens_workflow( + RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup() + ) + print(artifact) -# Reset all things to the first status, be careful to save important data -def reset(): - print("========== reset ==========") - task_manager.remove() - exp = R.get_exp(experiment_name=exp_name) - for rid in exp.list_recorders(): - exp.delete_recorder(rid) + # Reset all things to the first status, be careful to save important data + def reset(self): + print("========== reset ==========") + self.task_manager.remove() + exp = R.get_exp(experiment_name=self.exp_name) + for rid in exp.list_recorders(): + exp.delete_recorder(rid) -# Run this firstly to see the workflow in Task Management -def first_run(): - print("========== first_run ==========") - reset() + # Run this firstly to see the workflow in Task Management + def first_run(self): + print("========== first_run ==========") + self.reset() - tasks = task_generating() - task_training(tasks) - task_collecting() + tasks = self.task_generating() + self.task_training(tasks) + self.task_collecting() - latest_rec, _ = rolling_online_manager.list_latest_recorders() - rolling_online_manager.reset_online_tag(latest_rec.values()) + latest_rec, _ = self.rolling_online_manager.list_latest_recorders() + self.rolling_online_manager.reset_online_tag(list(latest_rec.values())) -def routine(): - print("========== routine ==========") - print_online_model() - rolling_online_manager.routine() - print_online_model() - task_collecting() + def routine(self): + print("========== routine ==========") + self.print_online_model() + self.rolling_online_manager.routine() + self.print_online_model() + self.task_collecting() if __name__ == "__main__": @@ -161,26 +178,7 @@ if __name__ == "__main__": ####### to update the models and predictions after the trading time, use the command below # python task_manager_rolling_with_updating.py after_day - - #################### you need to finish the configurations below ######################### - - provider_uri = "~/.qlib/qlib_data/cn_data" # data_dir - mongo_conf = { - "task_url": "mongodb://10.0.0.4:27017/", # your MongoDB url - "task_db_name": "rolling_db", # database name - } - qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf) - - exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow - task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB - rolling_step = 550 - - ########################################################################################## - rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD) - task_manager = TaskManager(task_pool=task_pool) - trainer = TrainerRM() - rolling_online_manager = RollingOnlineManager( - experiment_name=exp_name, rolling_gen=rolling_gen, task_manager=task_manager, trainer=trainer - ) - - fire.Fire() + + ####### to define your own parameters, use `--` + # python task_manager_rolling_with_updating.py first_run --exp_name='your_exp_name' --rolling_step=40 + fire.Fire(RollingOnlineExample) diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py index 7bce82ac8..84472bc3b 100644 --- a/examples/online_srv/update_online_pred.py +++ b/examples/online_srv/update_online_pred.py @@ -54,10 +54,10 @@ task = { def first_train(experiment_name="online_srv"): - rid = task_train(task_config=task, experiment_name=experiment_name) + rec = task_train(task_config=task, experiment_name=experiment_name) online_manager = OnlineManagerR(experiment_name) - online_manager.reset_online_tag(rid) + online_manager.reset_online_tag(rec) def update_online_pred(experiment_name="online_srv"): @@ -71,13 +71,17 @@ def update_online_pred(experiment_name="online_srv"): online_manager.update_online_pred() +def main(provider_uri = "~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"): + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + qlib.init(provider_uri=provider_uri, region=region) + first_train(experiment_name) + update_online_pred(experiment_name) if __name__ == "__main__": ## to train a model and set it to online model, use the command below # python update_online_pred.py first_train ## to update online predictions once a day, use the command below # python update_online_pred.py update_online_pred - - provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - qlib.init(provider_uri=provider_uri, region=REG_CN) + ## to see the whole process with your own parameters, use the command below + # python update_online_pred.py main --experiment_name="your_exp_name" fire.Fire() diff --git a/qlib/config.py b/qlib/config.py index d0479a345..4dedf75d0 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -147,7 +147,7 @@ _default_config = { "mongo": { "task_url": "mongodb://localhost:27017/", "task_db_name": "default_task_db", - } + }, } MODE_CONF = { diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index dcc4ba5d3..a2333cfeb 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -3,9 +3,10 @@ from typing import Callable, Union import pandas as pd from qlib.workflow.task.collect import Collector +from qlib.utils.serial import Serializable -def ens_workflow(collector: Collector, process_list, artifacts_key=None, rec_filter_func=None, *args, **kwargs): +def ens_workflow(collector: Collector, process_list, *args, **kwargs): """the ensemble workflow based on collector and different dict processors. Args: @@ -21,7 +22,7 @@ def ens_workflow(collector: Collector, process_list, artifacts_key=None, rec_fil Returns: dict: the ensemble dict """ - collect_dict = collector.collect(artifacts_key=artifacts_key, rec_filter_func=rec_filter_func) + collect_dict = collector.collect() if not isinstance(process_list, list): process_list = [process_list] @@ -37,23 +38,12 @@ def ens_workflow(collector: Collector, process_list, artifacts_key=None, rec_fil return ensemble -class Ensemble: +class Ensemble(Serializable): """Merge the objects in an Ensemble.""" - def __init__(self, merge_func=None): - """init Ensemble - - Args: - merge_func (Callable, optional): Given a dict and return the ensemble. - - For example: {Rollinga_b: object, Rollingb_c: object} -> object - - Defaults to None. - """ - self._merge = merge_func - def __call__(self, ensemble_dict: dict, *args, **kwargs): """Merge the ensemble_dict into an ensemble object. + For example: {Rollinga_b: object, Rollingb_c: object} -> object Args: ensemble_dict (dict): the ensemble dict waiting for merging like {name: things} @@ -61,38 +51,29 @@ class Ensemble: Returns: object: the ensemble object """ - if isinstance(getattr(self, "_merge", None), Callable): - return self._merge(ensemble_dict, *args, **kwargs) - else: - raise NotImplementedError(f"Please specify valid merge_func.") + raise NotImplementedError(f"Please implement the `__call__` method.") class RollingEnsemble(Ensemble): """Merge the rolling objects in an Ensemble""" - @staticmethod - def rolling_merge(rolling_dict: dict): + def __call__(self, ensemble_dict: dict, *args, **kwargs): """Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble. NOTE: The values of dict must be pd.Dataframe, and have the index "datetime" Args: - rolling_dict (dict): a dict like {"A": pd.Dataframe, "B": pd.Dataframe}. + ensemble_dict (dict): a dict like {"A": pd.Dataframe, "B": pd.Dataframe}. The key of the dict will be ignored. Returns: pd.Dataframe: the complete result of rolling. """ - artifact_list = list(rolling_dict.values()) + artifact_list = list(ensemble_dict.values()) artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min()) artifact = pd.concat(artifact_list) # If there are duplicated predition, use the latest perdiction artifact = artifact[~artifact.index.duplicated(keep="last")] artifact = artifact.sort_index() return artifact - - def __init__(self, merge_func=None): - super().__init__(merge_func=merge_func) - if merge_func is None: - self._merge = RollingEnsemble.rolling_merge \ No newline at end of file diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py index 1ef3da77f..9cc5db971 100644 --- a/qlib/model/ens/group.py +++ b/qlib/model/ens/group.py @@ -1,8 +1,9 @@ from qlib.model.ens.ensemble import Ensemble, RollingEnsemble from typing import Callable, Union +from qlib.utils.serial import Serializable -class Group: +class Group(Serializable): """Group the objects based on dict""" def __init__(self, group_func=None, ens: Ensemble = None): @@ -17,8 +18,8 @@ class Group: ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping. """ - self._group = group_func - self._ens = ens + self.group = group_func + self.ens = ens def __call__(self, ungrouped_dict: dict, *args, **kwargs): """Group the ungrouped_dict into different groups. @@ -29,16 +30,16 @@ class Group: Returns: dict: grouped_dict like {G1: object, G2: object} """ - if isinstance(getattr(self, "_group", None), Callable): - grouped_dict = self._group(ungrouped_dict, *args, **kwargs) - if self._ens is not None: + if isinstance(getattr(self, "group", None), Callable): + grouped_dict = self.group(ungrouped_dict, *args, **kwargs) + if self.ens is not None: ens_dict = {} for key, value in grouped_dict.items(): - ens_dict[key] = self._ens(value) + ens_dict[key] = self.ens(value) grouped_dict = ens_dict return grouped_dict else: - raise NotImplementedError(f"Please specify valid merge_func.") + raise NotImplementedError(f"Please specify valid group_func.") class RollingGroup(Group): @@ -65,4 +66,4 @@ class RollingGroup(Group): def __init__(self, group_func=None, ens: Ensemble = RollingEnsemble()): super().__init__(group_func=group_func, ens=ens) if group_func is None: - self._group = RollingGroup.rolling_group \ No newline at end of file + self.group = RollingGroup.rolling_group diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index e128e700d..f087cc248 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -3,11 +3,12 @@ from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R +from qlib.workflow.recorder import Recorder from qlib.workflow.record_temp import SignalRecord from qlib.workflow.task.manage import TaskManager, run_task -def task_train(task_config: dict, experiment_name: str) -> str: +def task_train(task_config: dict, experiment_name: str) -> Recorder: """ task based training @@ -20,8 +21,7 @@ def task_train(task_config: dict, experiment_name: str) -> str: Returns ---------- - rid : str - The id of the recorder of this task + Recorder : The instance of the recorder """ # model initiaiton @@ -80,30 +80,40 @@ class TrainerR(Trainer): Assumption: models were defined by `task` and the results will saved to `Recorder` """ - def train(self, tasks: list, experiment_name: str, train_func=task_train, *args, **kwargs): + def __init__(self, experiment_name, train_func=task_train): + self.experiment_name = experiment_name + self.train_func = train_func + + def train(self, tasks: list, train_func=None, *args, **kwargs): """Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. Args: tasks (list): a list of definition based on `task` dict - experiment_name (str): the experiment name - train_func (Callable): the train method which need at least `task` and `experiment_name` + train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default. Returns: list: a list of Recorders """ + if train_func is None: + train_func = self.train_func recs = [] for task in tasks: - recs.append(train_func(task, experiment_name, *args, **kwargs)) + recs.append(train_func(task, self.experiment_name, *args, **kwargs)) return recs -class TrainerRM(TrainerR): +class TrainerRM(Trainer): """Trainer based on (R)ecorder and Task(M)anager Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager """ - def train(self, tasks: list, experiment_name: str, task_pool: str, train_func=task_train, *args, **kwargs): + def __init__(self, experiment_name: str, task_pool: str, train_func=task_train): + self.experiment_name = experiment_name + self.task_pool = task_pool + self.train_func = train_func + + def train(self, tasks: list, train_func=None, *args, **kwargs): """Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. This method defaults to a single process, but TaskManager offered a great way to parallel training. @@ -111,17 +121,18 @@ class TrainerRM(TrainerR): Args: tasks (list): a list of definition based on `task` dict - experiment_name (str): the experiment name - train_func (Callable): the train method which need at least `task` and `experiment_name` + train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default. Returns: list: a list of Recorders """ - tm = TaskManager(task_pool=task_pool) + if train_func is None: + train_func = self.train_func + tm = TaskManager(task_pool=self.task_pool) _id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB - run_task(train_func, task_pool, experiment_name=experiment_name, *args, **kwargs) + run_task(train_func, self.task_pool, experiment_name=self.experiment_name, *args, **kwargs) recs = [] for _id in _id_list: recs.append(tm.re_query(_id)["res"]) - return recs \ No newline at end of file + return recs diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index 25a368269..0676bfb6b 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -20,7 +20,7 @@ class OnlineManager(Serializable): NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model OFFLINE_TAG = "offline" # the 'offline' model, not for online serving - def __init__(self, trainer: Trainer = None) -> None: + def __init__(self, trainer: Trainer = None): self._trainer = trainer self.logger = get_module_logger(self.__class__.__name__) @@ -81,7 +81,8 @@ class OnlineManagerR(OnlineManager): """ - def __init__(self, experiment_name: str, trainer: Trainer = TrainerR()) -> None: + def __init__(self, experiment_name: str, trainer: Trainer = None): + trainer = TrainerR(experiment_name) super().__init__(trainer) self.logger = get_module_logger(self.__class__.__name__) self.exp_name = experiment_name @@ -105,20 +106,22 @@ class OnlineManagerR(OnlineManager): the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. """ if recorder is None: - recorder = list_recorders( - self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG - ).values() + recorder = list( + list_recorders( + self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG + ).values() + ) if isinstance(recorder, Recorder): recorder = [recorder] if len(recorder) == 0: self.logger.info("No 'next online' model, just use current 'online' models.") return recs = list_recorders(self.exp_name) - self.set_online_tag(OnlineManager.OFFLINE_TAG, recs.values()) + self.set_online_tag(OnlineManager.OFFLINE_TAG, list(recs.values())) self.set_online_tag(OnlineManager.ONLINE_TAG, recorder) self.logger.info(f"Reset {len(recorder)} models to 'online'.") - def update_online_pred(self): + def update_online_pred(self, *args, **kwargs): """update all online model predictions to the latest day in Calendar""" mu = ModelUpdater(self.exp_name) cnt = mu.update_all_pred(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG) @@ -126,25 +129,24 @@ class OnlineManagerR(OnlineManager): class RollingOnlineManager(OnlineManagerR): - """An implementation of OnlineManager based on Rolling. - - """ + """An implementation of OnlineManager based on Rolling.""" def __init__( self, experiment_name: str, rolling_gen: RollingGen, - trainer: Trainer = TrainerR(), - ) -> None: + trainer: Trainer = None, + ): + trainer = TrainerR(experiment_name) super().__init__(experiment_name, trainer) self.ta = TimeAdjuster() self.rg = rolling_gen self.logger = get_module_logger(self.__class__.__name__) - def prepare_signals(self): + def prepare_signals(self, *args, **kwargs): pass - def prepare_tasks(self): + def prepare_tasks(self, *args, **kwargs): """prepare new tasks based on new date. Returns: @@ -155,7 +157,7 @@ class RollingOnlineManager(OnlineManagerR): ) if max_test is None: self.logger.warn(f"No latest online recorders, no new tasks.") - return None + return [] calendar_latest = self.ta.last_date() if self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step: old_tasks = [] @@ -168,7 +170,7 @@ class RollingOnlineManager(OnlineManagerR): new_tasks_tmp = task_generator(old_tasks, self.rg) new_tasks = [task for task in new_tasks_tmp if task not in old_tasks] return new_tasks - return None + return [] def list_latest_recorders(self, rec_filter_func=None): """find latest recorders based on test segments. @@ -187,4 +189,4 @@ class RollingOnlineManager(OnlineManagerR): for rid, rec in recs_flt.items(): if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test: latest_rec[rid] = rec - return latest_rec, max_test \ No newline at end of file + return latest_rec, max_test diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 7e555ed06..63d4a6a04 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -1,9 +1,10 @@ from abc import abstractmethod from typing import Callable, Union from qlib.workflow.task.utils import list_recorders +from qlib.utils.serial import Serializable -class Collector: +class Collector(Serializable): """The collector to collect different results""" def collect(self, *args, **kwargs): @@ -25,33 +26,46 @@ class Collector: class RecorderCollector(Collector): def __init__( - self, exp_name, artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}, rec_key_func=None - ) -> None: + self, + exp_name, + artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}, + rec_key_func=None, + artifacts_key=None, + rec_filter_func=None, + ): """init RecorderCollector Args: exp_name (str): the name of Experiment artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}. rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. + artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts. + rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. """ self.exp_name = exp_name self.artifacts_path = artifacts_path if rec_key_func is None: rec_key_func = lambda rec: rec.info["id"] - self._get_key = rec_key_func + if artifacts_key is None: + artifacts_key = self.artifacts_path.keys() + self.rec_key = rec_key_func + self.artifacts_key = artifacts_key + self.rec_filter = rec_filter_func - def collect(self, artifacts_key=None, rec_filter_func=None): # ensemble, get_group_key_func, + def collect(self, artifacts_key=None, rec_filter_func=None): """Collect different artifacts based on recorder after filtering. Args: - artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts. - rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. + artifacts_key (str or List, optional): the artifacts key you want to get. If None, use default. + rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use default. Returns: dict: the dict after collected like {artifact: {rec_key: object}} """ if artifacts_key is None: - artifacts_key = self.artifacts_path.keys() + artifacts_key = self.artifacts_key + if rec_filter_func is None: + rec_filter_func = self.rec_filter if isinstance(artifacts_key, str): artifacts_key = [artifacts_key] @@ -60,9 +74,9 @@ class RecorderCollector(Collector): # filter records recs_flt = list_recorders(self.exp_name, rec_filter_func) for _, rec in recs_flt.items(): - rec_key = self._get_key(rec) + rec_key = self.rec_key(rec) for key in artifacts_key: artifact = rec.load_object(self.artifacts_path[key]) collect_dict.setdefault(key, {})[rec_key] = artifact - return collect_dict \ No newline at end of file + return collect_dict diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index ddd833aa4..0d6f8c0de 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -49,7 +49,7 @@ class TaskManager: ENCODE_FIELDS_PREFIX = ["def", "res"] - def __init__(self, task_pool=None): + def __init__(self, task_pool: str): """ init Task Manager, remember to make the statement of MongoDB url and database name firstly. @@ -59,9 +59,13 @@ class TaskManager: the name of Collection in MongoDB """ self.mdb = get_mongodb() - self.task_pool = task_pool + self.task_pool = getattr(self.mdb, task_pool) self.logger = get_module_logger(self.__class__.__name__) + # @property + # def task_pool(self): + # return self._task_pool + def list(self): return self.mdb.list_collection_names() @@ -79,39 +83,39 @@ class TaskManager: task[k] = pickle.loads(task[k]) return task - def _get_task_pool(self, task_pool=None): - if task_pool is None: - task_pool = self.task_pool - if task_pool is None: - raise ValueError("You must specify a task pool.") - if isinstance(task_pool, str): - return getattr(self.mdb, task_pool) - return task_pool + # def _get_task_pool(self, task_pool=None): + # if task_pool is None: + # task_pool = self.task_pool + # if task_pool is None: + # raise ValueError("You must specify a task pool.") + # if isinstance(task_pool, str): + # return getattr(self.mdb, task_pool) + # return task_pool def _dict_to_str(self, flt): return {k: str(v) for k, v in flt.items()} - def replace_task(self, task, new_task, task_pool=None): + def replace_task(self, task, new_task): # assume that the data out of interface was decoded and the data in interface was encoded new_task = self._encode_task(new_task) - task_pool = self._get_task_pool(task_pool) + # task_pool = self._get_task_pool(task_pool) query = {"_id": ObjectId(task["_id"])} try: - task_pool.replace_one(query, new_task) + self.task_pool.replace_one(query, new_task) except InvalidDocument: task["filter"] = self._dict_to_str(task["filter"]) - task_pool.replace_one(query, new_task) + self.task_pool.replace_one(query, new_task) - def insert_task(self, task, task_pool=None): - task_pool = self._get_task_pool(task_pool) + def insert_task(self, task): + # task_pool = self._get_task_pool(task_pool) try: - insert_result = task_pool.insert_one(task) + insert_result = self.task_pool.insert_one(task) except InvalidDocument: task["filter"] = self._dict_to_str(task["filter"]) - insert_result = task_pool.insert_one(task) + insert_result = self.task_pool.insert_one(task) return insert_result - def insert_task_def(self, task_def, task_pool=None): + def insert_task_def(self, task_def): """ insert a task to task_pool @@ -126,7 +130,7 @@ class TaskManager: ------- """ - task_pool = self._get_task_pool(task_pool) + # task_pool = self._get_task_pool(task_pool) task = self._encode_task( { "def": task_def, @@ -134,10 +138,10 @@ class TaskManager: "status": self.STATUS_WAITING, } ) - insert_result = self.insert_task(task, task_pool) + insert_result = self.insert_task(task) return insert_result - def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False): + def create_task(self, task_def_l, dry_run=False, print_nt=False): """ if the tasks in task_def_l is new, then insert new tasks into the task_pool @@ -156,13 +160,13 @@ class TaskManager: list a list of the _id of new tasks """ - task_pool = self._get_task_pool(task_pool) + # task_pool = self._get_task_pool(task_pool) new_tasks = [] for t in task_def_l: try: - r = task_pool.find_one({"filter": t}) + r = self.task_pool.find_one({"filter": t}) except InvalidDocument: - r = task_pool.find_one({"filter": self._dict_to_str(t)}) + r = self.task_pool.find_one({"filter": self._dict_to_str(t)}) if r is None: new_tasks.append(t) self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}") @@ -176,18 +180,18 @@ class TaskManager: _id_list = [] for t in new_tasks: - insert_result = self.insert_task_def(t, task_pool) + insert_result = self.insert_task_def(t) _id_list.append(insert_result.inserted_id) return _id_list - def fetch_task(self, query={}, task_pool=None): - task_pool = self._get_task_pool(task_pool) + def fetch_task(self, query={}): + # task_pool = self._get_task_pool(task_pool) query = query.copy() if "_id" in query: query["_id"] = ObjectId(query["_id"]) query.update({"status": self.STATUS_WAITING}) - task = task_pool.find_one_and_update( + task = self.task_pool.find_one_and_update( query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)] ) # null will be at the top after sorting when using ASCENDING, so the larger the number higher, the higher the priority @@ -197,7 +201,7 @@ class TaskManager: return self._decode_task(task) @contextmanager - def safe_fetch_task(self, query={}, task_pool=None): + def safe_fetch_task(self, query={}): """ fetch task from task_pool using query with contextmanager @@ -212,7 +216,7 @@ class TaskManager: ------- """ - task = self.fetch_task(query=query, task_pool=task_pool) + task = self.fetch_task(query=query) try: yield task except Exception: @@ -229,7 +233,7 @@ class TaskManager: break yield task - def query(self, query={}, decode=True, task_pool=None): + def query(self, query={}, decode=True): """ This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator @@ -248,29 +252,30 @@ class TaskManager: query = query.copy() if "_id" in query: query["_id"] = ObjectId(query["_id"]) - task_pool = self._get_task_pool(task_pool) - for t in task_pool.find(query): + # task_pool = self._get_task_pool(task_pool) + for t in self.task_pool.find(query): yield self._decode_task(t) - def re_query(self, _id, task_pool=None): - task_pool = self._get_task_pool(task_pool) - return task_pool.find_one({"_id": ObjectId(_id)}) + def re_query(self, _id): + # task_pool = self._get_task_pool(task_pool) + t = self.task_pool.find_one({"_id": ObjectId(_id)}) + return self._decode_task(t) - def commit_task_res(self, task, res, status=None, task_pool=None): - task_pool = self._get_task_pool(task_pool) + def commit_task_res(self, task, res, status=None): + # task_pool = self._get_task_pool(task_pool) # A workaround to use the class attribute. if status is None: status = TaskManager.STATUS_DONE - task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}}) + self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}}) - def return_task(self, task, status=None, task_pool=None): - task_pool = self._get_task_pool(task_pool) + def return_task(self, task, status=None): + # task_pool = self._get_task_pool(task_pool) if status is None: status = TaskManager.STATUS_WAITING update_dict = {"$set": {"status": status}} - task_pool.update_one({"_id": task["_id"]}, update_dict) + self.task_pool.update_one({"_id": task["_id"]}, update_dict) - def remove(self, query={}, task_pool=None): + def remove(self, query={}): """ remove the task using query @@ -286,16 +291,16 @@ class TaskManager: """ query = query.copy() - task_pool = self._get_task_pool(task_pool) + # task_pool = self._get_task_pool(task_pool) if "_id" in query: query["_id"] = ObjectId(query["_id"]) - task_pool.delete_many(query) + self.task_pool.delete_many(query) - def task_stat(self, query={}, task_pool=None): + def task_stat(self, query={}): query = query.copy() if "_id" in query: query["_id"] = ObjectId(query["_id"]) - tasks = self.query(task_pool=task_pool, query=query, decode=False) + tasks = self.query(query=query, decode=False) status_stat = {} for t in tasks: status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1 @@ -306,14 +311,14 @@ class TaskManager: # default query if "status" not in query: query["status"] = self.STATUS_RUNNING - return self.reset_status(query=query, status=self.STATUS_WAITING, task_pool=task_pool) + return self.reset_status(query=query, status=self.STATUS_WAITING) - def reset_status(self, query, status, task_pool=None): + def reset_status(self, query, status): query = query.copy() - task_pool = self._get_task_pool(task_pool) + # task_pool = self._get_task_pool(task_pool) if "_id" in query: query["_id"] = ObjectId(query["_id"]) - print(task_pool.update_many(query, {"$set": {"status": status}})) + print(self.task_pool.update_many(query, {"$set": {"status": status}})) def _get_undone_n(self, task_stat): return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0) @@ -321,14 +326,14 @@ class TaskManager: def _get_total(self, task_stat): return sum(task_stat.values()) - def wait(self, query={}, task_pool=None): - task_stat = self.task_stat(query, task_pool) + def wait(self, query={}): + task_stat = self.task_stat(query) total = self._get_total(task_stat) last_undone_n = self._get_undone_n(task_stat) with tqdm(total=total, initial=total - last_undone_n) as pbar: while True: time.sleep(10) - undone_n = self._get_undone_n(self.task_stat(query, task_pool)) + undone_n = self._get_undone_n(self.task_stat(query)) pbar.update(last_undone_n - undone_n) last_undone_n = undone_n if undone_n == 0: @@ -365,7 +370,7 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs): break get_module_logger("run_task").info(task["def"]) if force_release: - with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: # what this means? res = executor.submit(task_func, task["def"], *args, **kwargs).result() else: res = task_func(task["def"], *args, **kwargs)