1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

bug fixed and update collect.py

This commit is contained in:
lzh222333
2021-03-11 16:25:46 +00:00
parent 48f0fc147f
commit 0df88c07f6
4 changed files with 88 additions and 22 deletions

View File

@@ -1,6 +1,7 @@
from qlib.workflow import R
import pandas as pd
from typing import Union
from typing import Callable
from qlib import get_module_logger
@@ -9,9 +10,63 @@ class TaskCollector:
Collect the record results of the finished tasks with key and filter
"""
@staticmethod
def __init__(self, experiment_name: str) -> None:
self.exp_name = experiment_name
self.exp = R.get_exp(experiment_name=experiment_name)
self.logger = get_module_logger("TaskCollector")
def list_recorders(self, rec_filter_func=None, task_filter_func=None, only_finished=True, only_have_task=False):
"""
Return a dict of {rid:recorder} by recorder filter and task filter. It is not necessary to use those filter.
If you don't train with "task_train", then there is no "task.pkl" which includes the task config.
If there is a "task.pkl", then it will become rec.task which can be get simply.
Parameters
----------
rec_filter_func : Callable[[MLflowRecorder], bool], optional
judge whether you need this recorder, by default None
task_filter_func : Callable[[dict], bool], optional
judge whether you need this task, by default None
only_finished : bool, optional
whether always use finished recorder, by default True
only_have_task : bool, optional
whether it is necessary to get the task config
Returns
-------
dict
a dict of {rid:recorder}
Raises
------
OSError
if you use a task filter, but there is no "task.pkl" which includes the task config
"""
recs = self.exp.list_recorders()
# return all recorders if the filter is None and you don't need task
if rec_filter_func==None and task_filter_func==None and only_have_task==False:
return recs
recs_flt = {}
for rid, rec in recs.items():
if (only_finished and rec.status == rec.STATUS_FI) or only_finished==False:
if rec_filter_func is None or rec_filter_func(rec):
task = None
try:
task = rec.load_object("task.pkl")
except OSError:
if task_filter_func is not None:
raise OSError('Can not find "task.pkl" in your records, have you train with "task_train" method in qlib.model.trainer?')
if task is None and only_have_task:
continue
if task_filter_func is None or task_filter_func(task):
rec.task = task
recs_flt[rid] = rec
return recs_flt
def collect_predictions(
experiment_name: str,
self,
get_key_func,
filter_func=None,
):
@@ -27,24 +82,15 @@ class TaskCollector:
Returns
-------
dict
the dict of predictions
"""
exp = R.get_exp(experiment_name=experiment_name)
# filter records
recs = exp.list_recorders()
recs_flt = {}
for rid, rec in recs.items():
params = rec.load_object("task.pkl")
if rec.status == rec.STATUS_FI:
if filter_func is None or filter_func(params):
rec.params = params
recs_flt[rid] = rec
recs_flt = self.list_recorders(task_filter_func=filter_func)
# group
recs_group = {}
for _, rec in recs_flt.items():
params = rec.params
params = rec.task
group_key = get_key_func(params)
recs_group.setdefault(group_key, []).append(rec)
@@ -57,9 +103,26 @@ class TaskCollector:
pred = pd.concat(pred_l).sort_index()
reduce_group[k] = pred
get_module_logger("TaskCollector").info(f"Collect {len(reduce_group)} predictions in {experiment_name}")
self.logger.info(f"Collect {len(reduce_group)} predictions in {self.exp_name}")
return reduce_group
def collect_latest_records(
self,
filter_func=None,
):
recs_flt = self.list_recorders(task_filter_func=filter_func,only_have_task=True)
max_test = max(rec.task['dataset']['kwargs']['segments']['test'] for rec in recs_flt.values())
latest_record = {}
for rid, rec in recs_flt.items():
if rec.task['dataset']['kwargs']['segments']['test'] == max_test:
latest_record[rid] = rec
self.logger.info(f"Collect {len(latest_record)} latest records in {self.exp_name}")
return latest_record
class RollingCollector:
"""

View File

@@ -363,7 +363,3 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
return ever_run
if __name__ == "__main__":
auto_init()
Fire(TaskManager)

View File

@@ -6,7 +6,7 @@ import pandas as pd
from qlib.utils import init_instance_by_config
from qlib import get_module_logger
from qlib.workflow import R
from qlib.model.trainer import task_train
class ModelUpdater:
"""
@@ -136,7 +136,7 @@ class ModelUpdater:
def online_filter(self, record):
tags = record.list_tags()
if tags[self.ONLINE_TAG] == self.ONLINE_TAG_TRUE:
if tags.get(self.ONLINE_TAG, self.ONLINE_TAG_FALSE) == self.ONLINE_TAG_TRUE:
return True
return False
@@ -146,6 +146,13 @@ class ModelUpdater:
self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.")
def list_online_model(self):
"""list the record of online model
Returns
-------
dict
{rid : record of the online model}
"""
recs = self.exp.list_recorders()
online_rec = {}
for rid, rec in recs.items():