diff --git a/examples/taskmanager/update_online_pred.py b/examples/taskmanager/update_online_pred.py index 016336c68..a24b38889 100644 --- a/examples/taskmanager/update_online_pred.py +++ b/examples/taskmanager/update_online_pred.py @@ -1,6 +1,6 @@ import qlib from qlib.model.trainer import task_train -from qlib.workflow.task.online import RollingOnlineManager +from qlib.workflow.task.online import OnlineManager from qlib.config import REG_CN import fire from qlib.workflow import R @@ -54,16 +54,15 @@ task = { def first_train(experiment_name="online_svr"): - rom = RollingOnlineManager(experiment_name) - rid = task_train(task_config=task, experiment_name=experiment_name) + rom = OnlineManager(experiment_name) rom.reset_online_model(rid) def update_online_pred(experiment_name="online_svr"): - rom = RollingOnlineManager(experiment_name) + rom = OnlineManager(experiment_name) print("Here are the online models waiting for update:") for rid, rec in rom.list_online_model().items(): diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index b6d4de6e2..5c5609eb0 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -27,7 +27,7 @@ def task_train(task_config: dict, experiment_name: str) -> str: model = init_instance_by_config(task_config["model"]) dataset = init_instance_by_config(task_config["dataset"]) datahandler = dataset.handler - + # start exp with R.start(experiment_name=experiment_name): diff --git a/qlib/workflow/task/online.py b/qlib/workflow/task/online.py index 72d349122..f2b8e5706 100644 --- a/qlib/workflow/task/online.py +++ b/qlib/workflow/task/online.py @@ -7,21 +7,7 @@ from qlib.workflow.task.collect import TaskCollector from qlib.workflow.task.update import ModelUpdater -class OnlineManagement: - def __init__(self, experiment_name): - pass - - def update_online_pred(self, recorder: Union[str, Recorder]): - """update the predictions of online models - - Parameters - ---------- - recorder : Union[str, Recorder] - the id or the instance of Recorder - - """ - raise NotImplementedError(f"Please implement the `update_pred` method.") - +class OnlineManager: def prepare_new_models(self, tasks: List[dict]): """prepare(train) new models @@ -33,20 +19,6 @@ class OnlineManagement: """ raise NotImplementedError(f"Please implement the `prepare_new_models` method.") - def reset_online_model(self, recorders: List[Union[str, Recorder]]): - """reset online model - - Parameters - ---------- - recorders : List[Union[str, Recorder]] - a list of the recorder id or the instance - - """ - raise NotImplementedError(f"Please implement the `reset_online_model` method.") - - -class RollingOnlineManager(OnlineManagement): - ONLINE_TAG = "online_model" ONLINE_TAG_TRUE = "True" ONLINE_TAG_FALSE = "False" @@ -59,8 +31,7 @@ class RollingOnlineManager(OnlineManagement): experiment_name : str experiment name string """ - super(RollingOnlineManager, self).__init__(experiment_name) - self.logger = get_module_logger("RollingOnlineManager") + self.logger = get_module_logger("OnlineManagement") self.exp_name = experiment_name self.tc = TaskCollector(experiment_name) @@ -122,3 +93,16 @@ class RollingOnlineManager(OnlineManagement): mu = ModelUpdater(self.exp_name) cnt = mu.update_all_pred(self.online_filter) self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.") + + +class RollingOnlineManager(OnlineManager): + def prepare_new_models(self, tasks: List[dict]): + """prepare(train) new models + + Parameters + ---------- + tasks : List[dict] + a list of tasks + + """ + pass