diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py index f5e3a2bd0..a7478afb7 100644 --- a/qlib/workflow/online/utils.py +++ b/qlib/workflow/online/utils.py @@ -90,15 +90,15 @@ class OnlineToolR(OnlineTool): The implementation of OnlineTool based on (R)ecorder. """ - def __init__(self, experiment_name: str): + def __init__(self, default_exp_name: str = None): """ Init OnlineToolR. Args: - experiment_name (str): the experiment name. + default_exp_name (str): the default experiment name. """ super().__init__() - self.exp_name = experiment_name + self.default_exp_name = default_exp_name def set_online_tag(self, tag, recorder: Union[Recorder, List]): """ @@ -127,38 +127,61 @@ class OnlineToolR(OnlineTool): tags = recorder.list_tags() return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG) - def reset_online_tag(self, recorder: Union[Recorder, List]): + def reset_online_tag(self, recorder: Union[Recorder, List], exp_name: str = None): """ Offline all models and set the recorders to 'online'. Args: recorder (Union[Recorder, List]): the recorder you want to reset to 'online'. + exp_name (str): the experiment name. If None, then use default_exp_name. """ + if exp_name is None: + if self.default_exp_name is None: + raise ValueError( + "Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment." + ) + exp_name = self.default_exp_name if isinstance(recorder, Recorder): recorder = [recorder] - recs = list_recorders(self.exp_name) + recs = list_recorders(exp_name) self.set_online_tag(self.OFFLINE_TAG, list(recs.values())) self.set_online_tag(self.ONLINE_TAG, recorder) - def online_models(self) -> list: + def online_models(self, exp_name: str = None) -> list: """ Get current `online` models + Args: + exp_name (str): the experiment name. If None, then use default_exp_name. + Returns: list: a list of `online` models. """ - return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values()) + if exp_name is None: + if self.default_exp_name is None: + raise ValueError( + "Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment." + ) + exp_name = self.default_exp_name + return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values()) - def update_online_pred(self, to_date=None): + def update_online_pred(self, to_date=None, exp_name: str = None): """ Update the predictions of online models to to_date. Args: to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar. + exp_name (str): the experiment name. If None, then use default_exp_name. """ - online_models = self.online_models() + if exp_name is None: + if self.default_exp_name is None: + raise ValueError( + "Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment." + ) + exp_name = self.default_exp_name + online_models = self.online_models(exp_name=exp_name) for rec in online_models: hist_ref = 0 task = rec.load_object("task") @@ -168,4 +191,4 @@ class OnlineToolR(OnlineTool): hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN) PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update() - self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.") + self.logger.info(f"Finished updating {len(online_models)} online model predictions of {exp_name}.")