diff --git a/scripts/check_dump_bin.py b/scripts/check_dump_bin.py new file mode 100644 index 000000000..7c2ceccda --- /dev/null +++ b/scripts/check_dump_bin.py @@ -0,0 +1,144 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from pathlib import Path +from concurrent.futures import ProcessPoolExecutor + +import qlib +from qlib.data import D + +import fire +import datacompy +import pandas as pd +from tqdm import tqdm +from loguru import logger + + +class CheckBin: + + NOT_IN_FEATURES = "not in features" + COMPARE_FALSE = "compare False" + COMPARE_TRUE = "compare True" + COMPARE_ERROR = "compare error" + + def __init__( + self, + qlib_dir: str, + csv_path: str, + check_fields: str = None, + freq: str = "day", + symbol_field_name: str = "symbol", + date_field_name: str = "date", + file_suffix: str = ".csv", + max_workers: int = 16, + ): + """ + + Parameters + ---------- + qlib_dir : str + qlib dir + csv_path : str + origin csv path + check_fields : str, optional + check fields, by default None, check qlib_dir/features//*..bin + freq : str, optional + freq, value from ["day", "1m"] + symbol_field_name: str, optional + symbol field name, by default "symbol" + date_field_name: str, optional + date field name, by default "date" + file_suffix: str, optional + csv file suffix, by default ".csv" + max_workers: int, optional + max workers, by default 16 + """ + self.qlib_dir = Path(qlib_dir).expanduser() + bin_path_list = list(self.qlib_dir.joinpath("features").iterdir()) + self.qlib_symbols = sorted(map(lambda x: x.name.lower(), bin_path_list)) + qlib.init( + provider_uri=str(self.qlib_dir.resolve()), + mount_path=str(self.qlib_dir.resolve()), + auto_mount=False, + redis_port=-1, + ) + csv_path = Path(csv_path).expanduser() + self.csv_files = sorted(csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path]) + + if check_fields is None: + check_fields = list(map(lambda x: x.split(".")[0], bin_path_list[0].glob(f"*.bin"))) + else: + check_fields = check_fields.split(",") if isinstance(check_fields, str) else check_fields + self.check_fields = list(map(lambda x: x.strip(), check_fields)) + self.qlib_fields = list(map(lambda x: f"${x}", self.check_fields)) + self.max_workers = max_workers + self.symbol_field_name = symbol_field_name + self.date_field_name = date_field_name + self.freq = freq + self.file_suffix = file_suffix + + def _compare(self, file_path: Path): + symbol = file_path.name.strip(self.file_suffix) + if symbol.lower() not in self.qlib_symbols: + return self.NOT_IN_FEATURES + # qlib data + qlib_df = D.features([symbol], self.qlib_fields, freq=self.freq) + qlib_df.rename(columns={_c: _c.strip("$") for _c in qlib_df.columns}, inplace=True) + # csv data + origin_df = pd.read_csv(file_path) + origin_df[self.date_field_name] = pd.to_datetime(origin_df[self.date_field_name]) + if self.symbol_field_name not in origin_df.columns: + origin_df[self.symbol_field_name] = symbol + origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True) + origin_df.index.names = qlib_df.index.names + try: + compare = datacompy.Compare( + origin_df, + qlib_df, + on_index=True, + abs_tol=1e-08, # Optional, defaults to 0 + rel_tol=1e-05, # Optional, defaults to 0 + df1_name="Original", # Optional, defaults to 'df1' + df2_name="New", # Optional, defaults to 'df2' + ) + _r = compare.matches(ignore_extra_columns=True) + return self.COMPARE_TRUE if _r else self.COMPARE_FALSE + except Exception as e: + logger.warning(f"{symbol} compare error: {e}") + return self.COMPARE_ERROR + + def check(self): + """Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data + + """ + logger.info("start check......") + + error_list = [] + not_in_features = [] + compare_false = [] + with tqdm(total=len(self.csv_files)) as p_bar: + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + for file_path, _check_res in zip(self.csv_files, executor.map(self._compare, self.csv_files)): + symbol = file_path.name.strip(self.file_suffix) + if _check_res == self.NOT_IN_FEATURES: + not_in_features.append(symbol) + elif _check_res == self.COMPARE_ERROR: + error_list.append(symbol) + elif _check_res == self.COMPARE_FALSE: + compare_false.append(symbol) + p_bar.update() + + logger.info("end of check......") + if error_list: + logger.warning(f"compare error: {error_list}") + if not_in_features: + logger.warning(f"not in features: {not_in_features}") + if compare_false: + logger.warning(f"compare False: {compare_false}") + logger.info( + f"total {len(self.csv_files)}, {len(error_list)} errors, {len(not_in_features)} not in features, {len(compare_false)} compare false" + ) + + +if __name__ == "__main__": + fire.Fire(CheckBin)