diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index fe6f0db6f..f7f9b62d5 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -1,18 +1,13 @@ -from typing import Dict, Union, List -from qlib import get_module_logger -from qlib.workflow import R -from qlib.model.trainer import task_train -from qlib.workflow.recorder import MLflowRecorder, Recorder -from qlib.workflow.online.update import PredUpdater, RecordUpdater -from qlib.workflow.task.collect import Collector -from qlib.workflow.task.utils import TimeAdjuster -from qlib.workflow.task.gen import RollingGen, task_generator -from qlib.workflow.task.manage import TaskManager -from qlib.workflow.task.manage import run_task -from qlib.workflow.task.utils import list_recorders -from qlib.utils.serial import Serializable -from qlib.model.trainer import Trainer, TrainerR from copy import deepcopy +from typing import Dict, List, Union +from qlib import get_module_logger +from qlib.data.data import D +from qlib.model.trainer import Trainer, TrainerR, task_train +from qlib.workflow.online.update import PredUpdater +from qlib.workflow.recorder import Recorder +from qlib.workflow.task.collect import Collector +from qlib.workflow.task.gen import RollingGen, task_generator +from qlib.workflow.task.utils import TimeAdjuster, list_recorders class OnlineManager: @@ -63,6 +58,7 @@ class OnlineManager: `NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag` `OFFLINE_TAG` for train but offline those models """ + # TODO: 回调 if not (tasks is None or len(tasks) == 0): if self.trainer is not None: new_models = self.trainer.train(tasks) @@ -158,7 +154,8 @@ class OnlineManagerR(OnlineManager): collector (Collector, optional): a instance of Collector. Defaults to None. need_log (bool, optional): print log or not. Defaults to True. """ - trainer = TrainerR(experiment_name) + if trainer is None: + trainer = TrainerR(experiment_name) super().__init__(trainer=trainer, collector=collector, need_log=need_log) self.exp_name = experiment_name @@ -239,7 +236,8 @@ class RollingOnlineManager(OnlineManagerR): collector (Collector, optional): a instance of Collector. Defaults to None. need_log (bool, optional): print log or not. Defaults to True. """ - trainer = TrainerR(experiment_name) + if trainer is None: + trainer = TrainerR(experiment_name) super().__init__(experiment_name=experiment_name, trainer=trainer, collector=collector, need_log=need_log) self.ta = TimeAdjuster() self.rg = rolling_gen @@ -247,9 +245,17 @@ class RollingOnlineManager(OnlineManagerR): def prepare_signals(self, *args, **kwargs): """ + Average the online models prediction and save them into a recorder + + Must use `pass` even though there is nothing to do. """ - pass + # 检查recorder是否存在,如果不存在就创建一个 + # 检查recorder的上一个信号时间,如果没有那就从上线模型的共同最早时间开始出信号 + # 从recorder的上一个信号时间开始出信号,出到self.cur_time + for model in self.online_models(): + + pass def prepare_tasks(self, *args, **kwargs): """ @@ -258,17 +264,17 @@ class RollingOnlineManager(OnlineManagerR): Returns: list: a list of new tasks. """ - self.ta.set_end_time(self.cur_time) + #TODO: max_test = self.cur_time latest_records, max_test = self.list_latest_recorders( lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG ) if max_test is None: self.logger.warn(f"No latest online recorders, no new tasks.") return [] - calendar_latest = self.ta.last_date() if self.cur_time is None else self.cur_time + calendar_latest = D.calendar(end_time=self.cur_time)[-1] if self.cur_time is None else self.cur_time if self.need_log: self.logger.info( - f"The interval between current time and last rolling test begin time is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}" + f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}" ) if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step: old_tasks = [] diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index a6f0aeefe..5b58360d8 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -73,9 +73,7 @@ class PredUpdater(RecordUpdater): Update the prediction in the Recorder """ - LATEST = "__latest" - - def __init__(self, record: Recorder, to_date=LATEST, hist_ref: int = 0, freq="day", need_log=True): + def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", need_log=True): """ Parameters ---------- @@ -95,8 +93,7 @@ class PredUpdater(RecordUpdater): self.freq = freq self.rmdl = RMDLoader(rec=record) - # FIXME: why we need LATEST? can we use to_date=None instead? - if to_date == self.LATEST or to_date == None: + if to_date == None: to_date = D.calendar(freq=freq)[-1] self.to_date = pd.Timestamp(to_date) self.old_pred = record.load_object("pred.pkl") diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 542466a5f..ad7a16218 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -169,7 +169,7 @@ class RollingGen(TaskGen): # First rolling # 1) prepare the end point segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"])) - test_end = self.ta.last_date() if segments[self.test_key][1] is None else segments[self.test_key][1] + test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1] # 2) and init test segments test_start_idx = self.ta.align_idx(segments[self.test_key][0]) segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index 03ba4ed68..ce8e0dfa3 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -90,17 +90,10 @@ class TimeAdjuster: def max(self): """ - (Deprecated) Return the max calendar datetime """ return max(self.cals) - def last_date(self) -> pd.Timestamp: - """ - Return the last datetime in the calendar - """ - return self.cals[-1] - def align_idx(self, time_point, tp_type="start"): """ align the index of time_point in the calendar