mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
update collector
This commit is contained in:
@@ -140,7 +140,7 @@ class RollingTaskExample:
|
||||
return False
|
||||
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter),
|
||||
RecorderCollector(experiment=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter),
|
||||
RollingGroup(),
|
||||
)
|
||||
print(artifact)
|
||||
|
||||
@@ -13,12 +13,7 @@ def ens_workflow(collector: Collector, process_list, *args, **kwargs):
|
||||
collector (Collector): the collector to collect the result into {result_key: things}
|
||||
process_list (list or Callable): the list of processors or the instance of processor to process dict.
|
||||
The processor order is same as the list order.
|
||||
|
||||
For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
|
||||
|
||||
artifacts_key (list, optional): the artifacts key you want to get. If None, get all artifacts.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
|
||||
For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
|
||||
Returns:
|
||||
dict: the ensemble dict
|
||||
"""
|
||||
@@ -38,7 +33,7 @@ def ens_workflow(collector: Collector, process_list, *args, **kwargs):
|
||||
return ensemble
|
||||
|
||||
|
||||
class Ensemble(Serializable):
|
||||
class Ensemble:
|
||||
"""Merge the objects in an Ensemble."""
|
||||
|
||||
def __call__(self, ensemble_dict: dict, *args, **kwargs):
|
||||
|
||||
@@ -389,7 +389,7 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
if rec_key_func is None:
|
||||
rec_key_func = rec_key
|
||||
|
||||
return RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func)
|
||||
return RecorderCollector(experiment=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func)
|
||||
|
||||
def collect_artifact(self, rec_key_func=None, rec_filter_func=None):
|
||||
"""
|
||||
|
||||
@@ -1,17 +1,28 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Union
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.utils.serial import Serializable
|
||||
import dill as pickle
|
||||
|
||||
|
||||
class Collector:
|
||||
"""The collector to collect different results"""
|
||||
|
||||
def collect(self, *args, **kwargs):
|
||||
def __init__(self, process_list=[]):
|
||||
"""
|
||||
Args:
|
||||
process_list (list, optional): process_list (list or Callable): the list of processors or the instance of processor to process dict.
|
||||
"""
|
||||
if not isinstance(process_list, list):
|
||||
process_list = [process_list]
|
||||
self.process_list = process_list
|
||||
|
||||
def collect(self):
|
||||
"""Collect the results and return a dict like {key: things}
|
||||
|
||||
Returns:
|
||||
dict: the dict after collected.
|
||||
dict: the dict after collecting.
|
||||
|
||||
For example:
|
||||
|
||||
@@ -23,13 +34,88 @@ class Collector:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `collect` method.")
|
||||
|
||||
@staticmethod
|
||||
def process_collect(collected_dict, process_list=[], *args, **kwargs):
|
||||
"""do a series of processing to the dict returned by collect and return a dict like {key: things}
|
||||
For example: you can group and ensemble.
|
||||
|
||||
Args:
|
||||
collected_dict (dict): the dict return by `collect`
|
||||
process_list (list or Callable): the list of processors or the instance of processor to process dict.
|
||||
The processor order is same as the list order.
|
||||
For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
|
||||
|
||||
Returns:
|
||||
dict: the dict after processing.
|
||||
"""
|
||||
if not isinstance(process_list, list):
|
||||
process_list = [process_list]
|
||||
result = {}
|
||||
for artifact in collected_dict:
|
||||
value = collected_dict[artifact]
|
||||
for process in process_list:
|
||||
if not callable(process):
|
||||
raise NotImplementedError(f"{type(process)} is not supported in `process_collect`.")
|
||||
value = process(value, *args, **kwargs)
|
||||
result[artifact] = value
|
||||
return result
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
do the workflow including collect and process_collect
|
||||
|
||||
Returns:
|
||||
dict: the dict after collecting and processing.
|
||||
"""
|
||||
collected = self.collect()
|
||||
return self.process_collect(collected, self.process_list, *args, **kwargs)
|
||||
|
||||
def save(self, filepath):
|
||||
"""
|
||||
save the collector into a file
|
||||
|
||||
Args:
|
||||
filepath (str): the path of file
|
||||
|
||||
Returns:
|
||||
bool: if successed
|
||||
"""
|
||||
try:
|
||||
with open(filepath, "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def load(filepath):
|
||||
"""
|
||||
load the collector from a file
|
||||
|
||||
Args:
|
||||
filepath (str): the path of file
|
||||
|
||||
Raises:
|
||||
TypeError: the pickled file must be `Collector`
|
||||
|
||||
Returns:
|
||||
Collector: the instance of Collector
|
||||
"""
|
||||
with open(filepath, "rb") as f:
|
||||
collector = pickle.load(f)
|
||||
if isinstance(collector, Collector):
|
||||
return collector
|
||||
else:
|
||||
raise TypeError(f"The instance of {type(collector)} is not a valid `Collector`!")
|
||||
|
||||
|
||||
class RecorderCollector(Collector):
|
||||
ART_KEY_RAW = "__raw"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exp_name,
|
||||
experiment,
|
||||
process_list=[],
|
||||
rec_key_func=None,
|
||||
rec_filter_func=None,
|
||||
artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"},
|
||||
@@ -38,13 +124,17 @@ class RecorderCollector(Collector):
|
||||
"""init RecorderCollector
|
||||
|
||||
Args:
|
||||
exp_name (str): the name of Experiment
|
||||
experiment (Experiment or str): an instance of a Experiment or the name of a Experiment
|
||||
process_list (list or Callable): the list of processors or the instance of processor to process dict.
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
|
||||
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
|
||||
"""
|
||||
self.exp_name = exp_name
|
||||
if isinstance(experiment, str):
|
||||
experiment = R.get_exp(experiment_name=experiment)
|
||||
self.experiment = experiment
|
||||
self.process_list = process_list
|
||||
self.artifacts_path = artifacts_path
|
||||
if rec_key_func is None:
|
||||
rec_key_func = lambda rec: rec.info["id"]
|
||||
@@ -74,7 +164,12 @@ class RecorderCollector(Collector):
|
||||
|
||||
collect_dict = {}
|
||||
# filter records
|
||||
recs_flt = list_recorders(self.exp_name, rec_filter_func)
|
||||
recs = self.experiment.list_recorders()
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
recs_flt[rid] = rec
|
||||
|
||||
for _, rec in recs_flt.items():
|
||||
rec_key = self._rec_key_func(rec)
|
||||
for key in artifacts_key:
|
||||
|
||||
Reference in New Issue
Block a user