mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
* change_url * fix_CI * fix_CI_2 * fix_CI_3 * fix_CI_4 * fix_CI_5 * fix_CI_6 * fix_CI_7 * fix_CI_8 * fix_CI_9 * fix_CI_10 * fix_CI_11 * fix_CI_12 * fix_CI_13 * fix_CI_13 * fix_CI_14 * fix_CI_15 * fix_CI_16 * fix_CI_17 * fix_CI_18 * fix_CI_19 * fix_CI_20 * fix_CI_21 * fix_CI_22 * fix_CI_23 * fix_CI_24 * fix_CI_25 * fix_CI_26 * fix_CI_27 * fix_get_data_error * fix_get_data_error2 * modify_get_data * modify_get_data2 * modify_get_data3 * modify_get_data4 * fix_CI_28 * fix_CI_29 * fix_CI_30 --------- Co-authored-by: Linlang <v-linlanglv@microsoft.com>
52 lines
1.6 KiB
Python
52 lines
1.6 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import shutil
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
import qlib
|
|
from qlib.data import D
|
|
from qlib.tests.data import GetData
|
|
|
|
# Scratch workspace for this test module, next to the test file itself;
# TestGetData.tearDownClass deletes it after the tests have run.
DATA_DIR = Path(__file__).parent / "test_get_data"
# Target directory for the raw CSV archive fetched by test_1_csv_data.
SOURCE_DIR = DATA_DIR / "source"
# Target directory (and qlib provider_uri) for the qlib-format bundle.
QLIB_DIR = DATA_DIR / "qlib"

# Both directories are created eagerly at import time, before any test runs.
for _subdir in (SOURCE_DIR, QLIB_DIR):
    _subdir.mkdir(parents=True, exist_ok=True)
del _subdir
|
|
|
|
|
|
class TestGetData(unittest.TestCase):
    """Tests for fetching qlib sample data through ``GetData``.

    ``test_0_qlib_data`` fetches the ``qlib_data_simple`` bundle into
    ``QLIB_DIR`` and reads it back via ``D.features``; ``test_1_csv_data``
    fetches ``csv_data_cn.zip`` into ``SOURCE_DIR`` and counts the stock CSVs.
    """

    # Feature expressions requested from the provider; the resulting frame is
    # asserted to expose exactly these columns, in this order.
    FIELDS = "$open,$close,$high,$low,$volume,$factor,$change".split(",")

    @classmethod
    def setUpClass(cls) -> None:
        """Initialise qlib on the local ``QLIB_DIR`` with caching disabled."""
        # qlib is initialised before data is fetched; test_0_qlib_data downloads
        # into this same directory before it queries features.
        qlib.init(
            provider_uri=str(QLIB_DIR.resolve()),
            expression_cache=None,
            dataset_cache=None,
        )

    @classmethod
    def tearDownClass(cls) -> None:
        """Remove the whole scratch directory, including all fetched data."""
        shutil.rmtree(DATA_DIR.resolve())

    def test_0_qlib_data(self):
        """Fetch the simple CN daily bundle and check csi300 features load."""
        GetData().qlib_data(
            name="qlib_data_simple",
            target_dir=QLIB_DIR,
            region="cn",
            interval="1d",
            delete_old=False,
            exists_skip=True,
        )
        df = D.features(D.instruments("csi300"), self.FIELDS)
        self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed")
        # At least one row must have every requested field populated.
        self.assertFalse(df.dropna().empty, "get qlib data failed")

    def test_1_csv_data(self):
        """Fetch the CN CSV archive and check that 85 distinct stocks arrive."""
        GetData().download_data(file_name="csv_data_cn.zip", target_dir=SOURCE_DIR)
        # One symbol per CSV file: the file stem, upper-cased, so names that
        # differ only in case are counted once.
        stock_names = {csv_path.stem.upper() for csv_path in SOURCE_DIR.glob("*.csv")}
        self.assertEqual(len(stock_names), 85, "get csv data failed")
|
|
|
|
|
|
# Running this file directly (rather than via a test runner) executes the
# TestCase classes defined above with unittest's default CLI runner.
if __name__ == "__main__":
    unittest.main()
|