diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index fe641be35..2acaa77fe 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -507,7 +507,9 @@ class TSDatasetH(DatasetH): - The dimension of a batch of data """ - def __init__(self, step_len=30, **kwargs): + DEFAULT_STEP_LEN = 30 + + def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs): self.step_len = step_len super().__init__(**kwargs) diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py index f3ef13aa9..f5e3a2bd0 100644 --- a/qlib/workflow/online/utils.py +++ b/qlib/workflow/online/utils.py @@ -8,8 +8,10 @@ This allows us to use efficient submodels as the market-style changing. """ from typing import List, Union +from qlib.data.dataset import TSDatasetH from qlib.log import get_module_logger +from qlib.utils import get_cls_kwargs from qlib.workflow.online.update import PredUpdater from qlib.workflow.recorder import Recorder from qlib.workflow.task.utils import list_recorders @@ -161,8 +163,9 @@ class OnlineToolR(OnlineTool): hist_ref = 0 task = rec.load_object("task") # Special treatment of historical dependencies - if task["dataset"]["class"] == "TSDatasetH": - hist_ref = task["dataset"]["kwargs"]["step_len"] + cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset") + if issubclass(cls, TSDatasetH): + hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN) PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update() self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")