mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
Online Serving V8
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.ens.ensemble import ens_workflow
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.manager import RollingOnlineManager
|
||||
from qlib.workflow.online.simulator import OnlineSimulator
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
"""
|
||||
This example is about the OnlineManager and OnlineSimulator based on rolling tasks.
|
||||
@@ -19,7 +19,7 @@ The OnlineSimulator will focus on the simulating real updating routine of your o
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2018-01-01",
|
||||
"end_time": None, # "2018-10-31",
|
||||
"end_time": "2018-10-31",
|
||||
"fit_start_time": "2018-01-01",
|
||||
"fit_end_time": "2018-03-31",
|
||||
"instruments": "csi100",
|
||||
@@ -74,7 +74,7 @@ task_xgboost_config = {
|
||||
}
|
||||
|
||||
|
||||
class OnlineManagerExample:
|
||||
class OnlineSimulationExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
@@ -86,6 +86,7 @@ class OnlineManagerExample:
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
tasks=[task_xgboost_config], # , task_lgb_config]
|
||||
):
|
||||
"""
|
||||
init OnlineManagerExample.
|
||||
@@ -100,6 +101,7 @@ class OnlineManagerExample:
|
||||
rolling_step (int, optional): the step for rolling. Defaults to 80.
|
||||
start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
|
||||
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
@@ -108,76 +110,49 @@ class OnlineManagerExample:
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD) # The rolling tasks generator
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool) # The trainer based on (R)ecorder and Task(M)anager
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, modify_end_time=False
|
||||
) # The rolling tasks generator, modify_end_time is false because we just need simulate to 2018-10-31.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)
|
||||
self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks
|
||||
self.collector = RecorderCollector(exp_name=self.exp_name, rec_key_func=self.rec_key) # The result collector
|
||||
self.grouper = RollingGroup() # Divide your results into different rolling group
|
||||
self.rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name,
|
||||
rolling_gen=self.rolling_gen,
|
||||
trainer=self.trainer,
|
||||
collector=self.collector,
|
||||
need_log=False,
|
||||
) # The OnlineManager based on Rolling
|
||||
self.onlinesimulator = OnlineSimulator(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
onlinemanager=self.rolling_online_manager,
|
||||
online_manager=self.rolling_online_manager,
|
||||
)
|
||||
self.tasks = tasks
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
self.task_manager.remove()
|
||||
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
@staticmethod
|
||||
def rec_key(recorder):
|
||||
"""
|
||||
given a Recorder and return its key to identify it
|
||||
|
||||
Args:
|
||||
recorder (Recorder): a instance of the Recorder
|
||||
|
||||
Returns:
|
||||
tuple: (model_key, rolling_key)
|
||||
"""
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
def result_collecting(self):
|
||||
print("========== result collecting ==========")
|
||||
|
||||
# ens_workflow can help collect, group and ensemble results in a easy way
|
||||
artifact = ens_workflow(self.rolling_online_manager.get_collector(), self.grouper)
|
||||
print(artifact)
|
||||
for rid in list_recorders(
|
||||
RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False
|
||||
):
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
# Run this firstly to see the workflow in OnlineManager
|
||||
def first_train(self):
|
||||
print("========== first train ==========")
|
||||
self.reset()
|
||||
|
||||
tasks = task_generator(
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
generators=[self.rolling_gen], # generate different date segment
|
||||
)
|
||||
|
||||
self.rolling_online_manager.prepare_new_models(tasks=tasks, tag=RollingOnlineManager.ONLINE_TAG)
|
||||
self.result_collecting()
|
||||
self.rolling_online_manager.first_train(self.tasks)
|
||||
|
||||
# Run this secondly to see the simulating in OnlineSimulator
|
||||
def simulate(self):
|
||||
|
||||
print("========== simulate ==========")
|
||||
self.onlinesimulator.simulate()
|
||||
|
||||
self.result_collecting()
|
||||
print(self.rolling_online_manager.collect_artifact())
|
||||
|
||||
print("========== online models ==========")
|
||||
recs_dict = self.onlinesimulator.online_models()
|
||||
@@ -186,6 +161,9 @@ class OnlineManagerExample:
|
||||
for rec in recs:
|
||||
print(rec.info["id"])
|
||||
|
||||
print("========== online signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
# Run this to run the whole workflow automatically
|
||||
def main(self):
|
||||
self.first_train()
|
||||
@@ -195,4 +173,4 @@ class OnlineManagerExample:
|
||||
if __name__ == "__main__":
|
||||
## to run the whole workflow automatically with your own parameters, use the command below
|
||||
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
|
||||
fire.Fire(OnlineManagerExample)
|
||||
fire.Fire(OnlineSimulationExample)
|
||||
|
||||
@@ -111,6 +111,11 @@ class RollingOnlineExample:
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
for rid in list_recorders(
|
||||
RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False
|
||||
):
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
def first_run(self):
|
||||
print("========== first_run ==========")
|
||||
self.reset()
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import time
|
||||
from xxlimited import Str
|
||||
from qlib.utils import init_instance_by_config, flatten_dict, get_cls_kwargs
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.recorder import Recorder
|
||||
@@ -11,6 +14,63 @@ from qlib.model.base import Model
|
||||
import socket
|
||||
|
||||
|
||||
def begin_task_train(task_config: dict, experiment_name: str, *args, **kwargs) -> Recorder:
    """
    Open a new recorder for a task and store the task config in it.

    No model is created or fitted here. The recorder is named after the
    current timestamp and tagged with ``train_status = "begin_task_train"``.

    Args:
        task_config (dict): the task definition to persist.
        experiment_name (str): the experiment the recorder is started under.
        *args, **kwargs: accepted but not used by this function.

    Returns:
        Recorder: the recorder that was active while the task was saved.
    """
    run_name = str(time.time())
    with R.start(experiment_name=experiment_name, recorder_name=run_name):
        flat_params = flatten_dict(task_config)
        R.log_params(**flat_params)
        # The raw config is saved as an object as well, so the original
        # format and datatypes are kept.
        R.save_objects(task=task_config)
        R.set_tags(hostname=socket.gethostname(), train_status="begin_task_train")
        started: Recorder = R.get_recorder()
    return started
|
||||
|
||||
|
||||
def end_task_train(rec: Recorder, experiment_name: str, *args, **kwargs):
    """
    Resume a recorder and do the model fitting and saving for its task.

    The task config is loaded back from the resumed recorder (object "task").
    The model and dataset are built from it, the model is fitted, and both
    are saved. Every configured record is then generated, and the recorder is
    tagged with ``train_status = "end_task_train"``.

    Args:
        rec (Recorder): the recorder to resume; resumed by ``rec.info["name"]``.
        experiment_name (str): the experiment the recorder belongs to.
        *args, **kwargs: accepted but not used by this function.

    Returns:
        Recorder: the same ``rec`` that was passed in.
    """
    with R.start(experiment_name=experiment_name, recorder_name=rec.info["name"], resume=True):
        task_config = R.load_object("task")

        # Build the model and the dataset from the stored config.
        model: Model = init_instance_by_config(task_config["model"])
        dataset: Dataset = init_instance_by_config(task_config["dataset"])

        # Fit, then persist the fitted model.
        model.fit(dataset)
        R.save_objects(**{"params.pkl": model})

        # The dataset is saved for online inference, so the concrete data
        # should not be dumped along with it.
        dataset.config(dump_all=False, recursive=True)
        R.save_objects(dataset=dataset)

        # Records: prediction, backtest, and analysis.
        record_confs = task_config.get("record", [])
        if isinstance(record_confs, dict):
            # A single record given as a bare dict is treated as a one-item list.
            record_confs = [record_confs]
        for record_conf in record_confs:
            record_cls, record_kwargs = get_cls_kwargs(record_conf, default_module="qlib.workflow.record_temp")
            extra = (
                {"model": model, "dataset": dataset, "recorder": rec}
                if record_cls is SignalRecord
                else {"recorder": rec}
            )
            record_cls(**record_kwargs, **extra).generate()

        R.set_tags(train_status="end_task_train")
    return rec
|
||||
|
||||
|
||||
def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
"""
|
||||
task based training
|
||||
@@ -26,36 +86,8 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
----------
|
||||
Recorder : The instance of the recorder
|
||||
"""
|
||||
# model initiaiton
|
||||
model: Model = init_instance_by_config(task_config["model"])
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name=experiment_name):
|
||||
|
||||
# train model
|
||||
R.log_params(**flatten_dict(task_config))
|
||||
R.save_objects(**{"task": task_config}) # keep the original format and datatype
|
||||
R.set_tags(hostname=socket.gethostname())
|
||||
model.fit(dataset)
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
# This dataset is saved for online inference. So the concrete data should not be dumped
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
R.save_objects(**{"dataset": dataset})
|
||||
|
||||
# generate records: prediction, backtest, and analysis
|
||||
records = task_config.get("record", [])
|
||||
recorder: Recorder = R.get_recorder()
|
||||
if isinstance(records, dict): # prevent only one dict
|
||||
records = [records]
|
||||
for record in records:
|
||||
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
|
||||
if cls is SignalRecord:
|
||||
rconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||
else:
|
||||
rconf = {"recorder": recorder}
|
||||
r = cls(**kwargs, **rconf)
|
||||
r.generate()
|
||||
recorder = begin_task_train(task_config, experiment_name)
|
||||
recorder = end_task_train(recorder, experiment_name)
|
||||
return recorder
|
||||
|
||||
|
||||
@@ -64,14 +96,22 @@ class Trainer:
|
||||
The trainer which can train a list of model
|
||||
"""
|
||||
|
||||
def train(self, *args, **kwargs):
|
||||
"""Given a list of model definition, finished training and return the results of them.
|
||||
def train(self, tasks: list, *args, **kwargs):
|
||||
"""Given a list of model definition, begin a training and return the models.
|
||||
|
||||
Returns:
|
||||
list: a list of trained results
|
||||
list: a list of models
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `train` method.")
|
||||
|
||||
def end_train(self, models, *args, **kwargs):
    """Given a list of models, finish anything left to do at the end of training.

    The base implementation has nothing left to do and hands the input back.

    Args:
        models (list): the models returned by `train`.
        *args, **kwargs: accepted for subclasses; unused here.

    Returns:
        list: a list of models (the same object that was passed in)
    """
    # Returning the input keeps the documented "list of models" contract;
    # a bare `pass` would hand callers None instead.
    return models
|
||||
|
||||
|
||||
class TrainerR(Trainer):
|
||||
"""Trainer based on (R)ecorder.
|
||||
@@ -112,7 +152,15 @@ class TrainerRM(Trainer):
|
||||
self.task_pool = task_pool
|
||||
self.train_func = train_func
|
||||
|
||||
def train(self, tasks: list, train_func=None, *args, **kwargs):
|
||||
def train(
|
||||
self,
|
||||
tasks: list,
|
||||
train_func=None,
|
||||
before_status=TaskManager.STATUS_WAITING,
|
||||
after_status=TaskManager.STATUS_DONE,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
|
||||
|
||||
This method defaults to a single process, but TaskManager offered a great way to parallel training.
|
||||
@@ -129,7 +177,15 @@ class TrainerRM(Trainer):
|
||||
train_func = self.train_func
|
||||
tm = TaskManager(task_pool=self.task_pool)
|
||||
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
run_task(train_func, self.task_pool, experiment_name=self.experiment_name, *args, **kwargs)
|
||||
run_task(
|
||||
train_func,
|
||||
self.task_pool,
|
||||
experiment_name=self.experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
recs = []
|
||||
for _id in _id_list:
|
||||
@@ -137,10 +193,96 @@ class TrainerRM(Trainer):
|
||||
return recs
|
||||
|
||||
|
||||
class DelayTrainer(Trainer):
|
||||
def fake_train(self):
|
||||
self.fake_trained = []
|
||||
class DelayTrainerR(TrainerR):
    """
    A delayed implementation based on TrainerR: the `train` method may only do
    some preparation, and the `end_train` method does the real model fitting.
    """

    def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train):
        """
        Args:
            experiment_name (str): the experiment the recorders are created in.
            train_func (Callable): the preparation step used by `train`.
                Defaults to `begin_task_train`.
            end_train_func (Callable): the fitting step used by `end_train`.
                Defaults to `end_task_train`.
        """
        super().__init__(experiment_name, train_func)
        # Kept on this class too, so `end_train` can always forward it.
        self.experiment_name = experiment_name
        self.end_train_func = end_train_func
        # Recorders produced by the last `train` call that still await `end_train`.
        self.recs = []

    def train(self, tasks: list, train_func=None, *args, **kwargs):
        """
        Same as `train` of TrainerR, the results will be recorded in self.recs

        Args:
            tasks (list): a list of definition based on `task` dict
            train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.

        Returns:
            list: a list of Recorders
        """
        # `train_func` defaults to None so callers that only pass `tasks`
        # (as documented above) do not fail on a missing argument.
        self.recs = super().train(tasks, train_func=train_func, *args, **kwargs)
        return self.recs

    def end_train(self, recs=None, end_train_func=None):
        """
        Given a list of Recorder and return a list of trained Recorder.
        This method finishes the real data loading and model fitting.

        Args:
            recs (list, optional): a list of Recorder, the tasks have been saved to them. Defaults to None for using self.recs.
            end_train_func (Callable, optional): the end_train method which need at least `rec` and `experiment_name`. Defaults to None for using self.end_train_func.

        Returns:
            list: a list of Recorders
        """
        if recs is None:
            # Take the pending recorders and clear the queue so each model is
            # only trained once. Rebinding `self.recs` already detaches the
            # returned list, so no copy of the recorders is needed.
            recs, self.recs = self.recs, []
        if end_train_func is None:
            end_train_func = self.end_train_func
        for rec in recs:
            # The end-train callable needs the experiment name as well as the
            # recorder; calling it with `rec` alone raises TypeError.
            end_train_func(rec, self.experiment_name)
        return recs
|
||||
|
||||
|
||||
class DelayTrainerRM(TrainerRM):
    """
    A delayed implementation based on TrainerRM: the `train` method may only do
    some preparation, and the `end_train` method does the real model fitting.
    """

    def __init__(self, experiment_name, task_pool: str, train_func=begin_task_train, end_train_func=end_task_train):
        """
        Args:
            experiment_name (str): forwarded to TrainerRM.
            task_pool (str): forwarded to TrainerRM.
            train_func (Callable): the preparation step, forwarded to TrainerRM.
                Defaults to `begin_task_train`.
            end_train_func (Callable): the fitting step used by `end_train`.
                Defaults to `end_task_train`.
        """
        super().__init__(experiment_name, task_pool, train_func)
        self.end_train_func = end_train_func

    def train(self, tasks: list, train_func=None, *args, **kwargs):
        """
        Same as `train` of TrainerRM, except that the parent is asked to use
        `TaskManager.STATUS_PART_DONE` as the status after running a task.

        Args:
            tasks (list): a list of definition based on `task` dict
            train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.

        Returns:
            list: a list of Recorders
        """
        prepared = super().train(
            tasks,
            *args,
            train_func=train_func,
            after_status=TaskManager.STATUS_PART_DONE,
            **kwargs,
        )
        return prepared

    def end_train(self, recs, end_train_func=None):
        """
        Run the end-train step over the task pool and hand `recs` back.

        The work is selected through `run_task` with
        `before_status=TaskManager.STATUS_PART_DONE`; `recs` itself is not
        inspected here and is returned unchanged.

        Args:
            recs (list): a list of Recorder, the tasks have been saved to them.
            end_train_func (Callable, optional): the end_train method which need at least `rec` and `experiment_name`. Defaults to None for using self.end_train_func.

        Returns:
            list: a list of Recorders
        """
        finisher = self.end_train_func if end_train_func is None else end_train_func
        run_task(
            finisher,
            self.task_pool,
            experiment_name=self.experiment_name,
            before_status=TaskManager.STATUS_PART_DONE,
        )
        return recs
|
||||
|
||||
@@ -304,7 +304,7 @@ class QlibRecorder:
|
||||
"""
|
||||
self.exp_manager.set_uri(uri)
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
|
||||
"""
|
||||
Method for retrieving a recorder.
|
||||
|
||||
|
||||
@@ -44,17 +44,21 @@ class OnlineManager(Serializable):
|
||||
self.trainer = trainer
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
self.delay_signals = {}
|
||||
self.cur_time = None
|
||||
|
||||
def prepare_signals(self, *args, **kwargs):
|
||||
def prepare_signals(self):
|
||||
"""
|
||||
After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine.
|
||||
Must use `pass` even though there is nothing to do.
|
||||
"""
|
||||
|
||||
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
|
||||
|
||||
def get_signals(self):
    """
    Return the signals once they have been prepared.

    This base version is only a placeholder and always raises.

    Raises:
        NotImplementedError: always; the method must be overridden.
    """
    message = "Please implement the `get_signals` method."
    raise NotImplementedError(message)
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
"""
|
||||
After the end of a routine, check whether we need to prepare and train some new tasks.
|
||||
@@ -62,7 +66,7 @@ class OnlineManager(Serializable):
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
|
||||
|
||||
def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None):
|
||||
def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None, *args, **kwargs):
|
||||
"""
|
||||
Use trainer to train a list of tasks and set the trained model to `tag`.
|
||||
|
||||
@@ -75,13 +79,14 @@ class OnlineManager(Serializable):
|
||||
check_func: the method to judge if a model can be online.
|
||||
The parameter is the model record and return True for online.
|
||||
None for online every models.
|
||||
*args, **kwargs: will be passed to end_train which means will be passed to customized train method.
|
||||
|
||||
"""
|
||||
if check_func is None:
|
||||
check_func = lambda x: True
|
||||
if len(tasks) > 0:
|
||||
if self.trainer is not None:
|
||||
new_models = self.trainer.train(tasks)
|
||||
new_models = self.trainer.train(tasks, *args, **kwargs)
|
||||
if check_func(new_models):
|
||||
self.set_online_tag(tag, new_models)
|
||||
if self.need_log:
|
||||
@@ -89,13 +94,13 @@ class OnlineManager(Serializable):
|
||||
else:
|
||||
self.logger.warn("No trainer to train new tasks.")
|
||||
|
||||
def update_online_pred(self, *args, **kwargs):
|
||||
def update_online_pred(self):
|
||||
"""
|
||||
After the end of a routine, update the predictions of online models to latest.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
|
||||
|
||||
def set_online_tag(self, tag, *args, **kwargs):
|
||||
def set_online_tag(self, tag, recorder):
|
||||
"""
|
||||
Set `tag` to the model to sign whether online.
|
||||
|
||||
@@ -104,15 +109,21 @@ class OnlineManager(Serializable):
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
|
||||
|
||||
def get_online_tag(self, *args, **kwargs):
|
||||
def get_online_tag(self):
|
||||
"""
|
||||
Given a model and return its online tag.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
|
||||
|
||||
def reset_online_tag(self, *args, **kwargs):
|
||||
"""
|
||||
Offline all models and set the models to 'online'.
|
||||
def reset_online_tag(self, recorders=None):
|
||||
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
|
||||
|
||||
Args:
|
||||
recorders (List, optional):
|
||||
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
|
||||
|
||||
Returns:
|
||||
list: new online recorder. [] if there is no update.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
|
||||
|
||||
@@ -137,31 +148,46 @@ class OnlineManager(Serializable):
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_collector` method.")
|
||||
|
||||
def run_delay_signals(self):
|
||||
def delay_prepare(self, rec_dict, *args, **kwargs):
|
||||
"""
|
||||
Prepare all signals if there are some dates waiting for prepare.
|
||||
Prepare all models and signals if there are something waiting for prepare.
|
||||
NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way.
|
||||
|
||||
Args:
|
||||
rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}.
|
||||
*args, **kwargs: will be passed to end_train which means will be passed to customized train method.
|
||||
"""
|
||||
for cur_time, params in self.delay_signals.items():
|
||||
self.cur_time = cur_time
|
||||
self.prepare_signals(*params[0], **params[1])
|
||||
self.delay_signals = {}
|
||||
for time_segment, recs_list in rec_dict.items():
|
||||
self.trainer.end_train(recs_list, *args, **kwargs)
|
||||
self.reset_online_tag(recs_list)
|
||||
self.prepare_signals()
|
||||
signal_max = self.get_signals().index.get_level_values("datetime").max()
|
||||
if time_segment[1] is not None and signal_max > time_segment[1]:
|
||||
raise ValueError(
|
||||
f"The max time of signals prepared by online models is {signal_max}, but those models only online in {time_segment}"
|
||||
)
|
||||
|
||||
def routine(self, cur_time=None, delay_prepare=False, *args, **kwargs):
|
||||
"""
|
||||
The typical update process after a routine, such as day by day or month by month.
|
||||
update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
|
||||
|
||||
NOTE: Assumption: if using simulator (delay_prepare is True), the prediction will be prepared well after every training, so there is no need to update predictions.
|
||||
|
||||
Args:
|
||||
cur_time ([type], optional): [description]. Defaults to None.
|
||||
delay_prepare (bool, optional): [description]. Defaults to False.
|
||||
*args, **kwargs: will be passed to `prepare_tasks` and `prepare_new_models`. It can be some hyper parameter or training config.
|
||||
|
||||
Returns:
|
||||
[type]: [description]
|
||||
"""
|
||||
self.cur_time = cur_time # None for latest date
|
||||
self.update_online_pred()
|
||||
if not delay_prepare:
|
||||
self.prepare_signals(*args, **kwargs)
|
||||
else:
|
||||
if cur_time is not None:
|
||||
self.delay_signals[cur_time] = (args, kwargs)
|
||||
else:
|
||||
raise ValueError("Can not delay prepare when cur_time is None")
|
||||
self.update_online_pred()
|
||||
self.prepare_signals()
|
||||
tasks = self.prepare_tasks(*args, **kwargs)
|
||||
self.prepare_new_models(tasks)
|
||||
self.prepare_new_models(tasks, *args, **kwargs)
|
||||
|
||||
return self.reset_online_tag()
|
||||
|
||||
@@ -185,8 +211,16 @@ class OnlineManagerR(OnlineManager):
|
||||
trainer = TrainerR(experiment_name)
|
||||
super().__init__(trainer=trainer, need_log=need_log)
|
||||
self.exp_name = experiment_name
|
||||
self.signal_rec = None
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
|
||||
"""
|
||||
Set `tag` to the model to sign whether online.
|
||||
|
||||
Args:
|
||||
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
|
||||
recorder (Union[Recorder, List])
|
||||
"""
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
for rec in recorder:
|
||||
@@ -195,6 +229,15 @@ class OnlineManagerR(OnlineManager):
|
||||
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
|
||||
|
||||
def get_online_tag(self, recorder: Recorder):
    """
    Look up the online tag of a model's recorder.

    Args:
        recorder (Recorder): an instance of recorder

    Returns:
        str: the value stored under `OnlineManager.ONLINE_KEY` in the
        recorder's tags, or `OnlineManager.OFFLINE_TAG` when that key is absent
    """
    return recorder.list_tags().get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG)
|
||||
|
||||
@@ -202,7 +245,7 @@ class OnlineManagerR(OnlineManager):
|
||||
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
|
||||
|
||||
Args:
|
||||
recorders (Union[List, Dict], optional):
|
||||
recorders (Union[Recorder, List], optional):
|
||||
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
|
||||
|
||||
Returns:
|
||||
@@ -225,7 +268,30 @@ class OnlineManagerR(OnlineManager):
|
||||
self.set_online_tag(OnlineManager.ONLINE_TAG, recorder)
|
||||
return recorder
|
||||
|
||||
def get_signals(self):
    """
    Get signals from the recorder (named self.exp_name) of the experiment (named self.SIGNAL_EXP).

    The recorder is resolved on first use and cached in `self.signal_rec`.

    Returns:
        The object stored as "signals" in that recorder, or None (after
        logging a warning) when loading it raises OSError.
    """
    if self.signal_rec is None:
        # Resume the signal recorder once and keep it for later calls.
        with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True):
            self.signal_rec = R.get_recorder()
    try:
        return self.signal_rec.load_object("signals")
    except OSError:
        # `Logger.warn` is a deprecated alias; `warning` is the supported name.
        self.logger.warning("Can not find `signals`, have you called `prepare_signals` before?")
        return None
|
||||
|
||||
def online_models(self):
    """
    Collect the recorders of this experiment whose online tag is 'online'.

    Returns:
        list: the list of online models
    """

    def _is_online(rec):
        return self.get_online_tag(rec) == OnlineManager.ONLINE_TAG

    matched = list_recorders(self.exp_name, _is_online)
    return list(matched.values())
|
||||
@@ -245,34 +311,35 @@ class OnlineManagerR(OnlineManager):
|
||||
"""
|
||||
Average the predictions of online models and offer a trading signals every routine.
|
||||
The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP`
|
||||
|
||||
Even if the latest signal already exists, the latest calculation result will be overwritten.
|
||||
NOTE: Given a prediction of a certain time, all signals before this time will be prepared well.
|
||||
Args:
|
||||
over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False.
|
||||
"""
|
||||
if self.signal_rec is None:
|
||||
with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True):
|
||||
self.signal_rec = R.get_recorder()
|
||||
|
||||
with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True):
|
||||
recorder = R.get_recorder()
|
||||
pred = []
|
||||
pred = []
|
||||
try:
|
||||
old_signals = self.signal_rec.load_object("signals")
|
||||
except OSError:
|
||||
old_signals = None
|
||||
|
||||
try:
|
||||
old_signals = recorder.load_object("signals")
|
||||
except OSError:
|
||||
old_signals = None
|
||||
for rec in self.online_models():
|
||||
pred.append(rec.load_object("pred.pkl"))
|
||||
|
||||
for rec in self.online_models():
|
||||
pred.append(rec.load_object("pred.pkl"))
|
||||
|
||||
signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score")
|
||||
signals = signals.sort_index()
|
||||
if old_signals is not None and not over_write:
|
||||
# signals = old_signals.reindex(signals.index).combine_first(signals)
|
||||
old_max = old_signals.index.get_level_values("datetime").max()
|
||||
new_signals = signals.loc[old_max:]
|
||||
signals = pd.concat([old_signals, new_signals], axis=0)
|
||||
else:
|
||||
new_signals = signals
|
||||
signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score")
|
||||
signals = signals.sort_index()
|
||||
if old_signals is not None and not over_write:
|
||||
old_max = old_signals.index.get_level_values("datetime").max()
|
||||
new_signals = signals.loc[old_max:]
|
||||
signals = pd.concat([old_signals, new_signals], axis=0)
|
||||
else:
|
||||
new_signals = signals
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finished preparing new {len(new_signals)} signals to {self.SIGNAL_EXP}/{self.exp_name}.")
|
||||
recorder.save_objects(**{"signals": signals})
|
||||
self.signal_rec.save_objects(**{"signals": signals})
|
||||
|
||||
|
||||
class RollingOnlineManager(OnlineManagerR):
|
||||
@@ -304,7 +371,9 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
|
||||
def get_collector(self, rec_key_func=None, rec_filter_func=None):
|
||||
"""
|
||||
get the instance of collector to collect results
|
||||
Get the instance of collector to collect results. The returned collector must can distinguish results in different models.
|
||||
Assumption: the models can be distinguished based on model name and rolling test segments.
|
||||
If you do not want this assumption, please implement your own method or use another rec_key_func.
|
||||
|
||||
Args:
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
@@ -353,10 +422,9 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
generators=self.rg, # generate different date segment
|
||||
)
|
||||
self.prepare_new_models(tasks, tag=self.ONLINE_TAG)
|
||||
self.prepare_signals(over_write=True)
|
||||
return self.get_collector()
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
def prepare_tasks(self):
|
||||
"""
|
||||
Prepare new tasks based on new date.
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class OnlineSimulator:
|
||||
self,
|
||||
start_time,
|
||||
end_time,
|
||||
onlinemanager: OnlineManager,
|
||||
online_manager: OnlineManager,
|
||||
frequency="day",
|
||||
):
|
||||
"""
|
||||
@@ -28,15 +28,14 @@ class OnlineSimulator:
|
||||
self.cal = D.calendar(start_time=start_time, end_time=end_time, freq=frequency)
|
||||
self.start_time = self.cal[0]
|
||||
self.end_time = self.cal[-1]
|
||||
self.olm = onlinemanager
|
||||
|
||||
self.olm = online_manager
|
||||
if len(self.cal) == 0:
|
||||
self.logger.warn(f"There is no need to simulate bacause start_time is larger than end_time.")
|
||||
|
||||
def simulate(self, *args, **kwargs):
|
||||
"""
|
||||
Starting from start time, this method will simulate every routine in OnlineManager.
|
||||
NOTE: Considering the parallel training, the signals will be prepared after all routines have been simulated.
|
||||
NOTE: Considering the parallel training, the models and signals can be prepared after all routines have been simulated.
|
||||
|
||||
Returns:
|
||||
Collector: the OnlineManager's collector
|
||||
@@ -54,12 +53,10 @@ class OnlineSimulator:
|
||||
self.rec_dict[(tmp_begin, tmp_end)] = prev_recorders
|
||||
tmp_begin = cur_time
|
||||
prev_recorders = recorders
|
||||
|
||||
self.rec_dict[(tmp_begin, self.end_time)] = prev_recorders
|
||||
# prepare signals again in case there was no trained model when it was first called
|
||||
self.olm.run_delay_signals()
|
||||
# finished preparing models (and pred) and signals
|
||||
self.olm.delay_prepare(self.rec_dict)
|
||||
self.logger.info(f"Finished preparing signals")
|
||||
|
||||
return self.olm.get_collector()
|
||||
|
||||
def online_models(self):
|
||||
|
||||
@@ -91,7 +91,7 @@ class RollingGen(TaskGen):
|
||||
ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date
|
||||
ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date
|
||||
|
||||
def __init__(self, step: int = 40, rtype: str = ROLL_EX):
|
||||
def __init__(self, step: int = 40, rtype: str = ROLL_EX, modify_end_time=True):
|
||||
"""
|
||||
Generate tasks for rolling
|
||||
|
||||
@@ -101,9 +101,12 @@ class RollingGen(TaskGen):
|
||||
step to rolling
|
||||
rtype : str
|
||||
rolling type (expanding, sliding)
|
||||
modify_end_time: bool
|
||||
Whether the data set configuration needs to be modified when the required scope exceeds the original data set scope
|
||||
"""
|
||||
self.step = step
|
||||
self.rtype = rtype
|
||||
self.modify_end_time = modify_end_time
|
||||
# TODO: Ask pengrong to update future date in dataset
|
||||
self.ta = TimeAdjuster(future=True)
|
||||
|
||||
@@ -113,7 +116,6 @@ class RollingGen(TaskGen):
|
||||
def generate(self, task: dict):
|
||||
"""
|
||||
Converting the task into a rolling task.
|
||||
# FIXME: only modify dataset layer, user need to change datahandler firstly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -196,7 +198,8 @@ class RollingGen(TaskGen):
|
||||
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
|
||||
# if end_time < the end of test_segments, then change end_time to allow load more data
|
||||
if (
|
||||
self.ta.cal_interval(
|
||||
self.modify_end_time
|
||||
and self.ta.cal_interval(
|
||||
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
|
||||
t["dataset"]["kwargs"]["segments"][self.test_key][1],
|
||||
)
|
||||
|
||||
@@ -174,11 +174,11 @@ class TaskManager:
|
||||
|
||||
return _id_list
|
||||
|
||||
def fetch_task(self, query={}):
|
||||
def fetch_task(self, query={}, status=STATUS_WAITING):
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query.update({"status": self.STATUS_WAITING})
|
||||
query.update({"status": status})
|
||||
task = self.task_pool.find_one_and_update(
|
||||
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
|
||||
)
|
||||
@@ -189,7 +189,7 @@ class TaskManager:
|
||||
return self._decode_task(task)
|
||||
|
||||
@contextmanager
|
||||
def safe_fetch_task(self, query={}):
|
||||
def safe_fetch_task(self, query={}, status=STATUS_WAITING):
|
||||
"""
|
||||
fetch task from task_pool using query with contextmanager
|
||||
|
||||
@@ -202,7 +202,7 @@ class TaskManager:
|
||||
-------
|
||||
|
||||
"""
|
||||
task = self.fetch_task(query=query)
|
||||
task = self.fetch_task(query=query, status=status)
|
||||
try:
|
||||
yield task
|
||||
except Exception:
|
||||
@@ -330,7 +330,15 @@ class TaskManager:
|
||||
return f"TaskManager({self.task_pool})"
|
||||
|
||||
|
||||
def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
|
||||
def run_task(
|
||||
task_func,
|
||||
task_pool,
|
||||
force_release=False,
|
||||
before_status=TaskManager.STATUS_WAITING,
|
||||
after_status=TaskManager.STATUS_DONE,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
While task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool
|
||||
|
||||
@@ -352,16 +360,24 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
|
||||
ever_run = False
|
||||
|
||||
while True:
|
||||
with tm.safe_fetch_task() as task:
|
||||
with tm.safe_fetch_task(status=before_status) as task:
|
||||
if task is None:
|
||||
break
|
||||
get_module_logger("run_task").info(task["def"])
|
||||
if force_release:
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: # what this means?
|
||||
res = executor.submit(task_func, task["def"], *args, **kwargs).result()
|
||||
# when fetching `WAITING` task, use task_def to train
|
||||
if before_status == TaskManager.STATUS_WAITING:
|
||||
param = task["def"]
|
||||
# when fetching a `PART_DONE` task, use task_res to train, because the partial result has already been saved
|
||||
elif before_status == TaskManager.STATUS_PART_DONE:
|
||||
param = task["res"]
|
||||
else:
|
||||
res = task_func(task["def"], *args, **kwargs)
|
||||
tm.commit_task_res(task, res)
|
||||
raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!")
|
||||
if force_release:
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
|
||||
res = executor.submit(task_func, param, *args, **kwargs).result()
|
||||
else:
|
||||
res = task_func(param, *args, **kwargs)
|
||||
tm.commit_task_res(task, res, status=after_status)
|
||||
ever_run = True
|
||||
|
||||
return ever_run
|
||||
|
||||
Reference in New Issue
Block a user