1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

finished update_online_pred demo

This commit is contained in:
lzh222333
2021-03-16 02:43:12 +00:00
parent e3730b32d7
commit 5953365af3
3 changed files with 19 additions and 36 deletions

View File

@@ -1,6 +1,6 @@
import qlib
from qlib.model.trainer import task_train
from qlib.workflow.task.online import RollingOnlineManager
from qlib.workflow.task.online import OnlineManager
from qlib.config import REG_CN
import fire
from qlib.workflow import R
@@ -54,16 +54,15 @@ task = {
def first_train(experiment_name="online_svr"):
rom = RollingOnlineManager(experiment_name)
rid = task_train(task_config=task, experiment_name=experiment_name)
rom = OnlineManager(experiment_name)
rom.reset_online_model(rid)
def update_online_pred(experiment_name="online_svr"):
rom = RollingOnlineManager(experiment_name)
rom = OnlineManager(experiment_name)
print("Here are the online models waiting for update:")
for rid, rec in rom.list_online_model().items():

View File

@@ -27,7 +27,7 @@ def task_train(task_config: dict, experiment_name: str) -> str:
model = init_instance_by_config(task_config["model"])
dataset = init_instance_by_config(task_config["dataset"])
datahandler = dataset.handler
# start exp
with R.start(experiment_name=experiment_name):

View File

@@ -7,21 +7,7 @@ from qlib.workflow.task.collect import TaskCollector
from qlib.workflow.task.update import ModelUpdater
class OnlineManagement:
def __init__(self, experiment_name):
pass
def update_online_pred(self, recorder: Union[str, Recorder]):
"""update the predictions of online models
Parameters
----------
recorder : Union[str, Recorder]
the id or the instance of Recorder
"""
raise NotImplementedError(f"Please implement the `update_pred` method.")
class OnlineManager:
def prepare_new_models(self, tasks: List[dict]):
"""prepare(train) new models
@@ -33,20 +19,6 @@ class OnlineManagement:
"""
raise NotImplementedError(f"Please implement the `prepare_new_models` method.")
def reset_online_model(self, recorders: List[Union[str, Recorder]]):
"""reset online model
Parameters
----------
recorders : List[Union[str, Recorder]]
a list of the recorder id or the instance
"""
raise NotImplementedError(f"Please implement the `reset_online_model` method.")
class RollingOnlineManager(OnlineManagement):
ONLINE_TAG = "online_model"
ONLINE_TAG_TRUE = "True"
ONLINE_TAG_FALSE = "False"
@@ -59,8 +31,7 @@ class RollingOnlineManager(OnlineManagement):
experiment_name : str
experiment name string
"""
super(RollingOnlineManager, self).__init__(experiment_name)
self.logger = get_module_logger("RollingOnlineManager")
self.logger = get_module_logger("OnlineManagement")
self.exp_name = experiment_name
self.tc = TaskCollector(experiment_name)
@@ -122,3 +93,16 @@ class RollingOnlineManager(OnlineManagement):
mu = ModelUpdater(self.exp_name)
cnt = mu.update_all_pred(self.online_filter)
self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.")
class RollingOnlineManager(OnlineManager):
def prepare_new_models(self, tasks: List[dict]):
"""prepare(train) new models
Parameters
----------
tasks : List[dict]
a list of tasks
"""
pass