# Mirror of https://github.com/microsoft/qlib.git (synced 2026-06-06 05:51:17 +08:00).
# Original file: 145 lines, 5.4 KiB, Python.
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
from pathlib import Path
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
|
|
import qlib
|
|
from qlib.data import D
|
|
|
|
import fire
|
|
import datacompy
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
from loguru import logger
|
|
|
|
|
|
class CheckBin:
    """Check that the ``.bin`` files produced by ``dump_bin.py`` match their source csv files.

    Every csv file is loaded, indexed by (symbol, date) and compared (via ``datacompy``)
    with the qlib features of the same symbol loaded through ``D.features``.
    """

    # Per-file results returned by ``_compare``.
    NOT_IN_FEATURES = "not in features"
    COMPARE_FALSE = "compare False"
    COMPARE_TRUE = "compare True"
    COMPARE_ERROR = "compare error"

    def __init__(
        self,
        qlib_dir: str,
        csv_path: str,
        check_fields: str = None,
        freq: str = "day",
        symbol_field_name: str = "symbol",
        date_field_name: str = "date",
        file_suffix: str = ".csv",
        max_workers: int = 16,
    ):
        """

        Parameters
        ----------
        qlib_dir : str
            qlib dir; its ``features`` sub-directory must contain one directory per symbol
        csv_path : str
            origin csv path: a directory (every ``*<file_suffix>`` file in it is checked)
            or a single csv file
        check_fields : str, optional
            comma-separated field names to check (a list/tuple of names is also accepted).
            By default None: use the field names of the ``*.bin`` files found in the first
            (sorted) symbol directory of ``qlib_dir/features``
        freq : str, optional
            freq passed to ``D.features``, value from ["day", "1m"]
        symbol_field_name: str, optional
            symbol field name, by default "symbol"
        date_field_name: str, optional
            date field name, by default "date"
        file_suffix: str, optional
            csv file suffix, by default ".csv"; the symbol is the file name without it
        max_workers: int, optional
            max workers of the process pool, by default 16

        Raises
        ------
        ValueError
            if ``check_fields`` is None and ``qlib_dir/features`` contains no symbol directory
        """
        self.qlib_dir = Path(qlib_dir).expanduser()
        # Only directories are symbols; stray files (e.g. OS metadata) must not be counted
        # or picked as the "first" symbol. Sorting makes that choice deterministic.
        bin_path_list = sorted(p for p in self.qlib_dir.joinpath("features").iterdir() if p.is_dir())
        self.qlib_symbols = sorted(p.name.lower() for p in bin_path_list)
        # Set copy for O(1) membership tests in ``_compare``.
        self._qlib_symbol_set = set(self.qlib_symbols)
        qlib.init(
            provider_uri=str(self.qlib_dir.resolve()),
            mount_path=str(self.qlib_dir.resolve()),
            auto_mount=False,
            redis_port=-1,
        )
        csv_path = Path(csv_path).expanduser()
        self.csv_files = sorted(csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path])

        if check_fields is None:
            if not bin_path_list:
                raise ValueError(f"no symbol directory found in {self.qlib_dir.joinpath('features')}")
            # "close.day.bin" -> "close"; a set removes duplicates when several freqs are stored.
            check_fields = sorted({p.name.split(".")[0] for p in bin_path_list[0].glob("*.bin")})
        elif isinstance(check_fields, str):
            check_fields = check_fields.split(",")
        self.check_fields = [x.strip() for x in check_fields]
        self.qlib_fields = [f"${x}" for x in self.check_fields]
        self.max_workers = max_workers
        self.symbol_field_name = symbol_field_name
        self.date_field_name = date_field_name
        self.freq = freq
        self.file_suffix = file_suffix

    def _get_symbol(self, file_path: Path) -> str:
        """Return the symbol of ``file_path``: its file name with ``file_suffix`` removed.

        ``str.strip(suffix)`` must not be used for this: it removes any of the suffix's
        characters from both ends (``"sh600000.csv".strip(".csv") == "h600000"``).
        """
        name = file_path.name
        if self.file_suffix and name.endswith(self.file_suffix):
            return name[: -len(self.file_suffix)]
        return name

    def _compare(self, file_path: Path):
        """Compare one csv file with the qlib data of the same symbol.

        Returns one of ``NOT_IN_FEATURES``, ``COMPARE_TRUE``, ``COMPARE_FALSE`` or
        ``COMPARE_ERROR``. Any failure while loading or comparing is logged and reported as
        ``COMPARE_ERROR`` so that one bad file does not abort the whole check.
        """
        symbol = self._get_symbol(file_path)
        if symbol.lower() not in self._qlib_symbol_set:
            return self.NOT_IN_FEATURES
        try:
            # qlib data; "$close" -> "close" so column names match the csv
            qlib_df = D.features([symbol], self.qlib_fields, freq=self.freq)
            qlib_df.rename(columns={_c: _c.strip("$") for _c in qlib_df.columns}, inplace=True)
            # csv data
            origin_df = pd.read_csv(file_path)
            origin_df[self.date_field_name] = pd.to_datetime(origin_df[self.date_field_name])
            if self.symbol_field_name not in origin_df.columns:
                origin_df[self.symbol_field_name] = symbol
            origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True)
            # Use qlib's index level names so both frames can be joined on the index.
            origin_df.index.names = qlib_df.index.names
            compare = datacompy.Compare(
                origin_df,
                qlib_df,
                on_index=True,
                abs_tol=1e-08,  # Optional, defaults to 0
                rel_tol=1e-05,  # Optional, defaults to 0
                df1_name="Original",  # Optional, defaults to 'df1'
                df2_name="New",  # Optional, defaults to 'df2'
            )
            # csv columns that were not dumped/checked are ignored
            _r = compare.matches(ignore_extra_columns=True)
        except Exception as e:
            logger.warning(f"{symbol} compare error: {e}")
            return self.COMPARE_ERROR
        return self.COMPARE_TRUE if _r else self.COMPARE_FALSE

    def check(self):
        """Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data.

        Files are compared in a process pool; symbols that errored, are missing from
        ``features``, or did not match are logged as warnings, followed by a summary line.
        """
        logger.info("start check......")

        error_list = []
        not_in_features = []
        compare_false = []
        with tqdm(total=len(self.csv_files)) as p_bar:
            with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                # executor.map yields results in input order, so zip pairs them correctly.
                for file_path, _check_res in zip(self.csv_files, executor.map(self._compare, self.csv_files)):
                    symbol = self._get_symbol(file_path)
                    if _check_res == self.NOT_IN_FEATURES:
                        not_in_features.append(symbol)
                    elif _check_res == self.COMPARE_ERROR:
                        error_list.append(symbol)
                    elif _check_res == self.COMPARE_FALSE:
                        compare_false.append(symbol)
                    p_bar.update()

        logger.info("end of check......")
        if error_list:
            logger.warning(f"compare error: {error_list}")
        if not_in_features:
            logger.warning(f"not in features: {not_in_features}")
        if compare_false:
            logger.warning(f"compare False: {compare_false}")
        logger.info(
            f"total {len(self.csv_files)}, {len(error_list)} errors, {len(not_in_features)} not in features, {len(compare_false)} compare false"
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Expose CheckBin as a command-line tool: constructor arguments become CLI flags and
    # methods such as ``check`` become sub-commands.
    fire.Fire(component=CheckBin)
|