1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

Online Serving V8

This commit is contained in:
lzh222333
2021-04-26 09:31:47 +00:00
parent 319396c815
commit 0058f7d0dc
8 changed files with 368 additions and 159 deletions

View File

@@ -1,14 +1,14 @@
import fire
import qlib
from qlib.model.ens.ensemble import ens_workflow
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM
from qlib.workflow import R
from qlib.workflow.online.manager import RollingOnlineManager
from qlib.workflow.online.simulator import OnlineSimulator
from qlib.workflow.task.collect import RecorderCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.utils import list_recorders
"""
This example is about the OnlineManager and OnlineSimulator based on rolling tasks.
@@ -19,7 +19,7 @@ The OnlineSimulator will focus on the simulating real updating routine of your o
data_handler_config = {
"start_time": "2018-01-01",
"end_time": None, # "2018-10-31",
"end_time": "2018-10-31",
"fit_start_time": "2018-01-01",
"fit_end_time": "2018-03-31",
"instruments": "csi100",
@@ -74,7 +74,7 @@ task_xgboost_config = {
}
class OnlineManagerExample:
class OnlineSimulationExample:
def __init__(
self,
provider_uri="~/.qlib/qlib_data/cn_data",
@@ -86,6 +86,7 @@ class OnlineManagerExample:
rolling_step=80,
start_time="2018-09-10",
end_time="2018-10-31",
tasks=[task_xgboost_config], # , task_lgb_config]
):
"""
init OnlineManagerExample.
@@ -100,6 +101,7 @@ class OnlineManagerExample:
rolling_step (int, optional): the step for rolling. Defaults to 80.
start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
"""
self.exp_name = exp_name
self.task_pool = task_pool
@@ -108,76 +110,49 @@ class OnlineManagerExample:
"task_db_name": task_db_name,
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD) # The rolling tasks generator
self.trainer = TrainerRM(self.exp_name, self.task_pool) # The trainer based on (R)ecorder and Task(M)anager
self.rolling_gen = RollingGen(
step=rolling_step, rtype=RollingGen.ROLL_SD, modify_end_time=False
) # The rolling tasks generator, modify_end_time is false because we just need simulate to 2018-10-31.
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)
self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks
self.collector = RecorderCollector(exp_name=self.exp_name, rec_key_func=self.rec_key) # The result collector
self.grouper = RollingGroup() # Divide your results into different rolling group
self.rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name,
rolling_gen=self.rolling_gen,
trainer=self.trainer,
collector=self.collector,
need_log=False,
) # The OnlineManager based on Rolling
self.onlinesimulator = OnlineSimulator(
start_time=start_time,
end_time=end_time,
onlinemanager=self.rolling_online_manager,
online_manager=self.rolling_online_manager,
)
self.tasks = tasks
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
self.task_manager.remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
@staticmethod
def rec_key(recorder):
"""
given a Recorder and return its key to identify it
Args:
recorder (Recorder): a instance of the Recorder
Returns:
tuple: (model_key, rolling_key)
"""
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
def result_collecting(self):
print("========== result collecting ==========")
# ens_workflow can help collect, group and ensemble results in a easy way
artifact = ens_workflow(self.rolling_online_manager.get_collector(), self.grouper)
print(artifact)
for rid in list_recorders(
RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False
):
exp.delete_recorder(rid)
# Run this firstly to see the workflow in OnlineManager
def first_train(self):
print("========== first train ==========")
self.reset()
tasks = task_generator(
tasks=[task_xgboost_config, task_lgb_config],
generators=[self.rolling_gen], # generate different date segment
)
self.rolling_online_manager.prepare_new_models(tasks=tasks, tag=RollingOnlineManager.ONLINE_TAG)
self.result_collecting()
self.rolling_online_manager.first_train(self.tasks)
# Run this secondly to see the simulating in OnlineSimulator
def simulate(self):
print("========== simulate ==========")
self.onlinesimulator.simulate()
self.result_collecting()
print(self.rolling_online_manager.collect_artifact())
print("========== online models ==========")
recs_dict = self.onlinesimulator.online_models()
@@ -186,6 +161,9 @@ class OnlineManagerExample:
for rec in recs:
print(rec.info["id"])
print("========== online signals ==========")
print(self.rolling_online_manager.get_signals())
# Run this to run the whole workflow automatically
def main(self):
self.first_train()
@@ -195,4 +173,4 @@ class OnlineManagerExample:
if __name__ == "__main__":
## to run the whole workflow automatically with your own parameters, use the command below
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
fire.Fire(OnlineManagerExample)
fire.Fire(OnlineSimulationExample)

View File

@@ -111,6 +111,11 @@ class RollingOnlineExample:
if os.path.exists(self._ROLLING_MANAGER_PATH):
os.remove(self._ROLLING_MANAGER_PATH)
for rid in list_recorders(
RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False
):
exp.delete_recorder(rid)
def first_run(self):
print("========== first_run ==========")
self.reset()

View File

@@ -1,6 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
import time
from xxlimited import Str
from qlib.utils import init_instance_by_config, flatten_dict, get_cls_kwargs
from qlib.workflow import R
from qlib.workflow.recorder import Recorder
@@ -11,6 +14,63 @@ from qlib.model.base import Model
import socket
def begin_task_train(task_config: dict, experiment_name: str, *args, **kwargs) -> Recorder:
    """
    Open a new recorder for a task and persist the task definition in it.

    No model is fitted here; the recorder is only tagged with the
    `begin_task_train` status so that the fitting can be done later.

    Args:
        task_config (dict): the task definition. It is logged as flattened params
            and also saved untouched as the `task` object.
        experiment_name (str): the experiment under which the recorder is started.
        *args, **kwargs: accepted but not used by this function.

    Returns:
        Recorder: the recorder that was active while the experiment was started.
    """
    # the current timestamp is used as the recorder name
    run_name = str(time.time())
    with R.start(experiment_name=experiment_name, recorder_name=run_name):
        R.log_params(**flatten_dict(task_config))
        # save the raw config as well, keeping the original format and datatype
        R.save_objects(task=task_config)
        R.set_tags(hostname=socket.gethostname(), train_status="begin_task_train")
        started: Recorder = R.get_recorder()
    return started
def end_task_train(rec: Recorder, experiment_name: str, *args, **kwargs):
    """
    Finish a task by resuming its recorder, fitting the model and saving the results.

    The task definition is read back from the `task` object stored in the recorder.

    Args:
        rec (Recorder): the recorder to resume; it is looked up by `rec.info["name"]`.
        experiment_name (str): the experiment the recorder belongs to.
        *args, **kwargs: accepted but not used by this function.

    Returns:
        Recorder: the same `rec` that was passed in.
    """
    with R.start(experiment_name=experiment_name, recorder_name=rec.info["name"], resume=True):
        task_config = R.load_object("task")

        # model & dataset initiation
        model: Model = init_instance_by_config(task_config["model"])
        dataset: Dataset = init_instance_by_config(task_config["dataset"])

        # model training
        model.fit(dataset)
        R.save_objects(**{"params.pkl": model})

        # The dataset is saved for online inference, so the concrete data is not dumped
        dataset.config(dump_all=False, recursive=True)
        R.save_objects(dataset=dataset)

        # generate records: prediction, backtest, and analysis
        record_confs = task_config.get("record", [])
        if isinstance(record_confs, dict):
            # a single record config is allowed; normalise it to a list
            record_confs = [record_confs]
        for record_conf in record_confs:
            record_cls, record_kwargs = get_cls_kwargs(record_conf, default_module="qlib.workflow.record_temp")
            extra = {"recorder": rec}
            if record_cls is SignalRecord:
                # SignalRecord additionally receives the fitted model and the dataset
                extra = {"model": model, "dataset": dataset, "recorder": rec}
            record_cls(**record_kwargs, **extra).generate()
        R.set_tags(train_status="end_task_train")
    return rec
def task_train(task_config: dict, experiment_name: str) -> Recorder:
"""
task based training
@@ -26,36 +86,8 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
----------
Recorder : The instance of the recorder
"""
# model initiaiton
model: Model = init_instance_by_config(task_config["model"])
dataset: Dataset = init_instance_by_config(task_config["dataset"])
# start exp
with R.start(experiment_name=experiment_name):
# train model
R.log_params(**flatten_dict(task_config))
R.save_objects(**{"task": task_config}) # keep the original format and datatype
R.set_tags(hostname=socket.gethostname())
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# This dataset is saved for online inference. So the concrete data should not be dumped
dataset.config(dump_all=False, recursive=True)
R.save_objects(**{"dataset": dataset})
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
recorder: Recorder = R.get_recorder()
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
if cls is SignalRecord:
rconf = {"model": model, "dataset": dataset, "recorder": recorder}
else:
rconf = {"recorder": recorder}
r = cls(**kwargs, **rconf)
r.generate()
recorder = begin_task_train(task_config, experiment_name)
recorder = end_task_train(recorder, experiment_name)
return recorder
@@ -64,14 +96,22 @@ class Trainer:
The trainer which can train a list of model
"""
def train(self, *args, **kwargs):
"""Given a list of model definition, finished training and return the results of them.
def train(self, tasks: list, *args, **kwargs):
"""Given a list of model definition, begin a training and return the models.
Returns:
list: a list of trained results
list: a list of models
"""
raise NotImplementedError(f"Please implement the `train` method.")
def end_train(self, models, *args, **kwargs):
    """
    Given a list of models, finish whatever is left at the end of training.

    The base implementation has nothing left to do, so it hands `models` back
    unchanged. Subclasses that delay work in `train` override this method.

    Args:
        models: the list of models returned by `train`.
        *args, **kwargs: accepted for subclasses; unused here.

    Returns:
        list: a list of models
    """
    # Return the input instead of None so the result matches the documented contract.
    return models
class TrainerR(Trainer):
"""Trainer based on (R)ecorder.
@@ -112,7 +152,15 @@ class TrainerRM(Trainer):
self.task_pool = task_pool
self.train_func = train_func
def train(self, tasks: list, train_func=None, *args, **kwargs):
def train(
self,
tasks: list,
train_func=None,
before_status=TaskManager.STATUS_WAITING,
after_status=TaskManager.STATUS_DONE,
*args,
**kwargs,
):
"""Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
This method defaults to a single process, but TaskManager offered a great way to parallel training.
@@ -129,7 +177,15 @@ class TrainerRM(Trainer):
train_func = self.train_func
tm = TaskManager(task_pool=self.task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
run_task(train_func, self.task_pool, experiment_name=self.experiment_name, *args, **kwargs)
run_task(
train_func,
self.task_pool,
experiment_name=self.experiment_name,
before_status=before_status,
after_status=after_status,
*args,
**kwargs,
)
recs = []
for _id in _id_list:
@@ -137,10 +193,96 @@ class TrainerRM(Trainer):
return recs
class DelayTrainer(Trainer):
def fake_train(self):
self.fake_trained = []
class DelayTrainerR(TrainerR):
    """
    A delayed implementation based on TrainerR, which means `train` method may only do some preparation
    and `end_train` method can do the real model fitting.
    """

    def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train):
        """
        Args:
            experiment_name (str): the experiment used by both `train` and `end_train`.
            train_func (Callable, optional): the preparation step. Defaults to `begin_task_train`.
            end_train_func (Callable, optional): the delayed step. Defaults to `end_task_train`.
        """
        super().__init__(experiment_name, train_func)
        # `end_train` must hand the experiment name to `end_train_func`. TrainerR's body is not
        # visible from here, so keep an explicit reference instead of assuming the parent stores it.
        self.experiment_name = experiment_name
        self.end_train_func = end_train_func
        # recorders returned by the last `train` call that still wait for `end_train`
        self.recs = []

    def train(self, tasks: list, train_func=None, *args, **kwargs):
        """
        Same as `train` of TrainerR, the results will be recorded in self.recs

        Args:
            tasks (list): a list of definition based on `task` dict
            train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.

        Returns:
            list: a list of Recorders
        """
        self.recs = super().train(tasks, train_func=train_func, *args, **kwargs)
        return self.recs

    def end_train(self, recs=None, end_train_func=None):
        """
        Given a list of Recorder and return a list of trained Recorder.
        This class will finish the real data loading and model fitting.

        Args:
            recs (list, optional): a list of Recorder, the tasks have been saved to them.
                Defaults to None for using self.recs.
            end_train_func (Callable, optional): the end_train method which need at least `rec` and
                `experiment_name`. Defaults to None for using self.end_train_func.

        Returns:
            list: a list of Recorders
        """
        if recs is None:
            recs = copy.deepcopy(self.recs)
            # the models will be only trained once
            self.recs = []
        if end_train_func is None:
            end_train_func = self.end_train_func
        for rec in recs:
            # the end-train function needs `rec` AND `experiment_name`
            end_train_func(rec, self.experiment_name)
        return recs
class DelayTrainerRM(TrainerRM):
    """
    Delayed flavour of TrainerRM: `train` may only do some preparation,
    while `end_train` can do the real model fitting.
    """

    def __init__(self, experiment_name, task_pool: str, train_func=begin_task_train, end_train_func=end_task_train):
        """
        Args:
            experiment_name (str): forwarded to TrainerRM.
            task_pool (str): forwarded to TrainerRM.
            train_func (Callable, optional): the preparation step. Defaults to `begin_task_train`.
            end_train_func (Callable, optional): the delayed step. Defaults to `end_task_train`.
        """
        super().__init__(experiment_name, task_pool, train_func)
        self.end_train_func = end_train_func

    def train(self, tasks: list, train_func=None, *args, **kwargs):
        """
        Run `train` of TrainerRM, but leave the tasks in the `STATUS_PART_DONE` state afterwards.

        Args:
            tasks (list): a list of definition based on `task` dict
            train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.

        Returns:
            list: a list of Recorders
        """
        part_done = TaskManager.STATUS_PART_DONE
        return super().train(tasks, train_func=train_func, after_status=part_done, *args, **kwargs)

    def end_train(self, recs, end_train_func=None):
        """
        Run the delayed step on the tasks of the pool that are in the `STATUS_PART_DONE` state.

        Args:
            recs (list): a list of Recorder; it is returned as given.
            end_train_func (Callable, optional): the end_train method which need at least `rec` and
                `experiment_name`. Defaults to None for using self.end_train_func.

        Returns:
            list: a list of Recorders
        """
        finisher = self.end_train_func if end_train_func is None else end_train_func
        run_task(
            finisher,
            self.task_pool,
            experiment_name=self.experiment_name,
            before_status=TaskManager.STATUS_PART_DONE,
        )
        return recs

View File

@@ -304,7 +304,7 @@ class QlibRecorder:
"""
self.exp_manager.set_uri(uri)
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
"""
Method for retrieving a recorder.

View File

@@ -44,17 +44,21 @@ class OnlineManager(Serializable):
self.trainer = trainer
self.logger = get_module_logger(self.__class__.__name__)
self.need_log = need_log
self.delay_signals = {}
self.cur_time = None
def prepare_signals(self, *args, **kwargs):
def prepare_signals(self):
"""
After preparing the data of the last routine (a box in box-plot), which means the end of the routine, we can prepare trading signals for the next routine.
Must use `pass` even though there is nothing to do.
"""
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
def get_signals(self):
    """
    Return the signals once they have been prepared.

    This is an abstract hook: the base implementation only raises.

    Raises:
        NotImplementedError: always, until a subclass overrides this method.
    """
    message = "Please implement the `get_signals` method."
    raise NotImplementedError(message)
def prepare_tasks(self, *args, **kwargs):
"""
After the end of a routine, check whether we need to prepare and train some new tasks.
@@ -62,7 +66,7 @@ class OnlineManager(Serializable):
"""
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None):
def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None, *args, **kwargs):
"""
Use trainer to train a list of tasks and set the trained model to `tag`.
@@ -75,13 +79,14 @@ class OnlineManager(Serializable):
check_func: the method to judge if a model can be online.
The parameter is the model record and return True for online.
None for online every models.
*args, **kwargs: will be passed to end_train which means will be passed to customized train method.
"""
if check_func is None:
check_func = lambda x: True
if len(tasks) > 0:
if self.trainer is not None:
new_models = self.trainer.train(tasks)
new_models = self.trainer.train(tasks, *args, **kwargs)
if check_func(new_models):
self.set_online_tag(tag, new_models)
if self.need_log:
@@ -89,13 +94,13 @@ class OnlineManager(Serializable):
else:
self.logger.warn("No trainer to train new tasks.")
def update_online_pred(self, *args, **kwargs):
def update_online_pred(self):
"""
After the end of a routine, update the predictions of online models to latest.
"""
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
def set_online_tag(self, tag, *args, **kwargs):
def set_online_tag(self, tag, recorder):
"""
Set `tag` to the model to sign whether online.
@@ -104,15 +109,21 @@ class OnlineManager(Serializable):
"""
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
def get_online_tag(self, *args, **kwargs):
def get_online_tag(self):
"""
Given a model and return its online tag.
"""
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
def reset_online_tag(self, *args, **kwargs):
"""
Offline all models and set the models to 'online'.
def reset_online_tag(self, recorders=None):
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
Args:
recorders (List, optional):
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
Returns:
list: new online recorder. [] if there is no update.
"""
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
@@ -137,31 +148,46 @@ class OnlineManager(Serializable):
"""
raise NotImplementedError(f"Please implement the `get_collector` method.")
def run_delay_signals(self):
def delay_prepare(self, rec_dict, *args, **kwargs):
"""
Prepare all signals if there are some dates waiting for prepare.
Prepare all models and signals if there are something waiting for prepare.
NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way.
Args:
rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}.
*args, **kwargs: will be passed to end_train which means will be passed to customized train method.
"""
for cur_time, params in self.delay_signals.items():
self.cur_time = cur_time
self.prepare_signals(*params[0], **params[1])
self.delay_signals = {}
for time_segment, recs_list in rec_dict.items():
self.trainer.end_train(recs_list, *args, **kwargs)
self.reset_online_tag(recs_list)
self.prepare_signals()
signal_max = self.get_signals().index.get_level_values("datetime").max()
if time_segment[1] is not None and signal_max > time_segment[1]:
raise ValueError(
f"The max time of signals prepared by online models is {signal_max}, but those models only online in {time_segment}"
)
def routine(self, cur_time=None, delay_prepare=False, *args, **kwargs):
"""
The typical update process after a routine, such as day by day or month by month.
update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
NOTE: Assumption: if using simulator (delay_prepare is True), the prediction will be prepared well after every training, so there is no need to update predictions.
Args:
cur_time ([type], optional): [description]. Defaults to None.
delay_prepare (bool, optional): [description]. Defaults to False.
*args, **kwargs: will be passed to `prepare_tasks` and `prepare_new_models`. It can be some hyper parameter or training config.
Returns:
[type]: [description]
"""
self.cur_time = cur_time # None for latest date
self.update_online_pred()
if not delay_prepare:
self.prepare_signals(*args, **kwargs)
else:
if cur_time is not None:
self.delay_signals[cur_time] = (args, kwargs)
else:
raise ValueError("Can not delay prepare when cur_time is None")
self.update_online_pred()
self.prepare_signals()
tasks = self.prepare_tasks(*args, **kwargs)
self.prepare_new_models(tasks)
self.prepare_new_models(tasks, *args, **kwargs)
return self.reset_online_tag()
@@ -185,8 +211,16 @@ class OnlineManagerR(OnlineManager):
trainer = TrainerR(experiment_name)
super().__init__(trainer=trainer, need_log=need_log)
self.exp_name = experiment_name
self.signal_rec = None
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
"""
Set `tag` to the model to sign whether online.
Args:
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
recorder (Union[Recorder, List])
"""
if isinstance(recorder, Recorder):
recorder = [recorder]
for rec in recorder:
@@ -195,6 +229,15 @@ class OnlineManagerR(OnlineManager):
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
def get_online_tag(self, recorder: Recorder):
    """
    Look up the online tag stored on a recorder.

    Args:
        recorder (Recorder): an instance of recorder

    Returns:
        str: the value tagged under `OnlineManager.ONLINE_KEY`, or
        `OnlineManager.OFFLINE_TAG` when the recorder has no such tag
    """
    recorder_tags = recorder.list_tags()
    tag_key = OnlineManager.ONLINE_KEY
    fallback = OnlineManager.OFFLINE_TAG
    return recorder_tags.get(tag_key, fallback)
@@ -202,7 +245,7 @@ class OnlineManagerR(OnlineManager):
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
Args:
recorders (Union[List, Dict], optional):
recorders (Union[Recorder, List], optional):
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
Returns:
@@ -225,7 +268,30 @@ class OnlineManagerR(OnlineManager):
self.set_online_tag(OnlineManager.ONLINE_TAG, recorder)
return recorder
def get_signals(self):
    """
    Get signals from the recorder (named self.exp_name) of the experiment (named self.SIGNAL_EXP).

    The recorder is fetched on first use and cached in `self.signal_rec`.

    Returns:
        the object stored as `signals` in that recorder, or None (with a warning logged)
        when loading it raises OSError
    """
    if self.signal_rec is None:
        # start the signal experiment with resume=True just to get hold of its recorder
        with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True):
            self.signal_rec = R.get_recorder()
    try:
        return self.signal_rec.load_object("signals")
    except OSError:
        # `Logger.warn` is deprecated; `warning` is the supported spelling
        self.logger.warning("Can not find `signals`, have you called `prepare_signals` before?")
        return None
def online_models(self):
    """
    Collect the recorders of this experiment that are currently tagged as online.

    Returns:
        list: the list of online models
    """

    def _is_online(rec):
        # a recorder counts as online when its tag equals ONLINE_TAG
        return self.get_online_tag(rec) == OnlineManager.ONLINE_TAG

    online_recs = list_recorders(self.exp_name, _is_online)
    return list(online_recs.values())
@@ -245,34 +311,35 @@ class OnlineManagerR(OnlineManager):
"""
Average the predictions of online models and offer a trading signals every routine.
The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP`
Even if the latest signal already exists, the latest calculation result will be overwritten.
NOTE: Given a prediction of a certain time, all signals before this time will be prepared well.
Args:
over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False.
"""
if self.signal_rec is None:
with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True):
self.signal_rec = R.get_recorder()
with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True):
recorder = R.get_recorder()
pred = []
pred = []
try:
old_signals = self.signal_rec.load_object("signals")
except OSError:
old_signals = None
try:
old_signals = recorder.load_object("signals")
except OSError:
old_signals = None
for rec in self.online_models():
pred.append(rec.load_object("pred.pkl"))
for rec in self.online_models():
pred.append(rec.load_object("pred.pkl"))
signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score")
signals = signals.sort_index()
if old_signals is not None and not over_write:
# signals = old_signals.reindex(signals.index).combine_first(signals)
old_max = old_signals.index.get_level_values("datetime").max()
new_signals = signals.loc[old_max:]
signals = pd.concat([old_signals, new_signals], axis=0)
else:
new_signals = signals
signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score")
signals = signals.sort_index()
if old_signals is not None and not over_write:
old_max = old_signals.index.get_level_values("datetime").max()
new_signals = signals.loc[old_max:]
signals = pd.concat([old_signals, new_signals], axis=0)
else:
new_signals = signals
if self.need_log:
self.logger.info(f"Finished preparing new {len(new_signals)} signals to {self.SIGNAL_EXP}/{self.exp_name}.")
recorder.save_objects(**{"signals": signals})
self.signal_rec.save_objects(**{"signals": signals})
class RollingOnlineManager(OnlineManagerR):
@@ -304,7 +371,9 @@ class RollingOnlineManager(OnlineManagerR):
def get_collector(self, rec_key_func=None, rec_filter_func=None):
"""
get the instance of collector to collect results
Get the instance of collector to collect results. The returned collector must can distinguish results in different models.
Assumption: the models can be distinguished based on model name and rolling test segments.
If you do not want this assumption, please implement your own method or use another rec_key_func.
Args:
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
@@ -353,10 +422,9 @@ class RollingOnlineManager(OnlineManagerR):
generators=self.rg, # generate different date segment
)
self.prepare_new_models(tasks, tag=self.ONLINE_TAG)
self.prepare_signals(over_write=True)
return self.get_collector()
def prepare_tasks(self, *args, **kwargs):
def prepare_tasks(self):
"""
Prepare new tasks based on new date.

View File

@@ -12,7 +12,7 @@ class OnlineSimulator:
self,
start_time,
end_time,
onlinemanager: OnlineManager,
online_manager: OnlineManager,
frequency="day",
):
"""
@@ -28,15 +28,14 @@ class OnlineSimulator:
self.cal = D.calendar(start_time=start_time, end_time=end_time, freq=frequency)
self.start_time = self.cal[0]
self.end_time = self.cal[-1]
self.olm = onlinemanager
self.olm = online_manager
if len(self.cal) == 0:
self.logger.warn(f"There is no need to simulate bacause start_time is larger than end_time.")
def simulate(self, *args, **kwargs):
"""
Starting from start time, this method will simulate every routine in OnlineManager.
NOTE: Considering the parallel training, the signals will be perpared after all routine simulating.
NOTE: Considering the parallel training, the models and signals can be prepared after all routines have been simulated.
Returns:
Collector: the OnlineManager's collector
@@ -54,12 +53,10 @@ class OnlineSimulator:
self.rec_dict[(tmp_begin, tmp_end)] = prev_recorders
tmp_begin = cur_time
prev_recorders = recorders
self.rec_dict[(tmp_begin, self.end_time)] = prev_recorders
# prepare signals again incase there is no trained model when call it
self.olm.run_delay_signals()
# finish preparing models (and pred) and signals
self.olm.delay_prepare(self.rec_dict)
self.logger.info(f"Finished preparing signals")
return self.olm.get_collector()
def online_models(self):

View File

@@ -91,7 +91,7 @@ class RollingGen(TaskGen):
ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date
ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date
def __init__(self, step: int = 40, rtype: str = ROLL_EX):
def __init__(self, step: int = 40, rtype: str = ROLL_EX, modify_end_time=True):
"""
Generate tasks for rolling
@@ -101,9 +101,12 @@ class RollingGen(TaskGen):
step to rolling
rtype : str
rolling type (expanding, sliding)
modify_end_time: bool
Whether the data set configuration needs to be modified when the required scope exceeds the original data set scope
"""
self.step = step
self.rtype = rtype
self.modify_end_time = modify_end_time
# TODO: Ask pengrong to update future date in dataset
self.ta = TimeAdjuster(future=True)
@@ -113,7 +116,6 @@ class RollingGen(TaskGen):
def generate(self, task: dict):
"""
Converting the task into a rolling task.
# FIXME: only modify dataset layer, user need to change datahandler firstly.
Parameters
----------
@@ -196,7 +198,8 @@ class RollingGen(TaskGen):
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
# if end_time < the end of test_segments, then change end_time to allow load more data
if (
self.ta.cal_interval(
self.modify_end_time
and self.ta.cal_interval(
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
t["dataset"]["kwargs"]["segments"][self.test_key][1],
)

View File

@@ -174,11 +174,11 @@ class TaskManager:
return _id_list
def fetch_task(self, query={}):
def fetch_task(self, query={}, status=STATUS_WAITING):
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query.update({"status": self.STATUS_WAITING})
query.update({"status": status})
task = self.task_pool.find_one_and_update(
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
)
@@ -189,7 +189,7 @@ class TaskManager:
return self._decode_task(task)
@contextmanager
def safe_fetch_task(self, query={}):
def safe_fetch_task(self, query={}, status=STATUS_WAITING):
"""
fetch task from task_pool using query with contextmanager
@@ -202,7 +202,7 @@ class TaskManager:
-------
"""
task = self.fetch_task(query=query)
task = self.fetch_task(query=query, status=status)
try:
yield task
except Exception:
@@ -330,7 +330,15 @@ class TaskManager:
return f"TaskManager({self.task_pool})"
def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
def run_task(
task_func,
task_pool,
force_release=False,
before_status=TaskManager.STATUS_WAITING,
after_status=TaskManager.STATUS_DONE,
*args,
**kwargs,
):
"""
While task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool
@@ -352,16 +360,24 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
ever_run = False
while True:
with tm.safe_fetch_task() as task:
with tm.safe_fetch_task(status=before_status) as task:
if task is None:
break
get_module_logger("run_task").info(task["def"])
if force_release:
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: # what this means?
res = executor.submit(task_func, task["def"], *args, **kwargs).result()
# when fetching `WAITING` task, use task_def to train
if before_status == TaskManager.STATUS_WAITING:
param = task["def"]
# when fetching `PART_DONE` task, use task_res to train for the result has been saved
elif before_status == TaskManager.STATUS_PART_DONE:
param = task["res"]
else:
res = task_func(task["def"], *args, **kwargs)
tm.commit_task_res(task, res)
raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!")
if force_release:
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
res = executor.submit(task_func, param, *args, **kwargs).result()
else:
res = task_func(param, *args, **kwargs)
tm.commit_task_res(task, res, status=after_status)
ever_run = True
return ever_run