1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 10:31:00 +08:00

Callable Exp (#683)

This commit is contained in:
you-n-g
2021-11-12 14:56:22 +08:00
committed by GitHub
parent 01bdf6c1b1
commit f2771f1beb
2 changed files with 23 additions and 11 deletions

View File

@@ -8,7 +8,7 @@ import warnings
import pandas as pd
from pathlib import Path
from pprint import pprint
from typing import Union, List
from typing import Union, List, Optional
from collections import defaultdict
from qlib.utils.exceptions import LoadObjectError
@@ -270,7 +270,13 @@ class SigAnaRecord(RecordTemp):
self.label_col = label_col
self.skip_existing = skip_existing
def generate(self, **kwargs):
def generate(self, label: Optional[pd.DataFrame] = None, **kwargs):
"""
Parameters
----------
label : Optional[pd.DataFrame]
Label should be a dataframe.
"""
if self.skip_existing:
try:
self.check(include_self=True, parents=False)
@@ -283,7 +289,8 @@ class SigAnaRecord(RecordTemp):
self.check()
pred = self.load("pred.pkl")
label = self.load("label.pkl")
if label is None:
label = self.load("label.pkl")
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
logger.warn(f"Empty label.")
return

View File

@@ -10,6 +10,7 @@ from typing import Callable, Dict, List
from qlib.log import get_module_logger
from qlib.utils.serial import Serializable
from qlib.workflow import R
from qlib.workflow.exp import Experiment
class Collector(Serializable):
@@ -146,7 +147,9 @@ class RecorderCollector(Collector):
Init RecorderCollector.
Args:
experiment (Experiment or str): an instance of an Experiment or the name of an Experiment
experiment:
(Experiment or str): an instance of an Experiment or the name of an Experiment
(Callable): an callable function, which returns a list of experiments
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
@@ -157,6 +160,7 @@ class RecorderCollector(Collector):
super().__init__(process_list=process_list)
if isinstance(experiment, str):
experiment = R.get_exp(experiment_name=experiment)
assert isinstance(experiment, (Experiment, Callable))
self.experiment = experiment
self.artifacts_path = artifacts_path
if rec_key_func is None:
@@ -192,15 +196,16 @@ class RecorderCollector(Collector):
collect_dict = {}
# filter records
with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"):
recs = self.experiment.list_recorders(**self.list_kwargs)
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):
recs_flt[rid] = rec
if isinstance(self.experiment, Experiment):
with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"):
recs = list(self.experiment.list_recorders(**self.list_kwargs).values())
elif isinstance(self.experiment, Callable):
recs = self.experiment()
recs = [rec for rec in recs if rec_filter_func is None or rec_filter_func(rec)]
logger = get_module_logger("RecorderCollector")
for _, rec in recs_flt.items():
for rec in recs:
rec_key = self.rec_key_func(rec)
for key in artifacts_key:
if self.ART_KEY_RAW == key: