mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 01:51:18 +08:00
modify OnlineToolR
This commit is contained in:
@@ -90,15 +90,15 @@ class OnlineToolR(OnlineTool):
|
||||
The implementation of OnlineTool based on (R)ecorder.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str):
|
||||
def __init__(self, default_exp_name: str = None):
|
||||
"""
|
||||
Init OnlineToolR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the experiment name.
|
||||
default_exp_name (str): the default experiment name.
|
||||
"""
|
||||
super().__init__()
|
||||
self.exp_name = experiment_name
|
||||
self.default_exp_name = default_exp_name
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
|
||||
"""
|
||||
@@ -127,38 +127,61 @@ class OnlineToolR(OnlineTool):
|
||||
tags = recorder.list_tags()
|
||||
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)
|
||||
|
||||
def reset_online_tag(self, recorder: Union[Recorder, List]):
|
||||
def reset_online_tag(self, recorder: Union[Recorder, List], exp_name: str = None):
|
||||
"""
|
||||
Offline all models and set the recorders to 'online'.
|
||||
|
||||
Args:
|
||||
recorder (Union[Recorder, List]):
|
||||
the recorder you want to reset to 'online'.
|
||||
exp_name (str): the experiment name. If None, then use default_exp_name.
|
||||
|
||||
"""
|
||||
if exp_name is None:
|
||||
if self.default_exp_name is None:
|
||||
raise ValueError(
|
||||
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
|
||||
)
|
||||
exp_name = self.default_exp_name
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
recs = list_recorders(self.exp_name)
|
||||
recs = list_recorders(exp_name)
|
||||
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
|
||||
self.set_online_tag(self.ONLINE_TAG, recorder)
|
||||
|
||||
def online_models(self) -> list:
|
||||
def online_models(self, exp_name: str = None) -> list:
|
||||
"""
|
||||
Get current `online` models
|
||||
|
||||
Args:
|
||||
exp_name (str): the experiment name. If None, then use default_exp_name.
|
||||
|
||||
Returns:
|
||||
list: a list of `online` models.
|
||||
"""
|
||||
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
|
||||
if exp_name is None:
|
||||
if self.default_exp_name is None:
|
||||
raise ValueError(
|
||||
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
|
||||
)
|
||||
exp_name = self.default_exp_name
|
||||
return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
|
||||
|
||||
def update_online_pred(self, to_date=None):
|
||||
def update_online_pred(self, to_date=None, exp_name: str = None):
|
||||
"""
|
||||
Update the predictions of online models to to_date.
|
||||
|
||||
Args:
|
||||
to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar.
|
||||
exp_name (str): the experiment name. If None, then use default_exp_name.
|
||||
"""
|
||||
online_models = self.online_models()
|
||||
if exp_name is None:
|
||||
if self.default_exp_name is None:
|
||||
raise ValueError(
|
||||
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
|
||||
)
|
||||
exp_name = self.default_exp_name
|
||||
online_models = self.online_models(exp_name=exp_name)
|
||||
for rec in online_models:
|
||||
hist_ref = 0
|
||||
task = rec.load_object("task")
|
||||
@@ -168,4 +191,4 @@ class OnlineToolR(OnlineTool):
|
||||
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
|
||||
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
|
||||
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {exp_name}.")
|
||||
|
||||
Reference in New Issue
Block a user