# Mirror of https://github.com/microsoft/qlib.git, synced 2026-06-29 09:01:18 +08:00.
# Python source: 54 lines, 2.3 KiB.
import sys
|
|
from pathlib import Path
|
|
import qlib
|
|
from qlib.data import D
|
|
from qlib.config import REG_CN
|
|
import unittest
|
|
import numpy as np
|
|
from qlib.utils import exists_qlib_data
|
|
|
|
|
|
class TestDataset(unittest.TestCase):
    """Statistical sanity checks on the simple CN Qlib dataset.

    Both tests summarise data for the ``csi300`` instrument pool with
    ``describe`` and assert on percentiles rather than on individual
    records, so a few bad rows do not fail the suite.
    """

    # Decile percentiles (10%, 20%, ..., 90%) requested from ``describe``.
    # ``linspace`` gives exactly 9 grid points; ``arange`` with a float step
    # is discouraged by NumPy because its length depends on rounding.
    PERCENTILES = np.linspace(0.1, 0.9, 9)

    @classmethod
    def setUpClass(cls) -> None:
        """Initialise Qlib on ``cn_data_simple``, downloading the data if it is missing."""
        # use default data
        provider_uri = "~/.qlib/qlib_data/cn_data_simple"  # target_dir
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            # ``get_data`` is imported from the ``scripts`` directory found three
            # parents above this file (presumably the repository's scripts/ dir).
            scripts_dir = str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts"))
            # Do not add duplicate entries to sys.path if setup runs more than once.
            if scripts_dir not in sys.path:
                sys.path.append(scripts_dir)
            from get_data import GetData

            GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri)
        qlib.init(provider_uri=provider_uri, region=REG_CN)

    def testCSI300(self):
        """Check that the number of CSI300 constituents per trading day is plausible.

        ``size`` counts all rows per ``datetime`` (NaN closes included), while
        ``cnt`` counts only non-NaN ``$close`` values per ``datetime``.
        """
        close_p = D.features(D.instruments("csi300"), ["$close"])
        by_day = close_p.groupby("datetime")
        size = by_day.size()
        cnt = by_day.count()["$close"]
        size_desc = size.describe(percentiles=self.PERCENTILES)
        cnt_desc = cnt.describe(percentiles=self.PERCENTILES)

        print(size_desc)
        print(cnt_desc)

        self.assertLessEqual(size_desc.loc["max"], 305, "Excessive number of CSI300 constituent stocks")
        self.assertGreaterEqual(size_desc.loc["80%"], 290, "Insufficient number of CSI300 constituent stocks")

        self.assertLessEqual(cnt_desc.loc["max"], 305, "Excessive number of CSI300 constituent stocks")
        # FIXME: The data quality is too low to guarantee enough non-NaN closes
        # per day, so this check is disabled for now.
        # self.assertEqual(cnt_desc.loc["80%"], 300, "Insufficient number of CSI300 constituent stocks")

    def testClose(self):
        """Check that the 10th and 90th percentiles of the close-ratio feature lie within +/-0.1.

        The feature is ``Ref($close, 1)/$close - 1``. In Qlib's expression
        language ``Ref($close, 1)`` is presumably the previous period's close;
        verify this against the Qlib operator docs if it matters.
        """
        close_p = D.features(D.instruments("csi300"), ["Ref($close, 1)/$close - 1"])
        close_desc = close_p.describe(percentiles=self.PERCENTILES)
        print(close_desc)

        # ``close_desc`` has a single column named after the expression above, so
        # ``.loc[label]`` yields a string-indexed Series. Take its only value
        # positionally with ``.iloc``: plain ``[0]`` depends on pandas' deprecated
        # integer-key-as-position fallback in ``Series.__getitem__``.
        self.assertLessEqual(abs(close_desc.loc["90%"].iloc[0]), 0.1, "Close value is abnormal")
        self.assertLessEqual(abs(close_desc.loc["10%"].iloc[0]), 0.1, "Close value is abnormal")
        # FIXME: The yahoo data is not perfect, so the extreme-value checks below
        # are disabled until the data is cleaned.
        # self.assertLessEqual(abs(close_desc.loc["max"].iloc[0]), 0.2, "Close value is abnormal")
        # self.assertGreaterEqual(close_desc.loc["min"].iloc[0], -0.2, "Close value is abnormal")
|
|
|
|
|
|
if __name__ == "__main__":
    # Only when executed as a script (not when imported): discover and run
    # every TestCase defined in this module.
    unittest.main(module="__main__")
|