From 77bfeadb65a250f59a2d6c5b119bac40c4e9a533 Mon Sep 17 00:00:00 2001 From: zhupr Date: Mon, 16 Nov 2020 15:49:28 +0800 Subject: [PATCH] refactor && support update bin files --- scripts/dump_bin.py | 513 +++++++++++++++++++++++++++++++------------- 1 file changed, 362 insertions(+), 151 deletions(-) diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index d972f6318..94e970808 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -1,10 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import abc import shutil +import traceback from pathlib import Path +from typing import Iterable, List, Union from functools import partial -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor import fire import numpy as np @@ -13,8 +16,20 @@ from tqdm import tqdm from loguru import logger -class DumpData(object): - FILE_SUFFIX = ".csv" +class DumpDataBase: + INSTRUMENTS_START_FIELD = "start_datetime" + INSTRUMENTS_END_FIELD = "end_datetime" + CALENDARS_DIR_NAME = "calendars" + FEATURES_DIR_NAME = "features" + INSTRUMENTS_DIR_NAME = "instruments" + DUMP_FILE_SUFFIX = ".bin" + DAILY_FORMAT = "%Y-%m-%d" + HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S" + INSTRUMENTS_SEP = "\t" + INSTRUMENTS_FILE_NAME = "all.txt" + + UPDATE_MODE = "update" + ALL_MODE = "all" def __init__( self, @@ -22,8 +37,13 @@ class DumpData(object): qlib_dir: str, backup_dir: str = None, freq: str = "day", - works: int = None, + max_workers: int = 16, date_field_name: str = "date", + file_suffix: str = ".csv", + symbol_field_name: str = "symbol", + exclude_fields: str = "", + include_fields: str = "", + limit_nums: int = None, ): """ @@ -37,80 +57,101 @@ class DumpData(object): if backup_dir is not None, backup qlib_dir to backup_dir freq: str, default "day" transaction frequency - works: int, default None + max_workers: int, default None number of threads date_field_name: str, default "date" the name of the date field in the csv + file_suffix: str, default ".csv" + file suffix + symbol_field_name: str, default "symbol" + symbol field name + include_fields: tuple + dump fields + exclude_fields: tuple + fields not dumped + limit_nums: int + Use when debugging, default None """ csv_path = Path(csv_path).expanduser() - self.csv_files = sorted(csv_path.glob(f"*{self.FILE_SUFFIX}") if csv_path.is_dir() else [csv_path]) + if isinstance(exclude_fields, str): + exclude_fields = exclude_fields.split(",") + if isinstance(include_fields, str): + include_fields = include_fields.split(",") + self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields))) + self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields))) + self.file_suffix = file_suffix + self.symbol_field_name = symbol_field_name + self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path]) + if limit_nums is not None: + self.csv_files = self.csv_files[: int(limit_nums)] self.qlib_dir = Path(qlib_dir).expanduser() self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser() if backup_dir is not None: self._backup_qlib_dir(Path(backup_dir).expanduser()) self.freq = freq - self.calendar_format = "%Y-%m-%d" if self.freq == "day" else "%Y-%m-%d %H:%M:%S" + self.calendar_format = self.DAILY_FORMAT if self.freq == "day" else self.HIGH_FREQ_FORMAT - self.works = works + self.works = max_workers self.date_field_name = date_field_name - self._calendars_dir = self.qlib_dir.joinpath("calendars") - self._features_dir = self.qlib_dir.joinpath("features") - self._instruments_dir = self.qlib_dir.joinpath("instruments") + self._calendars_dir = self.qlib_dir.joinpath(self.CALENDARS_DIR_NAME) + self._features_dir = self.qlib_dir.joinpath(self.FEATURES_DIR_NAME) + self._instruments_dir = self.qlib_dir.joinpath(self.INSTRUMENTS_DIR_NAME) self._calendars_list = [] - self._include_fields = () - self._exclude_fields = () + + self._mode = self.ALL_MODE + self._kwargs = {} def _backup_qlib_dir(self, target_dir: Path): shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve())) - def _get_date_for_df(self, file_path: Path, *, is_begin_end: bool = False): - df = pd.read_csv(str(file_path.resolve())) - if df.empty or self.date_field_name not in df.columns.tolist(): - return [] - if is_begin_end: - return [df[self.date_field_name].min(), df[self.date_field_name].max()] - return df[self.date_field_name].tolist() + def _format_datetime(self, datetime_d: [str, pd.Timestamp]): + datetime_d = pd.Timestamp(datetime_d) + return datetime_d.strftime(self.calendar_format) - def _get_source_data(self, file_path: Path): - df = pd.read_csv(str(file_path.resolve())) + def _get_date( + self, file_or_df: [Path, pd.DataFrame], *, is_begin_end: bool = False, as_set: bool = False + ) -> Iterable[pd.Timestamp]: + if not isinstance(file_or_df, pd.DataFrame): + df = self._get_source_data(file_or_df) + else: + df = file_or_df + if df.empty or self.date_field_name not in df.columns.tolist(): + _calendars = pd.Series() + else: + _calendars = df[self.date_field_name] + + if is_begin_end and as_set: + return (_calendars.min(), _calendars.max()), set(_calendars) + elif is_begin_end: + return _calendars.min(), _calendars.max() + elif as_set: + return set(_calendars) + else: + return _calendars.tolist() + + def _get_source_data(self, file_path: Path) -> pd.DataFrame: + df = pd.read_csv(str(file_path.resolve()), low_memory=False) df[self.date_field_name] = df[self.date_field_name].astype(np.datetime64) + # df.drop_duplicates([self.date_field_name], inplace=True) return df - def _file_to_bin(self, file_path: Path = None): - code = file_path.name[: -len(self.FILE_SUFFIX)].strip().lower() - features_dir = self._features_dir.joinpath(code) - features_dir.mkdir(parents=True, exist_ok=True) - calendars_df = pd.DataFrame(data=self._calendars_list, columns=[self.date_field_name]) - calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64) - # read csv file - df = self._get_source_data(file_path) - cal_df = calendars_df[ - (calendars_df[self.date_field_name] >= df[self.date_field_name].min()) - & (calendars_df[self.date_field_name] <= df[self.date_field_name].max()) - ] - cal_df.set_index(self.date_field_name, inplace=True) - df.set_index(self.date_field_name, inplace=True) - r_df = df.reindex(cal_df.index) - date_index = self._calendars_list.index(r_df.index.min()) - for field in ( + def get_symbol_from_file(self, file_path: Path) -> str: + return file_path.name[: -len(self.file_suffix)].strip().lower() + + def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]: + return ( self._include_fields if self._include_fields - else set(r_df.columns) - set(self._exclude_fields) + else set(df_columns) - set(self._exclude_fields) if self._exclude_fields - else r_df.columns - ): - - bin_path = features_dir.joinpath(f"{field}.{self.freq}.bin") - if field not in r_df.columns: - continue - r = np.hstack([date_index, r_df[field]]).astype(" List[pd.Timestamp]: return sorted( map( pd.Timestamp, @@ -118,133 +159,303 @@ class DumpData(object): ) ) - def dump_features( - self, - calendar_path: str = None, - include_fields: tuple = None, - exclude_fields: tuple = None, - ): - """dump features + def _read_instruments(self, instrument_path: Path) -> pd.DataFrame: + return pd.read_csv( + instrument_path, + sep=self.INSTRUMENTS_SEP, + names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD], + ) - Parameters - --------- - calendar_path: str - calendar path + def save_calendars(self, calendars_data: list): + self._calendars_dir.mkdir(parents=True, exist_ok=True) + calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve()) + result_calendars_list = list(map(lambda x: self._format_datetime(x), calendars_data)) + np.savetxt(calendars_path, result_calendars_list, fmt="%s", encoding="utf-8") - include_fields: str - dump fields + def save_instruments(self, instruments_data: Union[list, pd.DataFrame]): + self._instruments_dir.mkdir(parents=True, exist_ok=True) + instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve()) + if isinstance(instruments_data, pd.DataFrame): + instruments_data = instruments_data.loc[:, [self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]] + instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP) + else: + np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8") - exclude_fields: str - fields not dumped + def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestamp]) -> pd.DataFrame: + # calendars + calendars_df = pd.DataFrame(data=calendars_list, columns=[self.date_field_name]) + calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64) + cal_df = calendars_df[ + (calendars_df[self.date_field_name] >= df[self.date_field_name].min()) + & (calendars_df[self.date_field_name] <= df[self.date_field_name].max()) + ] + # align index + cal_df.set_index(self.date_field_name, inplace=True) + df.set_index(self.date_field_name, inplace=True) + r_df = df.reindex(cal_df.index) + return r_df - Notes - --------- - python dump_bin.py dump_features --csv_path --qlib_dir + @staticmethod + def get_datetime_index(df: pd.DataFrame, calendar_list: List[pd.Timestamp]) -> int: + return calendar_list.index(df.index.min()) - Examples - --------- + def _data_to_bin(self, df: pd.DataFrame, calendar_list: List[pd.Timestamp], features_dir: Path): + if df.empty: + logger.warning(f"{features_dir.name} data is None or empty") + return + # align index + _df = self.data_merge_calendar(df, self._calendars_list) + date_index = self.get_datetime_index(_df, calendar_list) + for field in self.get_dump_fields(_df.columns): + bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}") + if field not in _df.columns: + continue + if self._mode == self.UPDATE_MODE: + # update + with bin_path.open("ab") as fp: + np.array(_df[field]).astype(" --qlib_dir - Examples - --------- - python dump_bin.py dump_calendars --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data - """ - logger.info("start dump calendars......") - calendar_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve()) - all_datetime = set() - with tqdm(total=len(self.csv_files)) as p_bar: - with ThreadPoolExecutor(max_workers=self.works) as executor: - for temp_datetime in executor.map(self._get_date_for_df, self.csv_files): - all_datetime = all_datetime | set(temp_datetime) - p_bar.update() - - self._calendars_list = sorted(map(pd.Timestamp, all_datetime)) - self._calendars_dir.mkdir(parents=True, exist_ok=True) - result_calendar_list = list(map(lambda x: x.strftime(self.calendar_format), self._calendars_list)) - np.savetxt(calendar_path, result_calendar_list, fmt="%s", encoding="utf-8") - logger.info("end of calendars dump.\n") - - def dump_instruments(self): - """dump instruments - - Notes - --------- - python dump_bin.py dump_instruments --csv_path --qlib_dir - - Examples - --------- - python dump_bin.py dump_instruments --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data - """ +class DumpDataFix(DumpDataAll): + def _dump_instruments(self): logger.info("start dump instruments......") - symbol_list = list(map(lambda x: x.name[: -len(self.FILE_SUFFIX)], self.csv_files)) - _result_list = [] - _fun = partial(self._get_date_for_df, is_begin_end=True) - with tqdm(total=len(symbol_list)) as p_bar: - with ThreadPoolExecutor(max_workers=self.works) as execute: - for symbol, res in zip(symbol_list, execute.map(_fun, self.csv_files)): - if res: - begin_time = res[0] - end_time = res[-1] - _result_list.append(f"{symbol.upper()}\t{begin_time}\t{end_time}") + _fun = partial(self._get_date, is_begin_end=True) + new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files)) + with tqdm(total=len(new_stock_files)) as p_bar: + with ProcessPoolExecutor(max_workers=self.works) as execute: + for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)): + if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp): + symbol = self.get_symbol_from_file(file_path).upper() + _dt_map = self._old_instruments.setdefault(symbol, dict()) + _dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time) + _dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time) p_bar.update() - - self._instruments_dir.mkdir(parents=True, exist_ok=True) - to_path = str(self._instruments_dir.joinpath("all.txt").resolve()) - np.savetxt(to_path, _result_list, fmt="%s", encoding="utf-8") + self.save_instruments(pd.DataFrame.from_dict(self._old_instruments, orient="index")) logger.info("end of instruments dump.\n") - def dump(self, include_fields: str = None, exclude_fields: tuple = None): - """dump data + def dump(self): + self._calendars_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt")) + # noinspection PyAttributeOutsideInit + self._old_instruments = self._read_instruments( + self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME) + ).to_dict( + orient="index" + ) # type: dict + self._dump_instruments() + self._dump_features() + + +class DumpDataUpdate(DumpDataBase): + def __init__( + self, + csv_path: str, + qlib_dir: str, + backup_dir: str = None, + freq: str = "day", + max_workers: int = 16, + date_field_name: str = "date", + file_suffix: str = ".csv", + symbol_field_name: str = "symbol", + exclude_fields: str = "", + include_fields: str = "", + limit_nums: int = None, + ): + """ Parameters ---------- - include_fields: str + csv_path: str + stock data path or directory + qlib_dir: str + qlib(dump) data director + backup_dir: str, default None + if backup_dir is not None, backup qlib_dir to backup_dir + freq: str, default "day" + transaction frequency + max_workers: int, default None + number of threads + date_field_name: str, default "date" + the name of the date field in the csv + file_suffix: str, default ".csv" + file suffix + symbol_field_name: str, default "symbol" + symbol field name + include_fields: tuple dump fields - - exclude_fields: str + exclude_fields: tuple fields not dumped - - Examples - --------- - python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --include_fields open,close,high,low,volume,factor - python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name + limit_nums: int + Use when debugging, default None """ - if isinstance(exclude_fields, str): - exclude_fields = exclude_fields.split(",") - if isinstance(include_fields, str): - include_fields = include_fields.split(",") - self.dump_calendars() - self.dump_features(include_fields=include_fields, exclude_fields=exclude_fields) - self.dump_instruments() + super().__init__( + csv_path, + qlib_dir, + backup_dir, + freq, + max_workers, + date_field_name, + file_suffix, + symbol_field_name, + exclude_fields, + include_fields, + ) + self._mode = self.UPDATE_MODE + self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt")) + self._update_instruments = self._read_instruments( + self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME) + ).to_dict( + orient="index" + ) # type: dict + + # load all csv files + self._all_data = self._load_all_source_data() # type: pd.DataFrame + self._update_calendars = sorted( + filter(lambda x: x > self._old_calendar_list[-1], self._all_data[self.date_field_name].unique()) + ) + self._new_calendar_list = self._old_calendar_list + self._update_calendars + + def _load_all_source_data(self): + # NOTE: Need more memory + logger.info("start load all source data....") + all_df = [] + + def _read_csv(file_path: Path): + if self._include_fields: + _df = pd.read_csv(file_path, usecols=self._include_fields) + else: + _df = pd.read_csv(file_path) + if self.symbol_field_name not in _df.columns: + _df[self.symbol_field_name] = self.get_symbol_from_file(file_path) + return _df + + with tqdm(total=len(self.csv_files)) as p_bar: + with ThreadPoolExecutor(max_workers=self.works) as executor: + for df in executor.map(_read_csv, self.csv_files): + if df: + all_df.append(df) + p_bar.update() + + logger.info("end of load all data.\n") + return pd.concat(all_df, sort=False) + + def _dump_calendars(self): + pass + + def _dump_instruments(self): + pass + + def _dump_features(self): + logger.info("start dump features......") + error_code = {} + with ProcessPoolExecutor(max_workers=self.works) as executor: + futures = {} + for _code, _df in self._all_data.groupby(self.symbol_field_name): + _code = str(_code).upper() + _start, _end = self._get_date(_df, is_begin_end=True) + if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)): + continue + if _code in self._update_instruments: + self._update_instruments[_code]["end_time"] = _end + futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code + else: + # new stock + _dt_range = self._update_instruments.setdefault(_code, dict()) + _dt_range["start_time"] = _start + _dt_range["end_time"] = _end + futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code + + for _future in tqdm(as_completed(futures)): + try: + _future.result() + except Exception: + error_code[futures[_future]] = traceback.format_exc() + logger.info(f"dump bin errors: {error_code}") + + logger.info("end of features dump.\n") + + def dump(self): + self.save_calendars(self._new_calendar_list) + self._dump_features() + self.save_instruments(pd.DataFrame.from_dict(self._update_instruments, orient="index")) if __name__ == "__main__": - fire.Fire(DumpData) + fire.Fire({"dump_all": DumpDataAll, "dump_fix": DumpDataFix, "dump_update": DumpDataUpdate})