1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/tests/misc/test_utils.py
Di 14d54aa2a1 Add util function to help automatically get horizon (#1509)
* Add util function to help automatically get horizon

* Reformat for CI

* Leverage horizon change

* Update config yaml

* Update for formatting

* Adapt to pickled handler

* Fix CI error

* remove blank

* Fix lint

* Update tests

* Remove redundant check

* modify the code as suggested

* format code with pylint

* fix pytest error

---------

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2025-05-26 22:08:43 +08:00

137 lines
4.6 KiB
Python

from typing import List
from unittest.case import TestCase
import unittest
import pandas as pd
import numpy as np
from datetime import datetime
from qlib import init
from qlib.config import C
from qlib.log import TimeInspector
from qlib.constant import REG_CN, REG_US, REG_TW
from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal, CN_TIME, US_TIME, TW_TIME
from qlib.utils.data import guess_horizon
# Region constant -> trading-session time table consumed by the reference ``cal_sam_minute`` below.
REG_MAP = {
    REG_CN: CN_TIME,
    REG_US: US_TIME,
    REG_TW: TW_TIME,
}
def cal_sam_minute(x: pd.Timestamp, sam_minutes: int, region: str):
    """
    Reference implementation of minute-bar down-sampling, kept to cross-check
    ``qlib.utils.time.cal_sam_minute``.

    ``x`` is mapped to the first minute of the ``sam_minutes``-wide bucket that contains it.
    Every session boundary of ``region`` (looked up in ``REG_MAP``) is moved
    ``C.min_data_shift`` minutes earlier:
        - open time: first entry of the region's time table, minus the shift
        - close time: last entry of the region's time table, minus the shift
        - CN only: the morning session ends one minute before the second entry and the
          afternoon session starts at the third entry (both minus the shift); afternoon
          minutes are indexed starting from 120.

    Raises
    ------
    ValueError
        If ``x`` lies outside the trading sessions, or (CN only) the bucketed minute index
        does not fall inside ``[0, 240)``.
    """
    # TODO: actually, this version is much faster when no cache or optimization
    day_start = pd.Timestamp(x.date())
    shift_delta = C.min_data_shift * pd.Timedelta(minutes=1)
    session_table = REG_MAP[region]
    is_cn = session_table == CN_TIME
    one_minute = pd.Timedelta(minutes=1)

    def _boundary(moment, extra_minutes=0):
        # ``moment`` placed on the day of ``x``, nudged by ``extra_minutes``, then shifted earlier.
        return day_start + pd.Timedelta(hours=moment.hour, minutes=moment.minute + extra_minutes) - shift_delta

    open_time = _boundary(session_table[0])
    close_time = _boundary(session_table[-1])
    # Non-CN markets are treated as one continuous session: the lunch-break bounds collapse
    # onto the day's open/close, so only the first range check below can ever match.
    mid_close_time = _boundary(session_table[1], -1) if is_cn else close_time
    mid_open_time = _boundary(session_table[2]) if is_cn else open_time

    if open_time <= x <= mid_close_time:
        session_start, index_offset = open_time, 0
    elif mid_open_time <= x <= close_time:
        session_start, index_offset = mid_open_time, 120
    else:
        raise ValueError("datetime of calendar is out of range")

    raw_index = (x - session_start).seconds // 60 + index_offset
    bucket = raw_index // sam_minutes * sam_minutes

    if not is_cn or 0 <= bucket < 120:
        return open_time + bucket * one_minute
    if 120 <= bucket < 240:
        return mid_open_time + (bucket - 120) * one_minute
    raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C")
class TimeUtils(TestCase):
    """Tests for the minute-calendar helpers in ``qlib.utils.time``."""

    @classmethod
    def setUpClass(cls):
        init()

    def test_cal_sam_minute(self):
        """Compare ``cal_sam_minute_new`` with the reference ``cal_sam_minute``, then time both.

        Every minute of each region's calendar is paired with every sampling width. That makes
        the comparison exhaustive and reproducible, unlike the earlier unseeded random draws.
        ``self.assertEqual`` is used instead of a bare ``assert`` so the check still runs
        under ``python -O``.
        """
        regions = [REG_CN, REG_US, REG_TW]
        sam_minutes_choices = [1, 2, 3, 4, 5, 6]

        def gen_args(cal: List):
            # Place each calendar minute on a fixed date; only the time of day matters here.
            for moment in cal:
                dt = pd.Timestamp(
                    datetime(
                        2021,
                        month=3,
                        day=3,
                        hour=moment.hour,
                        minute=moment.minute,
                        second=moment.second,
                        microsecond=moment.microsecond,
                    )
                )
                for sam_minutes in sam_minutes_choices:
                    yield dt, sam_minutes

        for region in regions:
            cal_time = get_min_cal(region=region)
            args_l = list(gen_args(cal_time))

            # test the correctness of the code
            for dt, sam_minutes in args_l:
                self.assertEqual(
                    cal_sam_minute(dt, sam_minutes, region),
                    cal_sam_minute_new(dt, sam_minutes, region=region),
                    msg=f"region={region}, dt={dt}, sam_minutes={sam_minutes}",
                )

            # test the performance of the code; label each timing so the two logs are distinguishable
            with TimeInspector.logt(f"reference cal_sam_minute ({region})"):
                for dt, sam_minutes in args_l:
                    cal_sam_minute(dt, sam_minutes, region=region)
            with TimeInspector.logt(f"qlib cal_sam_minute ({region})"):
                for dt, sam_minutes in args_l:
                    cal_sam_minute_new(dt, sam_minutes, region=region)
class DataUtils(TestCase):
    """Tests for the data helpers in ``qlib.utils.data``."""

    @classmethod
    def setUpClass(cls):
        init()

    def test_guess_horizon(self):
        """``guess_horizon`` should return the expected horizon for each label expression list.

        Uses ``assertEqual`` rather than a bare ``assert``: bare asserts are stripped under
        ``python -O``, which would let this test pass without checking anything.
        """
        cases = [
            (["Ref($close, -2) / Ref($close, -1) - 1"], 2),
            (["Ref($close, -5) / Ref($close, -1) - 1"], 5),
            (["Ref($close, -1) / Ref($close, -1) - 1"], 1),
        ]
        for label, expected in cases:
            with self.subTest(label=label):
                self.assertEqual(guess_horizon(label), expected)
if __name__ == "__main__":
unittest.main()