mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 10:31:00 +08:00
Callable Exp (#683)
This commit is contained in:
@@ -8,7 +8,7 @@ import warnings
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
from typing import Union, List
|
||||
from typing import Union, List, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
@@ -270,7 +270,13 @@ class SigAnaRecord(RecordTemp):
|
||||
self.label_col = label_col
|
||||
self.skip_existing = skip_existing
|
||||
|
||||
def generate(self, **kwargs):
|
||||
def generate(self, label: Optional[pd.DataFrame] = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
label : Optional[pd.DataFrame]
|
||||
Label should be a dataframe.
|
||||
"""
|
||||
if self.skip_existing:
|
||||
try:
|
||||
self.check(include_self=True, parents=False)
|
||||
@@ -283,7 +289,8 @@ class SigAnaRecord(RecordTemp):
|
||||
self.check()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
if label is None:
|
||||
label = self.load("label.pkl")
|
||||
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
|
||||
logger.warn(f"Empty label.")
|
||||
return
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Callable, Dict, List
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.exp import Experiment
|
||||
|
||||
|
||||
class Collector(Serializable):
|
||||
@@ -146,7 +147,9 @@ class RecorderCollector(Collector):
|
||||
Init RecorderCollector.
|
||||
|
||||
Args:
|
||||
experiment (Experiment or str): an instance of an Experiment or the name of an Experiment
|
||||
experiment:
|
||||
(Experiment or str): an instance of an Experiment or the name of an Experiment
|
||||
(Callable): an callable function, which returns a list of experiments
|
||||
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
@@ -157,6 +160,7 @@ class RecorderCollector(Collector):
|
||||
super().__init__(process_list=process_list)
|
||||
if isinstance(experiment, str):
|
||||
experiment = R.get_exp(experiment_name=experiment)
|
||||
assert isinstance(experiment, (Experiment, Callable))
|
||||
self.experiment = experiment
|
||||
self.artifacts_path = artifacts_path
|
||||
if rec_key_func is None:
|
||||
@@ -192,15 +196,16 @@ class RecorderCollector(Collector):
|
||||
collect_dict = {}
|
||||
# filter records
|
||||
|
||||
with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"):
|
||||
recs = self.experiment.list_recorders(**self.list_kwargs)
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
recs_flt[rid] = rec
|
||||
if isinstance(self.experiment, Experiment):
|
||||
with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"):
|
||||
recs = list(self.experiment.list_recorders(**self.list_kwargs).values())
|
||||
elif isinstance(self.experiment, Callable):
|
||||
recs = self.experiment()
|
||||
|
||||
recs = [rec for rec in recs if rec_filter_func is None or rec_filter_func(rec)]
|
||||
|
||||
logger = get_module_logger("RecorderCollector")
|
||||
for _, rec in recs_flt.items():
|
||||
for rec in recs:
|
||||
rec_key = self.rec_key_func(rec)
|
||||
for key in artifacts_key:
|
||||
if self.ART_KEY_RAW == key:
|
||||
|
||||
Reference in New Issue
Block a user