mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 01:51:18 +08:00
202 lines
7.3 KiB
Python
202 lines
7.3 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
"""
|
|
Collector can collect object from everywhere and process them such as merging, grouping, averaging and so on.
|
|
"""
|
|
|
|
from typing import Callable, Dict, List
|
|
from qlib.utils.serial import Serializable
|
|
from qlib.workflow import R
|
|
|
|
|
|
class Collector(Serializable):
|
|
"""The collector to collect different results"""
|
|
|
|
pickle_backend = "dill" # use dill to dump user method
|
|
|
|
def __init__(self, process_list=[]):
|
|
"""
|
|
Args:
|
|
process_list (list, optional): process_list (list or Callable): the list of processors or the instance of processor to process dict.
|
|
"""
|
|
if not isinstance(process_list, list):
|
|
process_list = [process_list]
|
|
self.process_list = process_list
|
|
|
|
def collect(self) -> dict:
|
|
"""Collect the results and return a dict like {key: things}
|
|
|
|
Returns:
|
|
dict: the dict after collecting.
|
|
|
|
For example:
|
|
|
|
{"prediction": pd.Series}
|
|
|
|
{"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
|
|
|
|
......
|
|
"""
|
|
raise NotImplementedError(f"Please implement the `collect` method.")
|
|
|
|
@staticmethod
|
|
def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict:
|
|
"""do a series of processing to the dict returned by collect and return a dict like {key: things}
|
|
For example: you can group and ensemble.
|
|
|
|
Args:
|
|
collected_dict (dict): the dict return by `collect`
|
|
process_list (list or Callable): the list of processors or the instance of processor to process dict.
|
|
The processor order is same as the list order.
|
|
For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
|
|
|
|
Returns:
|
|
dict: the dict after processing.
|
|
"""
|
|
if not isinstance(process_list, list):
|
|
process_list = [process_list]
|
|
result = {}
|
|
for artifact in collected_dict:
|
|
value = collected_dict[artifact]
|
|
for process in process_list:
|
|
if not callable(process):
|
|
raise NotImplementedError(f"{type(process)} is not supported in `process_collect`.")
|
|
value = process(value, *args, **kwargs)
|
|
result[artifact] = value
|
|
return result
|
|
|
|
def __call__(self, *args, **kwargs) -> dict:
|
|
"""
|
|
do the workflow including collect and process_collect
|
|
|
|
Returns:
|
|
dict: the dict after collecting and processing.
|
|
"""
|
|
collected = self.collect()
|
|
return self.process_collect(collected, self.process_list, *args, **kwargs)
|
|
|
|
|
|
class MergeCollector(Collector):
|
|
"""
|
|
A collector to collect the results of other Collectors
|
|
|
|
For example:
|
|
|
|
We have 2 collector, which named A and B.
|
|
A can collect {"prediction": pd.Series} and B can collect {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}.
|
|
Then after this class's collect, we can collect {"A_prediction": pd.Series, "B_IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
|
|
|
|
......
|
|
|
|
"""
|
|
|
|
def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = []):
|
|
"""
|
|
Args:
|
|
collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector}
|
|
process_list (List[Callable]): the list of processors or the instance of processor to process dict.
|
|
"""
|
|
super().__init__(process_list=process_list)
|
|
self.collector_dict = collector_dict
|
|
|
|
def collect(self) -> dict:
|
|
"""
|
|
Collect all result of collector_dict and change the outermost key to "``collector_key``_``key``" (like merge them, but rename every key)
|
|
|
|
Returns:
|
|
dict: the dict after collecting.
|
|
"""
|
|
collect_dict = {}
|
|
for collector_key, collector in self.collector_dict.items():
|
|
tmp_dict = collector()
|
|
for key, value in tmp_dict.items():
|
|
collect_dict[collector_key + "_" + str(key)] = value
|
|
return collect_dict
|
|
|
|
|
|
class RecorderCollector(Collector):
|
|
ART_KEY_RAW = "__raw"
|
|
|
|
def __init__(
|
|
self,
|
|
experiment,
|
|
process_list=[],
|
|
rec_key_func=None,
|
|
rec_filter_func=None,
|
|
artifacts_path={"pred": "pred.pkl"},
|
|
artifacts_key=None,
|
|
):
|
|
"""init RecorderCollector
|
|
|
|
Args:
|
|
experiment (Experiment or str): an instance of a Experiment or the name of a Experiment
|
|
process_list (list or Callable): the list of processors or the instance of processor to process dict.
|
|
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
|
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
|
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
|
|
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
|
|
"""
|
|
super().__init__(process_list=process_list)
|
|
if isinstance(experiment, str):
|
|
experiment = R.get_exp(experiment_name=experiment)
|
|
self.experiment = experiment
|
|
self.artifacts_path = artifacts_path
|
|
if rec_key_func is None:
|
|
rec_key_func = lambda rec: rec.info["id"]
|
|
if artifacts_key is None:
|
|
artifacts_key = list(self.artifacts_path.keys())
|
|
self._rec_key_func = rec_key_func
|
|
self.artifacts_key = artifacts_key
|
|
self._rec_filter_func = rec_filter_func
|
|
|
|
def collect(self, artifacts_key=None, rec_filter_func=None) -> dict:
|
|
"""Collect different artifacts based on recorder after filtering.
|
|
|
|
Args:
|
|
artifacts_key (str or List, optional): the artifacts key you want to get. If None, use default.
|
|
rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use default.
|
|
|
|
Returns:
|
|
dict: the dict after collected like {artifact: {rec_key: object}}
|
|
"""
|
|
if artifacts_key is None:
|
|
artifacts_key = self.artifacts_key
|
|
if rec_filter_func is None:
|
|
rec_filter_func = self._rec_filter_func
|
|
|
|
if isinstance(artifacts_key, str):
|
|
artifacts_key = [artifacts_key]
|
|
|
|
collect_dict = {}
|
|
# filter records
|
|
recs = self.experiment.list_recorders()
|
|
recs_flt = {}
|
|
for rid, rec in recs.items():
|
|
if rec_filter_func is None or rec_filter_func(rec):
|
|
recs_flt[rid] = rec
|
|
|
|
for _, rec in recs_flt.items():
|
|
rec_key = self._rec_key_func(rec)
|
|
for key in artifacts_key:
|
|
if self.ART_KEY_RAW == key:
|
|
artifact = rec
|
|
else:
|
|
# only collect existing artifact
|
|
try:
|
|
artifact = rec.load_object(self.artifacts_path[key])
|
|
except Exception:
|
|
continue
|
|
collect_dict.setdefault(key, {})[rec_key] = artifact
|
|
|
|
return collect_dict
|
|
|
|
def get_exp_name(self) -> str:
|
|
"""
|
|
Get experiment name
|
|
|
|
Returns:
|
|
str: experiment name
|
|
"""
|
|
return self.experiment.name
|