1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

the second version of online serving

This commit is contained in:
lzh222333
2021-03-12 08:04:08 +00:00
parent 0df88c07f6
commit 6d8aa215d6
5 changed files with 75 additions and 127 deletions

View File

@@ -34,7 +34,7 @@ def task_train(task_config: dict, experiment_name: str) -> str:
model.fit(dataset)
recorder = R.get_recorder()
R.save_objects(**{"params.pkl": model})
R.save_objects(**{"task.pkl": task_config}) # keep the original format and datatype
R.save_objects(**{"task": task_config}) # keep the original format and datatype
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])

View File

@@ -2,6 +2,7 @@ from qlib.workflow import R
import pandas as pd
from typing import Union
from typing import Callable
from qlib import get_module_logger
@@ -17,13 +18,13 @@ class TaskCollector:
def list_recorders(self, rec_filter_func=None, task_filter_func=None, only_finished=True, only_have_task=False):
"""
Return a dict of {rid:recorder} by recorder filter and task filter. It is not necessary to use those filter.
If you don't train with "task_train", then there is no "task.pkl" which includes the task config.
If there is a "task.pkl", then it will become rec.task which can be get simply.
Return a dict of {rid:Recorder} by recorder filter and task filter. It is not necessary to use those filter.
If you don't train with "task_train", then there is no "task" which includes the task config.
If there is a "task", then it will become rec.task which can be get simply.
Parameters
----------
rec_filter_func : Callable[[MLflowRecorder], bool], optional
rec_filter_func : Callable[[Recorder], bool], optional
judge whether you need this recorder, by default None
task_filter_func : Callable[[dict], bool], optional
judge whether you need this task, by default None
@@ -35,30 +36,27 @@ class TaskCollector:
Returns
-------
dict
a dict of {rid:recorder}
a dict of {rid:Recorder}
Raises
------
OSError
if you use a task filter, but there is no "task.pkl" which includes the task config
if you use a task filter, but there is no "task" which includes the task config
"""
recs = self.exp.list_recorders()
# return all recorders if the filter is None and you don't need task
if rec_filter_func==None and task_filter_func==None and only_have_task==False:
return recs
recs_flt = {}
if task_filter_func is not None:
only_have_task = True
for rid, rec in recs.items():
if (only_finished and rec.status == rec.STATUS_FI) or only_finished==False:
if rec_filter_func is None or rec_filter_func(rec):
task = None
try:
task = rec.load_object("task.pkl")
task = rec.load_object("task")
except OSError:
if task_filter_func is not None:
raise OSError('Can not find "task.pkl" in your records, have you train with "task_train" method in qlib.model.trainer?')
pass
if task is None and only_have_task:
continue
if task_filter_func is None or task_filter_func(task):
rec.task = task
recs_flt[rid] = rec
@@ -68,7 +66,7 @@ class TaskCollector:
def collect_predictions(
self,
get_key_func,
filter_func=None,
task_filter_func=None,
):
"""
@@ -85,7 +83,7 @@ class TaskCollector:
dict
the dict of predictions
"""
recs_flt = self.list_recorders(task_filter_func=filter_func)
recs_flt = self.list_recorders(task_filter_func=task_filter_func,only_have_task=True)
# group
recs_group = {}
@@ -108,11 +106,14 @@ class TaskCollector:
def collect_latest_records(
self,
filter_func=None,
task_filter_func=None,
):
recs_flt = self.list_recorders(task_filter_func=filter_func,only_have_task=True)
max_test = max(rec.task['dataset']['kwargs']['segments']['test'] for rec in recs_flt.values())
recs_flt = self.list_recorders(task_filter_func=task_filter_func,only_have_task=True)
if len(recs_flt) == 0:
self.logger.warning("Can not collect any recorders...")
return None, None
max_test = max(rec.task['dataset']['kwargs']['segments']['test'] for rec in recs_flt.values())
latest_record = {}
for rid, rec in recs_flt.items():
@@ -120,52 +121,5 @@ class TaskCollector:
latest_record[rid] = rec
self.logger.info(f"Collect {len(latest_record)} latest records in {self.exp_name}")
return latest_record
class RollingCollector:
"""
Rolling Models Ensemble based on (R)ecord
This shares nothing with Ensemble
"""
# TODO: speed up this class
def __init__(self, get_key_func, flt_func=None):
self.get_key_func = get_key_func # get the key of a task based on task config
self.flt_func = flt_func # determine whether a task can be retained based on task config
def __call__(self, exp_name) -> Union[pd.Series, dict]:
# TODO;
# Should we split the scripts into several sub functions?
exp = R.get_exp(experiment_name=exp_name)
# filter records
recs = exp.list_recorders()
recs_flt = {}
for rid, rec in tqdm(recs.items(), desc="Loading data"):
params = rec.load_object("task.pkl")
if rec.status == rec.STATUS_FI:
if self.flt_func is None or self.flt_func(params):
rec.params = params
recs_flt[rid] = rec
# group
recs_group = {}
for _, rec in recs_flt.items():
params = rec.params
group_key = self.get_key_func(params)
recs_group.setdefault(group_key, []).append(rec)
# reduce group
reduce_group = {}
for k, rec_l in recs_group.items():
pred_l = []
for rec in rec_l:
pred_l.append(rec.load_object("pred.pkl").iloc[:, 0])
pred = pd.concat(pred_l).sort_index()
reduce_group[k] = pred
return reduce_group
return latest_record, max_test

View File

@@ -10,10 +10,8 @@ A task consists of 3 parts
from bson.binary import Binary
import pickle
from pymongo.errors import InvalidDocument
from fire import Fire
from bson.objectid import ObjectId
from contextlib import contextmanager
from loguru import logger
from tqdm.cli import tqdm
import time
import concurrent
@@ -21,7 +19,7 @@ import pymongo
from qlib.config import C
from .utils import get_mongodb
from qlib import auto_init
from qlib import get_module_logger
class TaskManager:
"""TaskManager
@@ -62,6 +60,7 @@ class TaskManager:
"""
self.mdb = get_mongodb()
self.task_pool = task_pool
self.logger = get_module_logger("TaskManager")
def list(self):
return self.mdb.list_collection_names()
@@ -210,9 +209,9 @@ class TaskManager:
yield task
except Exception:
if task is not None:
logger.info("Returning task before raising error")
self.logger.info("Returning task before raising error")
self.return_task(task)
logger.info("Task returned")
self.logger.info("Task returned")
raise
def task_fetcher_iter(self, query={}, task_pool=None):
@@ -352,7 +351,7 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
with tm.safe_fetch_task() as task:
if task is None:
break
logger.info(task["def"])
get_module_logger("run_task").info(task["def"])
if force_release:
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
res = executor.submit(task_func, task["def"], *args, **kwargs).result()

View File

@@ -1,4 +1,4 @@
from typing import Union
from typing import Union,List
from qlib.workflow import R
from tqdm.auto import tqdm
from qlib.data import D
@@ -7,8 +7,10 @@ from qlib.utils import init_instance_by_config
from qlib import get_module_logger
from qlib.workflow import R
from qlib.model.trainer import task_train
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import TaskCollector
class ModelUpdater:
class ModelUpdater(TaskCollector):
"""
The model updater to re-train model or update predictions
"""
@@ -29,58 +31,59 @@ class ModelUpdater:
self.exp = R.get_exp(experiment_name=experiment_name)
self.logger = get_module_logger("ModelUpdater")
def set_online_model(self, rid: str):
def set_online_model(self, recorder: Union[str,Recorder]):
"""online model will be identified at the tags of the record
Parameters
----------
rid : str
the id of a record
recorder: Union[str,Recorder]
the id of a Recorder or the Recorder instance
"""
rec = self.exp.get_recorder(recorder_id=rid)
rec.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_TRUE})
if isinstance(recorder,str):
recorder = self.exp.get_recorder(recorder_id=recorder)
recorder.set_tags(**{ModelUpdater.ONLINE_TAG: ModelUpdater.ONLINE_TAG_TRUE})
def cancel_online_model(self, rid: str):
rec = self.exp.get_recorder(recorder_id=rid)
rec.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_FALSE})
def cancel_online_model(self, recorder: Union[str,Recorder]):
if isinstance(recorder,str):
recorder = self.exp.get_recorder(recorder_id=recorder)
recorder.set_tags(**{ModelUpdater.ONLINE_TAG: ModelUpdater.ONLINE_TAG_FALSE})
def cancel_all_online_model(self):
recs = self.exp.list_recorders()
for rid, rec in recs.items():
self.cancel_online_model(rid)
self.cancel_online_model(rec)
def reset_online_model(self, rids: Union[str, list]):
def reset_online_model(self, recorders: List[Union[str,Recorder]]):
"""cancel all online model and reset the given model to online model
Parameters
----------
rids : Union[str, list]
the name of a record or the list of the name of records
recorders: List[Union[str,Recorder]]
the list of the id of a Recorder or the Recorder instance
"""
self.cancel_all_online_model()
if isinstance(rids, str):
rids = [rids]
for rid in rids:
self.set_online_model(rid)
for rec_or_rid in recorders:
self.set_online_model(rec_or_rid)
def update_pred(self, rid: str):
def update_pred(self, recorder: Union[str,Recorder]):
"""update predictions to the latest day in Calendar based on rid
Parameters
----------
rid : str
the id of the record
recorder: Union[str,Recorder]
the id of a Recorder or the Recorder instance
"""
rec = self.exp.get_recorder(recorder_id=rid)
old_pred = rec.load_object("pred.pkl")
if isinstance(recorder,str):
recorder = self.exp.get_recorder(recorder_id=recorder)
old_pred = recorder.load_object("pred.pkl")
last_end = old_pred.index.get_level_values("datetime").max()
task_config = rec.load_object("task.pkl")
task_config = recorder.load_object("task") # recorder.task
# updated to the latest trading day
cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
if len(cal) == 0:
self.logger.info(f"All prediction in {rid} of {self.exp_name} are latest. No need to update.")
self.logger.info(f"The prediction in {recorder.info['id']} of {self.exp_name} are latest. No need to update.")
return
start_time, end_time = cal[0], cal[-1]
@@ -89,32 +92,32 @@ class ModelUpdater:
dataset = init_instance_by_config(task_config["dataset"])
model = rec.load_object("params.pkl")
model = recorder.load_object("params.pkl")
new_pred = model.predict(dataset)
cb_pred = pd.concat([old_pred, new_pred.to_frame("score")], axis=0)
cb_pred = cb_pred.sort_index()
rec.save_objects(**{"pred.pkl": cb_pred})
recorder.save_objects(**{"pred.pkl": cb_pred})
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {rid} of {self.exp_name}.")
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {recorder.info['id']} of {self.exp_name}.")
def update_all_pred(self, filter_func=None):
def update_all_pred(self, rec_filter_func=None):
"""update all predictions in this experiment after filter.
An example of filter function:
.. code-block:: python
def record_filter(record):
task_config = record.load_object("task.pkl")
def rec_filter_func(recorder):
task_config = recorder.load_object("task")
if task_config["model"]["class"]=="LGBModel":
return True
return False
Parameters
----------
filter_func : function, optional
rec_filter_func : Callable[[Recorder], bool], optional
the filter function to decide whether this record will be updated, by default None
Returns
@@ -123,20 +126,14 @@ class ModelUpdater:
the count of updated record
"""
cnt = 0
recs = self.exp.list_recorders()
recs = self.list_recorders(rec_filter_func=rec_filter_func,only_have_task=True)
for rid, rec in recs.items():
if rec.status == rec.STATUS_FI:
if filter_func != None and filter_func(rec) == False:
# records that should be filtered out
continue
self.update_pred(rid)
cnt += 1
return cnt
self.update_pred(rec)
return len(recs)
def online_filter(self, record):
tags = record.list_tags()
if tags.get(self.ONLINE_TAG, self.ONLINE_TAG_FALSE) == self.ONLINE_TAG_TRUE:
def online_filter(self, recorder):
tags = recorder.list_tags()
if tags.get(ModelUpdater.ONLINE_TAG, ModelUpdater.ONLINE_TAG_FALSE) == ModelUpdater.ONLINE_TAG_TRUE:
return True
return False
@@ -151,11 +148,7 @@ class ModelUpdater:
Returns
-------
dict
{rid : record of the online model}
{rid : recorder of the online model}
"""
recs = self.exp.list_recorders()
online_rec = {}
for rid, rec in recs.items():
if self.online_filter(rec):
online_rec[rid] = rec
return online_rec
return self.list_recorders(rec_filter_func=self.online_filter)

View File

@@ -50,7 +50,6 @@ class TimeAdjuster:
if idx >= len(self.cals):
return None
return self.cals[idx]
def max(self):
"""
(Deprecated)
@@ -86,6 +85,9 @@ class TimeAdjuster:
raise NotImplementedError(f"This type of input is not supported")
return idx
def cal_interval(self, time_point_A, time_point_B):
return self.align_idx(time_point_A) - self.align_idx(time_point_B)
def align_time(self, time_point, tp_type="start"):
"""
Align time_point to trade date of calendar