1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 20:11:08 +08:00

update collector

This commit is contained in:
lzh222333
2021-04-27 04:12:08 +00:00
parent 0058f7d0dc
commit 42f510024c
5 changed files with 106 additions and 15 deletions

View File

@@ -140,7 +140,7 @@ class RollingTaskExample:
return False
artifact = ens_workflow(
RecorderCollector(exp_name=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter),
RecorderCollector(experiment=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter),
RollingGroup(),
)
print(artifact)

View File

@@ -13,12 +13,7 @@ def ens_workflow(collector: Collector, process_list, *args, **kwargs):
collector (Collector): the collector to collect the result into {result_key: things}
process_list (list or Callable): the list of processors or the instance of processor to process dict.
The processor order is same as the list order.
For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
artifacts_key (list, optional): the artifacts key you want to get. If None, get all artifacts.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
Returns:
dict: the ensemble dict
"""
@@ -38,7 +33,7 @@ def ens_workflow(collector: Collector, process_list, *args, **kwargs):
return ensemble
class Ensemble(Serializable):
class Ensemble:
"""Merge the objects in an Ensemble."""
def __call__(self, ensemble_dict: dict, *args, **kwargs):

View File

@@ -389,7 +389,7 @@ class RollingOnlineManager(OnlineManagerR):
if rec_key_func is None:
rec_key_func = rec_key
return RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func)
return RecorderCollector(experiment=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func)
def collect_artifact(self, rec_key_func=None, rec_filter_func=None):
"""

View File

@@ -1,17 +1,28 @@
from abc import abstractmethod
from typing import Callable, Union
from qlib.workflow import R
from qlib.workflow.task.utils import list_recorders
from qlib.utils.serial import Serializable
import dill as pickle
class Collector:
"""The collector to collect different results"""
def collect(self, *args, **kwargs):
def __init__(self, process_list=None):
    """Initialize the collector with its post-processing pipeline.

    Args:
        process_list (list or Callable, optional): the list of processors, or a
            single processor, used to process the collected dict.
            Defaults to None, which means no processing.
    """
    # Build a fresh list per instance: a literal `[]` default would be one
    # shared object stored on every collector created without arguments.
    if process_list is None:
        process_list = []
    elif isinstance(process_list, tuple):
        # A tuple is never a processor itself, so treat it as a sequence.
        process_list = list(process_list)
    elif not isinstance(process_list, list):
        # A single processor was given; wrap it.
        process_list = [process_list]
    self.process_list = process_list
def collect(self):
"""Collect the results and return a dict like {key: things}
Returns:
dict: the dict after collected.
dict: the dict after collecting.
For example:
@@ -23,13 +34,88 @@ class Collector:
"""
raise NotImplementedError(f"Please implement the `collect` method.")
@staticmethod
def process_collect(collected_dict, process_list=None, *args, **kwargs):
    """Run every processor, in order, over each value of a collected dict.

    For example: you can group and ensemble.

    Args:
        collected_dict (dict): the dict returned by `collect`.
        process_list (list or Callable, optional): the list of processors, or a
            single processor, used to process each value. Processors run in
            list order. Defaults to None, which means no processing.
            For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
        *args: extra positional arguments forwarded to every processor call.
        **kwargs: extra keyword arguments forwarded to every processor call.

    Returns:
        dict: a new dict with the same keys as `collected_dict`; each value is
        the result of passing the original value through all processors.

    Raises:
        NotImplementedError: if a processor in `process_list` is not callable.
    """
    # Normalize to a list without relying on a shared mutable default.
    if process_list is None:
        process_list = []
    elif isinstance(process_list, tuple):
        # A tuple is never a processor itself, so treat it as a sequence.
        process_list = list(process_list)
    elif not isinstance(process_list, list):
        process_list = [process_list]

    result = {}
    for artifact, value in collected_dict.items():
        # Each processor consumes the previous processor's output.
        for process in process_list:
            if not callable(process):
                raise NotImplementedError(f"{type(process)} is not supported in `process_collect`.")
            value = process(value, *args, **kwargs)
        result[artifact] = value
    return result
def __call__(self, *args, **kwargs):
    """Run the whole workflow: collect, then process what was collected.

    The positional and keyword arguments are handed on to `process_collect`
    together with this collector's `process_list`.

    Returns:
        dict: the dict after collecting and processing.
    """
    raw_results = self.collect()
    processors = self.process_list
    return self.process_collect(raw_results, processors, *args, **kwargs)
def save(self, filepath):
    """Pickle this collector into a file.

    Args:
        filepath (str): the path of the file to write.

    Returns:
        bool: True if the collector was written; False if opening the file or
        pickling raised an exception (the exception is logged, not re-raised).
    """
    import logging  # local import so the method does not depend on file-level imports

    try:
        with open(filepath, "wb") as f:
            pickle.dump(self, f)
    except Exception:
        # Keep the boolean contract callers rely on, but record why the save
        # failed instead of discarding the exception silently.
        logging.getLogger(__name__).warning("Failed to save collector to %s", filepath, exc_info=True)
        return False
    return True
@staticmethod
def load(filepath):
    """Restore a collector that was previously pickled into a file.

    NOTE: unpickling can execute code embedded in the file, so only load
    files that come from a trusted source.

    Args:
        filepath (str): the path of the file to read.

    Raises:
        TypeError: if the unpickled object is not a `Collector`.

    Returns:
        Collector: the instance of Collector stored in the file.
    """
    with open(filepath, "rb") as pickled_file:
        collector = pickle.load(pickled_file)
    # Reject anything that is not a Collector before handing it back.
    if not isinstance(collector, Collector):
        raise TypeError(f"The instance of {type(collector)} is not a valid `Collector`!")
    return collector
class RecorderCollector(Collector):
ART_KEY_RAW = "__raw"
def __init__(
self,
exp_name,
experiment,
process_list=[],
rec_key_func=None,
rec_filter_func=None,
artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"},
@@ -38,13 +124,17 @@ class RecorderCollector(Collector):
"""init RecorderCollector
Args:
exp_name (str): the name of Experiment
experiment (Experiment or str): an instance of a Experiment or the name of a Experiment
process_list (list or Callable): the list of processors or the instance of processor to process dict.
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
"""
self.exp_name = exp_name
if isinstance(experiment, str):
experiment = R.get_exp(experiment_name=experiment)
self.experiment = experiment
self.process_list = process_list
self.artifacts_path = artifacts_path
if rec_key_func is None:
rec_key_func = lambda rec: rec.info["id"]
@@ -74,7 +164,12 @@ class RecorderCollector(Collector):
collect_dict = {}
# filter records
recs_flt = list_recorders(self.exp_name, rec_filter_func)
recs = self.experiment.list_recorders()
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):
recs_flt[rid] = rec
for _, rec in recs_flt.items():
rec_key = self._rec_key_func(rec)
for key in artifacts_key:

View File

@@ -57,6 +57,7 @@ REQUIRED = [
"ruamel.yaml>=0.16.12",
"pymongo==3.7.2", # For task management
"scikit-learn>=0.22",
"dill",
]
# Numpy include