1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00

support subclass of TSDatasetH

This commit is contained in:
Young
2021-06-20 12:00:24 +00:00
committed by you-n-g
parent a3679e6758
commit d0f54343c7
2 changed files with 8 additions and 3 deletions

View File

@@ -507,7 +507,9 @@ class TSDatasetH(DatasetH):
- The dimension of a batch of data <batch_idx, feature, timestep>
"""
def __init__(self, step_len=30, **kwargs):
DEFAULT_STEP_LEN = 30
def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs):
self.step_len = step_len
super().__init__(**kwargs)

View File

@@ -8,8 +8,10 @@ This allows us to use efficient submodels as the market-style changing.
"""
from typing import List, Union
from qlib.data.dataset import TSDatasetH
from qlib.log import get_module_logger
from qlib.utils import get_cls_kwargs
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
@@ -161,8 +163,9 @@ class OnlineToolR(OnlineTool):
hist_ref = 0
task = rec.load_object("task")
# Special treatment of historical dependencies
if task["dataset"]["class"] == "TSDatasetH":
hist_ref = task["dataset"]["kwargs"]["step_len"]
cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset")
if issubclass(cls, TSDatasetH):
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")