mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
the second version of online serving
This commit is contained in:
@@ -34,7 +34,7 @@ def task_train(task_config: dict, experiment_name: str) -> str:
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
R.save_objects(**{"task.pkl": task_config}) # keep the original format and datatype
|
||||
R.save_objects(**{"task": task_config}) # keep the original format and datatype
|
||||
|
||||
# generate records: prediction, backtest, and analysis
|
||||
records = task_config.get("record", [])
|
||||
|
||||
@@ -2,6 +2,7 @@ from qlib.workflow import R
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
from typing import Callable
|
||||
|
||||
from qlib import get_module_logger
|
||||
|
||||
|
||||
@@ -17,13 +18,13 @@ class TaskCollector:
|
||||
|
||||
def list_recorders(self, rec_filter_func=None, task_filter_func=None, only_finished=True, only_have_task=False):
|
||||
"""
|
||||
Return a dict of {rid:recorder} by recorder filter and task filter. It is not necessary to use those filter.
|
||||
If you don't train with "task_train", then there is no "task.pkl" which includes the task config.
|
||||
If there is a "task.pkl", then it will become rec.task which can be get simply.
|
||||
Return a dict of {rid: Recorder} selected by the recorder filter and the task filter. Both filters are optional.
|
||||
If the model was not trained with "task_train", the recorder has no "task" object holding the task config.
|
||||
If there is a "task", it is loaded and attached to the recorder as rec.task so it can be accessed directly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rec_filter_func : Callable[[MLflowRecorder], bool], optional
|
||||
rec_filter_func : Callable[[Recorder], bool], optional
|
||||
judge whether you need this recorder, by default None
|
||||
task_filter_func : Callable[[dict], bool], optional
|
||||
judge whether you need this task, by default None
|
||||
@@ -35,30 +36,27 @@ class TaskCollector:
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
a dict of {rid:recorder}
|
||||
a dict of {rid:Recorder}
|
||||
|
||||
Raises
|
||||
------
|
||||
OSError
|
||||
if you use a task filter, but there is no "task.pkl" which includes the task config
|
||||
if you use a task filter, but there is no "task" which includes the task config
|
||||
"""
|
||||
recs = self.exp.list_recorders()
|
||||
# return all recorders if the filter is None and you don't need task
|
||||
if rec_filter_func==None and task_filter_func==None and only_have_task==False:
|
||||
return recs
|
||||
recs_flt = {}
|
||||
if task_filter_func is not None:
|
||||
only_have_task = True
|
||||
for rid, rec in recs.items():
|
||||
if (only_finished and rec.status == rec.STATUS_FI) or only_finished==False:
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
task = None
|
||||
try:
|
||||
task = rec.load_object("task.pkl")
|
||||
task = rec.load_object("task")
|
||||
except OSError:
|
||||
if task_filter_func is not None:
|
||||
raise OSError('Can not find "task.pkl" in your records, have you train with "task_train" method in qlib.model.trainer?')
|
||||
pass
|
||||
if task is None and only_have_task:
|
||||
continue
|
||||
|
||||
if task_filter_func is None or task_filter_func(task):
|
||||
rec.task = task
|
||||
recs_flt[rid] = rec
|
||||
@@ -68,7 +66,7 @@ class TaskCollector:
|
||||
def collect_predictions(
|
||||
self,
|
||||
get_key_func,
|
||||
filter_func=None,
|
||||
task_filter_func=None,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -85,7 +83,7 @@ class TaskCollector:
|
||||
dict
|
||||
the dict of predictions
|
||||
"""
|
||||
recs_flt = self.list_recorders(task_filter_func=filter_func)
|
||||
recs_flt = self.list_recorders(task_filter_func=task_filter_func,only_have_task=True)
|
||||
|
||||
# group
|
||||
recs_group = {}
|
||||
@@ -108,11 +106,14 @@ class TaskCollector:
|
||||
|
||||
def collect_latest_records(
|
||||
self,
|
||||
filter_func=None,
|
||||
task_filter_func=None,
|
||||
):
|
||||
recs_flt = self.list_recorders(task_filter_func=filter_func,only_have_task=True)
|
||||
|
||||
max_test = max(rec.task['dataset']['kwargs']['segments']['test'] for rec in recs_flt.values())
|
||||
recs_flt = self.list_recorders(task_filter_func=task_filter_func,only_have_task=True)
|
||||
|
||||
if len(recs_flt) == 0:
|
||||
self.logger.warning("Can not collect any recorders...")
|
||||
return None, None
|
||||
max_test = max(rec.task['dataset']['kwargs']['segments']['test'] for rec in recs_flt.values())
|
||||
|
||||
latest_record = {}
|
||||
for rid, rec in recs_flt.items():
|
||||
@@ -120,52 +121,5 @@ class TaskCollector:
|
||||
latest_record[rid] = rec
|
||||
|
||||
self.logger.info(f"Collect {len(latest_record)} latest records in {self.exp_name}")
|
||||
return latest_record
|
||||
|
||||
|
||||
|
||||
class RollingCollector:
|
||||
"""
|
||||
Rolling Models Ensemble based on (R)ecord
|
||||
|
||||
This shares nothing with Ensemble
|
||||
"""
|
||||
|
||||
# TODO: speed up this class
|
||||
def __init__(self, get_key_func, flt_func=None):
|
||||
self.get_key_func = get_key_func # get the key of a task based on task config
|
||||
self.flt_func = flt_func # determine whether a task can be retained based on task config
|
||||
|
||||
def __call__(self, exp_name) -> Union[pd.Series, dict]:
|
||||
# TODO;
|
||||
# Should we split the scripts into several sub functions?
|
||||
exp = R.get_exp(experiment_name=exp_name)
|
||||
|
||||
# filter records
|
||||
recs = exp.list_recorders()
|
||||
|
||||
recs_flt = {}
|
||||
for rid, rec in tqdm(recs.items(), desc="Loading data"):
|
||||
params = rec.load_object("task.pkl")
|
||||
if rec.status == rec.STATUS_FI:
|
||||
if self.flt_func is None or self.flt_func(params):
|
||||
rec.params = params
|
||||
recs_flt[rid] = rec
|
||||
|
||||
# group
|
||||
recs_group = {}
|
||||
for _, rec in recs_flt.items():
|
||||
params = rec.params
|
||||
group_key = self.get_key_func(params)
|
||||
recs_group.setdefault(group_key, []).append(rec)
|
||||
|
||||
# reduce group
|
||||
reduce_group = {}
|
||||
for k, rec_l in recs_group.items():
|
||||
pred_l = []
|
||||
for rec in rec_l:
|
||||
pred_l.append(rec.load_object("pred.pkl").iloc[:, 0])
|
||||
pred = pd.concat(pred_l).sort_index()
|
||||
reduce_group[k] = pred
|
||||
|
||||
return reduce_group
|
||||
return latest_record, max_test
|
||||
|
||||
@@ -10,10 +10,8 @@ A task consists of 3 parts
|
||||
from bson.binary import Binary
|
||||
import pickle
|
||||
from pymongo.errors import InvalidDocument
|
||||
from fire import Fire
|
||||
from bson.objectid import ObjectId
|
||||
from contextlib import contextmanager
|
||||
from loguru import logger
|
||||
from tqdm.cli import tqdm
|
||||
import time
|
||||
import concurrent
|
||||
@@ -21,7 +19,7 @@ import pymongo
|
||||
from qlib.config import C
|
||||
from .utils import get_mongodb
|
||||
from qlib import auto_init
|
||||
|
||||
from qlib import get_module_logger
|
||||
|
||||
class TaskManager:
|
||||
"""TaskManager
|
||||
@@ -62,6 +60,7 @@ class TaskManager:
|
||||
"""
|
||||
self.mdb = get_mongodb()
|
||||
self.task_pool = task_pool
|
||||
self.logger = get_module_logger("TaskManager")
|
||||
|
||||
def list(self):
|
||||
return self.mdb.list_collection_names()
|
||||
@@ -210,9 +209,9 @@ class TaskManager:
|
||||
yield task
|
||||
except Exception:
|
||||
if task is not None:
|
||||
logger.info("Returning task before raising error")
|
||||
self.logger.info("Returning task before raising error")
|
||||
self.return_task(task)
|
||||
logger.info("Task returned")
|
||||
self.logger.info("Task returned")
|
||||
raise
|
||||
|
||||
def task_fetcher_iter(self, query={}, task_pool=None):
|
||||
@@ -352,7 +351,7 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
|
||||
with tm.safe_fetch_task() as task:
|
||||
if task is None:
|
||||
break
|
||||
logger.info(task["def"])
|
||||
get_module_logger("run_task").info(task["def"])
|
||||
if force_release:
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
|
||||
res = executor.submit(task_func, task["def"], *args, **kwargs).result()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Union
|
||||
from typing import Union,List
|
||||
from qlib.workflow import R
|
||||
from tqdm.auto import tqdm
|
||||
from qlib.data import D
|
||||
@@ -7,8 +7,10 @@ from qlib.utils import init_instance_by_config
|
||||
from qlib import get_module_logger
|
||||
from qlib.workflow import R
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import TaskCollector
|
||||
|
||||
class ModelUpdater:
|
||||
class ModelUpdater(TaskCollector):
|
||||
"""
|
||||
The model updater to re-train model or update predictions
|
||||
"""
|
||||
@@ -29,58 +31,59 @@ class ModelUpdater:
|
||||
self.exp = R.get_exp(experiment_name=experiment_name)
|
||||
self.logger = get_module_logger("ModelUpdater")
|
||||
|
||||
def set_online_model(self, rid: str):
|
||||
def set_online_model(self, recorder: Union[str,Recorder]):
|
||||
"""Mark a model as online. An online model is identified by a tag on its recorder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rid : str
|
||||
the id of a record
|
||||
recorder: Union[str,Recorder]
|
||||
the id of a Recorder or the Recorder instance
|
||||
"""
|
||||
rec = self.exp.get_recorder(recorder_id=rid)
|
||||
rec.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_TRUE})
|
||||
if isinstance(recorder,str):
|
||||
recorder = self.exp.get_recorder(recorder_id=recorder)
|
||||
recorder.set_tags(**{ModelUpdater.ONLINE_TAG: ModelUpdater.ONLINE_TAG_TRUE})
|
||||
|
||||
def cancel_online_model(self, rid: str):
|
||||
rec = self.exp.get_recorder(recorder_id=rid)
|
||||
rec.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_FALSE})
|
||||
def cancel_online_model(self, recorder: Union[str,Recorder]):
|
||||
if isinstance(recorder,str):
|
||||
recorder = self.exp.get_recorder(recorder_id=recorder)
|
||||
recorder.set_tags(**{ModelUpdater.ONLINE_TAG: ModelUpdater.ONLINE_TAG_FALSE})
|
||||
|
||||
def cancel_all_online_model(self):
|
||||
recs = self.exp.list_recorders()
|
||||
for rid, rec in recs.items():
|
||||
self.cancel_online_model(rid)
|
||||
self.cancel_online_model(rec)
|
||||
|
||||
def reset_online_model(self, rids: Union[str, list]):
|
||||
def reset_online_model(self, recorders: List[Union[str,Recorder]]):
|
||||
"""Cancel all online models, then set the given models as the online models.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rids : Union[str, list]
|
||||
the name of a record or the list of the name of records
|
||||
recorders: List[Union[str,Recorder]]
|
||||
a list in which each item is either the id of a Recorder or a Recorder instance
|
||||
"""
|
||||
self.cancel_all_online_model()
|
||||
if isinstance(rids, str):
|
||||
rids = [rids]
|
||||
for rid in rids:
|
||||
self.set_online_model(rid)
|
||||
for rec_or_rid in recorders:
|
||||
self.set_online_model(rec_or_rid)
|
||||
|
||||
def update_pred(self, rid: str):
|
||||
def update_pred(self, recorder: Union[str,Recorder]):
|
||||
"""Update the predictions of the given recorder up to the latest day in the Calendar.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rid : str
|
||||
the id of the record
|
||||
recorder: Union[str,Recorder]
|
||||
the id of a Recorder or the Recorder instance
|
||||
"""
|
||||
rec = self.exp.get_recorder(recorder_id=rid)
|
||||
old_pred = rec.load_object("pred.pkl")
|
||||
if isinstance(recorder,str):
|
||||
recorder = self.exp.get_recorder(recorder_id=recorder)
|
||||
old_pred = recorder.load_object("pred.pkl")
|
||||
last_end = old_pred.index.get_level_values("datetime").max()
|
||||
task_config = rec.load_object("task.pkl")
|
||||
task_config = recorder.load_object("task") # recorder.task
|
||||
|
||||
# updated to the latest trading day
|
||||
cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
|
||||
|
||||
if len(cal) == 0:
|
||||
self.logger.info(f"All prediction in {rid} of {self.exp_name} are latest. No need to update.")
|
||||
self.logger.info(f"The prediction in {recorder.info['id']} of {self.exp_name} are latest. No need to update.")
|
||||
return
|
||||
|
||||
start_time, end_time = cal[0], cal[-1]
|
||||
@@ -89,32 +92,32 @@ class ModelUpdater:
|
||||
|
||||
dataset = init_instance_by_config(task_config["dataset"])
|
||||
|
||||
model = rec.load_object("params.pkl")
|
||||
model = recorder.load_object("params.pkl")
|
||||
new_pred = model.predict(dataset)
|
||||
|
||||
cb_pred = pd.concat([old_pred, new_pred.to_frame("score")], axis=0)
|
||||
cb_pred = cb_pred.sort_index()
|
||||
|
||||
rec.save_objects(**{"pred.pkl": cb_pred})
|
||||
recorder.save_objects(**{"pred.pkl": cb_pred})
|
||||
|
||||
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {rid} of {self.exp_name}.")
|
||||
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {recorder.info['id']} of {self.exp_name}.")
|
||||
|
||||
def update_all_pred(self, filter_func=None):
|
||||
def update_all_pred(self, rec_filter_func=None):
|
||||
"""update all predictions in this experiment after filter.
|
||||
|
||||
An example of filter function:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def record_filter(record):
|
||||
task_config = record.load_object("task.pkl")
|
||||
def rec_filter_func(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
if task_config["model"]["class"]=="LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filter_func : function, optional
|
||||
rec_filter_func : Callable[[Recorder], bool], optional
|
||||
the filter function to decide whether this record will be updated, by default None
|
||||
|
||||
Returns
|
||||
@@ -123,20 +126,14 @@ class ModelUpdater:
|
||||
the count of updated record
|
||||
|
||||
"""
|
||||
cnt = 0
|
||||
recs = self.exp.list_recorders()
|
||||
recs = self.list_recorders(rec_filter_func=rec_filter_func,only_have_task=True)
|
||||
for rid, rec in recs.items():
|
||||
if rec.status == rec.STATUS_FI:
|
||||
if filter_func != None and filter_func(rec) == False:
|
||||
# records that should be filtered out
|
||||
continue
|
||||
self.update_pred(rid)
|
||||
cnt += 1
|
||||
return cnt
|
||||
self.update_pred(rec)
|
||||
return len(recs)
|
||||
|
||||
def online_filter(self, record):
|
||||
tags = record.list_tags()
|
||||
if tags.get(self.ONLINE_TAG, self.ONLINE_TAG_FALSE) == self.ONLINE_TAG_TRUE:
|
||||
def online_filter(self, recorder):
|
||||
tags = recorder.list_tags()
|
||||
if tags.get(ModelUpdater.ONLINE_TAG, ModelUpdater.ONLINE_TAG_FALSE) == ModelUpdater.ONLINE_TAG_TRUE:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -151,11 +148,7 @@ class ModelUpdater:
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
{rid : record of the online model}
|
||||
{rid : recorder of the online model}
|
||||
"""
|
||||
recs = self.exp.list_recorders()
|
||||
online_rec = {}
|
||||
for rid, rec in recs.items():
|
||||
if self.online_filter(rec):
|
||||
online_rec[rid] = rec
|
||||
return online_rec
|
||||
|
||||
return self.list_recorders(rec_filter_func=self.online_filter)
|
||||
|
||||
@@ -50,7 +50,6 @@ class TimeAdjuster:
|
||||
if idx >= len(self.cals):
|
||||
return None
|
||||
return self.cals[idx]
|
||||
|
||||
def max(self):
|
||||
"""
|
||||
(Deprecated)
|
||||
@@ -86,6 +85,9 @@ class TimeAdjuster:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return idx
|
||||
|
||||
def cal_interval(self, time_point_A, time_point_B):
|
||||
return self.align_idx(time_point_A) - self.align_idx(time_point_B)
|
||||
|
||||
def align_time(self, time_point, tp_type="start"):
|
||||
"""
|
||||
Align time_point to trade date of calendar
|
||||
|
||||
Reference in New Issue
Block a user