mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
finished update_online_pred demo
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import qlib
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.task.online import RollingOnlineManager
|
||||
from qlib.workflow.task.online import OnlineManager
|
||||
from qlib.config import REG_CN
|
||||
import fire
|
||||
from qlib.workflow import R
|
||||
@@ -54,16 +54,15 @@ task = {
|
||||
|
||||
def first_train(experiment_name="online_svr"):
|
||||
|
||||
rom = RollingOnlineManager(experiment_name)
|
||||
|
||||
rid = task_train(task_config=task, experiment_name=experiment_name)
|
||||
|
||||
rom = OnlineManager(experiment_name)
|
||||
rom.reset_online_model(rid)
|
||||
|
||||
|
||||
def update_online_pred(experiment_name="online_svr"):
|
||||
|
||||
rom = RollingOnlineManager(experiment_name)
|
||||
rom = OnlineManager(experiment_name)
|
||||
|
||||
print("Here are the online models waiting for update:")
|
||||
for rid, rec in rom.list_online_model().items():
|
||||
|
||||
@@ -27,7 +27,7 @@ def task_train(task_config: dict, experiment_name: str) -> str:
|
||||
model = init_instance_by_config(task_config["model"])
|
||||
dataset = init_instance_by_config(task_config["dataset"])
|
||||
datahandler = dataset.handler
|
||||
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name=experiment_name):
|
||||
|
||||
|
||||
@@ -7,21 +7,7 @@ from qlib.workflow.task.collect import TaskCollector
|
||||
from qlib.workflow.task.update import ModelUpdater
|
||||
|
||||
|
||||
class OnlineManagement:
|
||||
def __init__(self, experiment_name):
|
||||
pass
|
||||
|
||||
def update_online_pred(self, recorder: Union[str, Recorder]):
|
||||
"""update the predictions of online models
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recorder : Union[str, Recorder]
|
||||
the id or the instance of Recorder
|
||||
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_pred` method.")
|
||||
|
||||
class OnlineManager:
|
||||
def prepare_new_models(self, tasks: List[dict]):
|
||||
"""prepare(train) new models
|
||||
|
||||
@@ -33,20 +19,6 @@ class OnlineManagement:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_new_models` method.")
|
||||
|
||||
def reset_online_model(self, recorders: List[Union[str, Recorder]]):
|
||||
"""reset online model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recorders : List[Union[str, Recorder]]
|
||||
a list of the recorder id or the instance
|
||||
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `reset_online_model` method.")
|
||||
|
||||
|
||||
class RollingOnlineManager(OnlineManagement):
|
||||
|
||||
ONLINE_TAG = "online_model"
|
||||
ONLINE_TAG_TRUE = "True"
|
||||
ONLINE_TAG_FALSE = "False"
|
||||
@@ -59,8 +31,7 @@ class RollingOnlineManager(OnlineManagement):
|
||||
experiment_name : str
|
||||
experiment name string
|
||||
"""
|
||||
super(RollingOnlineManager, self).__init__(experiment_name)
|
||||
self.logger = get_module_logger("RollingOnlineManager")
|
||||
self.logger = get_module_logger("OnlineManagement")
|
||||
self.exp_name = experiment_name
|
||||
self.tc = TaskCollector(experiment_name)
|
||||
|
||||
@@ -122,3 +93,16 @@ class RollingOnlineManager(OnlineManagement):
|
||||
mu = ModelUpdater(self.exp_name)
|
||||
cnt = mu.update_all_pred(self.online_filter)
|
||||
self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.")
|
||||
|
||||
|
||||
class RollingOnlineManager(OnlineManager):
|
||||
def prepare_new_models(self, tasks: List[dict]):
|
||||
"""prepare(train) new models
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tasks : List[dict]
|
||||
a list of tasks
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user