mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
* Add util function to help automatically get horizon * Reformat for CI * Leverage horizon change * Update config yaml * Update for formatting * Adapt to pickled handler * Fix CI error * remove blank * Fix lint * Update tests * Remove redundant check * modify the code as suggested * format code with pylint * fix pytest error --------- Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
137 lines
4.6 KiB
Python
137 lines
4.6 KiB
Python
from typing import List
|
|
from unittest.case import TestCase
|
|
import unittest
|
|
import pandas as pd
|
|
import numpy as np
|
|
from datetime import datetime
|
|
from qlib import init
|
|
from qlib.config import C
|
|
from qlib.log import TimeInspector
|
|
from qlib.constant import REG_CN, REG_US, REG_TW
|
|
from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal, CN_TIME, US_TIME, TW_TIME
|
|
from qlib.utils.data import guess_horizon
|
|
|
|
# Region constant -> list of session time points imported from qlib.utils.time.
# The reference ``cal_sam_minute`` in this module reads element 0 as the open time and
# the last element as the close time; for CN it also reads elements 1 and 2 as the
# boundaries of the midday break.
REG_MAP = {REG_CN: CN_TIME, REG_US: US_TIME, REG_TW: TW_TIME}
|
|
|
|
|
|
def cal_sam_minute(x: pd.Timestamp, sam_minutes: int, region: str):
    """Reference (pure-Python) sampler used to cross-check ``qlib.utils.time.cal_sam_minute``.

    Snap ``x`` down onto a ``sam_minutes``-wide grid counted in trading minutes from the
    market open. Every session boundary is moved earlier by ``C.min_data_shift`` minutes:

        - open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]
        - mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)]
        - mid open time of stock market is [13:00 - shift*pd.Timedelta(minutes=1)]
        - close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)]

    Only the CN calendar is split into a morning and an afternoon session. Every other
    region is treated as one continuous session.

    Parameters
    ----------
    x : pd.Timestamp
        Timestamp to sample; its calendar date selects the trading day.
    sam_minutes : int
        Width of a sampling bucket, in trading minutes.
    region : str
        Key into ``REG_MAP`` (``REG_CN``, ``REG_US`` or ``REG_TW``).

    Returns
    -------
    pd.Timestamp
        Timestamp of the first trading minute of the bucket that contains ``x``. For CN the
        morning is numbered 0-119 and the afternoon from 120 on, so a bucket can straddle
        the midday break.

    Raises
    ------
    ValueError
        If ``x`` lies outside the shifted trading sessions, or the CN bucket index lands at
        or beyond 240.
    """
    # TODO: actually, this version is much faster when no cache or optimization
    trade_date = pd.Timestamp(x.date())
    shift = C.min_data_shift
    sessions = REG_MAP[region]
    one_minute = pd.Timedelta(minutes=1)

    def _boundary(moment, minute_adjust=0):
        # Place a session time point on ``trade_date`` and move it back by the data shift.
        return (
            trade_date
            + pd.Timedelta(hours=moment.hour, minutes=moment.minute + minute_adjust)
            - shift * one_minute
        )

    open_time = _boundary(sessions[0])
    close_time = _boundary(sessions[-1])
    is_cn = sessions == CN_TIME
    if is_cn:
        # The morning session's last bar is one minute before sessions[1].
        mid_close_time = _boundary(sessions[1], minute_adjust=-1)
        mid_open_time = _boundary(sessions[2])
    else:
        # Single continuous session: the first range check below spans the whole day.
        mid_close_time, mid_open_time = close_time, open_time

    if open_time <= x <= mid_close_time:
        elapsed, base = x - open_time, 0
    elif mid_open_time <= x <= close_time:
        # Afternoon minutes are numbered after the 120 morning minutes.
        elapsed, base = x - mid_open_time, 120
    else:
        raise ValueError("datetime of calendar is out of range")

    # Floor the trading-minute index to a multiple of the bucket width.
    bucket_start = (elapsed.seconds // 60 + base) // sam_minutes * sam_minutes

    if not is_cn or 0 <= bucket_start < 120:
        return open_time + bucket_start * one_minute
    if 120 <= bucket_start < 240:
        return mid_open_time + (bucket_start - 120) * one_minute
    raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C")
|
|
|
|
|
|
class TimeUtils(TestCase):
    """Cross-check the reference ``cal_sam_minute`` in this module against ``qlib.utils.time.cal_sam_minute``."""

    @classmethod
    def setUpClass(cls):
        init()

    def test_cal_sam_minute(self):
        """Compare both implementations on random calendar minutes for each region, then time them."""
        # test the correctness of the code
        random_n = 1000
        regions = [REG_CN, REG_US, REG_TW]
        # Fixed seed so a failing sample reproduces on re-run (unseeded sampling differs every run).
        rng = np.random.RandomState(2021)

        def gen_args(cal: List):
            """Yield ``random_n`` ``(timestamp, sam_minutes)`` pairs drawn with replacement from ``cal``."""
            for cal_point in rng.choice(cal, size=random_n, replace=True):
                sam_minutes = rng.choice([1, 2, 3, 4, 5, 6])
                # Re-anchor the sampled time of day on a fixed date (2021-03-03).
                dt = pd.Timestamp(
                    datetime(
                        2021,
                        month=3,
                        day=3,
                        hour=cal_point.hour,
                        minute=cal_point.minute,
                        second=cal_point.second,
                        microsecond=cal_point.microsecond,
                    )
                )
                yield dt, sam_minutes

        for region in regions:
            cal_time = get_min_cal(region=region)
            for dt, sam_minutes in gen_args(cal_time):
                # assertEqual (unlike a bare ``assert``) is not stripped under ``python -O`` and
                # names the failing sample in the report.
                self.assertEqual(
                    cal_sam_minute(dt, sam_minutes, region),
                    cal_sam_minute_new(dt, sam_minutes, region=region),
                    msg=f"region={region}, dt={dt}, sam_minutes={sam_minutes}",
                )

            # test the performance of the code
            args_l = list(gen_args(cal_time))

            with TimeInspector.logt():
                for args in args_l:
                    cal_sam_minute(*args, region=region)

            with TimeInspector.logt():
                for args in args_l:
                    cal_sam_minute_new(*args, region=region)
|
|
|
|
|
|
class DataUtils(TestCase):
    """Tests for helpers in ``qlib.utils.data``."""

    @classmethod
    def setUpClass(cls):
        init()

    def test_guess_horizon(self):
        """``guess_horizon`` should recover the label's forward horizon."""
        # (label expressions, expected horizon). In these cases the expected value equals the
        # largest forward offset k among the ``Ref($close, -k)`` terms.
        cases = [
            (["Ref($close, -2) / Ref($close, -1) - 1"], 2),
            (["Ref($close, -5) / Ref($close, -1) - 1"], 5),
            (["Ref($close, -1) / Ref($close, -1) - 1"], 1),
        ]
        for label, expected in cases:
            # subTest reports every failing case instead of stopping at the first one;
            # assertEqual is not stripped under ``python -O`` like a bare assert would be.
            with self.subTest(label=label):
                self.assertEqual(guess_horizon(label), expected)
|
|
|
|
|
|
# Run the tests in this module when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|