1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00

online serving V9 middle status

This commit is contained in:
lzh222333
2021-04-28 09:23:07 +00:00
parent 42f510024c
commit 40cf83e557
9 changed files with 721 additions and 135 deletions

View File

@@ -1,20 +1,23 @@
import fire
import qlib
from qlib.model.ens.ensemble import ens_workflow
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM
from qlib.workflow import R
from qlib.workflow.online.manager import RollingOnlineManager
from qlib.workflow.online.simulator import OnlineSimulator
from qlib.workflow.task.collect import RecorderCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.utils import list_recorders
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example is about the OnlineManager and OnlineSimulator based on rolling tasks.
The OnlineManager will focus on the updating of your online models.
The OnlineSimulator will focus on the simulating real updating routine of your online models.
"""
import fire
import qlib
from qlib.model.ens.ensemble import ens_workflow
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM
from qlib.workflow import R
from qlib.workflow.online.manager import OnlineM # RollingOnlineManager
from qlib.workflow.online.strategy import OnlineStrategy, RollingAverageStrategy
from qlib.workflow.task.collect import RecorderCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.utils import list_recorders
data_handler_config = {
@@ -105,6 +108,8 @@ class OnlineSimulationExample:
"""
self.exp_name = exp_name
self.task_pool = task_pool
self.start_time = start_time
self.end_time = end_time
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,
@@ -115,17 +120,18 @@ class OnlineSimulationExample:
) # The rolling tasks generator, modify_end_time is false because we just need simulate to 2018-10-31.
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)
self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks
self.rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name,
rolling_gen=self.rolling_gen,
trainer=self.trainer,
self.rolling_online_manager = OnlineM(
RollingAverageStrategy(
exp_name, task_template=tasks, rolling_gen=self.rolling_gen, trainer=self.trainer, need_log=False
),
begin_time=self.start_time,
need_log=False,
) # The OnlineManager based on Rolling
self.onlinesimulator = OnlineSimulator(
start_time=start_time,
end_time=end_time,
online_manager=self.rolling_online_manager,
)
# self.onlinesimulator = OnlineSimulator(
# start_time=start_time,
# end_time=end_time,
# online_manager=self.rolling_online_manager,
# )
self.tasks = tasks
# Reset all things to the first status, be careful to save important data
@@ -137,37 +143,16 @@ class OnlineSimulationExample:
for rid in exp.list_recorders():
exp.delete_recorder(rid)
for rid in list_recorders(
RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False
):
for rid in list_recorders("OnlineManagerSignals", lambda x: True if x.info["name"] == self.exp_name else False):
exp.delete_recorder(rid)
# Run this firstly to see the workflow in OnlineManager
def first_train(self):
print("========== first train ==========")
self.reset()
self.rolling_online_manager.first_train(self.tasks)
# Run this secondly to see the simulating in OnlineSimulator
def simulate(self):
print("========== simulate ==========")
self.onlinesimulator.simulate()
print(self.rolling_online_manager.collect_artifact())
print("========== online models ==========")
recs_dict = self.onlinesimulator.online_models()
for time, recs in recs_dict.items():
print(f"{str(time[0])} to {str(time[1])}:")
for rec in recs:
print(rec.info["id"])
print("========== online signals ==========")
print(self.rolling_online_manager.get_signals())
# Run this to run all workflow automaticly
def main(self):
self.first_train()
self.simulate()
self.reset()
print("========== simulate ==========")
self.rolling_online_manager.simulate(end_time=self.end_time)
print(self.rolling_online_manager.get_collector()())
print(self.rolling_online_manager.get_online_history(self.exp_name))
if __name__ == "__main__":

View File

@@ -1,21 +1,22 @@
import os
from pathlib import Path
import pickle
import fire
import qlib
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.online.manager import RollingOnlineManager
from qlib.workflow.task.utils import list_recorders
from qlib.model.trainer import TrainerRM
"""
This example shows how RollingOnlineManager works with rolling tasks.
There are two parts including first train and routine.
Firstly, the RollingOnlineManager will finish the first training and set trained models to `online` models.
Next, the RollingOnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
"""
import os
from pathlib import Path
import pickle
import fire
import qlib
from qlib.workflow import R
from qlib.workflow.online.strategy import OnlineStrategy, RollingAverageStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.online.manager import OnlineM
from qlib.workflow.task.utils import list_recorders
from qlib.model.trainer import TrainerRM
from pprint import pprint
data_handler_config = {
"start_time": "2013-01-01",
@@ -77,58 +78,65 @@ task_xgboost_config = {
class RollingOnlineExample:
def __init__(
self,
exp_name="rolling_exp",
task_pool="rolling_task",
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
rolling_step=550,
tasks=[task_xgboost_config, task_lgb_config],
):
self.exp_name = exp_name
self.task_pool = task_pool
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name,
rolling_gen=RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
trainer=TrainerRM(self.exp_name, self.task_pool),
)
self.tasks = tasks
self.rolling_step = rolling_step
strategy = []
for task in tasks:
name_id = task["model"]["class"] + "_" + str(self.rolling_step)
strategy.append(
RollingAverageStrategy(
name_id,
task,
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
TrainerRM(experiment_name=name_id, task_pool=name_id),
)
)
self.rolling_online_manager = OnlineM(strategy)
_ROLLING_MANAGER_PATH = ".rolling_manager" # the RollingOnlineManager will dump to this file, for it will be loaded when calling routine.
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
TaskManager(self.task_pool).remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
for task in self.tasks:
name_id = task["model"]["class"] + "_" + str(self.rolling_step)
TaskManager(name_id).remove()
exp = R.get_exp(experiment_name=name_id)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
if os.path.exists(self._ROLLING_MANAGER_PATH):
os.remove(self._ROLLING_MANAGER_PATH)
if os.path.exists(self._ROLLING_MANAGER_PATH):
os.remove(self._ROLLING_MANAGER_PATH)
for rid in list_recorders(
RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False
):
exp.delete_recorder(rid)
for rid in list_recorders("OnlineManagerSignals", lambda x: True if x.info["name"] == name_id else False):
exp.delete_recorder(rid)
def first_run(self):
print("========== first_run ==========")
self.reset()
self.rolling_online_manager.first_train([task_xgboost_config, task_lgb_config])
self.rolling_online_manager.first_train()
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
print(self.rolling_online_manager.collect_artifact())
print(self.rolling_online_manager.get_collector()())
def routine(self):
print("========== routine ==========")
with Path(self._ROLLING_MANAGER_PATH).open("rb") as f:
self.rolling_online_manager = pickle.load(f)
self.rolling_online_manager.routine()
print(self.rolling_online_manager.collect_artifact())
print(self.rolling_online_manager.get_collector()())
def main(self):
self.first_run()

View File

@@ -1,16 +1,14 @@
"""
This example shows how OnlineTool works when we need to update predictions.
There are two parts including first_train and update_online_pred.
Firstly, we will finish the training and set the trained model to `online` model.
Next, we will finish updating online prediction.
"""
import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow.online.manager import OnlineManagerR
from qlib.workflow.task.utils import list_recorders
"""
This example show how OnlineManager works when we need update prediction.
There are two parts including first_train and update_online_pred.
Firstly, the RollingOnlineManager will finish the first training and set the trained model to `online` model.
Next, the RollingOnlineManager will finish updating online prediction
"""
from qlib.workflow.online.utils import OnlineToolR
data_handler_config = {
"start_time": "2008-01-01",
@@ -65,15 +63,15 @@ class UpdatePredExample:
):
qlib.init(provider_uri=provider_uri, region=region)
self.experiment_name = experiment_name
self.online_manager = OnlineManagerR(self.experiment_name)
self.online_tool = OnlineToolR(self.experiment_name)
self.task_config = task_config
def first_train(self):
rec = task_train(self.task_config, experiment_name=self.experiment_name)
self.online_manager.reset_online_tag(rec) # set to online model
self.online_tool.reset_online_tag(rec) # set to online model
def update_online_pred(self):
self.online_manager.update_online_pred()
self.online_tool.update_online_pred()
def main(self):
self.first_train()

View File

@@ -25,6 +25,7 @@ def begin_task_train(task_config: dict, experiment_name: str, *args, **kwargs) -
Returns:
Recorder
"""
# FIXME: recorder_id
with R.start(experiment_name=experiment_name, recorder_name=str(time.time())):
R.log_params(**flatten_dict(task_config))
R.save_objects(**{"task": task_config}) # keep the original format and datatype
@@ -112,6 +113,9 @@ class Trainer:
"""
pass
def is_delay(self):
    """
    Return the delay flag of this trainer.

    Returns:
        bool: always False for this class.
    """
    return False
class TrainerR(Trainer):
"""Trainer based on (R)ecorder.
@@ -240,6 +244,9 @@ class DelayTrainerR(TrainerR):
end_train_func(rec)
return recs
def is_delay(self):
    """
    Return the delay flag of this trainer.

    Returns:
        bool: always True for this class.
    """
    return True
class DelayTrainerRM(TrainerRM):
"""
@@ -286,3 +293,6 @@ class DelayTrainerRM(TrainerRM):
before_status=TaskManager.STATUS_PART_DONE,
)
return recs
def is_delay(self):
    """
    Return the delay flag of this trainer.

    Returns:
        bool: always True for this class.
    """
    return True

View File

@@ -1,5 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This class is a component of online serving, it can manage a series of models dynamically.
With the change of time, the decisive models will be also changed. In this module, we called those contributing models as `online` models.
In every routine (such as every day or every minute), the `online` models may change and their predictions need to be updated.
So this module provide a series methods to control this process.
"""
from copy import deepcopy
from operator import index
from pprint import pprint
import pandas as pd
from qlib.model.ens.ensemble import ens_workflow
from qlib.model.ens.group import RollingGroup
@@ -9,20 +18,13 @@ from qlib import get_module_logger
from qlib.data.data import D
from qlib.model.trainer import Trainer, TrainerR, task_train
from qlib.workflow import R
from qlib.workflow.online.strategy import OnlineStrategy
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import Collector, RecorderCollector
from qlib.workflow.task.collect import Collector, HyperCollector, RecorderCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
"""
This class is a component of online serving, it can manage a series of models dynamically.
With the change of time, the decisive models will be also changed. In this module, we called those contributing models as `online` models.
In every routine(such as everyday or every minutes), the `online` models maybe changed and the prediction of them need to be updated.
So this module provide a series methods to control this process.
"""
class OnlineManager(Serializable):
ONLINE_KEY = "online_status" # the online status key in recorder
@@ -357,9 +359,9 @@ class RollingOnlineManager(OnlineManagerR):
Args:
experiment_name (str): the experiment name.
rolling_gen (RollingGen): a instance of RollingGen
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
collector (Collector, optional): a instance of Collector. Defaults to None.
rolling_gen (RollingGen): an instance of RollingGen
trainer (Trainer, optional): an instance of Trainer. Defaults to None.
collector (Collector, optional): an instance of Collector. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
"""
if trainer is None:
@@ -475,3 +477,98 @@ class RollingOnlineManager(OnlineManagerR):
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
latest_rec[rid] = rec
return latest_rec, max_test
class OnlineM(Serializable):
    """
    Drive one or more `OnlineStrategy` objects through first training, routine updates and simulation.

    The manager tracks the current time (`cur_time`) and keeps, for every strategy, a history
    `{strategy.name_id: {time: online_models}}` of the online models produced at each update.
    """

    def __init__(
        self, strategy: Union[OnlineStrategy, List[OnlineStrategy]], begin_time=None, freq="day", need_log=True
    ):
        """
        Args:
            strategy (Union[OnlineStrategy, List[OnlineStrategy]]): one strategy or a list of strategies.
            begin_time: the initial current time. None means the last date of `D.calendar(freq=freq)`.
            freq (str): the calendar frequency used to look up the latest date. Defaults to "day".
            need_log (bool, optional): stored on the instance as `need_log`. Defaults to True.
        """
        self.logger = get_module_logger(self.__class__.__name__)
        self.need_log = need_log
        if not isinstance(strategy, list):
            strategy = [strategy]
        self.strategy = strategy
        self.freq = freq
        if begin_time is None:
            begin_time = D.calendar(freq=self.freq).max()
        self.cur_time = pd.Timestamp(begin_time)
        self.history = {}

    def first_train(self):
        """
        Run the first training of every strategy and record the resulting online models at `cur_time`.
        """
        for strategy in self.strategy:
            self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
            online_models = strategy.first_train()
            self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models

    def routine(self, cur_time=None, task_kwargs=None, model_kwargs=None):
        """
        The typical update process after a routine, such as day by day or month by month.
        update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models

        NOTE: If the strategy's trainer is a delayed trainer (`trainer.is_delay()` is True), signals are
        not prepared here; they are expected to be prepared later by `delay_prepare`.

        Args:
            cur_time: the time of this routine. None means the last date of `D.calendar(freq=self.freq)`.
            task_kwargs (dict, optional): keyword arguments passed to `strategy.prepare_tasks`.
            model_kwargs (dict, optional): keyword arguments passed to `strategy.prepare_online_models`.
        """
        # None defaults avoid sharing one mutable dict object between calls.
        task_kwargs = {} if task_kwargs is None else task_kwargs
        model_kwargs = {} if model_kwargs is None else model_kwargs
        if cur_time is None:
            cur_time = D.calendar(freq=self.freq).max()
        self.cur_time = pd.Timestamp(cur_time)
        for strategy in self.strategy:
            self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
            if not strategy.trainer.is_delay():
                strategy.prepare_signals()
            tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
            online_models = strategy.prepare_online_models(tasks, **model_kwargs)
            # Only record a history entry when the online models actually changed.
            if len(online_models) > 0:
                self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models

    def get_collector(self):
        """
        Build a collector that gathers the results of all strategies.

        Returns:
            HyperCollector: wraps `{strategy.name_id: strategy.get_collector()}`.
        """
        collector_dict = {}
        for strategy in self.strategy:
            collector_dict[strategy.name_id] = strategy.get_collector()
        return HyperCollector(collector_dict)

    def get_online_history(self, strategy_name_id):
        """
        Return the recorded online models of one strategy ordered by time.

        Args:
            strategy_name_id (str): the `name_id` of the strategy.

        Returns:
            list: `[(time, online_models), ...]` sorted by time; empty if nothing has been recorded
            for this strategy yet.
        """
        # A strategy without any recorded update has no entry; treat it as an empty history
        # instead of raising KeyError (which would also break `delay_prepare`).
        history_dict = self.history.get(strategy_name_id, {})
        return [(time, history_dict[time]) for time in sorted(history_dict)]

    def delay_prepare(self, delay_kwargs=None):
        """
        Let every strategy finish its delayed work based on its recorded online history.

        Args:
            delay_kwargs (dict, optional): keyword arguments passed to `strategy.delay_prepare`.
        """
        delay_kwargs = {} if delay_kwargs is None else delay_kwargs
        for strategy in self.strategy:
            strategy.delay_prepare(self.get_online_history(strategy.name_id), **delay_kwargs)

    def simulate(self, end_time, frequency="day", task_kwargs=None, model_kwargs=None, delay_kwargs=None):
        """
        Starting from the current time, simulate every routine until `end_time`.

        The first training is run once, then `routine` is called for every calendar point, and
        finally `delay_prepare` is called.
        NOTE: Considering the parallel training, the models and signals can be prepared after all routine simulating.

        Args:
            end_time: the last time to simulate.
            frequency (str): the calendar frequency of the simulated routines. Defaults to "day".
            task_kwargs (dict, optional): forwarded to `routine`.
            model_kwargs (dict, optional): forwarded to `routine`.
            delay_kwargs (dict, optional): forwarded to `delay_prepare`.

        Returns:
            Collector: the collector returned by `get_collector`.
        """
        cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)
        self.first_train()
        for cur_time in cal:
            self.logger.info(f"Simulating at {str(cur_time)}......")
            self.routine(cur_time, task_kwargs=task_kwargs, model_kwargs=model_kwargs)
        self.delay_prepare(delay_kwargs=delay_kwargs)
        self.logger.info("Finished preparing signals")
        return self.get_collector()

View File

@@ -1,6 +1,6 @@
from qlib.data import D
from qlib import get_module_logger
from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.manager import OnlineM
class OnlineSimulator:
@@ -32,7 +32,35 @@ class OnlineSimulator:
if len(self.cal) == 0:
self.logger.warn(f"There is no need to simulate bacause start_time is larger than end_time.")
def simulate(self, *args, **kwargs):
# def simulate(self, *args, **kwargs):
# """
# Starting from start time, this method will simulate every routine in OnlineManager.
# NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating.
# Returns:
# Collector: the OnlineManager's collector
# """
# self.rec_dict = {}
# tmp_begin = self.start_time
# tmp_end = None
# self.olm.first_train()
# prev_recorders = self.olm.online_models()
# for cur_time in self.cal:
# self.logger.info(f"Simulating at {str(cur_time)}......")
# recorders = self.olm.routine(cur_time, True, *args, **kwargs)
# if len(recorders) == 0:
# tmp_end = cur_time
# else:
# self.rec_dict[(tmp_begin, tmp_end)] = prev_recorders
# tmp_begin = cur_time
# prev_recorders = recorders
# self.rec_dict[(tmp_begin, self.end_time)] = prev_recorders
# # finished perparing models (and pred) and signals
# self.olm.delay_prepare(self.rec_dict)
# self.logger.info(f"Finished preparing signals")
# return self.olm.get_collector()
def simulate(self, task_kwargs=None, model_kwargs=None):
    """
    Starting from start time, this method will simulate every routine in OnlineManager.
    NOTE: Considering the parallel training, the models and signals can be prepared after all routine simulating.

    Args:
        task_kwargs (dict, optional): forwarded to `self.olm.routine` as `task_kwargs`.
        model_kwargs (dict, optional): forwarded to `self.olm.routine` as `model_kwargs`.

    Returns:
        Collector: the OnlineManager's collector
    """
    # None defaults avoid sharing one mutable dict object between calls.
    task_kwargs = {} if task_kwargs is None else task_kwargs
    model_kwargs = {} if model_kwargs is None else model_kwargs
    self.olm.first_train()
    for cur_time in self.cal:
        self.logger.info(f"Simulating at {str(cur_time)}......")
        # Forward the caller's arguments; passing literal `{}` here would silently drop them.
        self.olm.routine(cur_time, task_kwargs=task_kwargs, model_kwargs=model_kwargs)
    self.olm.delay_prepare()
    self.logger.info("Finished preparing signals")
    return self.olm.get_collector()
def online_models(self):
"""
Return a online models dict likes {(begin_time, end_time):[online models]}.
Returns:
dict
"""
if hasattr(self, "rec_dict"):
return self.rec_dict
self.logger.warn(f"Please call `simulate` firstly when calling `online_models`")
return {}

View File

@@ -0,0 +1,293 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This module works with OnlineManager and is responsible for a set of strategies describing how the models are updated and how the signals are prepared.
"""
from copy import deepcopy
from typing import List, Union
import pandas as pd
from qlib.data.data import D
from qlib.log import get_module_logger
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import Trainer, TrainerR
from qlib.workflow import R
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
from qlib.workflow.task.collect import HyperCollector, RecorderCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
class OnlineStrategy:
    """
    Base class of an online strategy: it owns a trainer and an online tool and defines the hooks
    (`first_train`, `prepare_tasks`, `prepare_signals`, `get_collector`) that concrete strategies implement.
    """

    def __init__(self, name_id: str, trainer: Trainer = None, need_log=True):
        """
        Init OnlineStrategy.

        Args:
            name_id (str): a unique name or id
            trainer (Trainer, optional): an instance of Trainer. Defaults to None.
            need_log (bool, optional): print log or not. Defaults to True.
        """
        self.name_id = name_id
        self.trainer = trainer
        self.need_log = need_log
        self.logger = get_module_logger(type(self).__name__)
        self.tool = OnlineTool()
        self.history = {}

    def prepare_signals(self, delay=False):
        """
        Prepare the trading signals for the next routine once the data of the last routine is ready.
        Subclasses must override this, even if the override does nothing.

        Args:
            delay (bool): passed as True by `delay_prepare`. Defaults to False.
        """
        raise NotImplementedError("Please implement the `prepare_signals` method.")

    def prepare_tasks(self, *args, **kwargs):
        """
        Decide, after the end of a routine, whether new tasks have to be trained.
        Implementations return the new tasks waiting for training.
        """
        raise NotImplementedError("Please implement the `prepare_tasks` method.")

    def prepare_online_models(self, tasks, check_func=None, **kwargs):
        """
        Train `tasks` with the trainer and set the accepted models to `online`.

        Args:
            tasks (list): a list of tasks. Nothing is trained when it is empty.
            check_func: called with each trained model record; the model goes online when it returns True.
                None puts every trained model online.
            **kwargs: passed to `self.trainer.train`.

        Returns:
            list: the models set to online; empty when there were no tasks.
        """
        accept = check_func if check_func is not None else (lambda rec: True)
        if len(tasks) == 0:
            return []
        trained = self.trainer.train(tasks, **kwargs)
        online_models = [rec for rec in trained if accept(rec)]
        self.tool.reset_online_tag(online_models)
        return online_models

    def first_train(self):
        """
        Train a first series of models and set some of them to online models.
        """
        raise NotImplementedError("Please implement the `first_train` method.")

    def get_collector(self):
        """
        Return the collector used to gather the results of this strategy.

        Returns:
            Collector
        """
        raise NotImplementedError("Please implement the `get_collector` method.")

    def delay_prepare(self, history, **kwargs):
        """
        Finish the delayed work for every entry of the online history.

        For each entry the trainer ends training of the recorded models, the models are reset to
        online and the signals are prepared with `delay=True`.
        NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way.

        Args:
            history (list): an iterable of `(time, online_models)` pairs.
            **kwargs: passed to `self.trainer.end_train`.
        """
        for _, recs_list in history:
            self.trainer.end_train(recs_list, **kwargs)
            self.tool.reset_online_tag(recs_list)
            self.prepare_signals(delay=True)
class RollingAverageStrategy(OnlineStrategy):
    """
    This example strategy always uses the latest rolling models as online models and prepares trading
    signals using the average prediction of the online models.
    """

    def __init__(
        self,
        name_id: str,
        task_template: Union[dict, List[dict]],
        rolling_gen: RollingGen,
        trainer: Trainer = None,
        need_log=True,
        signal_exp_name="OnlineManagerSignals",
    ):
        """
        Init RollingAverageStrategy.

        Assumption: the str of name_id, the experiment name and the trainer's experiment name are same one.

        Args:
            name_id (str): a unique name or id. Will be also the name of Experiment.
            task_template (Union[dict,List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
            rolling_gen (RollingGen): an instance of RollingGen
            trainer (Trainer, optional): an instance of Trainer. Defaults to None.
            need_log (bool, optional): print log or not. Defaults to True.
            signal_exp_name (str): the experiment in which the signals of different experiments are saved.
        """
        super().__init__(name_id=name_id, trainer=trainer, need_log=need_log)
        self.exp_name = self.name_id
        if not isinstance(task_template, list):
            task_template = [task_template]
        self.task_template = task_template
        self.signal_rec = None  # created lazily by `prepare_signals`
        self.signal_exp_name = signal_exp_name
        self.ta = TimeAdjuster()
        self.rg = rolling_gen
        self.tool = OnlineToolR(self.exp_name)

    def get_collector(self, rec_key_func=None, rec_filter_func=None):
        """
        Get the instance of collector to collect results. The returned collector must be able to distinguish results of different models.
        Assumption: the models can be distinguished based on model name and rolling test segments.
        If you do not want this assumption, please implement your own method or use another rec_key_func.

        Args:
            rec_key_func (Callable): a function to get the key of a recorder. If None, the key is
                `(model class, rolling test segment)` taken from the recorder's task config.
            rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.

        Returns:
            HyperCollector: with an "artifacts" collector over this experiment and a "signals"
            collector over the signal experiment.
        """

        def rec_key(recorder):
            task_config = recorder.load_object("task")
            model_key = task_config["model"]["class"]
            rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
            return model_key, rolling_key

        if rec_key_func is None:
            rec_key_func = rec_key

        artifacts_collector = RecorderCollector(
            experiment=self.exp_name,
            process_list=RollingGroup(),
            rec_key_func=rec_key_func,
            rec_filter_func=rec_filter_func,
        )

        signals_collector = RecorderCollector(
            experiment=self.signal_exp_name,
            rec_key_func=lambda rec: rec.info["name"],
            rec_filter_func=lambda rec: rec.info["name"] == self.exp_name,
            artifacts_path={"signals": "signals"},
        )
        return HyperCollector({"artifacts": artifacts_collector, "signals": signals_collector})

    def first_train(self):
        """
        Use rolling_gen to generate different tasks based on task_template and train them.

        Returns:
            list: the trained models that were set to online (the result of `prepare_online_models`).
        """
        tasks = task_generator(
            tasks=self.task_template,
            generators=self.rg,  # generate different date segment
        )
        return self.prepare_online_models(tasks)

    def prepare_tasks(self, cur_time):
        """
        Prepare new tasks based on cur_time (None for latest).

        New tasks are only generated when the interval between `cur_time` and the begin of the latest
        rolling test segment reaches the rolling step.

        Returns:
            list: a list of new tasks.
        """
        latest_records, max_test = self._list_latest(self.tool.online_models())
        if max_test is None:
            self.logger.warning("No latest online recorders, no new tasks.")
            return []
        calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
        interval = self.ta.cal_interval(calendar_latest, max_test[0])
        if self.need_log:
            self.logger.info(
                f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {interval}, the rolling step is {self.rg.step}"
            )
        if interval >= self.rg.step:
            old_tasks = []
            tasks_tmp = []
            for rec in latest_records:
                task = rec.load_object("task")
                old_tasks.append(deepcopy(task))
                test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
                # modify the test segment to generate new tasks
                task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
                tasks_tmp.append(task)
            new_tasks_tmp = task_generator(tasks_tmp, self.rg)
            # The generator also re-creates the tasks that are already trained; drop them.
            return [task for task in new_tasks_tmp if task not in old_tasks]
        return []

    def prepare_signals(self, delay=False, over_write=False):
        """
        Average the predictions of online models and offer trading signals every routine.
        The signals will be saved to the `signals` object of a recorder named self.exp_name in the experiment named `self.signal_exp_name`.
        Even if the latest signal already exists, the latest calculation result will be overwritten.
        NOTE: Given a prediction of a certain time, all signals before this time will be prepared well.

        Args:
            delay (bool, optional): If False, the predictions of online models are updated first. Defaults to False.
            over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False.

        Returns:
            object: the signals.
        """
        if not delay:
            self.tool.update_online_pred()
        if self.signal_rec is None:
            with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True):
                self.signal_rec = R.get_recorder()

        pred = []
        try:
            old_signals = self.signal_rec.load_object("signals")
        except OSError:
            # no signals have been saved yet
            old_signals = None

        for rec in self.tool.online_models():
            pred.append(rec.load_object("pred.pkl"))

        signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score")
        signals = signals.sort_index()
        if old_signals is not None and not over_write:
            old_max = old_signals.index.get_level_values("datetime").max()
            new_signals = signals.loc[old_max:]
            signals = pd.concat([old_signals, new_signals], axis=0)
            # `.loc[old_max:]` is inclusive, so rows at `old_max` are present in both frames.
            # Keep only the freshly computed ones so the latest result overwrites the old one
            # instead of producing duplicated index entries.
            signals = signals[~signals.index.duplicated(keep="last")]
        else:
            new_signals = signals
        if self.need_log:
            self.logger.info(
                f"Finished preparing new {len(new_signals)} signals to {self.signal_exp_name}/{self.exp_name}."
            )
        self.signal_rec.save_objects(**{"signals": signals})
        return signals

    # def get_signals(self):
    #     """
    #     get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP)
    #     Returns:
    #         signals
    #     """
    #     if self.signal_rec is None:
    #         with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True):
    #             self.signal_rec = R.get_recorder()
    #     signals = None
    #     try:
    #         signals = self.signal_rec.load_object("signals")
    #     except OSError:
    #         self.logger.warn("Can not find `signals`, have you called `prepare_signals` before?")
    #     return signals

    def _list_latest(self, rec_list):
        """
        Select the recorders whose test segment is the latest one.

        Args:
            rec_list (list): recorders whose "task" object holds the test segment.

        Returns:
            tuple: `(latest recorders, latest test segment)`; `(rec_list, None)` when `rec_list` is empty.
        """
        if len(rec_list) == 0:
            return rec_list, None
        max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list)
        latest_rec = []
        for rec in rec_list:
            if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
                latest_rec.append(rec)
        return latest_rec, max_test

View File

@@ -0,0 +1,165 @@
"""
This module acts like an online backend, deciding which models are `online` models and how to change them.
"""
from typing import List, Union
from qlib.log import get_module_logger
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
class OnlineTool:
    """
    Interface of the online backend: it defines the online tags and the operations used to
    read, set and reset them and to update the predictions of online models.
    Every operation raises NotImplementedError here and has to be implemented by a subclass.
    """

    ONLINE_KEY = "online_status"  # the online status key in recorder
    ONLINE_TAG = "online"  # the 'online' model
    # NOTE: The meaning of this tag is that we can not assume the training models can be trained before we need its prediction. Whenever finished training, it can be guaranteed that there are some online models.
    NEXT_ONLINE_TAG = "next_online"  # the 'next online' model, which can be 'online' model when call reset_online_model
    OFFLINE_TAG = "offline"  # the 'offline' model, not for online serving

    def __init__(self, need_log=True):
        """
        Init OnlineTool.

        Args:
            need_log (bool, optional): print log or not. Defaults to True.
        """
        self.logger = get_module_logger(self.__class__.__name__)
        self.need_log = need_log
        self.cur_time = None

    def set_online_tag(self, tag, recorder):
        """
        Set `tag` to the model to sign whether online.

        Args:
            tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
            recorder: the model record(s) to tag.
        """
        raise NotImplementedError("Please implement the `set_online_tag` method.")

    def get_online_tag(self, recorder=None):
        """
        Given a model and return its online tag.

        Args:
            recorder: the model record to inspect. It is optional only to stay compatible with
                callers of the former parameterless signature; implementations need it.
        """
        raise NotImplementedError("Please implement the `get_online_tag` method.")

    def reset_online_tag(self, recorders=None):
        """
        Offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.

        Args:
            recorders (List, optional):
                the recorders you want to reset to 'online'. If not given, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.

        Returns:
            list: new online recorder. [] if there is no update.
        """
        raise NotImplementedError("Please implement the `reset_online_tag` method.")

    def online_models(self):
        """
        Return `online` models.
        """
        raise NotImplementedError("Please implement the `online_models` method.")

    def update_online_pred(self, to_date=None):
        """
        Update the predictions of online models to a date.

        Args:
            to_date (pd.Timestamp): the pred before this date will be updated. None for latest.
        """
        raise NotImplementedError("Please implement the `update_online_pred` method.")
class OnlineToolR(OnlineTool):
    """
    An OnlineTool whose models are (R)ecorders of a single experiment.
    """

    def __init__(self, experiment_name: str, need_log=True):
        """
        Create the tool for one experiment.

        Args:
            experiment_name (str): name of the experiment whose recorders are managed.
            need_log (bool, optional): whether to print log messages. Defaults to True.
        """
        super().__init__(need_log=need_log)
        self.exp_name = experiment_name

    def _recorders_with_tag(self, tag):
        # List the recorders of this experiment whose online tag equals `tag`.
        def has_tag(rec):
            return self.get_online_tag(rec) == tag

        return list(list_recorders(self.exp_name, has_tag).values())

    def set_online_tag(self, tag, recorder: Union[Recorder, List]):
        """
        Write the online status `tag` onto one recorder or a list of recorders.

        Args:
            tag (str): one of `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
            recorder (Union[Recorder, List]): a single recorder or a list of them
        """
        targets = [recorder] if isinstance(recorder, Recorder) else recorder
        status = {self.ONLINE_KEY: tag}
        for target in targets:
            target.set_tags(**status)
        if self.need_log:
            self.logger.info(f"Set {len(targets)} models to '{tag}'.")

    def get_online_tag(self, recorder: Recorder):
        """
        Read the online status tag stored on a recorder.

        Args:
            recorder (Recorder): an instance of recorder

        Returns:
            str: the stored tag, or `OFFLINE_TAG` when the recorder carries no online status.
        """
        return recorder.list_tags().get(self.ONLINE_KEY, self.OFFLINE_TAG)

    def reset_online_tag(self, recorder: Union[Recorder, List] = None):
        """
        Mark every recorder of the experiment 'offline', then mark the chosen ones 'online'.

        Args:
            recorder (Union[Recorder, List], optional):
                the recorder(s) to become 'online'. When omitted, the current 'next online'
                recorders are used; if there are none, nothing is changed.

        Returns:
            list: the new online recorders. [] if there is no update.
        """
        if recorder is None:
            chosen = self._recorders_with_tag(self.NEXT_ONLINE_TAG)
        elif isinstance(recorder, Recorder):
            chosen = [recorder]
        else:
            chosen = recorder

        if len(chosen) == 0:
            if self.need_log:
                self.logger.info("No 'next online' model, just use current 'online' models.")
            return []

        # Everything goes offline first; the chosen recorders are re-tagged afterwards.
        everything = list(list_recorders(self.exp_name).values())
        self.set_online_tag(self.OFFLINE_TAG, everything)
        self.set_online_tag(self.ONLINE_TAG, chosen)
        return chosen

    def online_models(self):
        """
        List the recorders currently tagged 'online'.

        Returns:
            list: the list of online models
        """
        return self._recorders_with_tag(self.ONLINE_TAG)

    def update_online_pred(self, to_date=None):
        """
        Run a PredUpdater on every online model.

        Args:
            to_date (pd.Timestamp): the pred before this date will be updated. None for latest in Calendar.
        """
        serving = self.online_models()
        for model in serving:
            updater = PredUpdater(model, to_date=to_date, need_log=self.need_log)
            updater.update()
        if self.need_log:
            self.logger.info(f"Finished updating {len(serving)} online model predictions of {self.exp_name}.")

View File

@@ -1,5 +1,6 @@
from abc import abstractmethod
from typing import Callable, Union
from qlib import init
from qlib.workflow import R
from qlib.workflow.task.utils import list_recorders
from qlib.utils.serial import Serializable
@@ -109,6 +110,27 @@ class Collector:
raise TypeError(f"The instance of {type(collector)} is not a valid `Collector`!")
class HyperCollector(Collector):
    """
    A collector to collect the results of other Collectors
    """

    def __init__(self, collector_dict, process_list=None):
        """
        Args:
            collector_dict (dict): the dict like {collector_key, Collector}
            process_list (list or Callable): the list of processors or the instance of processor to process dict.
                Defaults to None, which is treated as an empty list.
        """
        # A literal `[]` default would be one list object shared by every instance
        # created without this argument; build a fresh list per instance instead.
        if process_list is None:
            process_list = []
        super().__init__(process_list=process_list)
        self.collector_dict = collector_dict

    def collect(self):
        """
        Call every collector in `collector_dict` and gather the results.

        Returns:
            dict: {collector_key: the value returned by calling that collector}
        """
        return {key: collector() for key, collector in self.collector_dict.items()}
class RecorderCollector(Collector):
ART_KEY_RAW = "__raw"
@@ -180,3 +202,6 @@ class RecorderCollector(Collector):
collect_dict.setdefault(key, {})[rec_key] = artifact
return collect_dict
def get_exp_name(self):
    """
    Return the `name` attribute of `self.experiment`.
    """
    experiment = self.experiment
    return experiment.name