1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-06 04:20:57 +08:00

online serving v5

This commit is contained in:
lzh222333
2021-04-02 07:09:29 +00:00
parent bd7a1c11b9
commit 431a9c92c1
2 changed files with 78 additions and 45 deletions

View File

@@ -6,11 +6,13 @@ from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow import R
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.ensemble import RollingEnsemble
from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.online.manager import RollingOnlineManager
from qlib.workflow.task.utils import list_recorders
from qlib.model.trainer import TrainerRM
from qlib.model.ens.group import RollingGroup
data_handler_config = {
"start_time": "2013-01-01",
@@ -96,24 +98,15 @@ def task_generating():
return tasks
# This part corresponds to "Task Storing" in the document
def task_storing(tasks):
print("========== task_storing ==========")
tm = TaskManager(task_pool=task_pool)
tm.create_task(tasks) # all tasks will be saved to MongoDB
# This part corresponds to "Task Running" in the document
def task_running():
print("========== task_running ==========")
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method
def task_training(tasks):
trainer.train(tasks, exp_name, task_pool)
# This part corresponds to "Task Collecting" in the document
def task_collecting():
print("========== task_collecting ==========")
def get_group_key_func(recorder):
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
@@ -121,14 +114,14 @@ def task_collecting():
def my_filter(recorder):
# only choose the results of "LGBModel"
model_key, rolling_key = get_group_key_func(recorder)
model_key, rolling_key = rec_key(recorder)
if model_key == "LGBModel":
return True
return False
collector = RecorderCollector(exp_name)
# group tasks by "get_task_key" and filter tasks by "my_filter"
artifact = collector.collect(RollingEnsemble(), get_group_key_func, rec_filter_func=my_filter)
artifact = ens_workflow(
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
)
print(artifact)
@@ -147,8 +140,7 @@ def first_run():
reset()
tasks = task_generating()
task_storing(tasks)
task_running()
task_training(tasks)
task_collecting()
latest_rec, _ = rolling_online_manager.list_latest_recorders()
@@ -156,7 +148,7 @@ def first_run():
def routine():
print("========== after_day ==========")
print("========== routine ==========")
print_online_model()
rolling_online_manager.routine()
print_online_model()
@@ -185,8 +177,10 @@ if __name__ == "__main__":
##########################################################################################
rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name, rolling_gen=rolling_gen, task_pool=task_pool
)
task_manager = TaskManager(task_pool=task_pool)
trainer = TrainerRM()
rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name, rolling_gen=rolling_gen, task_manager=task_manager, trainer=trainer
)
fire.Fire()

View File

@@ -10,6 +10,7 @@ from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.manage import run_task
from qlib.workflow.task.utils import list_recorders
from qlib.utils.serial import Serializable
from qlib.model.trainer import Trainer, TrainerR
class OnlineManager(Serializable):
@@ -19,31 +20,57 @@ class OnlineManager(Serializable):
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
def __init__(self, trainer: Trainer = None) -> None:
self._trainer = trainer
self.logger = get_module_logger(self.__class__.__name__)
def prepare_signals(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
def prepare_tasks(self, *args, **kwargs):
"""return the new tasks waiting for training."""
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
def prepare_new_models(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `prepare_new_models` method.")
def prepare_new_models(self, tasks, *args, **kwargs):
"""Use trainer to train a list of tasks and set the trained model to next_online.
Args:
tasks (list): a list of tasks.
"""
if not (tasks is None or len(tasks) == 0):
if self._trainer is not None:
new_models = self._trainer.train(tasks, *args, **kwargs)
self.set_online_tag(self.NEXT_ONLINE_TAG, new_models)
self.logger.info(
f"Finished prepare {len(new_models)} new models and set them to `{self.NEXT_ONLINE_TAG}`."
)
else:
self.logger.warn("No trainer to train new tasks.")
def update_online_pred(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
def set_online_tag(self, tag, *args, **kwargs):
"""set `tag` to the model to sign whether online
Args:
tag (str): the tags in ONLINE_TAG, NEXT_ONLINE_TAG, OFFLINE_TAG
"""
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
def get_online_tag(self, *args, **kwargs):
"""given a model and return its online tag"""
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
def reset_online_tag(self, *args, **kwargs):
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing."""
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
def routine(self, *args, **kwargs):
"""The typical update process in a routine such as day by day or month by month"""
self.prepare_signals(*args, **kwargs)
self.prepare_tasks(*args, **kwargs)
self.prepare_new_models(*args, **kwargs)
tasks = self.prepare_tasks(*args, **kwargs)
self.prepare_new_models(tasks, *args, **kwargs)
self.update_online_pred(*args, **kwargs)
self.reset_online_tag(*args, **kwargs)
@@ -54,7 +81,8 @@ class OnlineManagerR(OnlineManager):
"""
def __init__(self, experiment_name: str) -> None:
def __init__(self, experiment_name: str, trainer: Trainer = TrainerR()) -> None:
super().__init__(trainer)
self.logger = get_module_logger(self.__class__.__name__)
self.exp_name = experiment_name
@@ -98,27 +126,36 @@ class OnlineManagerR(OnlineManager):
class RollingOnlineManager(OnlineManagerR):
# FIXME: TaskManager不应该与onlinemanager强耦合
"""An implementation of OnlineManager based on Rolling.
"""
def __init__(
self, experiment_name: str, rolling_gen: RollingGen, task_manager: TaskManager, trainer=run_task
self,
experiment_name: str,
rolling_gen: RollingGen,
trainer: Trainer = TrainerR(),
) -> None:
super().__init__(experiment_name)
super().__init__(experiment_name, trainer)
self.ta = TimeAdjuster()
self.rg = rolling_gen
self.tm = task_manager
self.logger = get_module_logger(self.__class__.__name__)
self.trainer = trainer
def prepare_signals(self):
pass
def prepare_tasks(self):
"""prepare new tasks based on new date.
Returns:
list: a list of new tasks.
"""
latest_records, max_test = self.list_latest_recorders(
lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG
)
if max_test is None:
self.logger.warn(f"No latest_recorders.")
return
self.logger.warn(f"No latest online recorders, no new tasks.")
return None
calendar_latest = self.ta.last_date()
if self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step:
old_tasks = []
@@ -128,18 +165,20 @@ class RollingOnlineManager(OnlineManagerR):
# modify the test segment to generate new tasks
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
old_tasks.append(task)
new_tasks = task_generator(old_tasks, self.rg)
self.tm.create_task(new_tasks)
def prepare_new_models(self):
"""prepare(train) new models based on online model"""
run_task(task_train, task_pool=self.tm.task_pool, experiment_name=self.exp_name)
latest_records, _ = self.list_latest_recorders()
# FIXME: 现有的流程如果没有可更新的模型仍会调用这个导致会先将以前的模型设置成nextonline再去更新pred但这个时候online已经没有了pred无法更新
self.set_online_tag(OnlineManager.NEXT_ONLINE_TAG, latest_records.values())
self.logger.info(f"Finished prepare {len(latest_records)} new models and set them to next_online.")
new_tasks_tmp = task_generator(old_tasks, self.rg)
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
return new_tasks
return None
def list_latest_recorders(self, rec_filter_func=None):
"""find latest recorders based on test segments.
Args:
rec_filter_func (Callable, optional): recorder filter. Defaults to None.
Returns:
dict, tuple: the latest recorders and the latest date of them
"""
recs_flt = list_recorders(self.exp_name, rec_filter_func)
if len(recs_flt) == 0:
return recs_flt, None