From def132e1407bc97585efa2d261feefd8386c34f6 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Mon, 8 Mar 2021 16:10:16 +0800 Subject: [PATCH] modified format and added TaskCollector --- qlib/model/trainer.py | 6 ++-- qlib/workflow/task/collect.py | 58 ++++++++++++++++++++++++++++++++++- qlib/workflow/task/gen.py | 1 + qlib/workflow/task/utils.py | 1 + 4 files changed, 63 insertions(+), 3 deletions(-) diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 71cf9061f..91061636d 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -6,7 +6,7 @@ from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord -def task_train(task_config: dict, experiment_name): +def task_train(task_config: dict, experiment_name: str): """ task based training @@ -14,6 +14,8 @@ def task_train(task_config: dict, experiment_name): ---------- task_config : dict A dict describes a task setting. + experiment_name: str + The name of experiment """ # model initiaiton @@ -30,7 +32,7 @@ def task_train(task_config: dict, experiment_name): R.save_objects(param=task_config) # keep the original format and datatype # generate records: prediction, backtest, and analysis - records = task_config.get('record', []) + records = task_config.get("record", []) if isinstance(records, dict): # prevent only one dict records = [records] for record in records: diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 9a67d8e06..4562a1cec 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -4,6 +4,62 @@ from typing import Union from tqdm.auto import tqdm +class TaskCollector: + """ + Collect the record results of the finished tasks with key and filter + """ + + @staticmethod + def collect( + experiment_name: str, + get_key_func, + filter_func=None, + ): + """ + + Parameters + ---------- + experiment_name : str + get_key_func : function(task: dict) -> Union[Number, str, tuple] + get the key of a task when collect it + filter_func : function(task: dict) -> bool + to judge a task will be collected or not + + Returns + ------- + + """ + exp = R.get_exp(experiment_name=experiment_name) + # filter records + recs = exp.list_recorders() + + recs_flt = {} + for rid, rec in tqdm(recs.items(), desc="Loading data"): + params = rec.load_object("param") + if rec.status == rec.STATUS_FI: + if filter_func is None or filter_func(params): + rec.params = params + recs_flt[rid] = rec + + # group + recs_group = {} + for _, rec in recs_flt.items(): + params = rec.params + group_key = get_key_func(params) + recs_group.setdefault(group_key, []).append(rec) + + # reduce group + reduce_group = {} + for k, rec_l in recs_group.items(): + pred_l = [] + for rec in rec_l: + pred_l.append(rec.load_object("pred.pkl").iloc[:, 0]) + pred = pd.concat(pred_l).sort_index() + reduce_group[k] = pred + + return reduce_group + + class RollingCollector: """ Rolling Models Ensemble based on (R)ecord @@ -13,7 +69,7 @@ class RollingCollector: # TODO: speed up this class def __init__(self, get_key_func, flt_func=None): - self.get_key_func = get_key_func # user need to implement this method to get the key of a task based on task config + self.get_key_func = get_key_func # get the key of a task based on task config self.flt_func = flt_func # determine whether a task can be retained based on task config def __call__(self, exp_name) -> Union[pd.Series, dict]: diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 60fc5c221..b1c2e0ce2 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -79,6 +79,7 @@ class TaskGen(metaclass=abc.ABCMeta): output: a set of tasks with different losses """ + @abc.abstractmethod def generate(self, task: dict) -> typing.List[dict]: """ diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index 5e94f55ae..63563e2f6 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -8,6 +8,7 @@ from qlib.log import get_module_logger from pymongo import MongoClient from typing import Union + def get_mongodb(): """