mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Add util function to help automatically get horizon (#1509)
* Add util function to help automatically get horizon * Reformat for CI * Leverage horizon change * Udpate config yaml * Update for formatting * Adapt to pickled handler * Fix CI error * remove blank * Fix lint * Update tests * Remove redundant check * modify the code as suggested * format code with pylint * fix pytest error --------- Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
This commit is contained in:
@@ -110,7 +110,6 @@ task:
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
|
||||
@@ -104,7 +104,6 @@ task:
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
|
||||
@@ -104,7 +104,6 @@ task:
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size: 6
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
|
||||
@@ -6,6 +6,8 @@ import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.utils.data import guess_horizon
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
@@ -130,6 +132,14 @@ class MTSDatasetH(DatasetH):
|
||||
input_size=None,
|
||||
**kwargs,
|
||||
):
|
||||
if horizon == 0:
|
||||
# Try to guess horizon
|
||||
if isinstance(handler, (dict, str)):
|
||||
handler = init_instance_by_config(handler)
|
||||
assert "label" in getattr(handler.data_loader, "fields", None)
|
||||
label = handler.data_loader.fields["label"][0][0]
|
||||
horizon = guess_horizon([label])
|
||||
|
||||
assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
assert memory_mode in ["sample", "daily"], "unsupported memory mode"
|
||||
assert memory_mode == "sample" or batch_size < 0, "daily memory requires daily sampling (`batch_size < 0`)"
|
||||
|
||||
@@ -5,8 +5,11 @@ This module covers some utility functions that operate on data or basic object
|
||||
"""
|
||||
from copy import deepcopy
|
||||
from typing import List, Union
|
||||
import pandas as pd
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.data.data import DatasetProvider
|
||||
|
||||
|
||||
def robust_zscore(x: pd.Series, zscore=False):
|
||||
@@ -103,3 +106,12 @@ def update_config(base_config: dict, ext_config: Union[dict, List[dict]]):
|
||||
# one of then are not dict. Then replace
|
||||
base_config[key] = ec[key]
|
||||
return base_config
|
||||
|
||||
|
||||
def guess_horizon(label: List):
|
||||
"""
|
||||
Try to guess the horizon by parsing label
|
||||
"""
|
||||
expr = DatasetProvider.parse_fields(label)[0]
|
||||
lft_etd, rght_etd = expr.get_extended_window_size()
|
||||
return rght_etd
|
||||
|
||||
@@ -9,6 +9,7 @@ from qlib.config import C
|
||||
from qlib.log import TimeInspector
|
||||
from qlib.constant import REG_CN, REG_US, REG_TW
|
||||
from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal, CN_TIME, US_TIME, TW_TIME
|
||||
from qlib.utils.data import guess_horizon
|
||||
|
||||
REG_MAP = {REG_CN: CN_TIME, REG_US: US_TIME, REG_TW: TW_TIME}
|
||||
|
||||
@@ -112,5 +113,24 @@ class TimeUtils(TestCase):
|
||||
cal_sam_minute_new(*args, region=region)
|
||||
|
||||
|
||||
class DataUtils(TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
init()
|
||||
|
||||
def test_guess_horizon(self):
|
||||
label = ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
result = guess_horizon(label)
|
||||
assert result == 2
|
||||
|
||||
label = ["Ref($close, -5) / Ref($close, -1) - 1"]
|
||||
result = guess_horizon(label)
|
||||
assert result == 5
|
||||
|
||||
label = ["Ref($close, -1) / Ref($close, -1) - 1"]
|
||||
result = guess_horizon(label)
|
||||
assert result == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user