mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Add script to verify that bin data and csv data are consistent
This commit is contained in:
144
scripts/check_dump_bin.py
Normal file
144
scripts/check_dump_bin.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
import fire
|
||||
import datacompy
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class CheckBin:
|
||||
|
||||
NOT_IN_FEATURES = "not in features"
|
||||
COMPARE_FALSE = "compare False"
|
||||
COMPARE_TRUE = "compare True"
|
||||
COMPARE_ERROR = "compare error"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qlib_dir: str,
|
||||
csv_path: str,
|
||||
check_fields: str = None,
|
||||
freq: str = "day",
|
||||
symbol_field_name: str = "symbol",
|
||||
date_field_name: str = "date",
|
||||
file_suffix: str = ".csv",
|
||||
max_workers: int = 16,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir : str
|
||||
qlib dir
|
||||
csv_path : str
|
||||
origin csv path
|
||||
check_fields : str, optional
|
||||
check fields, by default None, check qlib_dir/features/<first_dir>/*.<freq>.bin
|
||||
freq : str, optional
|
||||
freq, value from ["day", "1m"]
|
||||
symbol_field_name: str, optional
|
||||
symbol field name, by default "symbol"
|
||||
date_field_name: str, optional
|
||||
date field name, by default "date"
|
||||
file_suffix: str, optional
|
||||
csv file suffix, by default ".csv"
|
||||
max_workers: int, optional
|
||||
max workers, by default 16
|
||||
"""
|
||||
self.qlib_dir = Path(qlib_dir).expanduser()
|
||||
bin_path_list = list(self.qlib_dir.joinpath("features").iterdir())
|
||||
self.qlib_symbols = sorted(map(lambda x: x.name.lower(), bin_path_list))
|
||||
qlib.init(
|
||||
provider_uri=str(self.qlib_dir.resolve()),
|
||||
mount_path=str(self.qlib_dir.resolve()),
|
||||
auto_mount=False,
|
||||
redis_port=-1,
|
||||
)
|
||||
csv_path = Path(csv_path).expanduser()
|
||||
self.csv_files = sorted(csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path])
|
||||
|
||||
if check_fields is None:
|
||||
check_fields = list(map(lambda x: x.split(".")[0], bin_path_list[0].glob(f"*.bin")))
|
||||
else:
|
||||
check_fields = check_fields.split(",") if isinstance(check_fields, str) else check_fields
|
||||
self.check_fields = list(map(lambda x: x.strip(), check_fields))
|
||||
self.qlib_fields = list(map(lambda x: f"${x}", self.check_fields))
|
||||
self.max_workers = max_workers
|
||||
self.symbol_field_name = symbol_field_name
|
||||
self.date_field_name = date_field_name
|
||||
self.freq = freq
|
||||
self.file_suffix = file_suffix
|
||||
|
||||
def _compare(self, file_path: Path):
|
||||
symbol = file_path.name.strip(self.file_suffix)
|
||||
if symbol.lower() not in self.qlib_symbols:
|
||||
return self.NOT_IN_FEATURES
|
||||
# qlib data
|
||||
qlib_df = D.features([symbol], self.qlib_fields, freq=self.freq)
|
||||
qlib_df.rename(columns={_c: _c.strip("$") for _c in qlib_df.columns}, inplace=True)
|
||||
# csv data
|
||||
origin_df = pd.read_csv(file_path)
|
||||
origin_df[self.date_field_name] = pd.to_datetime(origin_df[self.date_field_name])
|
||||
if self.symbol_field_name not in origin_df.columns:
|
||||
origin_df[self.symbol_field_name] = symbol
|
||||
origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True)
|
||||
origin_df.index.names = qlib_df.index.names
|
||||
try:
|
||||
compare = datacompy.Compare(
|
||||
origin_df,
|
||||
qlib_df,
|
||||
on_index=True,
|
||||
abs_tol=1e-08, # Optional, defaults to 0
|
||||
rel_tol=1e-05, # Optional, defaults to 0
|
||||
df1_name="Original", # Optional, defaults to 'df1'
|
||||
df2_name="New", # Optional, defaults to 'df2'
|
||||
)
|
||||
_r = compare.matches(ignore_extra_columns=True)
|
||||
return self.COMPARE_TRUE if _r else self.COMPARE_FALSE
|
||||
except Exception as e:
|
||||
logger.warning(f"{symbol} compare error: {e}")
|
||||
return self.COMPARE_ERROR
|
||||
|
||||
def check(self):
|
||||
"""Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data
|
||||
|
||||
"""
|
||||
logger.info("start check......")
|
||||
|
||||
error_list = []
|
||||
not_in_features = []
|
||||
compare_false = []
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
for file_path, _check_res in zip(self.csv_files, executor.map(self._compare, self.csv_files)):
|
||||
symbol = file_path.name.strip(self.file_suffix)
|
||||
if _check_res == self.NOT_IN_FEATURES:
|
||||
not_in_features.append(symbol)
|
||||
elif _check_res == self.COMPARE_ERROR:
|
||||
error_list.append(symbol)
|
||||
elif _check_res == self.COMPARE_FALSE:
|
||||
compare_false.append(symbol)
|
||||
p_bar.update()
|
||||
|
||||
logger.info("end of check......")
|
||||
if error_list:
|
||||
logger.warning(f"compare error: {error_list}")
|
||||
if not_in_features:
|
||||
logger.warning(f"not in features: {not_in_features}")
|
||||
if compare_false:
|
||||
logger.warning(f"compare False: {compare_false}")
|
||||
logger.info(
|
||||
f"total {len(self.csv_files)}, {len(error_list)} errors, {len(not_in_features)} not in features, {len(compare_false)} compare false"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(CheckBin)
|
||||
Reference in New Issue
Block a user