1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00

first version of online serving

This commit is contained in:
lzh222333
2021-03-11 03:00:30 +00:00
parent 2ca2071d95
commit 48f0fc147f
4 changed files with 234 additions and 3 deletions

View File

@@ -0,0 +1,77 @@
import qlib
from qlib.model.trainer import task_train
from qlib.workflow.task.update import ModelUpdater
from qlib.config import REG_CN
import fire
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
}
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
"record": {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp",},
}
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
def first_train(experiment_name="online_svr"):
qlib.init(provider_uri=provider_uri, region=REG_CN)
model_updater = ModelUpdater(experiment_name)
rid = task_train(task_config=task, experiment_name=experiment_name)
model_updater.reset_online_model(rid)
def update_online_pred(experiment_name="online_svr"):
qlib.init(provider_uri=provider_uri, region=REG_CN)
model_updater = ModelUpdater(experiment_name)
print("Here are the online models waiting for update:")
for rid, rec in model_updater.list_online_model().items():
print(rid)
model_updater.update_online_pred()
if __name__ == '__main__':
fire.Fire()
# to train a model and set it to online model, use the command below
# python update_online_pred.py first_train
# to update online predictions once a day, use the command below
# python update_online_pred.py update_online_pred

View File

@@ -53,4 +53,4 @@ def task_train(task_config: dict, experiment_name: str) -> str:
record["kwargs"].update(rconf)
ar = init_instance_by_config(record)
ar.generate()
return record.info["id"]
return recorder.info["id"]

View File

@@ -11,8 +11,8 @@ class TaskCollector:
@staticmethod
def collect_predictions(
experiment_name: str,
get_key_func,
experiment_name: str,
get_key_func,
filter_func=None,
):
"""

View File

@@ -0,0 +1,154 @@
from typing import Union
from qlib.workflow import R
from tqdm.auto import tqdm
from qlib.data import D
import pandas as pd
from qlib.utils import init_instance_by_config
from qlib import get_module_logger
from qlib.workflow import R
class ModelUpdater:
"""
The model updater to re-train model or update predictions
"""
ONLINE_TAG = "online_model"
ONLINE_TAG_TRUE = "True"
ONLINE_TAG_FALSE = "False"
def __init__(self, experiment_name: str) -> None:
"""ModelUpdater needs experiment name to find the records
Parameters
----------
experiment_name : str
experiment name string
"""
self.exp_name = experiment_name
self.exp = R.get_exp(experiment_name=experiment_name)
self.logger = get_module_logger("ModelUpdater")
def set_online_model(self, rid: str):
"""online model will be identified at the tags of the record
Parameters
----------
rid : str
the id of a record
"""
rec = self.exp.get_recorder(recorder_id=rid)
rec.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_TRUE})
def cancel_online_model(self, rid: str):
rec = self.exp.get_recorder(recorder_id=rid)
rec.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_FALSE})
def cancel_all_online_model(self):
recs = self.exp.list_recorders()
for rid, rec in recs.items():
self.cancel_online_model(rid)
def reset_online_model(self, rids: Union[str, list]):
"""cancel all online model and reset the given model to online model
Parameters
----------
rids : Union[str, list]
the name of a record or the list of the name of records
"""
self.cancel_all_online_model()
if isinstance(rids, str):
rids = [rids]
for rid in rids:
self.set_online_model(rid)
def update_pred(self, rid: str):
"""update predictions to the latest day in Calendar based on rid
Parameters
----------
rid : str
the id of the record
"""
rec = self.exp.get_recorder(recorder_id=rid)
old_pred = rec.load_object("pred.pkl")
last_end = old_pred.index.get_level_values("datetime").max()
task_config = rec.load_object("task.pkl")
# updated to the latest trading day
cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
if len(cal) == 0:
self.logger.info(f"All prediction in {rid} of {self.exp_name} are latest. No need to update.")
return
start_time, end_time = cal[0], cal[-1]
task_config["dataset"]["kwargs"]["segments"]["test"] = (start_time, end_time)
task_config["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = end_time
dataset = init_instance_by_config(task_config["dataset"])
model = rec.load_object("params.pkl")
new_pred = model.predict(dataset)
cb_pred = pd.concat([old_pred, new_pred.to_frame("score")], axis=0)
cb_pred = cb_pred.sort_index()
rec.save_objects(**{"pred.pkl": cb_pred})
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {rid} of {self.exp_name}.")
def update_all_pred(self, filter_func=None):
"""update all predictions in this experiment after filter.
An example of filter function:
.. code-block:: python
def record_filter(record):
task_config = record.load_object("task.pkl")
if task_config["model"]["class"]=="LGBModel":
return True
return False
Parameters
----------
filter_func : function, optional
the filter function to decide whether this record will be updated, by default None
Returns
----------
cnt: int
the count of updated record
"""
cnt = 0
recs = self.exp.list_recorders()
for rid, rec in recs.items():
if rec.status == rec.STATUS_FI:
if filter_func != None and filter_func(rec) == False:
# records that should be filtered out
continue
self.update_pred(rid)
cnt += 1
return cnt
def online_filter(self, record):
tags = record.list_tags()
if tags[self.ONLINE_TAG] == self.ONLINE_TAG_TRUE:
return True
return False
def update_online_pred(self):
"""update all online model predictions to the latest day in Calendar."""
cnt = self.update_all_pred(self.online_filter)
self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.")
def list_online_model(self):
recs = self.exp.list_recorders()
online_rec = {}
for rid, rec in recs.items():
if self.online_filter(rec):
online_rec[rid] = rec
return online_rec