1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Add util function to help automatically get horizon (#1509)

* Add util function to help automatically get horizon

* Reformat for CI

* Leverage horizon change

* Update config yaml

* Update for formatting

* Adapt to pickled handler

* Fix CI error

* remove blank

* Fix lint

* Update tests

* Remove redundant check

* modify the code as suggested

* format code with pylint

* fix pytest error

---------

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
This commit is contained in:
Di
2025-05-26 22:08:43 +08:00
committed by GitHub
parent 89ae312109
commit 14d54aa2a1
6 changed files with 43 additions and 4 deletions

View File

@@ -110,7 +110,6 @@ task:
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
seq_len: 60
horizon: 2
input_size:
num_states: *num_states
batch_size: 1024

View File

@@ -104,7 +104,6 @@ task:
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
seq_len: 60
horizon: 2
input_size:
num_states: *num_states
batch_size: 1024

View File

@@ -104,7 +104,6 @@ task:
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
seq_len: 60
horizon: 2
input_size: 6
num_states: *num_states
batch_size: 1024

View File

@@ -6,6 +6,8 @@ import torch
import warnings
import numpy as np
import pandas as pd
from qlib.utils.data import guess_horizon
from qlib.utils import init_instance_by_config
from qlib.data.dataset import DatasetH
@@ -130,6 +132,14 @@ class MTSDatasetH(DatasetH):
input_size=None,
**kwargs,
):
if horizon == 0:
# Try to guess horizon
if isinstance(handler, (dict, str)):
handler = init_instance_by_config(handler)
assert "label" in getattr(handler.data_loader, "fields", None)
label = handler.data_loader.fields["label"][0][0]
horizon = guess_horizon([label])
assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage"
assert memory_mode in ["sample", "daily"], "unsupported memory mode"
assert memory_mode == "sample" or batch_size < 0, "daily memory requires daily sampling (`batch_size < 0`)"

View File

@@ -5,8 +5,11 @@ This module covers some utility functions that operate on data or basic object
"""
from copy import deepcopy
from typing import List, Union
import pandas as pd
import numpy as np
import pandas as pd
from qlib.data.data import DatasetProvider
def robust_zscore(x: pd.Series, zscore=False):
@@ -103,3 +106,12 @@ def update_config(base_config: dict, ext_config: Union[dict, List[dict]]):
# one of then are not dict. Then replace
base_config[key] = ec[key]
return base_config
def guess_horizon(label: List):
    """
    Try to guess the horizon from a label definition.

    ``label`` is parsed with ``DatasetProvider.parse_fields``; only the first
    parsed expression is used. The second value reported by that expression's
    ``get_extended_window_size()`` is returned.

    Parameters
    ----------
    label : List
        label field(s) to parse, e.g. ``["Ref($close, -2) / Ref($close, -1) - 1"]``

    Returns
    -------
    the right-hand extended window size of the first parsed expression
    (presumably the number of future steps the label looks ahead -- confirm
    against ``get_extended_window_size``).
    """
    parsed_fields = DatasetProvider.parse_fields(label)
    first_expression = parsed_fields[0]
    # The left-hand window is not needed for guessing the horizon.
    _left_window, right_window = first_expression.get_extended_window_size()
    return right_window

View File

@@ -9,6 +9,7 @@ from qlib.config import C
from qlib.log import TimeInspector
from qlib.constant import REG_CN, REG_US, REG_TW
from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal, CN_TIME, US_TIME, TW_TIME
from qlib.utils.data import guess_horizon
REG_MAP = {REG_CN: CN_TIME, REG_US: US_TIME, REG_TW: TW_TIME}
@@ -112,5 +113,24 @@ class TimeUtils(TestCase):
cal_sam_minute_new(*args, region=region)
class DataUtils(TestCase):
    """Tests for the ``guess_horizon`` helper imported from ``qlib.utils.data``."""

    @classmethod
    def setUpClass(cls):
        # Run the project's `init()` once for the whole class, before any test.
        init()

    def test_guess_horizon(self):
        """Check the horizon guessed for several ``Ref``-based label expressions."""
        # (label expression, expected horizon) pairs, checked in order; the
        # first mismatch aborts the test.
        expected_by_expression = (
            ("Ref($close, -2) / Ref($close, -1) - 1", 2),
            ("Ref($close, -5) / Ref($close, -1) - 1", 5),
            ("Ref($close, -1) / Ref($close, -1) - 1", 1),
        )
        for expression, expected_horizon in expected_by_expression:
            guessed = guess_horizon([expression])
            assert guessed == expected_horizon
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()