# Source: mirror of https://github.com/microsoft/qlib.git (synced 2026-07-04 11:30:57 +08:00)
# File: qlib/qlib/workflow/task/update.py (2021-03-12 08:04:08 +00:00; 155 lines, 5.4 KiB, Python)

from typing import Union,List
from qlib.workflow import R
from tqdm.auto import tqdm
from qlib.data import D
import pandas as pd
from qlib.utils import init_instance_by_config
from qlib import get_module_logger
from qlib.workflow import R
from qlib.model.trainer import task_train
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import TaskCollector
class ModelUpdater(TaskCollector):
    """
    The model updater to re-train model or update predictions.

    A recorder is considered an "online model" when its ``ONLINE_TAG`` tag
    equals ``ONLINE_TAG_TRUE``; a missing tag is treated as ``ONLINE_TAG_FALSE``.
    """

    ONLINE_TAG = "online_model"
    ONLINE_TAG_TRUE = "True"
    ONLINE_TAG_FALSE = "False"

    def __init__(self, experiment_name: str) -> None:
        """ModelUpdater needs experiment name to find the records

        Parameters
        ----------
        experiment_name : str
            experiment name string
        """
        self.exp_name = experiment_name
        self.exp = R.get_exp(experiment_name=experiment_name)
        self.logger = get_module_logger("ModelUpdater")

    def _resolve_recorder(self, recorder: Union[str, Recorder]) -> Recorder:
        """Return ``recorder`` unchanged, or look it up in ``self.exp`` when it is a recorder id string."""
        if isinstance(recorder, str):
            return self.exp.get_recorder(recorder_id=recorder)
        return recorder

    def set_online_model(self, recorder: Union[str, Recorder]):
        """Mark a recorder as an online model (online model is identified by the tags of the record).

        Parameters
        ----------
        recorder : Union[str, Recorder]
            the id of a Recorder or the Recorder instance
        """
        recorder = self._resolve_recorder(recorder)
        recorder.set_tags(**{ModelUpdater.ONLINE_TAG: ModelUpdater.ONLINE_TAG_TRUE})

    def cancel_online_model(self, recorder: Union[str, Recorder]):
        """Mark a recorder as offline by setting its online tag to ``ONLINE_TAG_FALSE``.

        Parameters
        ----------
        recorder : Union[str, Recorder]
            the id of a Recorder or the Recorder instance
        """
        recorder = self._resolve_recorder(recorder)
        recorder.set_tags(**{ModelUpdater.ONLINE_TAG: ModelUpdater.ONLINE_TAG_FALSE})

    def cancel_all_online_model(self):
        """Mark every recorder of this experiment as offline."""
        for rec in self.exp.list_recorders().values():
            self.cancel_online_model(rec)

    def reset_online_model(self, recorders: List[Union[str, Recorder]]):
        """Cancel all online models, then set the given recorders as the online models.

        Parameters
        ----------
        recorders : List[Union[str, Recorder]]
            the list of the id of a Recorder or the Recorder instance
        """
        self.cancel_all_online_model()
        for rec_or_rid in recorders:
            self.set_online_model(rec_or_rid)

    def update_pred(self, recorder: Union[str, Recorder]):
        """Extend the stored predictions of a recorder up to the latest day in the Calendar.

        The recorder's saved model (``params.pkl``) predicts on the trading days
        after the last datetime already present in ``pred.pkl``. The new
        predictions are appended, and the combined result is saved back as
        ``pred.pkl``. If there are no newer trading days, nothing is changed.

        Parameters
        ----------
        recorder : Union[str, Recorder]
            the id of a Recorder or the Recorder instance
        """
        recorder = self._resolve_recorder(recorder)
        old_pred = recorder.load_object("pred.pkl")
        last_end = old_pred.index.get_level_values("datetime").max()
        task_config = recorder.load_object("task")  # recorder.task

        # Only trading days strictly after the last predicted day still need predictions.
        cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
        if len(cal) == 0:
            self.logger.info(f"The prediction in {recorder.info['id']} of {self.exp_name} are latest. No need to update.")
            return
        start_time, end_time = cal[0], cal[-1]

        # Point the test segment at the missing days and let the handler load data up to the newest one.
        dataset_kwargs = task_config["dataset"]["kwargs"]
        dataset_kwargs["segments"]["test"] = (start_time, end_time)
        dataset_kwargs["handler"]["kwargs"]["end_time"] = end_time
        dataset = init_instance_by_config(task_config["dataset"])

        model = recorder.load_object("params.pkl")
        new_pred = model.predict(dataset)
        # Assumes the stored pred.pkl is a frame with a single "score" column — a Series
        # prediction is converted to match; a DataFrame prediction is used as-is.
        if isinstance(new_pred, pd.Series):
            new_pred = new_pred.to_frame("score")
        cb_pred = pd.concat([old_pred, new_pred], axis=0).sort_index()
        recorder.save_objects(**{"pred.pkl": cb_pred})
        self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {recorder.info['id']} of {self.exp_name}.")

    def update_all_pred(self, rec_filter_func=None):
        """Update all predictions in this experiment after filter.

        An example of filter function:

        .. code-block:: python

            def rec_filter_func(recorder):
                task_config = recorder.load_object("task")
                if task_config["model"]["class"]=="LGBModel":
                    return True
                return False

        Parameters
        ----------
        rec_filter_func : Callable[[Recorder], bool], optional
            the filter function to decide whether this record will be updated, by default None

        Returns
        -------
        cnt : int
            the count of updated record
        """
        # only_have_task=True presumably restricts to recorders that stored a "task" config,
        # which update_pred needs — confirm against TaskCollector.list_recorders.
        recs = self.list_recorders(rec_filter_func=rec_filter_func, only_have_task=True)
        for rec in recs.values():
            self.update_pred(rec)
        return len(recs)

    def online_filter(self, recorder: Recorder) -> bool:
        """Return True if ``recorder`` is tagged as an online model (a missing tag counts as offline)."""
        tags = recorder.list_tags()
        return tags.get(ModelUpdater.ONLINE_TAG, ModelUpdater.ONLINE_TAG_FALSE) == ModelUpdater.ONLINE_TAG_TRUE

    def update_online_pred(self):
        """Update all online model predictions to the latest day in the Calendar."""
        cnt = self.update_all_pred(self.online_filter)
        self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.")

    def list_online_model(self):
        """List the records of online models.

        Returns
        -------
        dict
            {rid : recorder of the online model}
        """
        return self.list_recorders(rec_filter_func=self.online_filter)