mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-06 04:20:57 +08:00
online serving v5
This commit is contained in:
@@ -6,11 +6,13 @@ from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.ensemble import RollingEnsemble
|
||||
from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.online.manager import RollingOnlineManager
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.model.trainer import TrainerRM
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2013-01-01",
|
||||
@@ -96,24 +98,15 @@ def task_generating():
|
||||
return tasks
|
||||
|
||||
|
||||
# This part corresponds to "Task Storing" in the document
|
||||
def task_storing(tasks):
|
||||
print("========== task_storing ==========")
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
|
||||
|
||||
# This part corresponds to "Task Running" in the document
|
||||
def task_running():
|
||||
print("========== task_running ==========")
|
||||
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method
|
||||
def task_training(tasks):
|
||||
trainer.train(tasks, exp_name, task_pool)
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting():
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def get_group_key_func(recorder):
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
@@ -121,14 +114,14 @@ def task_collecting():
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = get_group_key_func(recorder)
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
collector = RecorderCollector(exp_name)
|
||||
# group recorders with the key function and keep only those accepted by "my_filter"
|
||||
artifact = collector.collect(RollingEnsemble(), get_group_key_func, rec_filter_func=my_filter)
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
|
||||
)
|
||||
print(artifact)
|
||||
|
||||
|
||||
@@ -147,8 +140,7 @@ def first_run():
|
||||
reset()
|
||||
|
||||
tasks = task_generating()
|
||||
task_storing(tasks)
|
||||
task_running()
|
||||
task_training(tasks)
|
||||
task_collecting()
|
||||
|
||||
latest_rec, _ = rolling_online_manager.list_latest_recorders()
|
||||
@@ -156,7 +148,7 @@ def first_run():
|
||||
|
||||
|
||||
def routine():
|
||||
print("========== after_day ==========")
|
||||
print("========== routine ==========")
|
||||
print_online_model()
|
||||
rolling_online_manager.routine()
|
||||
print_online_model()
|
||||
@@ -185,8 +177,10 @@ if __name__ == "__main__":
|
||||
|
||||
##########################################################################################
|
||||
rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
|
||||
rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name, rolling_gen=rolling_gen, task_pool=task_pool
|
||||
)
|
||||
task_manager = TaskManager(task_pool=task_pool)
|
||||
trainer = TrainerRM()
|
||||
rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name, rolling_gen=rolling_gen, task_manager=task_manager, trainer=trainer
|
||||
)
|
||||
|
||||
fire.Fire()
|
||||
|
||||
@@ -10,6 +10,7 @@ from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.manage import run_task
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.model.trainer import Trainer, TrainerR
|
||||
|
||||
|
||||
class OnlineManager(Serializable):
|
||||
@@ -19,31 +20,57 @@ class OnlineManager(Serializable):
|
||||
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
|
||||
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
|
||||
|
||||
def __init__(self, trainer: Trainer = None) -> None:
    """Remember the trainer and create a logger named after the concrete class.

    Args:
        trainer (Trainer): object whose `train` method is used by
            `prepare_new_models`; may be None, in which case no training
            is performed there.
    """
    # The two assignments are independent, so their order does not matter.
    self.logger = get_module_logger(type(self).__name__)
    self._trainer = trainer
|
||||
|
||||
def prepare_signals(self, *args, **kwargs):
    """Abstract hook invoked first by `routine`; subclasses must override it.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # Plain literal: the message has no placeholders, so no f-prefix is needed.
    raise NotImplementedError("Please implement the `prepare_signals` method.")
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
    """Return the new tasks waiting for training; subclasses must override it.

    `routine` passes the returned value on to `prepare_new_models`.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # Plain literal: the message has no placeholders, so no f-prefix is needed.
    raise NotImplementedError("Please implement the `prepare_tasks` method.")
|
||||
|
||||
def prepare_new_models(self, *args, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `prepare_new_models` method.")
|
||||
def prepare_new_models(self, tasks, *args, **kwargs):
    """Train `tasks` with the configured trainer and tag the results as next-online.

    Nothing happens when `tasks` is None or empty. When no trainer was
    configured, a warning is logged and nothing is trained.

    Args:
        tasks (list): a list of tasks; may be None or empty.
        *args: forwarded to the trainer's `train` method after `tasks`.
        **kwargs: forwarded to the trainer's `train` method.
    """
    # Guard clauses keep the training path flat.
    if tasks is None or len(tasks) == 0:
        return
    if self._trainer is None:
        # `Logger.warn` is a deprecated alias that emits a DeprecationWarning;
        # `warning` is the supported spelling (assumes a stdlib-compatible logger).
        self.logger.warning("No trainer to train new tasks.")
        return

    new_models = self._trainer.train(tasks, *args, **kwargs)
    self.set_online_tag(self.NEXT_ONLINE_TAG, new_models)
    self.logger.info(f"Finished prepare {len(new_models)} new models and set them to `{self.NEXT_ONLINE_TAG}`.")
|
||||
|
||||
def update_online_pred(self, *args, **kwargs):
    """Abstract hook invoked by `routine` after new models are prepared.

    Subclasses must override it.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # Plain literal: the message has no placeholders, so no f-prefix is needed.
    raise NotImplementedError("Please implement the `update_online_pred` method.")
|
||||
|
||||
def set_online_tag(self, tag, *args, **kwargs):
    """Mark model(s) with `tag` to record their online status; must be overridden.

    Args:
        tag (str): one of ONLINE_TAG, NEXT_ONLINE_TAG, OFFLINE_TAG.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # Plain literal: the message has no placeholders, so no f-prefix is needed.
    raise NotImplementedError("Please implement the `set_online_tag` method.")
|
||||
|
||||
def get_online_tag(self, *args, **kwargs):
    """Return the online tag of a given model; subclasses must override it.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # Plain literal: the message has no placeholders, so no f-prefix is needed.
    raise NotImplementedError("Please implement the `get_online_tag` method.")
|
||||
|
||||
def reset_online_tag(self, *args, **kwargs):
    """Take all models offline and tag the given recorders as 'online'.

    Intended contract for overrides: when no argument is given and there is
    no 'next online' model, do nothing.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # Plain literal: the message has no placeholders, so no f-prefix is needed.
    raise NotImplementedError("Please implement the `reset_online_tag` method.")
|
||||
|
||||
def routine(self, *args, **kwargs):
|
||||
"""The typical update process in a routine such as day by day or month by month"""
|
||||
self.prepare_signals(*args, **kwargs)
|
||||
self.prepare_tasks(*args, **kwargs)
|
||||
self.prepare_new_models(*args, **kwargs)
|
||||
tasks = self.prepare_tasks(*args, **kwargs)
|
||||
self.prepare_new_models(tasks, *args, **kwargs)
|
||||
self.update_online_pred(*args, **kwargs)
|
||||
self.reset_online_tag(*args, **kwargs)
|
||||
|
||||
@@ -54,7 +81,8 @@ class OnlineManagerR(OnlineManager):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str) -> None:
|
||||
def __init__(self, experiment_name: str, trainer: Trainer = TrainerR()) -> None:
|
||||
super().__init__(trainer)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.exp_name = experiment_name
|
||||
|
||||
@@ -98,27 +126,36 @@ class OnlineManagerR(OnlineManager):
|
||||
|
||||
|
||||
class RollingOnlineManager(OnlineManagerR):
|
||||
# FIXME: TaskManager should not be tightly coupled with the online manager
|
||||
"""An implementation of OnlineManager based on Rolling.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, experiment_name: str, rolling_gen: RollingGen, task_manager: TaskManager, trainer=run_task
|
||||
self,
|
||||
experiment_name: str,
|
||||
rolling_gen: RollingGen,
|
||||
trainer: Trainer = TrainerR(),
|
||||
) -> None:
|
||||
super().__init__(experiment_name)
|
||||
super().__init__(experiment_name, trainer)
|
||||
self.ta = TimeAdjuster()
|
||||
self.rg = rolling_gen
|
||||
self.tm = task_manager
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.trainer = trainer
|
||||
|
||||
def prepare_signals(self, *args, **kwargs):
    """Do nothing; this hook is intentionally empty here.

    Arbitrary arguments are accepted and ignored because the base `routine`
    forwards its own `*args, **kwargs` to this hook.
    """
    return None
|
||||
|
||||
def prepare_tasks(self):
|
||||
"""prepare new tasks based on new date.
|
||||
|
||||
Returns:
|
||||
list: a list of new tasks.
|
||||
"""
|
||||
latest_records, max_test = self.list_latest_recorders(
|
||||
lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG
|
||||
)
|
||||
if max_test is None:
|
||||
self.logger.warn(f"No latest_recorders.")
|
||||
return
|
||||
self.logger.warn(f"No latest online recorders, no new tasks.")
|
||||
return None
|
||||
calendar_latest = self.ta.last_date()
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step:
|
||||
old_tasks = []
|
||||
@@ -128,18 +165,20 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
# modify the test segment to generate new tasks
|
||||
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
|
||||
old_tasks.append(task)
|
||||
new_tasks = task_generator(old_tasks, self.rg)
|
||||
self.tm.create_task(new_tasks)
|
||||
|
||||
def prepare_new_models(self):
|
||||
"""prepare(train) new models based on online model"""
|
||||
run_task(task_train, task_pool=self.tm.task_pool, experiment_name=self.exp_name)
|
||||
latest_records, _ = self.list_latest_recorders()
|
||||
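# FIXME: in the current flow this is still called even when there is no model to update, so the previous models are first tagged next_online before the predictions are updated; by then no online model remains and the predictions cannot be updated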
|
||||
self.set_online_tag(OnlineManager.NEXT_ONLINE_TAG, latest_records.values())
|
||||
self.logger.info(f"Finished prepare {len(latest_records)} new models and set them to next_online.")
|
||||
new_tasks_tmp = task_generator(old_tasks, self.rg)
|
||||
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
|
||||
return new_tasks
|
||||
return None
|
||||
|
||||
def list_latest_recorders(self, rec_filter_func=None):
|
||||
"""find latest recorders based on test segments.
|
||||
|
||||
Args:
|
||||
rec_filter_func (Callable, optional): recorder filter. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict, tuple: the latest recorders and the latest date of them
|
||||
"""
|
||||
recs_flt = list_recorders(self.exp_name, rec_filter_func)
|
||||
if len(recs_flt) == 0:
|
||||
return recs_flt, None
|
||||
|
||||
Reference in New Issue
Block a user