1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/tests/misc/test_utils.py
Chia-hung Tai 84e9df2603 Add REG_US and REG_TW into test case: test_utils.py. (#1310)
* Add REG_US and REG_TW into test case: test_utils.py.

* Fix black.

* Trigger checks.

* Add REG_US and REG_TW into test case: test_utils.py.

* Fix black.

* Trigger checks.
2022-10-14 11:49:30 +08:00

117 lines
4.0 KiB
Python

from typing import List
from unittest.case import TestCase
import unittest
import pandas as pd
import numpy as np
from datetime import datetime
from qlib import init
from qlib.config import C
from qlib.log import TimeInspector
from qlib.constant import REG_CN, REG_US, REG_TW
from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal, CN_TIME, US_TIME, TW_TIME
REG_MAP = {REG_CN: CN_TIME, REG_US: US_TIME, REG_TW: TW_TIME}
def cal_sam_minute(x: pd.Timestamp, sam_minutes: int, region: str):
"""
Sample raw calendar into calendar with sam_minutes freq, shift represents the shift minute the market time
- open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]
- mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)]
- mid open time of stock market is [13:00 - shift*pd.Timedelta(minutes=1)]
- close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)]
"""
# TODO: actually, this version is much faster when no cache or optimization
day_time = pd.Timestamp(x.date())
shift = C.min_data_shift
region_time = REG_MAP[region]
open_time = (
day_time
+ pd.Timedelta(hours=region_time[0].hour, minutes=region_time[0].minute)
- shift * pd.Timedelta(minutes=1)
)
close_time = (
day_time
+ pd.Timedelta(hours=region_time[-1].hour, minutes=region_time[-1].minute)
- shift * pd.Timedelta(minutes=1)
)
if region_time == CN_TIME:
mid_close_time = (
day_time
+ pd.Timedelta(hours=region_time[1].hour, minutes=region_time[1].minute - 1)
- shift * pd.Timedelta(minutes=1)
)
mid_open_time = (
day_time
+ pd.Timedelta(hours=region_time[2].hour, minutes=region_time[2].minute)
- shift * pd.Timedelta(minutes=1)
)
else:
mid_close_time = close_time
mid_open_time = open_time
if open_time <= x <= mid_close_time:
minute_index = (x - open_time).seconds // 60
elif mid_open_time <= x <= close_time:
minute_index = (x - mid_open_time).seconds // 60 + 120
else:
raise ValueError("datetime of calendar is out of range")
minute_index = minute_index // sam_minutes * sam_minutes
if 0 <= minute_index < 120 or region_time != CN_TIME:
return open_time + minute_index * pd.Timedelta(minutes=1)
elif 120 <= minute_index < 240:
return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1)
else:
raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C")
class TimeUtils(TestCase):
@classmethod
def setUpClass(cls):
init()
def test_cal_sam_minute(self):
# test the correctness of the code
random_n = 1000
regions = [REG_CN, REG_US, REG_TW]
def gen_args(cal: List):
for time in np.random.choice(cal, size=random_n, replace=True):
sam_minutes = np.random.choice([1, 2, 3, 4, 5, 6])
dt = pd.Timestamp(
datetime(
2021,
month=3,
day=3,
hour=time.hour,
minute=time.minute,
second=time.second,
microsecond=time.microsecond,
)
)
args = dt, sam_minutes
yield args
for region in regions:
cal_time = get_min_cal(region=region)
for args in gen_args(cal_time):
assert cal_sam_minute(*args, region) == cal_sam_minute_new(*args, region=region)
# test the performance of the code
args_l = list(gen_args(cal_time))
with TimeInspector.logt():
for args in args_l:
cal_sam_minute(*args, region=region)
with TimeInspector.logt():
for args in args_l:
cal_sam_minute_new(*args, region=region)
if __name__ == "__main__":
unittest.main()