diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 36ccf434d..c7d82d541 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -139,6 +139,7 @@ class RecorderCollector(Collector): rec_filter_func=None, artifacts_path={"pred": "pred.pkl"}, artifacts_key=None, + filter_string: str = "" ): """ Init RecorderCollector. @@ -150,6 +151,7 @@ class RecorderCollector(Collector): rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}. artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts. + filter_string (str): filter string that used to apply in recorder quering (only support mlflow for now). """ super().__init__(process_list=process_list) if isinstance(experiment, str): @@ -163,6 +165,7 @@ class RecorderCollector(Collector): self.rec_key_func = rec_key_func self.artifacts_key = artifacts_key self.rec_filter_func = rec_filter_func + self.filter_string = filter_string def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict: """ @@ -187,7 +190,7 @@ class RecorderCollector(Collector): collect_dict = {} # filter records - recs = self.experiment.list_recorders() + recs = self.experiment.list_recorders(filter_string=self.filter_string) recs_flt = {} for rid, rec in recs.items(): if rec_filter_func is None or rec_filter_func(rec):