diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml index 02c4ecac3..35b59086a 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml @@ -110,7 +110,6 @@ task: valid: [2015-01-01, 2016-12-31] test: [2017-01-01, 2020-08-01] seq_len: 60 - horizon: 2 input_size: num_states: *num_states batch_size: 1024 diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml index 9ccf56e86..b837c57a3 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml @@ -104,7 +104,6 @@ task: valid: [2015-01-01, 2016-12-31] test: [2017-01-01, 2020-08-01] seq_len: 60 - horizon: 2 input_size: num_states: *num_states batch_size: 1024 diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml index 29686d7da..43f389739 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml @@ -104,7 +104,6 @@ task: valid: [2015-01-01, 2016-12-31] test: [2017-01-01, 2020-08-01] seq_len: 60 - horizon: 2 input_size: 6 num_states: *num_states batch_size: 1024 diff --git a/qlib/contrib/data/dataset.py b/qlib/contrib/data/dataset.py index 24160d7ba..812e2cc71 100644 --- a/qlib/contrib/data/dataset.py +++ b/qlib/contrib/data/dataset.py @@ -6,6 +6,8 @@ import torch import warnings import numpy as np import pandas as pd +from qlib.utils.data import guess_horizon +from qlib.utils import init_instance_by_config from qlib.data.dataset import DatasetH @@ -130,6 +132,14 @@ class MTSDatasetH(DatasetH): input_size=None, **kwargs, ): + if horizon == 0: + # Try to guess horizon + if isinstance(handler, (dict, str)): + handler = init_instance_by_config(handler) + assert "label" in getattr(handler.data_loader, "fields", None) + label = handler.data_loader.fields["label"][0][0] + horizon = guess_horizon([label]) + assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage" assert memory_mode in ["sample", "daily"], "unsupported memory mode" assert memory_mode == "sample" or batch_size < 0, "daily memory requires daily sampling (`batch_size < 0`)" diff --git a/qlib/utils/data.py b/qlib/utils/data.py index 6c62f7558..39634b866 100644 --- a/qlib/utils/data.py +++ b/qlib/utils/data.py @@ -5,8 +5,11 @@ This module covers some utility functions that operate on data or basic object """ from copy import deepcopy from typing import List, Union -import pandas as pd + import numpy as np +import pandas as pd + +from qlib.data.data import DatasetProvider def robust_zscore(x: pd.Series, zscore=False): @@ -103,3 +106,12 @@ def update_config(base_config: dict, ext_config: Union[dict, List[dict]]): # one of then are not dict. Then replace base_config[key] = ec[key] return base_config + + +def guess_horizon(label: List): + """ + Try to guess the horizon by parsing label + """ + expr = DatasetProvider.parse_fields(label)[0] + lft_etd, rght_etd = expr.get_extended_window_size() + return rght_etd diff --git a/tests/misc/test_utils.py b/tests/misc/test_utils.py index 2be792faf..db5b07248 100644 --- a/tests/misc/test_utils.py +++ b/tests/misc/test_utils.py @@ -9,6 +9,7 @@ from qlib.config import C from qlib.log import TimeInspector from qlib.constant import REG_CN, REG_US, REG_TW from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal, CN_TIME, US_TIME, TW_TIME +from qlib.utils.data import guess_horizon REG_MAP = {REG_CN: CN_TIME, REG_US: US_TIME, REG_TW: TW_TIME} @@ -112,5 +113,24 @@ class TimeUtils(TestCase): cal_sam_minute_new(*args, region=region) +class DataUtils(TestCase): + @classmethod + def setUpClass(cls): + init() + + def test_guess_horizon(self): + label = ["Ref($close, -2) / Ref($close, -1) - 1"] + result = guess_horizon(label) + assert result == 2 + + label = ["Ref($close, -5) / Ref($close, -1) - 1"] + result = guess_horizon(label) + assert result == 5 + + label = ["Ref($close, -1) / Ref($close, -1) - 1"] + result = guess_horizon(label) + assert result == 1 + + if __name__ == "__main__": unittest.main()