1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 01:51:18 +08:00

modify OnlineToolR

This commit is contained in:
lzh222333
2021-06-22 02:50:41 +00:00
committed by you-n-g
parent 8709dde65b
commit c4c438249c

View File

@@ -90,15 +90,15 @@ class OnlineToolR(OnlineTool):
The implementation of OnlineTool based on (R)ecorder.
"""
def __init__(self, experiment_name: str):
def __init__(self, default_exp_name: str = None):
"""
Init OnlineToolR.
Args:
experiment_name (str): the experiment name.
default_exp_name (str): the default experiment name.
"""
super().__init__()
self.exp_name = experiment_name
self.default_exp_name = default_exp_name
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
"""
@@ -127,38 +127,61 @@ class OnlineToolR(OnlineTool):
tags = recorder.list_tags()
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)
def reset_online_tag(self, recorder: Union[Recorder, List]):
def reset_online_tag(self, recorder: Union[Recorder, List], exp_name: str = None):
"""
Offline all models and set the recorders to 'online'.
Args:
recorder (Union[Recorder, List]):
the recorder you want to reset to 'online'.
exp_name (str): the experiment name. If None, then use default_exp_name.
"""
if exp_name is None:
if self.default_exp_name is None:
raise ValueError(
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
)
exp_name = self.default_exp_name
if isinstance(recorder, Recorder):
recorder = [recorder]
recs = list_recorders(self.exp_name)
recs = list_recorders(exp_name)
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
self.set_online_tag(self.ONLINE_TAG, recorder)
def online_models(self) -> list:
def online_models(self, exp_name: str = None) -> list:
"""
Get current `online` models
Args:
exp_name (str): the experiment name. If None, then use default_exp_name.
Returns:
list: a list of `online` models.
"""
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
if exp_name is None:
if self.default_exp_name is None:
raise ValueError(
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
)
exp_name = self.default_exp_name
return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
def update_online_pred(self, to_date=None):
def update_online_pred(self, to_date=None, exp_name: str = None):
"""
Update the predictions of online models to to_date.
Args:
to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar.
exp_name (str): the experiment name. If None, then use default_exp_name.
"""
online_models = self.online_models()
if exp_name is None:
if self.default_exp_name is None:
raise ValueError(
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
)
exp_name = self.default_exp_name
online_models = self.online_models(exp_name=exp_name)
for rec in online_models:
hist_ref = 0
task = rec.load_object("task")
@@ -168,4 +191,4 @@ class OnlineToolR(OnlineTool):
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {exp_name}.")