diff --git a/docs/advanced/task_managment.rst b/docs/advanced/task_management.rst similarity index 100% rename from docs/advanced/task_managment.rst rename to docs/advanced/task_management.rst diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 059871ab1..2e4746f59 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -1,6 +1,7 @@ from qlib.workflow import R import pandas as pd from typing import Union +from typing import Callable from qlib import get_module_logger @@ -9,9 +10,63 @@ class TaskCollector: Collect the record results of the finished tasks with key and filter """ - @staticmethod + def __init__(self, experiment_name: str) -> None: + self.exp_name = experiment_name + self.exp = R.get_exp(experiment_name=experiment_name) + self.logger = get_module_logger("TaskCollector") + + def list_recorders(self, rec_filter_func=None, task_filter_func=None, only_finished=True, only_have_task=False): + """ + Return a dict of {rid:recorder} by recorder filter and task filter. It is not necessary to use those filter. + If you don't train with "task_train", then there is no "task.pkl" which includes the task config. + If there is a "task.pkl", then it will become rec.task which can be get simply. + + Parameters + ---------- + rec_filter_func : Callable[[MLflowRecorder], bool], optional + judge whether you need this recorder, by default None + task_filter_func : Callable[[dict], bool], optional + judge whether you need this task, by default None + only_finished : bool, optional + whether always use finished recorder, by default True + only_have_task : bool, optional + whether it is necessary to get the task config + + Returns + ------- + dict + a dict of {rid:recorder} + + Raises + ------ + OSError + if you use a task filter, but there is no "task.pkl" which includes the task config + """ + recs = self.exp.list_recorders() + # return all recorders if the filter is None and you don't need task + if rec_filter_func==None and task_filter_func==None and only_have_task==False: + return recs + recs_flt = {} + for rid, rec in recs.items(): + if (only_finished and rec.status == rec.STATUS_FI) or only_finished==False: + if rec_filter_func is None or rec_filter_func(rec): + task = None + try: + task = rec.load_object("task.pkl") + except OSError: + if task_filter_func is not None: + raise OSError('Can not find "task.pkl" in your records, have you train with "task_train" method in qlib.model.trainer?') + if task is None and only_have_task: + continue + + if task_filter_func is None or task_filter_func(task): + rec.task = task + recs_flt[rid] = rec + + return recs_flt + def collect_predictions( - experiment_name: str, + self, get_key_func, filter_func=None, ): @@ -27,24 +82,15 @@ class TaskCollector: Returns ------- - + dict + the dict of predictions """ - exp = R.get_exp(experiment_name=experiment_name) - # filter records - recs = exp.list_recorders() - - recs_flt = {} - for rid, rec in recs.items(): - params = rec.load_object("task.pkl") - if rec.status == rec.STATUS_FI: - if filter_func is None or filter_func(params): - rec.params = params - recs_flt[rid] = rec + recs_flt = self.list_recorders(task_filter_func=filter_func) # group recs_group = {} for _, rec in recs_flt.items(): - params = rec.params + params = rec.task group_key = get_key_func(params) recs_group.setdefault(group_key, []).append(rec) @@ -57,9 +103,26 @@ class TaskCollector: pred = pd.concat(pred_l).sort_index() reduce_group[k] = pred - get_module_logger("TaskCollector").info(f"Collect {len(reduce_group)} predictions in {experiment_name}") + self.logger.info(f"Collect {len(reduce_group)} predictions in {self.exp_name}") return reduce_group + def collect_latest_records( + self, + filter_func=None, + ): + recs_flt = self.list_recorders(task_filter_func=filter_func,only_have_task=True) + + max_test = max(rec.task['dataset']['kwargs']['segments']['test'] for rec in recs_flt.values()) + + latest_record = {} + for rid, rec in recs_flt.items(): + if rec.task['dataset']['kwargs']['segments']['test'] == max_test: + latest_record[rid] = rec + + self.logger.info(f"Collect {len(latest_record)} latest records in {self.exp_name}") + return latest_record + + class RollingCollector: """ diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index ae4aee147..f27d02594 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -363,7 +363,3 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs): return ever_run - -if __name__ == "__main__": - auto_init() - Fire(TaskManager) diff --git a/qlib/workflow/task/update.py b/qlib/workflow/task/update.py index f9d03efbc..5127a87da 100644 --- a/qlib/workflow/task/update.py +++ b/qlib/workflow/task/update.py @@ -6,7 +6,7 @@ import pandas as pd from qlib.utils import init_instance_by_config from qlib import get_module_logger from qlib.workflow import R - +from qlib.model.trainer import task_train class ModelUpdater: """ @@ -136,7 +136,7 @@ class ModelUpdater: def online_filter(self, record): tags = record.list_tags() - if tags[self.ONLINE_TAG] == self.ONLINE_TAG_TRUE: + if tags.get(self.ONLINE_TAG, self.ONLINE_TAG_FALSE) == self.ONLINE_TAG_TRUE: return True return False @@ -146,6 +146,13 @@ class ModelUpdater: self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.") def list_online_model(self): + """list the record of online model + + Returns + ------- + dict + {rid : record of the online model} + """ recs = self.exp.list_recorders() online_rec = {} for rid, rec in recs.items():