From 42f510024cfbff7ce412a95e1ad7c05c85f59ec1 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Tue, 27 Apr 2021 04:12:08 +0000 Subject: [PATCH] update collector --- .../model_rolling/task_manager_rolling.py | 2 +- qlib/model/ens/ensemble.py | 9 +- qlib/workflow/online/manager.py | 2 +- qlib/workflow/task/collect.py | 107 +++++++++++++++++- setup.py | 1 + 5 files changed, 106 insertions(+), 15 deletions(-) diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 9c1cbf891..ab3a4eee5 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -140,7 +140,7 @@ class RollingTaskExample: return False artifact = ens_workflow( - RecorderCollector(exp_name=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter), + RecorderCollector(experiment=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup(), ) print(artifact) diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index 942303c18..63f6438c2 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -13,12 +13,7 @@ def ens_workflow(collector: Collector, process_list, *args, **kwargs): collector (Collector): the collector to collect the result into {result_key: things} process_list (list or Callable): the list of processors or the instance of processor to process dict. The processor order is same as the list order. - - For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())] - - artifacts_key (list, optional): the artifacts key you want to get. If None, get all artifacts. - rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. 
- + For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())] Returns: dict: the ensemble dict """ @@ -38,7 +33,7 @@ def ens_workflow(collector: Collector, process_list, *args, **kwargs): return ensemble -class Ensemble(Serializable): +class Ensemble: """Merge the objects in an Ensemble.""" def __call__(self, ensemble_dict: dict, *args, **kwargs): diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index c94cf2455..e107271d0 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -389,7 +389,7 @@ class RollingOnlineManager(OnlineManagerR): if rec_key_func is None: rec_key_func = rec_key - return RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func) + return RecorderCollector(experiment=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func) def collect_artifact(self, rec_key_func=None, rec_filter_func=None): """ diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index ef6a7a7d4..b4c81122d 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -1,17 +1,28 @@ from abc import abstractmethod from typing import Callable, Union +from qlib.workflow import R from qlib.workflow.task.utils import list_recorders from qlib.utils.serial import Serializable +import dill as pickle class Collector: """The collector to collect different results""" - def collect(self, *args, **kwargs): + def __init__(self, process_list=[]): + """ + Args: + process_list (list or Callable, optional): the list of processors, or a single processor instance, used to process the collected dict. + """ + if not isinstance(process_list, list): + process_list = [process_list] + self.process_list = process_list + + def collect(self): """Collect the results and return a dict like {key: things} Returns: - dict: the dict after collected. + dict: the dict after collecting. 
For example: @@ -23,13 +34,88 @@ class Collector: """ raise NotImplementedError(f"Please implement the `collect` method.") + @staticmethod + def process_collect(collected_dict, process_list=[], *args, **kwargs): + """do a series of processing to the dict returned by collect and return a dict like {key: things} + For example: you can group and ensemble. + + Args: + collected_dict (dict): the dict return by `collect` + process_list (list or Callable): the list of processors or the instance of processor to process dict. + The processor order is same as the list order. + For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())] + + Returns: + dict: the dict after processing. + """ + if not isinstance(process_list, list): + process_list = [process_list] + result = {} + for artifact in collected_dict: + value = collected_dict[artifact] + for process in process_list: + if not callable(process): + raise NotImplementedError(f"{type(process)} is not supported in `process_collect`.") + value = process(value, *args, **kwargs) + result[artifact] = value + return result + + def __call__(self, *args, **kwargs): + """ + do the workflow including collect and process_collect + + Returns: + dict: the dict after collecting and processing. 
+ """ + collected = self.collect() + return self.process_collect(collected, self.process_list, *args, **kwargs) + + def save(self, filepath): + """ + save the collector into a file + + Args: + filepath (str): the path of file + + Returns: + bool: whether the save succeeded + """ + try: + with open(filepath, "wb") as f: + pickle.dump(self, f) + except Exception: + return False + return True + + @staticmethod + def load(filepath): + """ + load the collector from a file + + Args: + filepath (str): the path of file + + Raises: + TypeError: the pickled file must be `Collector` + + Returns: + Collector: the instance of Collector + """ + with open(filepath, "rb") as f: + collector = pickle.load(f) + if isinstance(collector, Collector): + return collector + else: + raise TypeError(f"The instance of {type(collector)} is not a valid `Collector`!") + class RecorderCollector(Collector): ART_KEY_RAW = "__raw" def __init__( self, - exp_name, + experiment, + process_list=[], rec_key_func=None, rec_filter_func=None, artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}, @@ -38,13 +124,17 @@ class RecorderCollector(Collector): """init RecorderCollector Args: - exp_name (str): the name of Experiment + experiment (Experiment or str): an instance of an Experiment or the name of an Experiment + process_list (list or Callable): the list of processors or the instance of processor to process dict. rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}. artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts. 
""" - self.exp_name = exp_name + if isinstance(experiment, str): + experiment = R.get_exp(experiment_name=experiment) + self.experiment = experiment + self.process_list = process_list self.artifacts_path = artifacts_path if rec_key_func is None: rec_key_func = lambda rec: rec.info["id"] @@ -74,7 +164,12 @@ class RecorderCollector(Collector): collect_dict = {} # filter records - recs_flt = list_recorders(self.exp_name, rec_filter_func) + recs = self.experiment.list_recorders() + recs_flt = {} + for rid, rec in recs.items(): + if rec_filter_func is None or rec_filter_func(rec): + recs_flt[rid] = rec + for _, rec in recs_flt.items(): rec_key = self._rec_key_func(rec) for key in artifacts_key: diff --git a/setup.py b/setup.py index 699fdf75d..c90d7d1c3 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,7 @@ REQUIRED = [ "ruamel.yaml>=0.16.12", "pymongo==3.7.2", # For task management "scikit-learn>=0.22", + "dill", ] # Numpy include