1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 18:11:18 +08:00

modified format and added TaskCollector

This commit is contained in:
lzh222333
2021-03-08 16:10:16 +08:00
parent a244f87f95
commit def132e140
4 changed files with 63 additions and 3 deletions

View File

@@ -6,7 +6,7 @@ from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
def task_train(task_config: dict, experiment_name):
def task_train(task_config: dict, experiment_name: str):
"""
task based training
@@ -14,6 +14,8 @@ def task_train(task_config: dict, experiment_name):
----------
task_config : dict
A dict describes a task setting.
experiment_name: str
The name of experiment
"""
# model initialization
@@ -30,7 +32,7 @@ def task_train(task_config: dict, experiment_name):
R.save_objects(param=task_config) # keep the original format and datatype
# generate records: prediction, backtest, and analysis
records = task_config.get('record', [])
records = task_config.get("record", [])
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:

View File

@@ -4,6 +4,62 @@ from typing import Union
from tqdm.auto import tqdm
class TaskCollector:
    """
    Collect the prediction results of the finished tasks, grouped by a key and optionally filtered
    """

    @staticmethod
    def collect(
        experiment_name: str,
        get_key_func,
        filter_func=None,
    ):
        """
        Gather the predictions of every finished recorder of an experiment.

        Parameters
        ----------
        experiment_name : str
            name of the experiment whose recorders are scanned
        get_key_func : function(task: dict) -> Union[Number, str, tuple]
            get the key of a task when collecting it; recorders that share a key are merged
        filter_func : function(task: dict) -> bool
            judge whether a task will be collected or not; when None, every finished task is kept

        Returns
        -------
        dict
            maps each key to the concatenation (sorted by index) of the first column
            of the "pred.pkl" object of every recorder in that group
        """
        exp = R.get_exp(experiment_name=experiment_name)

        # filter and group the recorders in a single pass
        recs_group = {}
        for rec in tqdm(exp.list_recorders().values(), desc="Loading data"):
            # Check the status BEFORE loading anything: an unfinished recorder is
            # discarded anyway, so loading its "param" object is wasted I/O and
            # may fail if the object has not been saved yet.
            if rec.status != rec.STATUS_FI:
                continue
            params = rec.load_object("param")
            if filter_func is not None and not filter_func(params):
                continue
            recs_group.setdefault(get_key_func(params), []).append(rec)

        # reduce each group to a single prediction series
        # NOTE(review): assumes "pred.pkl" holds a pandas DataFrame whose first column is the prediction - confirm
        return {
            key: pd.concat([rec.load_object("pred.pkl").iloc[:, 0] for rec in rec_l]).sort_index()
            for key, rec_l in recs_group.items()
        }
class RollingCollector:
"""
Rolling Models Ensemble based on (R)ecord
@@ -13,7 +69,7 @@ class RollingCollector:
# TODO: speed up this class
def __init__(self, get_key_func, flt_func=None):
self.get_key_func = get_key_func # user need to implement this method to get the key of a task based on task config
self.get_key_func = get_key_func # get the key of a task based on task config
self.flt_func = flt_func # determine whether a task can be retained based on task config
def __call__(self, exp_name) -> Union[pd.Series, dict]:

View File

@@ -79,6 +79,7 @@ class TaskGen(metaclass=abc.ABCMeta):
output: a set of tasks with different losses
"""
@abc.abstractmethod
def generate(self, task: dict) -> typing.List[dict]:
"""

View File

@@ -8,6 +8,7 @@ from qlib.log import get_module_logger
from pymongo import MongoClient
from typing import Union
def get_mongodb():
"""