mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
bug fixed and update collect.py
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from qlib.workflow import R
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
from typing import Callable
|
||||
from qlib import get_module_logger
|
||||
|
||||
|
||||
@@ -9,9 +10,63 @@ class TaskCollector:
|
||||
Collect the record results of the finished tasks with key and filter
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __init__(self, experiment_name: str) -> None:
|
||||
self.exp_name = experiment_name
|
||||
self.exp = R.get_exp(experiment_name=experiment_name)
|
||||
self.logger = get_module_logger("TaskCollector")
|
||||
|
||||
def list_recorders(self, rec_filter_func=None, task_filter_func=None, only_finished=True, only_have_task=False):
|
||||
"""
|
||||
Return a dict of {rid:recorder} by recorder filter and task filter. It is not necessary to use those filter.
|
||||
If you don't train with "task_train", then there is no "task.pkl" which includes the task config.
|
||||
If there is a "task.pkl", then it will become rec.task which can be get simply.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rec_filter_func : Callable[[MLflowRecorder], bool], optional
|
||||
judge whether you need this recorder, by default None
|
||||
task_filter_func : Callable[[dict], bool], optional
|
||||
judge whether you need this task, by default None
|
||||
only_finished : bool, optional
|
||||
whether always use finished recorder, by default True
|
||||
only_have_task : bool, optional
|
||||
whether it is necessary to get the task config
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
a dict of {rid:recorder}
|
||||
|
||||
Raises
|
||||
------
|
||||
OSError
|
||||
if you use a task filter, but there is no "task.pkl" which includes the task config
|
||||
"""
|
||||
recs = self.exp.list_recorders()
|
||||
# return all recorders if the filter is None and you don't need task
|
||||
if rec_filter_func==None and task_filter_func==None and only_have_task==False:
|
||||
return recs
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
if (only_finished and rec.status == rec.STATUS_FI) or only_finished==False:
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
task = None
|
||||
try:
|
||||
task = rec.load_object("task.pkl")
|
||||
except OSError:
|
||||
if task_filter_func is not None:
|
||||
raise OSError('Can not find "task.pkl" in your records, have you train with "task_train" method in qlib.model.trainer?')
|
||||
if task is None and only_have_task:
|
||||
continue
|
||||
|
||||
if task_filter_func is None or task_filter_func(task):
|
||||
rec.task = task
|
||||
recs_flt[rid] = rec
|
||||
|
||||
return recs_flt
|
||||
|
||||
def collect_predictions(
|
||||
experiment_name: str,
|
||||
self,
|
||||
get_key_func,
|
||||
filter_func=None,
|
||||
):
|
||||
@@ -27,24 +82,15 @@ class TaskCollector:
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
dict
|
||||
the dict of predictions
|
||||
"""
|
||||
exp = R.get_exp(experiment_name=experiment_name)
|
||||
# filter records
|
||||
recs = exp.list_recorders()
|
||||
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
params = rec.load_object("task.pkl")
|
||||
if rec.status == rec.STATUS_FI:
|
||||
if filter_func is None or filter_func(params):
|
||||
rec.params = params
|
||||
recs_flt[rid] = rec
|
||||
recs_flt = self.list_recorders(task_filter_func=filter_func)
|
||||
|
||||
# group
|
||||
recs_group = {}
|
||||
for _, rec in recs_flt.items():
|
||||
params = rec.params
|
||||
params = rec.task
|
||||
group_key = get_key_func(params)
|
||||
recs_group.setdefault(group_key, []).append(rec)
|
||||
|
||||
@@ -57,9 +103,26 @@ class TaskCollector:
|
||||
pred = pd.concat(pred_l).sort_index()
|
||||
reduce_group[k] = pred
|
||||
|
||||
get_module_logger("TaskCollector").info(f"Collect {len(reduce_group)} predictions in {experiment_name}")
|
||||
self.logger.info(f"Collect {len(reduce_group)} predictions in {self.exp_name}")
|
||||
return reduce_group
|
||||
|
||||
def collect_latest_records(
|
||||
self,
|
||||
filter_func=None,
|
||||
):
|
||||
recs_flt = self.list_recorders(task_filter_func=filter_func,only_have_task=True)
|
||||
|
||||
max_test = max(rec.task['dataset']['kwargs']['segments']['test'] for rec in recs_flt.values())
|
||||
|
||||
latest_record = {}
|
||||
for rid, rec in recs_flt.items():
|
||||
if rec.task['dataset']['kwargs']['segments']['test'] == max_test:
|
||||
latest_record[rid] = rec
|
||||
|
||||
self.logger.info(f"Collect {len(latest_record)} latest records in {self.exp_name}")
|
||||
return latest_record
|
||||
|
||||
|
||||
|
||||
class RollingCollector:
|
||||
"""
|
||||
|
||||
@@ -363,7 +363,3 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
|
||||
|
||||
return ever_run
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
auto_init()
|
||||
Fire(TaskManager)
|
||||
|
||||
@@ -6,7 +6,7 @@ import pandas as pd
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib import get_module_logger
|
||||
from qlib.workflow import R
|
||||
|
||||
from qlib.model.trainer import task_train
|
||||
|
||||
class ModelUpdater:
|
||||
"""
|
||||
@@ -136,7 +136,7 @@ class ModelUpdater:
|
||||
|
||||
def online_filter(self, record):
|
||||
tags = record.list_tags()
|
||||
if tags[self.ONLINE_TAG] == self.ONLINE_TAG_TRUE:
|
||||
if tags.get(self.ONLINE_TAG, self.ONLINE_TAG_FALSE) == self.ONLINE_TAG_TRUE:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -146,6 +146,13 @@ class ModelUpdater:
|
||||
self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.")
|
||||
|
||||
def list_online_model(self):
|
||||
"""list the record of online model
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
{rid : record of the online model}
|
||||
"""
|
||||
recs = self.exp.list_recorders()
|
||||
online_rec = {}
|
||||
for rid, rec in recs.items():
|
||||
|
||||
Reference in New Issue
Block a user