mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 18:11:18 +08:00
modified format and added TaskCollector
This commit is contained in:
@@ -6,7 +6,7 @@ from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
|
||||
|
||||
def task_train(task_config: dict, experiment_name):
|
||||
def task_train(task_config: dict, experiment_name: str):
|
||||
"""
|
||||
task based training
|
||||
|
||||
@@ -14,6 +14,8 @@ def task_train(task_config: dict, experiment_name):
|
||||
----------
|
||||
task_config : dict
|
||||
A dict describes a task setting.
|
||||
experiment_name: str
|
||||
The name of experiment
|
||||
"""
|
||||
|
||||
# model initiaiton
|
||||
@@ -30,7 +32,7 @@ def task_train(task_config: dict, experiment_name):
|
||||
R.save_objects(param=task_config) # keep the original format and datatype
|
||||
|
||||
# generate records: prediction, backtest, and analysis
|
||||
records = task_config.get('record', [])
|
||||
records = task_config.get("record", [])
|
||||
if isinstance(records, dict): # prevent only one dict
|
||||
records = [records]
|
||||
for record in records:
|
||||
|
||||
@@ -4,6 +4,62 @@ from typing import Union
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
class TaskCollector:
|
||||
"""
|
||||
Collect the record results of the finished tasks with key and filter
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def collect(
|
||||
experiment_name: str,
|
||||
get_key_func,
|
||||
filter_func=None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_name : str
|
||||
get_key_func : function(task: dict) -> Union[Number, str, tuple]
|
||||
get the key of a task when collect it
|
||||
filter_func : function(task: dict) -> bool
|
||||
to judge a task will be collected or not
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
exp = R.get_exp(experiment_name=experiment_name)
|
||||
# filter records
|
||||
recs = exp.list_recorders()
|
||||
|
||||
recs_flt = {}
|
||||
for rid, rec in tqdm(recs.items(), desc="Loading data"):
|
||||
params = rec.load_object("param")
|
||||
if rec.status == rec.STATUS_FI:
|
||||
if filter_func is None or filter_func(params):
|
||||
rec.params = params
|
||||
recs_flt[rid] = rec
|
||||
|
||||
# group
|
||||
recs_group = {}
|
||||
for _, rec in recs_flt.items():
|
||||
params = rec.params
|
||||
group_key = get_key_func(params)
|
||||
recs_group.setdefault(group_key, []).append(rec)
|
||||
|
||||
# reduce group
|
||||
reduce_group = {}
|
||||
for k, rec_l in recs_group.items():
|
||||
pred_l = []
|
||||
for rec in rec_l:
|
||||
pred_l.append(rec.load_object("pred.pkl").iloc[:, 0])
|
||||
pred = pd.concat(pred_l).sort_index()
|
||||
reduce_group[k] = pred
|
||||
|
||||
return reduce_group
|
||||
|
||||
|
||||
class RollingCollector:
|
||||
"""
|
||||
Rolling Models Ensemble based on (R)ecord
|
||||
@@ -13,7 +69,7 @@ class RollingCollector:
|
||||
|
||||
# TODO: speed up this class
|
||||
def __init__(self, get_key_func, flt_func=None):
|
||||
self.get_key_func = get_key_func # user need to implement this method to get the key of a task based on task config
|
||||
self.get_key_func = get_key_func # get the key of a task based on task config
|
||||
self.flt_func = flt_func # determine whether a task can be retained based on task config
|
||||
|
||||
def __call__(self, exp_name) -> Union[pd.Series, dict]:
|
||||
|
||||
@@ -79,6 +79,7 @@ class TaskGen(metaclass=abc.ABCMeta):
|
||||
output: a set of tasks with different losses
|
||||
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate(self, task: dict) -> typing.List[dict]:
|
||||
"""
|
||||
|
||||
@@ -8,6 +8,7 @@ from qlib.log import get_module_logger
|
||||
from pymongo import MongoClient
|
||||
from typing import Union
|
||||
|
||||
|
||||
def get_mongodb():
|
||||
"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user