mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
update tests && fix typo
This commit is contained in:
@@ -41,7 +41,7 @@
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(CUR_DIR.parent.parent.joinpath(\"scripts\")))\n",
|
||||
" from get_data import GetData\n",
|
||||
" GetData().qlib_data_cn(target_dir=provider_uri)\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
},
|
||||
@@ -219,4 +219,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
@@ -26,7 +26,7 @@ if __name__ == "__main__":
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data_cn(target_dir=provider_uri)
|
||||
GetData().qlib_data(target_dir=provider_uri)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
|
||||
" from get_data import GetData\n",
|
||||
" GetData().qlib_data_cn(target_dir=provider_uri)\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
},
|
||||
@@ -335,4 +335,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
@@ -313,7 +313,7 @@ class YahooCollectorUS(YahooCollector):
|
||||
|
||||
def get_stock_list(self):
|
||||
logger.info("get US stock symbols......")
|
||||
symbols = get_us_stock_symbols(qlib_data_path="/data1/data/yahoo_staock_data/backup/us_1d_qlib") + [
|
||||
symbols = get_us_stock_symbols() + [
|
||||
"^GSPC",
|
||||
"^NDX",
|
||||
"^DJI",
|
||||
|
||||
@@ -287,7 +287,6 @@ class DumpDataAll(DumpDataBase):
|
||||
logger.info("end of features dump.\n")
|
||||
|
||||
def dump(self):
|
||||
print("dump 2")
|
||||
self._get_all_date()
|
||||
self._dump_calendars()
|
||||
self._dump_instruments()
|
||||
|
||||
@@ -55,7 +55,7 @@ class GetData:
|
||||
for _file in tqdm(zp.namelist()):
|
||||
zp.extract(_file, str(target_dir.resolve()))
|
||||
|
||||
def qlib_data(self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", intervel="1d", region="cn"):
|
||||
def qlib_data(self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
Parameters
|
||||
@@ -63,10 +63,10 @@ class GetData:
|
||||
target_dir: str
|
||||
data save directory
|
||||
name: str
|
||||
dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data_us
|
||||
dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
|
||||
version: str
|
||||
data version, value from [v0, v1, ..., latest], by default latest
|
||||
intervel: str
|
||||
interval: str
|
||||
data freq, value from [1d], by default 1d
|
||||
region: str
|
||||
data region, value from [cn, us], by default cn
|
||||
@@ -80,7 +80,7 @@ class GetData:
|
||||
# TODO: The US stock code contains "PRN", and the directory cannot be created on Windows system
|
||||
if region.lower() == "us":
|
||||
logger.warning(f"The US stock code contains 'PRN', and the directory cannot be created on Windows system")
|
||||
file_name = f"{name}_{region.lower()}_{intervel}_{version}.zip"
|
||||
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
|
||||
self._download_data(file_name.lower(), target_dir)
|
||||
|
||||
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
|
||||
|
||||
@@ -18,7 +18,7 @@ class TestDataset(unittest.TestCase):
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=provider_uri)
|
||||
GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def testCSI300(self):
|
||||
|
||||
@@ -149,7 +149,7 @@ class TestAllFlow(unittest.TestCase):
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=provider_uri)
|
||||
GetData().qlib_data(name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def test_0_train(self):
|
||||
|
||||
@@ -14,7 +14,7 @@ from qlib.data import D
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
from dump_bin import DumpData
|
||||
from dump_bin import DumpDataAll, DumpDataFix
|
||||
|
||||
|
||||
DATA_DIR = Path(__file__).parent.joinpath("test_dump_data")
|
||||
@@ -36,7 +36,7 @@ class TestDumpData(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
GetData().csv_data_cn(SOURCE_DIR)
|
||||
TestDumpData.DUMP_DATA = DumpData(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR)
|
||||
TestDumpData.DUMP_DATA = DumpDataAll(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS)
|
||||
TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
|
||||
provider_uri = str(QLIB_DIR.resolve())
|
||||
qlib.init(
|
||||
@@ -49,8 +49,10 @@ class TestDumpData(unittest.TestCase):
|
||||
def tearDownClass(cls) -> None:
|
||||
shutil.rmtree(str(DATA_DIR.resolve()))
|
||||
|
||||
def test_0_dump_calendars(self):
|
||||
self.DUMP_DATA.dump_calendars()
|
||||
def test_0_dump_bin(self):
|
||||
self.DUMP_DATA.dump()
|
||||
|
||||
def test_1_dump_calendars(self):
|
||||
ori_calendars = set(
|
||||
map(
|
||||
pd.Timestamp,
|
||||
@@ -60,23 +62,21 @@ class TestDumpData(unittest.TestCase):
|
||||
res_calendars = set(D.calendar())
|
||||
assert len(ori_calendars - res_calendars) == len(res_calendars - ori_calendars) == 0, "dump calendars failed"
|
||||
|
||||
def test_1_dump_instruments(self):
|
||||
self.DUMP_DATA.dump_instruments()
|
||||
def test_2_dump_instruments(self):
|
||||
ori_ins = set(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
|
||||
res_ins = set(D.list_instruments(D.instruments("all"), as_list=True))
|
||||
assert len(ori_ins - res_ins) == len(ori_ins - res_ins) == 0, "dump instruments failed"
|
||||
|
||||
def test_2_dump_features(self):
|
||||
self.DUMP_DATA.dump_features(include_fields=self.FIELDS)
|
||||
def test_3_dump_features(self):
|
||||
df = D.features(self.STOCK_NAMES, self.QLIB_FIELDS)
|
||||
TestDumpData.SIMPLE_DATA = df.loc(axis=0)[self.STOCK_NAMES[0], :]
|
||||
self.assertFalse(df.dropna().empty, "features data failed")
|
||||
self.assertListEqual(list(df.columns), self.QLIB_FIELDS, "features columns failed")
|
||||
|
||||
def test_3_dump_features_simple(self):
|
||||
def test_4_dump_features_simple(self):
|
||||
stock = self.STOCK_NAMES[0]
|
||||
dump_data = DumpData(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR)
|
||||
dump_data.dump_features(include_fields=self.FIELDS, calendar_path=QLIB_DIR.joinpath("calendars", "day.txt"))
|
||||
dump_data = DumpDataFix(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS)
|
||||
dump_data.dump()
|
||||
|
||||
df = D.features([stock], self.QLIB_FIELDS)
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class TestGetData(unittest.TestCase):
|
||||
|
||||
def test_0_qlib_data(self):
|
||||
|
||||
GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=QLIB_DIR)
|
||||
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", version="latest")
|
||||
df = D.features(D.instruments("csi300"), self.FIELDS)
|
||||
self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed")
|
||||
self.assertFalse(df.dropna().empty, "get qlib data failed")
|
||||
|
||||
Reference in New Issue
Block a user