mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
online serving V9 middle status
This commit is contained in:
@@ -1,20 +1,23 @@
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.ens.ensemble import ens_workflow
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.manager import RollingOnlineManager
|
||||
from qlib.workflow.online.simulator import OnlineSimulator
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
This examples is about the OnlineManager and OnlineSimulator based on rolling tasks.
|
||||
The OnlineManager will focus on the updating of your online models.
|
||||
The OnlineSimulator will focus on the simulating real updating routine of your online models.
|
||||
"""
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.ens.ensemble import ens_workflow
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.manager import OnlineM # RollingOnlineManager
|
||||
from qlib.workflow.online.strategy import OnlineStrategy, RollingAverageStrategy
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
|
||||
|
||||
|
||||
data_handler_config = {
|
||||
@@ -105,6 +108,8 @@ class OnlineSimulationExample:
|
||||
"""
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
@@ -115,17 +120,18 @@ class OnlineSimulationExample:
|
||||
) # The rolling tasks generator, modify_end_time is false because we just need simulate to 2018-10-31.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)
|
||||
self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks
|
||||
self.rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name,
|
||||
rolling_gen=self.rolling_gen,
|
||||
trainer=self.trainer,
|
||||
self.rolling_online_manager = OnlineM(
|
||||
RollingAverageStrategy(
|
||||
exp_name, task_template=tasks, rolling_gen=self.rolling_gen, trainer=self.trainer, need_log=False
|
||||
),
|
||||
begin_time=self.start_time,
|
||||
need_log=False,
|
||||
) # The OnlineManager based on Rolling
|
||||
self.onlinesimulator = OnlineSimulator(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
online_manager=self.rolling_online_manager,
|
||||
)
|
||||
# self.onlinesimulator = OnlineSimulator(
|
||||
# start_time=start_time,
|
||||
# end_time=end_time,
|
||||
# online_manager=self.rolling_online_manager,
|
||||
# )
|
||||
self.tasks = tasks
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
@@ -137,37 +143,16 @@ class OnlineSimulationExample:
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
for rid in list_recorders(
|
||||
RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False
|
||||
):
|
||||
for rid in list_recorders("OnlineManagerSignals", lambda x: True if x.info["name"] == self.exp_name else False):
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
# Run this firstly to see the workflow in OnlineManager
|
||||
def first_train(self):
|
||||
print("========== first train ==========")
|
||||
self.reset()
|
||||
self.rolling_online_manager.first_train(self.tasks)
|
||||
|
||||
# Run this secondly to see the simulating in OnlineSimulator
|
||||
def simulate(self):
|
||||
print("========== simulate ==========")
|
||||
self.onlinesimulator.simulate()
|
||||
print(self.rolling_online_manager.collect_artifact())
|
||||
|
||||
print("========== online models ==========")
|
||||
recs_dict = self.onlinesimulator.online_models()
|
||||
for time, recs in recs_dict.items():
|
||||
print(f"{str(time[0])} to {str(time[1])}:")
|
||||
for rec in recs:
|
||||
print(rec.info["id"])
|
||||
|
||||
print("========== online signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
# Run this to run all workflow automaticly
|
||||
def main(self):
|
||||
self.first_train()
|
||||
self.simulate()
|
||||
self.reset()
|
||||
print("========== simulate ==========")
|
||||
self.rolling_online_manager.simulate(end_time=self.end_time)
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print(self.rolling_online_manager.get_online_history(self.exp_name))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.online.manager import RollingOnlineManager
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.model.trainer import TrainerRM
|
||||
|
||||
"""
|
||||
This example show how RollingOnlineManager works with rolling tasks.
|
||||
There are two parts including first train and routine.
|
||||
Firstly, the RollingOnlineManager will finish the first training and set trained models to `online` models.
|
||||
Next, the RollingOnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import OnlineStrategy, RollingAverageStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.online.manager import OnlineM
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.model.trainer import TrainerRM
|
||||
from pprint import pprint
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2013-01-01",
|
||||
@@ -77,58 +78,65 @@ task_xgboost_config = {
|
||||
class RollingOnlineExample:
|
||||
def __init__(
|
||||
self,
|
||||
exp_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
rolling_step=550,
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
):
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name,
|
||||
rolling_gen=RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
trainer=TrainerRM(self.exp_name, self.task_pool),
|
||||
)
|
||||
self.tasks = tasks
|
||||
self.rolling_step = rolling_step
|
||||
strategy = []
|
||||
for task in tasks:
|
||||
name_id = task["model"]["class"] + "_" + str(self.rolling_step)
|
||||
strategy.append(
|
||||
RollingAverageStrategy(
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
TrainerRM(experiment_name=name_id, task_pool=name_id),
|
||||
)
|
||||
)
|
||||
|
||||
self.rolling_online_manager = OnlineM(strategy)
|
||||
|
||||
_ROLLING_MANAGER_PATH = ".rolling_manager" # the RollingOnlineManager will dump to this file, for it will be loaded when calling routine.
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
TaskManager(self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
for task in self.tasks:
|
||||
name_id = task["model"]["class"] + "_" + str(self.rolling_step)
|
||||
TaskManager(name_id).remove()
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
for rid in list_recorders(
|
||||
RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False
|
||||
):
|
||||
exp.delete_recorder(rid)
|
||||
for rid in list_recorders("OnlineManagerSignals", lambda x: True if x.info["name"] == name_id else False):
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
def first_run(self):
|
||||
print("========== first_run ==========")
|
||||
self.reset()
|
||||
self.rolling_online_manager.first_train([task_xgboost_config, task_lgb_config])
|
||||
self.rolling_online_manager.first_train()
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
print(self.rolling_online_manager.collect_artifact())
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
|
||||
def routine(self):
|
||||
print("========== routine ==========")
|
||||
with Path(self._ROLLING_MANAGER_PATH).open("rb") as f:
|
||||
self.rolling_online_manager = pickle.load(f)
|
||||
self.rolling_online_manager.routine()
|
||||
print(self.rolling_online_manager.collect_artifact())
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
|
||||
def main(self):
|
||||
self.first_run()
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
"""
|
||||
This example show how OnlineTool works when we need update prediction.
|
||||
There are two parts including first_train and update_online_pred.
|
||||
Firstly, we will finish the training and set the trained model to `online` model.
|
||||
Next, we will finish updating online prediction.
|
||||
"""
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.online.manager import OnlineManagerR
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
"""
|
||||
This example show how OnlineManager works when we need update prediction.
|
||||
There are two parts including first_train and update_online_pred.
|
||||
Firstly, the RollingOnlineManager will finish the first training and set the trained model to `online` model.
|
||||
Next, the RollingOnlineManager will finish updating online prediction
|
||||
"""
|
||||
from qlib.workflow.online.utils import OnlineToolR
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
@@ -65,15 +63,15 @@ class UpdatePredExample:
|
||||
):
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
self.experiment_name = experiment_name
|
||||
self.online_manager = OnlineManagerR(self.experiment_name)
|
||||
self.online_tool = OnlineToolR(self.experiment_name)
|
||||
self.task_config = task_config
|
||||
|
||||
def first_train(self):
|
||||
rec = task_train(self.task_config, experiment_name=self.experiment_name)
|
||||
self.online_manager.reset_online_tag(rec) # set to online model
|
||||
self.online_tool.reset_online_tag(rec) # set to online model
|
||||
|
||||
def update_online_pred(self):
|
||||
self.online_manager.update_online_pred()
|
||||
self.online_tool.update_online_pred()
|
||||
|
||||
def main(self):
|
||||
self.first_train()
|
||||
|
||||
@@ -25,6 +25,7 @@ def begin_task_train(task_config: dict, experiment_name: str, *args, **kwargs) -
|
||||
Returns:
|
||||
Recorder
|
||||
"""
|
||||
# FIXME: recorder_id
|
||||
with R.start(experiment_name=experiment_name, recorder_name=str(time.time())):
|
||||
R.log_params(**flatten_dict(task_config))
|
||||
R.save_objects(**{"task": task_config}) # keep the original format and datatype
|
||||
@@ -112,6 +113,9 @@ class Trainer:
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_delay(self):
|
||||
return False
|
||||
|
||||
|
||||
class TrainerR(Trainer):
|
||||
"""Trainer based on (R)ecorder.
|
||||
@@ -240,6 +244,9 @@ class DelayTrainerR(TrainerR):
|
||||
end_train_func(rec)
|
||||
return recs
|
||||
|
||||
def is_delay(self):
|
||||
return True
|
||||
|
||||
|
||||
class DelayTrainerRM(TrainerRM):
|
||||
"""
|
||||
@@ -286,3 +293,6 @@ class DelayTrainerRM(TrainerRM):
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
)
|
||||
return recs
|
||||
|
||||
def is_delay(self):
|
||||
return True
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This class is a component of online serving, it can manage a series of models dynamically.
|
||||
With the change of time, the decisive models will be also changed. In this module, we called those contributing models as `online` models.
|
||||
In every routine(such as everyday or every minutes), the `online` models maybe changed and the prediction of them need to be updated.
|
||||
So this module provide a series methods to control this process.
|
||||
"""
|
||||
from copy import deepcopy
|
||||
from operator import index
|
||||
from pprint import pprint
|
||||
import pandas as pd
|
||||
from qlib.model.ens.ensemble import ens_workflow
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
@@ -9,20 +18,13 @@ from qlib import get_module_logger
|
||||
from qlib.data.data import D
|
||||
from qlib.model.trainer import Trainer, TrainerR, task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import OnlineStrategy
|
||||
from qlib.workflow.online.update import PredUpdater
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import Collector, RecorderCollector
|
||||
from qlib.workflow.task.collect import Collector, HyperCollector, RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
|
||||
|
||||
"""
|
||||
This class is a component of online serving, it can manage a series of models dynamically.
|
||||
With the change of time, the decisive models will be also changed. In this module, we called those contributing models as `online` models.
|
||||
In every routine(such as everyday or every minutes), the `online` models maybe changed and the prediction of them need to be updated.
|
||||
So this module provide a series methods to control this process.
|
||||
"""
|
||||
|
||||
|
||||
class OnlineManager(Serializable):
|
||||
|
||||
ONLINE_KEY = "online_status" # the online status key in recorder
|
||||
@@ -357,9 +359,9 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
|
||||
Args:
|
||||
experiment_name (str): the experiment name.
|
||||
rolling_gen (RollingGen): a instance of RollingGen
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
collector (Collector, optional): a instance of Collector. Defaults to None.
|
||||
rolling_gen (RollingGen): an instance of RollingGen
|
||||
trainer (Trainer, optional): an instance of Trainer. Defaults to None.
|
||||
collector (Collector, optional): an instance of Collector. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
if trainer is None:
|
||||
@@ -475,3 +477,98 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
|
||||
latest_rec[rid] = rec
|
||||
return latest_rec, max_test
|
||||
|
||||
|
||||
class OnlineM(Serializable):
|
||||
def __init__(
|
||||
self, strategy: Union[OnlineStrategy, List[OnlineStrategy]], begin_time=None, freq="day", need_log=True
|
||||
):
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
if not isinstance(strategy, list):
|
||||
strategy = [strategy]
|
||||
self.strategy = strategy
|
||||
self.freq = freq
|
||||
if begin_time is None:
|
||||
begin_time = D.calendar(freq=self.freq).max()
|
||||
self.cur_time = pd.Timestamp(begin_time)
|
||||
self.history = {}
|
||||
|
||||
def first_train(self):
|
||||
"""
|
||||
Train a series of models firstly and set some of them into online models.
|
||||
"""
|
||||
for strategy in self.strategy:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
|
||||
online_models = strategy.first_train()
|
||||
self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models
|
||||
|
||||
def routine(self, cur_time=None, task_kwargs={}, model_kwargs={}):
|
||||
"""
|
||||
The typical update process after a routine, such as day by day or month by month.
|
||||
update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
|
||||
|
||||
NOTE: Assumption: if using simulator (delay_prepare is True), the prediction will be prepared well after every training, so there is no need to update predictions.
|
||||
|
||||
Args:
|
||||
cur_time ([type], optional): [description]. Defaults to None.
|
||||
delay_prepare (bool, optional): [description]. Defaults to False.
|
||||
*args, **kwargs: will be passed to `prepare_tasks` and `prepare_new_models`. It can be some hyper parameter or training config.
|
||||
|
||||
Returns:
|
||||
[type]: [description]
|
||||
"""
|
||||
if cur_time is None:
|
||||
cur_time = D.calendar(freq=self.freq).max()
|
||||
self.cur_time = pd.Timestamp(cur_time) # None for latest date
|
||||
for strategy in self.strategy:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
if not strategy.trainer.is_delay():
|
||||
strategy.prepare_signals()
|
||||
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
|
||||
online_models = strategy.prepare_online_models(tasks, **model_kwargs)
|
||||
if len(online_models) > 0:
|
||||
self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models
|
||||
|
||||
def get_collector(self):
|
||||
collector_dict = {}
|
||||
for strategy in self.strategy:
|
||||
collector_dict[strategy.name_id] = strategy.get_collector()
|
||||
return HyperCollector(collector_dict)
|
||||
|
||||
def get_online_history(self, strategy_name_id):
|
||||
history_dict = self.history[strategy_name_id]
|
||||
history = []
|
||||
for time in sorted(history_dict):
|
||||
models = history_dict[time]
|
||||
history.append((time, models))
|
||||
return history
|
||||
|
||||
def delay_prepare(self, delay_kwargs={}):
|
||||
"""
|
||||
Prepare all models and signals if there are something waiting for prepare.
|
||||
NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way.
|
||||
|
||||
Args:
|
||||
rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}.
|
||||
*args, **kwargs: will be passed to end_train which means will be passed to customized train method.
|
||||
"""
|
||||
for strategy in self.strategy:
|
||||
strategy.delay_prepare(self.get_online_history(strategy.name_id), **delay_kwargs)
|
||||
|
||||
def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, delay_kwargs={}):
|
||||
"""
|
||||
Starting from start time, this method will simulate every routine in OnlineManager.
|
||||
NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating.
|
||||
|
||||
Returns:
|
||||
Collector: the OnlineManager's collector
|
||||
"""
|
||||
cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)
|
||||
self.first_train()
|
||||
for cur_time in cal:
|
||||
self.logger.info(f"Simulating at {str(cur_time)}......")
|
||||
self.routine(cur_time, task_kwargs=task_kwargs, model_kwargs=model_kwargs)
|
||||
self.delay_prepare(delay_kwargs=delay_kwargs)
|
||||
self.logger.info(f"Finished preparing signals")
|
||||
return self.get_collector()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from qlib.data import D
|
||||
from qlib import get_module_logger
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.manager import OnlineM
|
||||
|
||||
|
||||
class OnlineSimulator:
|
||||
@@ -32,7 +32,35 @@ class OnlineSimulator:
|
||||
if len(self.cal) == 0:
|
||||
self.logger.warn(f"There is no need to simulate bacause start_time is larger than end_time.")
|
||||
|
||||
def simulate(self, *args, **kwargs):
|
||||
# def simulate(self, *args, **kwargs):
|
||||
# """
|
||||
# Starting from start time, this method will simulate every routine in OnlineManager.
|
||||
# NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating.
|
||||
|
||||
# Returns:
|
||||
# Collector: the OnlineManager's collector
|
||||
# """
|
||||
# self.rec_dict = {}
|
||||
# tmp_begin = self.start_time
|
||||
# tmp_end = None
|
||||
# self.olm.first_train()
|
||||
# prev_recorders = self.olm.online_models()
|
||||
# for cur_time in self.cal:
|
||||
# self.logger.info(f"Simulating at {str(cur_time)}......")
|
||||
# recorders = self.olm.routine(cur_time, True, *args, **kwargs)
|
||||
# if len(recorders) == 0:
|
||||
# tmp_end = cur_time
|
||||
# else:
|
||||
# self.rec_dict[(tmp_begin, tmp_end)] = prev_recorders
|
||||
# tmp_begin = cur_time
|
||||
# prev_recorders = recorders
|
||||
# self.rec_dict[(tmp_begin, self.end_time)] = prev_recorders
|
||||
# # finished perparing models (and pred) and signals
|
||||
# self.olm.delay_prepare(self.rec_dict)
|
||||
# self.logger.info(f"Finished preparing signals")
|
||||
# return self.olm.get_collector()
|
||||
|
||||
def simulate(self, task_kwargs={}, model_kwargs={}):
|
||||
"""
|
||||
Starting from start time, this method will simulate every routine in OnlineManager.
|
||||
NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating.
|
||||
@@ -40,33 +68,10 @@ class OnlineSimulator:
|
||||
Returns:
|
||||
Collector: the OnlineManager's collector
|
||||
"""
|
||||
self.rec_dict = {}
|
||||
tmp_begin = self.start_time
|
||||
tmp_end = None
|
||||
prev_recorders = self.olm.online_models()
|
||||
self.olm.first_train()
|
||||
for cur_time in self.cal:
|
||||
self.logger.info(f"Simulating at {str(cur_time)}......")
|
||||
recorders = self.olm.routine(cur_time, True, *args, **kwargs)
|
||||
if len(recorders) == 0:
|
||||
tmp_end = cur_time
|
||||
else:
|
||||
self.rec_dict[(tmp_begin, tmp_end)] = prev_recorders
|
||||
tmp_begin = cur_time
|
||||
prev_recorders = recorders
|
||||
self.rec_dict[(tmp_begin, self.end_time)] = prev_recorders
|
||||
# finished perparing models (and pred) and signals
|
||||
self.olm.delay_prepare(self.rec_dict)
|
||||
self.olm.routine(cur_time, task_kwargs={}, model_kwargs={})
|
||||
self.olm.delay_prepare()
|
||||
self.logger.info(f"Finished preparing signals")
|
||||
return self.olm.get_collector()
|
||||
|
||||
def online_models(self):
|
||||
"""
|
||||
Return a online models dict likes {(begin_time, end_time):[online models]}.
|
||||
|
||||
Returns:
|
||||
dict
|
||||
"""
|
||||
if hasattr(self, "rec_dict"):
|
||||
return self.rec_dict
|
||||
self.logger.warn(f"Please call `simulate` firstly when calling `online_models`")
|
||||
return {}
|
||||
|
||||
293
qlib/workflow/online/strategy.py
Normal file
293
qlib/workflow/online/strategy.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
This module is working with OnlineManager, responsing for a set of strategy about how the models are updated and signals are perpared.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Union
|
||||
import pandas as pd
|
||||
from qlib.data.data import D
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import Trainer, TrainerR
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
|
||||
from qlib.workflow.task.collect import HyperCollector, RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
|
||||
|
||||
|
||||
class OnlineStrategy:
|
||||
def __init__(self, name_id: str, trainer: Trainer = None, need_log=True):
|
||||
"""
|
||||
init OnlineManager.
|
||||
|
||||
Args:
|
||||
name_id (str): a unique name or id
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
self.name_id = name_id
|
||||
self.trainer = trainer
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
self.tool = OnlineTool()
|
||||
self.history = {}
|
||||
|
||||
def prepare_signals(self, delay=False):
|
||||
"""
|
||||
After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine.
|
||||
Must use `pass` even though there is nothing to do.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
"""
|
||||
After the end of a routine, check whether we need to prepare and train some new tasks.
|
||||
return the new tasks waiting for training.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
|
||||
|
||||
def prepare_online_models(self, tasks, check_func=None, **kwargs):
|
||||
"""
|
||||
Use trainer to train a list of tasks and set the trained model to `online`.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of tasks.
|
||||
tag (str):
|
||||
`ONLINE_TAG` for first train or additional train
|
||||
`NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag`
|
||||
`OFFLINE_TAG` for train but offline those models
|
||||
check_func: the method to judge if a model can be online.
|
||||
The parameter is the model record and return True for online.
|
||||
None for online every models.
|
||||
**kwargs: will be passed to end_train which means will be passed to customized train method.
|
||||
|
||||
"""
|
||||
if check_func is None:
|
||||
check_func = lambda x: True
|
||||
online_models = []
|
||||
if len(tasks) > 0:
|
||||
new_models = self.trainer.train(tasks, **kwargs)
|
||||
for model in new_models:
|
||||
if check_func(model):
|
||||
online_models.append(model)
|
||||
self.tool.reset_online_tag(online_models)
|
||||
return online_models
|
||||
|
||||
def first_train(self):
|
||||
"""
|
||||
Train a series of models firstly and set some of them into online models.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `first_train` method.")
|
||||
|
||||
def get_collector(self):
|
||||
"""
|
||||
Return the collector.
|
||||
|
||||
Returns:
|
||||
Collector
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_collector` method.")
|
||||
|
||||
def delay_prepare(self, history, **kwargs):
|
||||
"""
|
||||
Prepare all models and signals if there are something waiting for prepare.
|
||||
NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way.
|
||||
|
||||
Args:
|
||||
rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}.
|
||||
*args, **kwargs: will be passed to end_train which means will be passed to customized train method.
|
||||
"""
|
||||
for time_begin, recs_list in history:
|
||||
self.trainer.end_train(recs_list, **kwargs)
|
||||
self.tool.reset_online_tag(recs_list)
|
||||
self.prepare_signals(delay=True)
|
||||
|
||||
|
||||
class RollingAverageStrategy(OnlineStrategy):
|
||||
|
||||
"""
|
||||
This example strategy always use latest rolling model as online model and prepare trading signals using the average prediction of online models
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name_id: str,
|
||||
task_template: Union[dict, List[dict]],
|
||||
rolling_gen: RollingGen,
|
||||
trainer: Trainer = None,
|
||||
need_log=True,
|
||||
signal_exp_name="OnlineManagerSignals",
|
||||
):
|
||||
"""
|
||||
init OnlineManagerR.
|
||||
|
||||
Assumption: the str of name_id, the experiment name and the trainer's experiment name are same one.
|
||||
|
||||
Args:
|
||||
name_id (str): a unique name or id. Will be also the name of Experiment.
|
||||
task_template (Union[dict,List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
|
||||
rolling_gen (RollingGen): an instance of RollingGen
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
signal_exp_path (str): a specific experiment to save signals of different experiment.
|
||||
"""
|
||||
super().__init__(name_id=name_id, trainer=trainer, need_log=need_log)
|
||||
self.exp_name = self.name_id
|
||||
if not isinstance(task_template, list):
|
||||
task_template = [task_template]
|
||||
self.task_template = task_template
|
||||
self.signal_rec = None
|
||||
self.signal_exp_name = signal_exp_name
|
||||
self.ta = TimeAdjuster()
|
||||
self.rg = rolling_gen
|
||||
self.tool = OnlineToolR(self.exp_name)
|
||||
|
||||
def get_collector(self, rec_key_func=None, rec_filter_func=None):
|
||||
"""
|
||||
Get the instance of collector to collect results. The returned collector must can distinguish results in different models.
|
||||
Assumption: the models can be distinguished based on model name and rolling test segments.
|
||||
If you do not want this assumption, please implement your own method or use another rec_key_func.
|
||||
|
||||
Args:
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
"""
|
||||
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
if rec_key_func is None:
|
||||
rec_key_func = rec_key
|
||||
|
||||
artifacts_collector = RecorderCollector(
|
||||
experiment=self.exp_name,
|
||||
process_list=RollingGroup(),
|
||||
rec_key_func=rec_key_func,
|
||||
rec_filter_func=rec_filter_func,
|
||||
)
|
||||
|
||||
signals_collector = RecorderCollector(
|
||||
experiment=self.signal_exp_name,
|
||||
rec_key_func=lambda rec: rec.info["name"],
|
||||
rec_filter_func=lambda rec: rec.info["name"] == self.exp_name,
|
||||
artifacts_path={"signals": "signals"},
|
||||
)
|
||||
return HyperCollector({"artifacts": artifacts_collector, "signals": signals_collector})
|
||||
|
||||
def first_train(self):
|
||||
"""
|
||||
Use rolling_gen to generate different tasks based on task_template and trained them.
|
||||
|
||||
Returns:
|
||||
Collector: a instance of a Collector.
|
||||
"""
|
||||
tasks = task_generator(
|
||||
tasks=self.task_template,
|
||||
generators=self.rg, # generate different date segment
|
||||
)
|
||||
return self.prepare_online_models(tasks)
|
||||
|
||||
def prepare_tasks(self, cur_time):
|
||||
"""
|
||||
Prepare new tasks based on cur_time (None for latest).
|
||||
|
||||
Returns:
|
||||
list: a list of new tasks.
|
||||
"""
|
||||
latest_records, max_test = self._list_latest(self.tool.online_models())
|
||||
if max_test is None:
|
||||
self.logger.warn(f"No latest online recorders, no new tasks.")
|
||||
return []
|
||||
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
|
||||
if self.need_log:
|
||||
self.logger.info(
|
||||
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
|
||||
)
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
|
||||
old_tasks = []
|
||||
tasks_tmp = []
|
||||
for rec in latest_records:
|
||||
task = rec.load_object("task")
|
||||
old_tasks.append(deepcopy(task))
|
||||
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
|
||||
# modify the test segment to generate new tasks
|
||||
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
|
||||
tasks_tmp.append(task)
|
||||
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
|
||||
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
|
||||
return new_tasks
|
||||
return []
|
||||
|
||||
def prepare_signals(self, delay=False, over_write=False):
|
||||
"""
|
||||
Average the predictions of online models and offer a trading signals every routine.
|
||||
The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP`
|
||||
Even if the latest signal already exists, the latest calculation result will be overwritten.
|
||||
NOTE: Given a prediction of a certain time, all signals before this time will be prepared well.
|
||||
Args:
|
||||
over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False.
|
||||
Returns:
|
||||
object: the signals.
|
||||
"""
|
||||
if not delay:
|
||||
self.tool.update_online_pred()
|
||||
if self.signal_rec is None:
|
||||
with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True):
|
||||
self.signal_rec = R.get_recorder()
|
||||
|
||||
pred = []
|
||||
try:
|
||||
old_signals = self.signal_rec.load_object("signals")
|
||||
except OSError:
|
||||
old_signals = None
|
||||
|
||||
for rec in self.tool.online_models():
|
||||
pred.append(rec.load_object("pred.pkl"))
|
||||
|
||||
signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score")
|
||||
signals = signals.sort_index()
|
||||
if old_signals is not None and not over_write:
|
||||
old_max = old_signals.index.get_level_values("datetime").max()
|
||||
new_signals = signals.loc[old_max:]
|
||||
signals = pd.concat([old_signals, new_signals], axis=0)
|
||||
else:
|
||||
new_signals = signals
|
||||
if self.need_log:
|
||||
self.logger.info(
|
||||
f"Finished preparing new {len(new_signals)} signals to {self.signal_exp_name}/{self.exp_name}."
|
||||
)
|
||||
self.signal_rec.save_objects(**{"signals": signals})
|
||||
return signals
|
||||
|
||||
# def get_signals(self):
|
||||
# """
|
||||
# get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP)
|
||||
|
||||
# Returns:
|
||||
# signals
|
||||
# """
|
||||
# if self.signal_rec is None:
|
||||
# with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True):
|
||||
# self.signal_rec = R.get_recorder()
|
||||
# signals = None
|
||||
# try:
|
||||
# signals = self.signal_rec.load_object("signals")
|
||||
# except OSError:
|
||||
# self.logger.warn("Can not find `signals`, have you called `prepare_signals` before?")
|
||||
# return signals
|
||||
|
||||
def _list_latest(self, rec_list):
|
||||
if len(rec_list) == 0:
|
||||
return rec_list, None
|
||||
max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list)
|
||||
latest_rec = []
|
||||
for rec in rec_list:
|
||||
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
|
||||
latest_rec.append(rec)
|
||||
return latest_rec, max_test
|
||||
165
qlib/workflow/online/utils.py
Normal file
165
qlib/workflow/online/utils.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
This module is like a online backend, deciding which models are `online` models and how can change them
|
||||
"""
|
||||
from typing import List, Union
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.workflow.online.update import PredUpdater
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
|
||||
class OnlineTool:
|
||||
|
||||
ONLINE_KEY = "online_status" # the online status key in recorder
|
||||
ONLINE_TAG = "online" # the 'online' model
|
||||
# NOTE: The meaning of this tag is that we can not assume the training models can be trained before we need its predition. Whenever finished training, it can be guaranteed that there are some online models.
|
||||
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
|
||||
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
|
||||
|
||||
def __init__(self, need_log=True):
|
||||
"""
|
||||
init OnlineTool.
|
||||
|
||||
Args:
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
self.cur_time = None
|
||||
|
||||
def set_online_tag(self, tag, recorder):
|
||||
"""
|
||||
Set `tag` to the model to sign whether online.
|
||||
|
||||
Args:
|
||||
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
|
||||
|
||||
def get_online_tag(self):
|
||||
"""
|
||||
Given a model and return its online tag.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
|
||||
|
||||
def reset_online_tag(self, recorders=None):
|
||||
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
|
||||
|
||||
Args:
|
||||
recorders (List, optional):
|
||||
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
|
||||
|
||||
Returns:
|
||||
list: new online recorder. [] if there is no update.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
|
||||
|
||||
def online_models(self):
|
||||
"""
|
||||
Return `online` models.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `online_models` method.")
|
||||
|
||||
def update_online_pred(self, to_date=None):
|
||||
"""
|
||||
Update the predictions of online models to a date.
|
||||
|
||||
Args:
|
||||
to_date (pd.Timestamp): the pred before this date will be updated. None for latest.
|
||||
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
|
||||
|
||||
|
||||
class OnlineToolR(OnlineTool):
|
||||
"""
|
||||
The implementation of OnlineTool based on (R)ecorder.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str, need_log=True):
|
||||
"""
|
||||
init OnlineToolR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the experiment name.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
super().__init__(need_log=need_log)
|
||||
self.exp_name = experiment_name
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
|
||||
"""
|
||||
Set `tag` to the model to sign whether online.
|
||||
|
||||
Args:
|
||||
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
|
||||
recorder (Union[Recorder, List])
|
||||
"""
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
for rec in recorder:
|
||||
rec.set_tags(**{self.ONLINE_KEY: tag})
|
||||
if self.need_log:
|
||||
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
|
||||
|
||||
def get_online_tag(self, recorder: Recorder):
|
||||
"""
|
||||
Given a model and return its online tag.
|
||||
|
||||
Args:
|
||||
recorder (Recorder): a instance of recorder
|
||||
|
||||
Returns:
|
||||
str: the tag
|
||||
"""
|
||||
tags = recorder.list_tags()
|
||||
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)
|
||||
|
||||
def reset_online_tag(self, recorder: Union[Recorder, List] = None):
|
||||
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
|
||||
|
||||
Args:
|
||||
recorders (Union[Recorder, List], optional):
|
||||
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
|
||||
|
||||
Returns:
|
||||
list: new online recorder. [] if there is no update.
|
||||
"""
|
||||
if recorder is None:
|
||||
recorder = list(
|
||||
list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.NEXT_ONLINE_TAG).values()
|
||||
)
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
if len(recorder) == 0:
|
||||
if self.need_log:
|
||||
self.logger.info("No 'next online' model, just use current 'online' models.")
|
||||
return []
|
||||
recs = list_recorders(self.exp_name)
|
||||
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
|
||||
self.set_online_tag(self.ONLINE_TAG, recorder)
|
||||
return recorder
|
||||
|
||||
def online_models(self):
|
||||
"""
|
||||
Return online models.
|
||||
|
||||
Returns:
|
||||
list: the list of online models
|
||||
"""
|
||||
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
|
||||
|
||||
def update_online_pred(self, to_date=None):
|
||||
"""
|
||||
Update the predictions of online models to a date.
|
||||
|
||||
Args:
|
||||
to_date (pd.Timestamp): the pred before this date will be updated. None for latest in Calendar.
|
||||
"""
|
||||
online_models = self.online_models()
|
||||
for rec in online_models:
|
||||
PredUpdater(rec, to_date=to_date, need_log=self.need_log).update()
|
||||
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Union
|
||||
from qlib import init
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.utils.serial import Serializable
|
||||
@@ -109,6 +110,27 @@ class Collector:
|
||||
raise TypeError(f"The instance of {type(collector)} is not a valid `Collector`!")
|
||||
|
||||
|
||||
class HyperCollector(Collector):
|
||||
"""
|
||||
A collector to collect the results of other Collectors
|
||||
"""
|
||||
|
||||
def __init__(self, collector_dict, process_list=[]):
|
||||
"""
|
||||
Args:
|
||||
collector_dict (dict): the dict like {collector_key, Collector}
|
||||
process_list (list or Callable): the list of processors or the instance of processor to process dict.
|
||||
"""
|
||||
super().__init__(process_list=process_list)
|
||||
self.collector_dict = collector_dict
|
||||
|
||||
def collect(self):
|
||||
collect_dict = {}
|
||||
for key, collector in self.collector_dict.items():
|
||||
collect_dict[key] = collector()
|
||||
return collect_dict
|
||||
|
||||
|
||||
class RecorderCollector(Collector):
|
||||
ART_KEY_RAW = "__raw"
|
||||
|
||||
@@ -180,3 +202,6 @@ class RecorderCollector(Collector):
|
||||
collect_dict.setdefault(key, {})[rec_key] = artifact
|
||||
|
||||
return collect_dict
|
||||
|
||||
def get_exp_name(self):
|
||||
return self.experiment.name
|
||||
|
||||
Reference in New Issue
Block a user