mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 11:30:57 +08:00
online serving V8
This commit is contained in:
@@ -15,6 +15,11 @@ from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
|
||||
"""
|
||||
This example shows how a Trainer works based on TaskManager with rolling tasks.
|
||||
After training, task_collecting shows how to collect the rolling results.
|
||||
"""
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
@@ -71,81 +76,83 @@ task_xgboost_config = {
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(task_pool, exp_name):
|
||||
print("========== reset ==========")
|
||||
TaskManager(task_pool=task_pool).remove()
|
||||
|
||||
exp = R.get_exp(experiment_name=exp_name)
|
||||
class RollingTaskExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region=REG_CN,
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
task_config=[task_xgboost_config, task_lgb_config],
|
||||
rolling_step=550,
|
||||
rolling_type=RollingGen.ROLL_SD,
|
||||
):
|
||||
# TaskManager config
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.task_config = task_config
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
|
||||
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
TaskManager(task_pool=self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.experiment_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
def task_generating(self):
|
||||
print("========== task_generating ==========")
|
||||
tasks = task_generator(
|
||||
tasks=self.task_config,
|
||||
generators=self.rolling_gen, # generate different date segments
|
||||
)
|
||||
pprint(tasks)
|
||||
return tasks
|
||||
|
||||
# This part corresponds to "Task Generating" in the document
|
||||
def task_generating():
|
||||
def task_training(self, tasks):
|
||||
print("========== task_training ==========")
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
print("========== task_generating ==========")
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
tasks = task_generator(
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
generators=RollingGen(step=550, rtype=RollingGen.ROLL_SD), # generate different date segment
|
||||
)
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
pprint(tasks)
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
return tasks
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter),
|
||||
RollingGroup(),
|
||||
)
|
||||
print(artifact)
|
||||
|
||||
|
||||
def task_training(tasks, task_pool, exp_name):
|
||||
trainer = TrainerRM(exp_name, task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting(exp_name):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter),
|
||||
RollingGroup(),
|
||||
)
|
||||
print(artifact)
|
||||
|
||||
|
||||
def main(
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
):
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
|
||||
|
||||
reset(task_pool, experiment_name)
|
||||
tasks = task_generating()
|
||||
task_training(tasks, task_pool, experiment_name)
|
||||
task_collecting(experiment_name)
|
||||
def main(self):
|
||||
self.reset()
|
||||
tasks = self.task_generating()
|
||||
self.task_training(tasks)
|
||||
self.task_collecting()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python update_online_pred.py main --experiment_name="your_exp_name"
|
||||
fire.Fire()
|
||||
# python task_manager_rolling.py main --experiment_name="your_exp_name"
|
||||
fire.Fire(RollingTaskExample)
|
||||
|
||||
@@ -11,7 +11,7 @@ from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
"""
|
||||
This examples is about the OnlineManager and OnlineSimulator based on Rolling tasks.
|
||||
This example is about the OnlineManager and OnlineSimulator based on rolling tasks.
|
||||
The OnlineManager will focus on the updating of your online models.
|
||||
The OnlineSimulator will focus on simulating the real updating routine of your online models.
|
||||
"""
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
from pprint import pprint
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.online.manager import RollingOnlineManager
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.model.trainer import TrainerRM
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
|
||||
"""
|
||||
This example shows how RollingOnlineManager works with rolling tasks.
|
||||
There are two parts including first train and routine.
|
||||
Firstly, the RollingOnlineManager will finish the first training and set trained models to `online` models.
|
||||
Next, the RollingOnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
|
||||
"""
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2013-01-01",
|
||||
@@ -89,92 +92,38 @@ class RollingOnlineExample:
|
||||
"task_db_name": task_db_name, # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool)
|
||||
self.task_manager = TaskManager(self.task_pool)
|
||||
self.rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer
|
||||
experiment_name=exp_name,
|
||||
rolling_gen=RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
trainer=TrainerRM(self.exp_name, self.task_pool),
|
||||
)
|
||||
|
||||
def print_online_model(self):
|
||||
print("========== print_online_model ==========")
|
||||
print("Current 'online' model:")
|
||||
|
||||
for rec in self.rolling_online_manager.online_models():
|
||||
print(rec.info["id"])
|
||||
print("Current 'next online' model:")
|
||||
for rid, rec in list_recorders(self.exp_name).items():
|
||||
if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.NEXT_ONLINE_TAG:
|
||||
print(rid)
|
||||
|
||||
# This part corresponds to "Task Generating" in the document
|
||||
def task_generating(self):
|
||||
|
||||
print("========== task_generating ==========")
|
||||
|
||||
tasks = task_generator(
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
generators=self.rolling_gen, # generate different date segment
|
||||
)
|
||||
|
||||
pprint(tasks)
|
||||
|
||||
return tasks
|
||||
|
||||
def task_training(self, tasks):
|
||||
# self.trainer.train(tasks)
|
||||
self.rolling_online_manager.prepare_new_models(tasks, tag=RollingOnlineManager.ONLINE_TAG)
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup()
|
||||
)
|
||||
print(artifact)
|
||||
_ROLLING_MANAGER_PATH = ".rolling_manager" # the RollingOnlineManager will dump to this file, for it will be loaded when calling routine.
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
self.task_manager.remove()
|
||||
TaskManager(self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
# Run this firstly to see the workflow in Task Management
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def first_run(self):
|
||||
print("========== first_run ==========")
|
||||
self.reset()
|
||||
|
||||
tasks = self.task_generating()
|
||||
pprint(tasks)
|
||||
self.task_training(tasks)
|
||||
self.task_collecting()
|
||||
|
||||
# latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
|
||||
# self.rolling_online_manager.reset_online_tag(list(latest_rec.values()))
|
||||
self.rolling_online_manager.first_train([task_xgboost_config, task_lgb_config])
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
print(self.rolling_online_manager.collect_artifact())
|
||||
|
||||
def routine(self):
|
||||
print("========== routine ==========")
|
||||
self.print_online_model()
|
||||
with Path(self._ROLLING_MANAGER_PATH).open("rb") as f:
|
||||
self.rolling_online_manager = pickle.load(f)
|
||||
self.rolling_online_manager.routine()
|
||||
self.print_online_model()
|
||||
self.task_collecting()
|
||||
print(self.rolling_online_manager.collect_artifact())
|
||||
|
||||
def main(self):
|
||||
self.first_run()
|
||||
@@ -5,6 +5,13 @@ from qlib.model.trainer import task_train
|
||||
from qlib.workflow.online.manager import OnlineManagerR
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
"""
|
||||
This example shows how OnlineManager works when we need to update predictions.
|
||||
There are two parts including first_train and update_online_pred.
|
||||
Firstly, the OnlineManagerR will finish the first training and set the trained model as the `online` model.
|
||||
Next, the OnlineManagerR will finish updating the online prediction.
|
||||
"""
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
@@ -52,31 +59,25 @@ task = {
|
||||
}
|
||||
|
||||
|
||||
def first_train(experiment_name="online_srv"):
|
||||
class UpdatePredExample:
|
||||
def __init__(
|
||||
self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task
|
||||
):
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
self.experiment_name = experiment_name
|
||||
self.online_manager = OnlineManagerR(self.experiment_name)
|
||||
self.task_config = task_config
|
||||
|
||||
rec = task_train(task_config=task, experiment_name=experiment_name)
|
||||
def first_train(self):
|
||||
rec = task_train(self.task_config, experiment_name=self.experiment_name)
|
||||
self.online_manager.reset_online_tag(rec) # set to online model
|
||||
|
||||
online_manager = OnlineManagerR(experiment_name)
|
||||
online_manager.reset_online_tag(rec)
|
||||
def update_online_pred(self):
|
||||
self.online_manager.update_online_pred()
|
||||
|
||||
|
||||
def update_online_pred(experiment_name="online_srv"):
|
||||
|
||||
online_manager = OnlineManagerR(experiment_name)
|
||||
|
||||
print("Here are the online models waiting for update:")
|
||||
for rid, rec in list_recorders(experiment_name).items():
|
||||
if online_manager.get_online_tag(rec) == OnlineManagerR.ONLINE_TAG:
|
||||
print(rid)
|
||||
|
||||
online_manager.update_online_pred()
|
||||
|
||||
|
||||
def main(provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"):
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
first_train(experiment_name)
|
||||
update_online_pred(experiment_name)
|
||||
def main(self):
|
||||
self.first_train()
|
||||
self.update_online_pred()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -86,4 +87,4 @@ if __name__ == "__main__":
|
||||
# python update_online_pred.py update_online_pred
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python update_online_pred.py main --experiment_name="your_exp_name"
|
||||
fire.Fire()
|
||||
fire.Fire(UpdatePredExample)
|
||||
|
||||
@@ -135,3 +135,12 @@ class TrainerRM(Trainer):
|
||||
for _id in _id_list:
|
||||
recs.append(tm.re_query(_id)["res"])
|
||||
return recs
|
||||
|
||||
|
||||
class DelayTrainer(Trainer):
    """
    Placeholder trainer for delayed training.

    `fake_train` only resets the list of fake-trained records and `train` walks over that list
    without doing any real work yet.
    """

    def fake_train(self):
        """Reset the list of records that have only been fake-trained."""
        self.fake_trained = []

    def train(self):
        """Walk over the fake-trained records; no real training is implemented yet."""
        # `fake_trained` is only created by `fake_train`, so fall back to an empty list
        # instead of raising AttributeError when `train` is called first.
        for rec in getattr(self, "fake_trained", []):
            pass
|
||||
|
||||
@@ -1,16 +1,29 @@
|
||||
from copy import deepcopy
|
||||
from operator import index
|
||||
import pandas as pd
|
||||
from qlib.model.ens.ensemble import ens_workflow
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.utils.serial import Serializable
|
||||
from typing import Dict, List, Union
|
||||
from qlib import get_module_logger
|
||||
from qlib.data.data import D
|
||||
from qlib.model.trainer import Trainer, TrainerR, task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.update import PredUpdater
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import Collector
|
||||
from qlib.workflow.task.collect import Collector, RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
|
||||
|
||||
"""
|
||||
This module is a component of online serving; it can manage a series of models dynamically.
|
||||
As time goes by, the decisive models will also change. In this module, we call those contributing models `online` models.
|
||||
In every routine (such as every day or every minute), the `online` models may be changed and their predictions need to be updated.
|
||||
So this module provides a series of methods to control this process.
|
||||
"""
|
||||
|
||||
class OnlineManager:
|
||||
|
||||
class OnlineManager(Serializable):
|
||||
|
||||
ONLINE_KEY = "online_status" # the online status key in recorder
|
||||
ONLINE_TAG = "online" # the 'online' model
|
||||
@@ -18,26 +31,28 @@ class OnlineManager:
|
||||
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
|
||||
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
|
||||
|
||||
def __init__(self, trainer: Trainer = None, collector: Collector = None, need_log=True):
|
||||
SIGNAL_EXP = "OnlineManagerSignals" # a specific experiment to save signals of different experiment.
|
||||
|
||||
def __init__(self, trainer: Trainer = None, need_log=True):
|
||||
"""
|
||||
init OnlineManager.
|
||||
|
||||
Args:
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
collector (Collector, optional): a instance of Collector. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
self.trainer = trainer
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
self.delay_signals = {}
|
||||
self.collector = collector
|
||||
self.cur_time = None
|
||||
|
||||
def prepare_signals(self, *args, **kwargs):
    """
    Prepare the trading signals for the next routine.

    This is called once the data of the last routine (a box in box-plot) has been prepared,
    i.e. at the end of that routine.
    The base implementation always raises, so a subclass must override this method,
    using `pass` as the body if it has nothing to do.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    message = "Please implement the `prepare_signals` method."
    raise NotImplementedError(message)
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
@@ -47,7 +62,7 @@ class OnlineManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
|
||||
|
||||
def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG):
|
||||
def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None):
|
||||
"""
|
||||
Use trainer to train a list of tasks and set the trained model to `tag`.
|
||||
|
||||
@@ -57,14 +72,20 @@ class OnlineManager:
|
||||
`ONLINE_TAG` for first train or additional train
|
||||
`NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag`
|
||||
`OFFLINE_TAG` for train but offline those models
|
||||
check_func: the method to judge if a model can be online.
|
||||
The parameter is the model record and return True for online.
|
||||
None for online every models.
|
||||
|
||||
"""
|
||||
# TODO: callback
|
||||
if not (tasks is None or len(tasks) == 0):
|
||||
if check_func is None:
|
||||
check_func = lambda x: True
|
||||
if len(tasks) > 0:
|
||||
if self.trainer is not None:
|
||||
new_models = self.trainer.train(tasks)
|
||||
self.set_online_tag(tag, new_models)
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finished prepare {len(new_models)} new models and set them to {tag}.")
|
||||
if check_func(new_models):
|
||||
self.set_online_tag(tag, new_models)
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finished preparing {len(new_models)} new models and set them to {tag}.")
|
||||
else:
|
||||
self.logger.warn("No trainer to train new tasks.")
|
||||
|
||||
@@ -101,6 +122,12 @@ class OnlineManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `online_models` method.")
|
||||
|
||||
def first_train(self):
    """
    Train a first series of models and set some of them as online models.

    The base implementation always raises, so a subclass must override this method.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    message = "Please implement the `first_train` method."
    raise NotImplementedError(message)
|
||||
|
||||
def get_collector(self):
|
||||
"""
|
||||
Return the collector.
|
||||
@@ -108,7 +135,7 @@ class OnlineManager:
|
||||
Returns:
|
||||
Collector
|
||||
"""
|
||||
return self.collector
|
||||
raise NotImplementedError(f"Please implement the `get_collector` method.")
|
||||
|
||||
def run_delay_signals(self):
|
||||
"""
|
||||
@@ -122,9 +149,10 @@ class OnlineManager:
|
||||
def routine(self, cur_time=None, delay_prepare=False, *args, **kwargs):
|
||||
"""
|
||||
The typical update process after a routine, such as day by day or month by month.
|
||||
Prepare signals -> prepare tasks -> prepare new models -> update online prediction -> reset online models
|
||||
update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
|
||||
"""
|
||||
self.cur_time = cur_time # None for latest date
|
||||
self.update_online_pred()
|
||||
if not delay_prepare:
|
||||
self.prepare_signals(*args, **kwargs)
|
||||
else:
|
||||
@@ -134,7 +162,7 @@ class OnlineManager:
|
||||
raise ValueError("Can not delay prepare when cur_time is None")
|
||||
tasks = self.prepare_tasks(*args, **kwargs)
|
||||
self.prepare_new_models(tasks)
|
||||
self.update_online_pred()
|
||||
|
||||
return self.reset_online_tag()
|
||||
|
||||
|
||||
@@ -144,19 +172,18 @@ class OnlineManagerR(OnlineManager):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str, trainer: Trainer = None, collector: Collector = None, need_log=True):
|
||||
def __init__(self, experiment_name: str, trainer: Trainer = None, need_log=True):
|
||||
"""
|
||||
init OnlineManagerR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the experiment name.
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
collector (Collector, optional): a instance of Collector. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
if trainer is None:
|
||||
trainer = TrainerR(experiment_name)
|
||||
super().__init__(trainer=trainer, collector=collector, need_log=need_log)
|
||||
super().__init__(trainer=trainer, need_log=need_log)
|
||||
self.exp_name = experiment_name
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
|
||||
@@ -212,7 +239,40 @@ class OnlineManagerR(OnlineManager):
|
||||
PredUpdater(rec, to_date=self.cur_time, need_log=self.need_log).update()
|
||||
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finish updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
|
||||
def prepare_signals(self, over_write=False):
    """
    Average the predictions of the online models into one trading signal per routine.

    The result is saved as the `signals` object of the recorder named `self.exp_name`
    inside the experiment named `SIGNAL_EXP`.

    Args:
        over_write (bool, optional): If True, the new signals replace the saved ones.
            If False, only the signals dated after the newest saved datetime are appended
            to the saved ones. Defaults to False.

    Raises:
        ValueError: if there is no online model prediction to average.
    """
    with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True):
        recorder = R.get_recorder()

        try:
            old_signals = recorder.load_object("signals")
        except OSError:
            # a failed load is treated as "no signals saved yet"
            old_signals = None

        pred = [rec.load_object("pred.pkl") for rec in self.online_models()]
        if not pred:
            # pd.concat would raise the same exception type with a far less helpful message
            raise ValueError(f"No online model predictions found in {self.exp_name}, can not prepare signals.")

        signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score")
        signals = signals.sort_index()
        if old_signals is not None and not over_write:
            old_max = old_signals.index.get_level_values("datetime").max()
            # A label slice `signals.loc[old_max:]` includes `old_max` itself and would append
            # the rows of that datetime a second time, so keep only strictly later rows.
            is_newer = signals.index.get_level_values("datetime") > old_max
            new_signals = signals[is_newer]
            signals = pd.concat([old_signals, new_signals], axis=0)
        else:
            new_signals = signals

        if self.need_log:
            self.logger.info(f"Finished preparing new {len(new_signals)} signals to {self.SIGNAL_EXP}/{self.exp_name}.")
        recorder.save_objects(**{"signals": signals})
|
||||
|
||||
|
||||
class RollingOnlineManager(OnlineManagerR):
|
||||
@@ -223,7 +283,6 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
experiment_name: str,
|
||||
rolling_gen: RollingGen,
|
||||
trainer: Trainer = None,
|
||||
collector: Collector = None,
|
||||
need_log=True,
|
||||
):
|
||||
"""
|
||||
@@ -238,24 +297,64 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
"""
|
||||
if trainer is None:
|
||||
trainer = TrainerR(experiment_name)
|
||||
super().__init__(experiment_name=experiment_name, trainer=trainer, collector=collector, need_log=need_log)
|
||||
super().__init__(experiment_name=experiment_name, trainer=trainer, need_log=need_log)
|
||||
self.ta = TimeAdjuster()
|
||||
self.rg = rolling_gen
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def prepare_signals(self, *args, **kwargs):
|
||||
def get_collector(self, rec_key_func=None, rec_filter_func=None):
|
||||
"""
|
||||
Average the online models prediction and save them into a recorder
|
||||
|
||||
get the instance of collector to collect results
|
||||
|
||||
Must use `pass` even though there is nothing to do.
|
||||
Args:
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
"""
|
||||
# Check whether the recorder exists; if it does not, create one
|
||||
# Check the last signal time of the recorder; if there is none, start producing signals from the common earliest time of the online models
|
||||
# Produce signals from the last signal time of the recorder up to self.cur_time
|
||||
for model in self.online_models():
|
||||
|
||||
pass
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
if rec_key_func is None:
|
||||
rec_key_func = rec_key
|
||||
|
||||
return RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func)
|
||||
|
||||
def collect_artifact(self, rec_key_func=None, rec_filter_func=None):
    """
    Collect the artifact by running `ens_workflow` on this manager's collector with a RollingGroup.

    Args:
        rec_key_func (Callable, optional): forwarded unchanged to `self.get_collector`. Defaults to None.
        rec_filter_func (Callable, optional): forwarded unchanged to `self.get_collector`. Defaults to None.

    Returns:
        dict: the artifact dict after rolling ensemble
    """
    collector = self.get_collector(rec_key_func=rec_key_func, rec_filter_func=rec_filter_func)
    return ens_workflow(collector, RollingGroup())
|
||||
|
||||
def first_train(self, task_configs: list):
    """
    Generate rolling tasks from `task_configs` with the rolling generator, train them and set them online.

    After training, the signals are prepared with `over_write=True`.

    Args:
        task_configs (list or dict): a list of task configs or a task config

    Returns:
        Collector: the collector returned by `self.get_collector()`.
    """
    # expand every task config into its rolling date segments
    rolling_tasks = task_generator(tasks=task_configs, generators=self.rg)
    self.prepare_new_models(rolling_tasks, tag=self.ONLINE_TAG)
    self.prepare_signals(over_write=True)
    return self.get_collector()
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
"""
|
||||
@@ -264,7 +363,6 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
Returns:
|
||||
list: a list of new tasks.
|
||||
"""
|
||||
#TODO: max_test = self.cur_time
|
||||
latest_records, max_test = self.list_latest_recorders(
|
||||
lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG
|
||||
)
|
||||
|
||||
@@ -49,7 +49,7 @@ class RecorderCollector(Collector):
|
||||
if rec_key_func is None:
|
||||
rec_key_func = lambda rec: rec.info["id"]
|
||||
if artifacts_key is None:
|
||||
artifacts_key = self.artifacts_path.keys()
|
||||
artifacts_key = list(self.artifacts_path.keys())
|
||||
self._rec_key_func = rec_key_func
|
||||
self.artifacts_key = artifacts_key
|
||||
self._rec_filter_func = rec_filter_func
|
||||
|
||||
@@ -194,6 +194,15 @@ class RollingGen(TaskGen):
|
||||
|
||||
# update segments of this task
|
||||
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
|
||||
# if end_time < the end of test_segments, then change end_time to allow load more data
|
||||
if (
|
||||
self.ta.cal_interval(
|
||||
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
|
||||
t["dataset"]["kwargs"]["segments"][self.test_key][1],
|
||||
)
|
||||
< 0
|
||||
):
|
||||
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1])
|
||||
prev_seg = segments
|
||||
res.append(t)
|
||||
return res
|
||||
|
||||
Reference in New Issue
Block a user