1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 20:11:08 +08:00
Files
qlib/qlib/workflow/task/collect.py
2021-03-03 16:36:15 +08:00

52 lines
1.6 KiB
Python

from qlib.workflow import R
import pandas as pd
from typing import Union
from tqdm.auto import tqdm
class RollingCollector:
    """
    Rolling Models Ensemble based on (R)ecord

    This shares nothing with Ensemble.

    Calling an instance with an experiment name:

    1. keeps the finished recorders of the experiment (optionally filtered by ``flt_func``),
    2. groups them with ``get_key_func``,
    3. concatenates the first column of every recorder's "pred.pkl" object per group and
       sorts the result by index.
    """

    # TODO: speed up this class
    def __init__(self, get_key_func, flt_func=None):
        """
        Parameters
        ----------
        get_key_func : Callable
            user need to implement this method to get the key of a task based on task config.
            It receives the object loaded from a recorder's "param" artifact; the returned
            value is used as a dict key, so it must be hashable.
        flt_func : Callable, optional
            determine whether a task can be retained based on task config.
            It receives the same "param" object; a truthy result keeps the recorder.
            When None, every finished recorder is retained.
        """
        self.get_key_func = get_key_func
        self.flt_func = flt_func

    def __call__(self, exp_name) -> Union[pd.Series, dict]:
        """
        Collect the grouped predictions of an experiment.

        Parameters
        ----------
        exp_name : str
            name of the experiment, passed to ``R.get_exp(experiment_name=...)``.

        Returns
        -------
        dict
            ``{group_key: merged prediction}``; see ``_merge_predictions``.
        """
        exp = R.get_exp(experiment_name=exp_name)
        recs = exp.list_recorders()
        finished = self._filter_finished(tqdm(recs.items(), desc="Loading data"))
        return self._merge_predictions(self._group_by_key(finished))

    def _filter_finished(self, rec_items):
        """Return the finished recorders from ``(rid, rec)`` pairs that pass ``flt_func``."""
        kept = []
        for _, rec in rec_items:
            # Check the status *before* touching artifacts: unfinished recorders are
            # discarded anyway, so loading their "param" is wasted I/O and may fail
            # if the artifact has not been stored yet.
            if rec.status != rec.STATUS_FI:
                continue
            params = rec.load_object("param")
            if self.flt_func is None or self.flt_func(params):
                # Cache the config on the recorder so grouping does not load it again.
                rec.params = params
                kept.append(rec)
        return kept

    def _group_by_key(self, recs):
        """Bucket recorders by ``get_key_func(rec.params)``, preserving encounter order."""
        groups = {}
        for rec in recs:
            groups.setdefault(self.get_key_func(rec.params), []).append(rec)
        return groups

    @staticmethod
    def _merge_predictions(rec_groups):
        """
        Concatenate the predictions of every group into one index-sorted object.

        Only the first column of each "pred.pkl" object is used (``.iloc[:, 0]``).
        NOTE(review): this presumes "pred.pkl" holds a pandas DataFrame - confirm.
        """
        return {
            key: pd.concat([rec.load_object("pred.pkl").iloc[:, 0] for rec in group]).sort_index()
            for key, group in rec_groups.items()
        }