1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00

online serving V8

This commit is contained in:
lzh222333
2021-04-25 06:26:45 +00:00
parent de0a0c083d
commit 319396c815
8 changed files with 270 additions and 197 deletions

View File

@@ -15,6 +15,11 @@ from qlib.workflow.task.utils import list_recorders
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
"""
This example shows how a Trainer works based on TaskManager with rolling tasks.
After training, task_collecting shows how to collect the rolling results.
"""
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
@@ -71,81 +76,83 @@ task_xgboost_config = {
"record": record_config,
}
# Reset all things to the first status, be careful to save important data
def reset(task_pool, exp_name):
print("========== reset ==========")
TaskManager(task_pool=task_pool).remove()
exp = R.get_exp(experiment_name=exp_name)
class RollingTaskExample:
def __init__(
    self,
    provider_uri="~/.qlib/qlib_data/cn_data",
    region=REG_CN,
    task_url="mongodb://10.0.0.4:27017/",
    task_db_name="rolling_db",
    experiment_name="rolling_exp",
    task_pool="rolling_task",
    task_config=None,
    rolling_step=550,
    rolling_type=RollingGen.ROLL_SD,
):
    """
    Initialise qlib and remember the settings used by the rolling example.

    Args:
        provider_uri (str): data location passed to `qlib.init`.
        region: region passed to `qlib.init`.
        task_url (str): value stored under `task_url` in the mongo config.
        task_db_name (str): value stored under `task_db_name` in the mongo config.
        experiment_name (str): experiment name used by the other methods.
        task_pool (str): task pool name used by the other methods.
        task_config (list, optional): task config templates.
            None means `[task_xgboost_config, task_lgb_config]`.
        rolling_step (int): `step` passed to `RollingGen`.
        rolling_type: `rtype` passed to `RollingGen`.
    """
    # Build the default inside the call so instances never share one list object.
    if task_config is None:
        task_config = [task_xgboost_config, task_lgb_config]

    # TaskManager config
    mongo_conf = {
        "task_url": task_url,
        "task_db_name": task_db_name,
    }
    qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)

    self.experiment_name = experiment_name
    self.task_pool = task_pool
    self.task_config = task_config
    self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
TaskManager(task_pool=self.task_pool).remove()
exp = R.get_exp(experiment_name=self.experiment_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
def task_generating(self):
print("========== task_generating ==========")
tasks = task_generator(
tasks=self.task_config,
generators=self.rolling_gen, # generate different date segments
)
pprint(tasks)
return tasks
# This part corresponds to "Task Generating" in the document
def task_generating():
def task_training(self, tasks):
print("========== task_training ==========")
trainer = TrainerRM(self.experiment_name, self.task_pool)
trainer.train(tasks)
print("========== task_generating ==========")
def task_collecting(self):
print("========== task_collecting ==========")
tasks = task_generator(
tasks=[task_xgboost_config, task_lgb_config],
generators=RollingGen(step=550, rtype=RollingGen.ROLL_SD), # generate different date segment
)
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
pprint(tasks)
def my_filter(recorder):
# only choose the results of "LGBModel"
model_key, rolling_key = rec_key(recorder)
if model_key == "LGBModel":
return True
return False
return tasks
artifact = ens_workflow(
RecorderCollector(exp_name=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter),
RollingGroup(),
)
print(artifact)
def task_training(tasks, task_pool, exp_name):
trainer = TrainerRM(exp_name, task_pool)
trainer.train(tasks)
# This part corresponds to "Task Collecting" in the document
def task_collecting(exp_name):
    """
    Collect the recorders of `exp_name`, ensemble them by rolling group and print the result.

    Args:
        exp_name (str): the experiment whose recorders are collected.
    """
    print("========== task_collecting ==========")

    def rec_key(recorder):
        # key of a recorder: (model class name, test segment) read from its saved task
        saved_task = recorder.load_object("task")
        return saved_task["model"]["class"], saved_task["dataset"]["kwargs"]["segments"]["test"]

    def my_filter(recorder):
        # only choose the results of "LGBModel"
        return rec_key(recorder)[0] == "LGBModel"

    collector = RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter)
    print(ens_workflow(collector, RollingGroup()))
def main(
provider_uri="~/.qlib/qlib_data/cn_data",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
experiment_name="rolling_exp",
task_pool="rolling_task",
):
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,
}
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
reset(task_pool, experiment_name)
tasks = task_generating()
task_training(tasks, task_pool, experiment_name)
task_collecting(experiment_name)
def main(self):
self.reset()
tasks = self.task_generating()
self.task_training(tasks)
self.task_collecting()
if __name__ == "__main__":
## to see the whole process with your own parameters, use the command below
# python update_online_pred.py main --experiment_name="your_exp_name"
fire.Fire()
# python task_manager_rolling.py main --experiment_name="your_exp_name"
fire.Fire(RollingTaskExample)

View File

@@ -11,7 +11,7 @@ from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
"""
This examples is about the OnlineManager and OnlineSimulator based on Rolling tasks.
This example is about the OnlineManager and OnlineSimulator based on rolling tasks.
The OnlineManager will focus on the updating of your online models.
The OnlineSimulator will focus on the simulating real updating routine of your online models.
"""

View File

@@ -1,18 +1,21 @@
from pprint import pprint
import os
from pathlib import Path
import pickle
import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow import R
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.online.manager import RollingOnlineManager
from qlib.workflow.task.utils import list_recorders
from qlib.model.trainer import TrainerRM
from qlib.model.ens.group import RollingGroup
"""
This example shows how RollingOnlineManager works with rolling tasks.
There are two parts: the first training and the routine.
First, the RollingOnlineManager finishes the first training and sets the trained models as `online` models.
Next, the RollingOnlineManager finishes a routine process: update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models.
"""
data_handler_config = {
"start_time": "2013-01-01",
@@ -89,92 +92,38 @@ class RollingOnlineExample:
"task_db_name": task_db_name, # database name
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
self.trainer = TrainerRM(self.exp_name, self.task_pool)
self.task_manager = TaskManager(self.task_pool)
self.rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer
experiment_name=exp_name,
rolling_gen=RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
trainer=TrainerRM(self.exp_name, self.task_pool),
)
def print_online_model(self):
print("========== print_online_model ==========")
print("Current 'online' model:")
for rec in self.rolling_online_manager.online_models():
print(rec.info["id"])
print("Current 'next online' model:")
for rid, rec in list_recorders(self.exp_name).items():
if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.NEXT_ONLINE_TAG:
print(rid)
# This part corresponds to "Task Generating" in the document
def task_generating(self):
print("========== task_generating ==========")
tasks = task_generator(
tasks=[task_xgboost_config, task_lgb_config],
generators=self.rolling_gen, # generate different date segment
)
pprint(tasks)
return tasks
def task_training(self, tasks):
# self.trainer.train(tasks)
self.rolling_online_manager.prepare_new_models(tasks, tag=RollingOnlineManager.ONLINE_TAG)
# This part corresponds to "Task Collecting" in the document
def task_collecting(self):
print("========== task_collecting ==========")
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
def my_filter(recorder):
# only choose the results of "LGBModel"
model_key, rolling_key = rec_key(recorder)
if model_key == "LGBModel":
return True
return False
artifact = ens_workflow(
RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup()
)
print(artifact)
_ROLLING_MANAGER_PATH = ".rolling_manager" # the RollingOnlineManager will dump to this file, for it will be loaded when calling routine.
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
self.task_manager.remove()
TaskManager(self.task_pool).remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
# Run this firstly to see the workflow in Task Management
if os.path.exists(self._ROLLING_MANAGER_PATH):
os.remove(self._ROLLING_MANAGER_PATH)
def first_run(self):
print("========== first_run ==========")
self.reset()
tasks = self.task_generating()
pprint(tasks)
self.task_training(tasks)
self.task_collecting()
# latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
# self.rolling_online_manager.reset_online_tag(list(latest_rec.values()))
self.rolling_online_manager.first_train([task_xgboost_config, task_lgb_config])
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
print(self.rolling_online_manager.collect_artifact())
def routine(self):
print("========== routine ==========")
self.print_online_model()
with Path(self._ROLLING_MANAGER_PATH).open("rb") as f:
self.rolling_online_manager = pickle.load(f)
self.rolling_online_manager.routine()
self.print_online_model()
self.task_collecting()
print(self.rolling_online_manager.collect_artifact())
def main(self):
self.first_run()

View File

@@ -5,6 +5,13 @@ from qlib.model.trainer import task_train
from qlib.workflow.online.manager import OnlineManagerR
from qlib.workflow.task.utils import list_recorders
"""
This example shows how the online manager (OnlineManagerR) works when predictions need to be updated.
There are two parts: first_train and update_online_pred.
First, the OnlineManagerR finishes the first training and sets the trained model as the `online` model.
Next, the OnlineManagerR updates the online prediction.
"""
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
@@ -52,31 +59,25 @@ task = {
}
def first_train(experiment_name="online_srv"):
class UpdatePredExample:
def __init__(
self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task
):
qlib.init(provider_uri=provider_uri, region=region)
self.experiment_name = experiment_name
self.online_manager = OnlineManagerR(self.experiment_name)
self.task_config = task_config
rec = task_train(task_config=task, experiment_name=experiment_name)
def first_train(self):
rec = task_train(self.task_config, experiment_name=self.experiment_name)
self.online_manager.reset_online_tag(rec) # set to online model
online_manager = OnlineManagerR(experiment_name)
online_manager.reset_online_tag(rec)
def update_online_pred(self):
self.online_manager.update_online_pred()
def update_online_pred(experiment_name="online_srv"):
online_manager = OnlineManagerR(experiment_name)
print("Here are the online models waiting for update:")
for rid, rec in list_recorders(experiment_name).items():
if online_manager.get_online_tag(rec) == OnlineManagerR.ONLINE_TAG:
print(rid)
online_manager.update_online_pred()
def main(provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"):
    """
    Run the whole example: initialise qlib, train once, then update the online prediction.

    Args:
        provider_uri (str): data location passed to `qlib.init`.
        region: region passed to `qlib.init`.
        experiment_name (str): experiment name passed to `first_train` and `update_online_pred`.
    """
    # provider_uri is used as given; re-assigning it here would discard the caller's value.
    qlib.init(provider_uri=provider_uri, region=region)
    first_train(experiment_name)
    update_online_pred(experiment_name)
def main(self):
self.first_train()
self.update_online_pred()
if __name__ == "__main__":
@@ -86,4 +87,4 @@ if __name__ == "__main__":
# python update_online_pred.py update_online_pred
## to see the whole process with your own parameters, use the command below
# python update_online_pred.py main --experiment_name="your_exp_name"
fire.Fire()
fire.Fire(UpdatePredExample)

View File

@@ -135,3 +135,12 @@ class TrainerRM(Trainer):
for _id in _id_list:
recs.append(tm.re_query(_id)["res"])
return recs
class DelayTrainer(Trainer):
    """
    Trainer stub that keeps a `fake_trained` list.

    NOTE(review): this looks like an unfinished placeholder; neither method does real work yet.
    """

    def fake_train(self):
        """Reset `self.fake_trained` to an empty list."""
        self.fake_trained = []

    def train(self):
        """Walk over `self.fake_trained` without doing anything with the items."""
        # NOTE(review): `fake_trained` only exists after `fake_train()` has been called,
        # and this signature takes no tasks (TrainerRM.train is called with tasks elsewhere) -- confirm intent.
        for rec in self.fake_trained:
            pass

View File

@@ -1,16 +1,29 @@
from copy import deepcopy
from operator import index
import pandas as pd
from qlib.model.ens.ensemble import ens_workflow
from qlib.model.ens.group import RollingGroup
from qlib.utils.serial import Serializable
from typing import Dict, List, Union
from qlib import get_module_logger
from qlib.data.data import D
from qlib.model.trainer import Trainer, TrainerR, task_train
from qlib.workflow import R
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import Collector
from qlib.workflow.task.collect import Collector, RecorderCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
"""
This module is a component of online serving; it can manage a series of models dynamically.
As time goes by, the decisive models also change. In this module, those contributing models are called `online` models.
In every routine (such as every day or every minute), the `online` models may change and their predictions need to be updated.
So this module provides a series of methods to control this process.
"""
class OnlineManager:
class OnlineManager(Serializable):
ONLINE_KEY = "online_status" # the online status key in recorder
ONLINE_TAG = "online" # the 'online' model
@@ -18,26 +31,28 @@ class OnlineManager:
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
def __init__(self, trainer: Trainer = None, collector: Collector = None, need_log=True):
SIGNAL_EXP = "OnlineManagerSignals" # a specific experiment to save signals of different experiment.
def __init__(self, trainer: Trainer = None, need_log=True):
"""
init OnlineManager.
Args:
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
collector (Collector, optional): a instance of Collector. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
"""
self.trainer = trainer
self.logger = get_module_logger(self.__class__.__name__)
self.need_log = need_log
self.delay_signals = {}
self.collector = collector
self.cur_time = None
def prepare_signals(self, *args, **kwargs):
    """
    Prepare the trading signals for the next routine.

    This is called once the data of the last routine (a box in the box-plot) has been prepared,
    i.e. at the end of that routine.
    Subclasses must override this method; use `pass` if there is nothing to do.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError("Please implement the `prepare_signals` method.")
def prepare_tasks(self, *args, **kwargs):
@@ -47,7 +62,7 @@ class OnlineManager:
"""
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG):
def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None):
"""
Use trainer to train a list of tasks and set the trained model to `tag`.
@@ -57,14 +72,20 @@ class OnlineManager:
`ONLINE_TAG` for first train or additional train
`NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag`
`OFFLINE_TAG` for train but offline those models
check_func: the method to judge if a model can be online.
The parameter is the model record and return True for online.
None for online every models.
"""
# TODO: 回调
if not (tasks is None or len(tasks) == 0):
if check_func is None:
check_func = lambda x: True
if len(tasks) > 0:
if self.trainer is not None:
new_models = self.trainer.train(tasks)
self.set_online_tag(tag, new_models)
if self.need_log:
self.logger.info(f"Finished prepare {len(new_models)} new models and set them to {tag}.")
if check_func(new_models):
self.set_online_tag(tag, new_models)
if self.need_log:
self.logger.info(f"Finished preparing {len(new_models)} new models and set them to {tag}.")
else:
self.logger.warn("No trainer to train new tasks.")
@@ -101,6 +122,12 @@ class OnlineManager:
"""
raise NotImplementedError(f"Please implement the `online_models` method.")
def first_train(self, *args, **kwargs):
    """
    Train a series of models for the first time and set some of them as online models.

    Arguments are accepted (and ignored here) so that subclasses may define their own
    parameters, in line with the other abstract hooks such as `prepare_signals`.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError("Please implement the `first_train` method.")
def get_collector(self):
"""
Return the collector.
@@ -108,7 +135,7 @@ class OnlineManager:
Returns:
Collector
"""
return self.collector
raise NotImplementedError(f"Please implement the `get_collector` method.")
def run_delay_signals(self):
"""
@@ -122,9 +149,10 @@ class OnlineManager:
def routine(self, cur_time=None, delay_prepare=False, *args, **kwargs):
"""
The typical update process after a routine, such as day by day or month by month.
Prepare signals -> prepare tasks -> prepare new models -> update online prediction -> reset online models
update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
"""
self.cur_time = cur_time # None for latest date
self.update_online_pred()
if not delay_prepare:
self.prepare_signals(*args, **kwargs)
else:
@@ -134,7 +162,7 @@ class OnlineManager:
raise ValueError("Can not delay prepare when cur_time is None")
tasks = self.prepare_tasks(*args, **kwargs)
self.prepare_new_models(tasks)
self.update_online_pred()
return self.reset_online_tag()
@@ -144,19 +172,18 @@ class OnlineManagerR(OnlineManager):
"""
def __init__(self, experiment_name: str, trainer: Trainer = None, collector: Collector = None, need_log=True):
def __init__(self, experiment_name: str, trainer: Trainer = None, need_log=True):
"""
init OnlineManagerR.
Args:
experiment_name (str): the experiment name.
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
collector (Collector, optional): a instance of Collector. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
"""
if trainer is None:
trainer = TrainerR(experiment_name)
super().__init__(trainer=trainer, collector=collector, need_log=need_log)
super().__init__(trainer=trainer, need_log=need_log)
self.exp_name = experiment_name
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
@@ -212,7 +239,40 @@ class OnlineManagerR(OnlineManager):
PredUpdater(rec, to_date=self.cur_time, need_log=self.need_log).update()
if self.need_log:
self.logger.info(f"Finish updating {len(online_models)} online model predictions of {self.exp_name}.")
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
def prepare_signals(self, over_write=False):
    """
    Average the predictions of the online models into trading signals and save them.

    The signals are saved as the `signals` object of the recorder named `self.exp_name`
    in the experiment named `SIGNAL_EXP`.

    Args:
        over_write (bool, optional): If True, the new signals replace the saved ones.
            If False, only the new signals dated after the latest saved datetime are
            appended to the saved signals. Defaults to False.

    Raises:
        ValueError: if there is no online model to build signals from.
    """
    with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True):
        recorder = R.get_recorder()

        try:
            old_signals = recorder.load_object("signals")
        except OSError:
            # presumably raised when no signals have been saved yet -- treated as "nothing saved"
            old_signals = None

        pred = [rec.load_object("pred.pkl") for rec in self.online_models()]
        if not pred:
            # pd.concat would fail with a less helpful message on an empty list
            raise ValueError(f"No online models in {self.exp_name} to prepare signals from.")

        signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score").sort_index()

        if old_signals is not None and not over_write:
            old_max = old_signals.index.get_level_values("datetime").max()
            # Label slicing with `.loc[old_max:]` includes `old_max` itself, which would append
            # the rows of the latest saved datetime a second time; keep strictly newer rows only.
            new_signals = signals.loc[signals.index.get_level_values("datetime") > old_max]
            signals = pd.concat([old_signals, new_signals], axis=0)
        else:
            new_signals = signals

        if self.need_log:
            self.logger.info(f"Finished preparing new {len(new_signals)} signals to {self.SIGNAL_EXP}/{self.exp_name}.")
        recorder.save_objects(**{"signals": signals})
class RollingOnlineManager(OnlineManagerR):
@@ -223,7 +283,6 @@ class RollingOnlineManager(OnlineManagerR):
experiment_name: str,
rolling_gen: RollingGen,
trainer: Trainer = None,
collector: Collector = None,
need_log=True,
):
"""
@@ -238,24 +297,64 @@ class RollingOnlineManager(OnlineManagerR):
"""
if trainer is None:
trainer = TrainerR(experiment_name)
super().__init__(experiment_name=experiment_name, trainer=trainer, collector=collector, need_log=need_log)
super().__init__(experiment_name=experiment_name, trainer=trainer, need_log=need_log)
self.ta = TimeAdjuster()
self.rg = rolling_gen
self.logger = get_module_logger(self.__class__.__name__)
def prepare_signals(self, *args, **kwargs):
def get_collector(self, rec_key_func=None, rec_filter_func=None):
"""
Average the online models prediction and save them into a recorder
get the instance of collector to collect results
Must use `pass` even though there is nothing to do.
Args:
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
"""
# 检查recorder是否存在如果不存在就创建一个
# 检查recorder的上一个信号时间如果没有那就从上线模型的共同最早时间开始出信号
# 从recorder的上一个信号时间开始出信号出到self.cur_time
for model in self.online_models():
pass
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
if rec_key_func is None:
rec_key_func = rec_key
return RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func)
def collect_artifact(self, rec_key_func=None, rec_filter_func=None):
    """
    Collect artifacts with the collector from `get_collector` and group them with RollingGroup.

    Args:
        rec_key_func (Callable, optional): forwarded to `get_collector`. Defaults to None.
        rec_filter_func (Callable, optional): forwarded to `get_collector`. Defaults to None.

    Returns:
        dict: the artifact dict after rolling ensemble
    """
    collector = self.get_collector(rec_key_func=rec_key_func, rec_filter_func=rec_filter_func)
    return ens_workflow(collector, RollingGroup())
def first_train(self, task_configs: Union[List[dict], dict]):
    """
    Generate rolling tasks from `task_configs`, train them and set the trained models online.

    After training, the signals are prepared with `over_write=True`.

    Args:
        task_configs (list or dict): a list of task configs or a single task config

    Returns:
        Collector: the collector returned by `get_collector()`.
    """
    tasks = task_generator(
        tasks=task_configs,
        generators=self.rg,  # generate different date segments
    )
    self.prepare_new_models(tasks, tag=self.ONLINE_TAG)
    self.prepare_signals(over_write=True)
    return self.get_collector()
def prepare_tasks(self, *args, **kwargs):
"""
@@ -264,7 +363,6 @@ class RollingOnlineManager(OnlineManagerR):
Returns:
list: a list of new tasks.
"""
#TODO: max_test = self.cur_time
latest_records, max_test = self.list_latest_recorders(
lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG
)

View File

@@ -49,7 +49,7 @@ class RecorderCollector(Collector):
if rec_key_func is None:
rec_key_func = lambda rec: rec.info["id"]
if artifacts_key is None:
artifacts_key = self.artifacts_path.keys()
artifacts_key = list(self.artifacts_path.keys())
self._rec_key_func = rec_key_func
self.artifacts_key = artifacts_key
self._rec_filter_func = rec_filter_func

View File

@@ -194,6 +194,15 @@ class RollingGen(TaskGen):
# update segments of this task
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
# if end_time < the end of test_segments, then change end_time to allow load more data
if (
self.ta.cal_interval(
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
t["dataset"]["kwargs"]["segments"][self.test_key][1],
)
< 0
):
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1])
prev_seg = segments
res.append(t)
return res