1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

refactor && support update bin files

This commit is contained in:
zhupr
2020-11-16 15:49:28 +08:00
committed by you-n-g
parent 87bf5cb01a
commit 77bfeadb65

View File

@@ -1,10 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
import shutil
import traceback
from pathlib import Path
from typing import Iterable, List, Union
from functools import partial
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
import fire
import numpy as np
@@ -13,8 +16,20 @@ from tqdm import tqdm
from loguru import logger
class DumpData(object):
FILE_SUFFIX = ".csv"
class DumpDataBase:
INSTRUMENTS_START_FIELD = "start_datetime"
INSTRUMENTS_END_FIELD = "end_datetime"
CALENDARS_DIR_NAME = "calendars"
FEATURES_DIR_NAME = "features"
INSTRUMENTS_DIR_NAME = "instruments"
DUMP_FILE_SUFFIX = ".bin"
DAILY_FORMAT = "%Y-%m-%d"
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
INSTRUMENTS_SEP = "\t"
INSTRUMENTS_FILE_NAME = "all.txt"
UPDATE_MODE = "update"
ALL_MODE = "all"
def __init__(
self,
@@ -22,8 +37,13 @@ class DumpData(object):
qlib_dir: str,
backup_dir: str = None,
freq: str = "day",
works: int = None,
max_workers: int = 16,
date_field_name: str = "date",
file_suffix: str = ".csv",
symbol_field_name: str = "symbol",
exclude_fields: str = "",
include_fields: str = "",
limit_nums: int = None,
):
"""
@@ -37,80 +57,101 @@ class DumpData(object):
if backup_dir is not None, backup qlib_dir to backup_dir
freq: str, default "day"
transaction frequency
works: int, default None
max_workers: int, default None
number of threads
date_field_name: str, default "date"
the name of the date field in the csv
file_suffix: str, default ".csv"
file suffix
symbol_field_name: str, default "symbol"
symbol field name
include_fields: tuple
dump fields
exclude_fields: tuple
fields not dumped
limit_nums: int
Use when debugging, default None
"""
csv_path = Path(csv_path).expanduser()
self.csv_files = sorted(csv_path.glob(f"*{self.FILE_SUFFIX}") if csv_path.is_dir() else [csv_path])
if isinstance(exclude_fields, str):
exclude_fields = exclude_fields.split(",")
if isinstance(include_fields, str):
include_fields = include_fields.split(",")
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
self.file_suffix = file_suffix
self.symbol_field_name = symbol_field_name
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
if limit_nums is not None:
self.csv_files = self.csv_files[: int(limit_nums)]
self.qlib_dir = Path(qlib_dir).expanduser()
self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
if backup_dir is not None:
self._backup_qlib_dir(Path(backup_dir).expanduser())
self.freq = freq
self.calendar_format = "%Y-%m-%d" if self.freq == "day" else "%Y-%m-%d %H:%M:%S"
self.calendar_format = self.DAILY_FORMAT if self.freq == "day" else self.HIGH_FREQ_FORMAT
self.works = works
self.works = max_workers
self.date_field_name = date_field_name
self._calendars_dir = self.qlib_dir.joinpath("calendars")
self._features_dir = self.qlib_dir.joinpath("features")
self._instruments_dir = self.qlib_dir.joinpath("instruments")
self._calendars_dir = self.qlib_dir.joinpath(self.CALENDARS_DIR_NAME)
self._features_dir = self.qlib_dir.joinpath(self.FEATURES_DIR_NAME)
self._instruments_dir = self.qlib_dir.joinpath(self.INSTRUMENTS_DIR_NAME)
self._calendars_list = []
self._include_fields = ()
self._exclude_fields = ()
self._mode = self.ALL_MODE
self._kwargs = {}
def _backup_qlib_dir(self, target_dir: Path):
shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve()))
def _get_date_for_df(self, file_path: Path, *, is_begin_end: bool = False):
df = pd.read_csv(str(file_path.resolve()))
if df.empty or self.date_field_name not in df.columns.tolist():
return []
if is_begin_end:
return [df[self.date_field_name].min(), df[self.date_field_name].max()]
return df[self.date_field_name].tolist()
def _format_datetime(self, datetime_d: [str, pd.Timestamp]):
datetime_d = pd.Timestamp(datetime_d)
return datetime_d.strftime(self.calendar_format)
def _get_source_data(self, file_path: Path):
df = pd.read_csv(str(file_path.resolve()))
def _get_date(
self, file_or_df: [Path, pd.DataFrame], *, is_begin_end: bool = False, as_set: bool = False
) -> Iterable[pd.Timestamp]:
if not isinstance(file_or_df, pd.DataFrame):
df = self._get_source_data(file_or_df)
else:
df = file_or_df
if df.empty or self.date_field_name not in df.columns.tolist():
_calendars = pd.Series()
else:
_calendars = df[self.date_field_name]
if is_begin_end and as_set:
return (_calendars.min(), _calendars.max()), set(_calendars)
elif is_begin_end:
return _calendars.min(), _calendars.max()
elif as_set:
return set(_calendars)
else:
return _calendars.tolist()
def _get_source_data(self, file_path: Path) -> pd.DataFrame:
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
df[self.date_field_name] = df[self.date_field_name].astype(np.datetime64)
# df.drop_duplicates([self.date_field_name], inplace=True)
return df
def _file_to_bin(self, file_path: Path = None):
code = file_path.name[: -len(self.FILE_SUFFIX)].strip().lower()
features_dir = self._features_dir.joinpath(code)
features_dir.mkdir(parents=True, exist_ok=True)
calendars_df = pd.DataFrame(data=self._calendars_list, columns=[self.date_field_name])
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
# read csv file
df = self._get_source_data(file_path)
cal_df = calendars_df[
(calendars_df[self.date_field_name] >= df[self.date_field_name].min())
& (calendars_df[self.date_field_name] <= df[self.date_field_name].max())
]
cal_df.set_index(self.date_field_name, inplace=True)
df.set_index(self.date_field_name, inplace=True)
r_df = df.reindex(cal_df.index)
date_index = self._calendars_list.index(r_df.index.min())
for field in (
def get_symbol_from_file(self, file_path: Path) -> str:
return file_path.name[: -len(self.file_suffix)].strip().lower()
def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
return (
self._include_fields
if self._include_fields
else set(r_df.columns) - set(self._exclude_fields)
else set(df_columns) - set(self._exclude_fields)
if self._exclude_fields
else r_df.columns
):
bin_path = features_dir.joinpath(f"{field}.{self.freq}.bin")
if field not in r_df.columns:
continue
r = np.hstack([date_index, r_df[field]]).astype("<f")
r.tofile(str(bin_path.resolve()))
else df_columns
)
@staticmethod
def _read_calendar(calendar_path: Path):
def _read_calendars(calendar_path: Path) -> List[pd.Timestamp]:
return sorted(
map(
pd.Timestamp,
@@ -118,133 +159,303 @@ class DumpData(object):
)
)
def dump_features(
self,
calendar_path: str = None,
include_fields: tuple = None,
exclude_fields: tuple = None,
):
"""dump features
def _read_instruments(self, instrument_path: Path) -> pd.DataFrame:
return pd.read_csv(
instrument_path,
sep=self.INSTRUMENTS_SEP,
names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD],
)
Parameters
---------
calendar_path: str
calendar path
def save_calendars(self, calendars_data: list):
self._calendars_dir.mkdir(parents=True, exist_ok=True)
calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
result_calendars_list = list(map(lambda x: self._format_datetime(x), calendars_data))
np.savetxt(calendars_path, result_calendars_list, fmt="%s", encoding="utf-8")
include_fields: str
dump fields
def save_instruments(self, instruments_data: Union[list, pd.DataFrame]):
self._instruments_dir.mkdir(parents=True, exist_ok=True)
instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
if isinstance(instruments_data, pd.DataFrame):
instruments_data = instruments_data.loc[:, [self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]]
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
else:
np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")
exclude_fields: str
fields not dumped
def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestamp]) -> pd.DataFrame:
# calendars
calendars_df = pd.DataFrame(data=calendars_list, columns=[self.date_field_name])
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
cal_df = calendars_df[
(calendars_df[self.date_field_name] >= df[self.date_field_name].min())
& (calendars_df[self.date_field_name] <= df[self.date_field_name].max())
]
# align index
cal_df.set_index(self.date_field_name, inplace=True)
df.set_index(self.date_field_name, inplace=True)
r_df = df.reindex(cal_df.index)
return r_df
Notes
---------
python dump_bin.py dump_features --csv_path <stock data directory or path> --qlib_dir <dump data directory>
@staticmethod
def get_datetime_index(df: pd.DataFrame, calendar_list: List[pd.Timestamp]) -> int:
return calendar_list.index(df.index.min())
Examples
---------
def _data_to_bin(self, df: pd.DataFrame, calendar_list: List[pd.Timestamp], features_dir: Path):
if df.empty:
logger.warning(f"{features_dir.name} data is None or empty")
return
# align index
_df = self.data_merge_calendar(df, self._calendars_list)
date_index = self.get_datetime_index(_df, calendar_list)
for field in self.get_dump_fields(_df.columns):
bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
if field not in _df.columns:
continue
if self._mode == self.UPDATE_MODE:
# update
with bin_path.open("ab") as fp:
np.array(_df[field]).astype("<f").tofile(fp)
elif self._mode == self.ALL_MODE:
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
else:
raise ValueError(f"{self._mode} cannot support!")
# dump all stock
python dump_bin.py dump_features --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
# dump one stock
python dump_bin.py dump_features --csv_path ~/tmp/stock_data/sh600000.csv --qlib_dir ~/tmp/qlib_data --calendar_path ~/tmp/qlib_data/calendar/all.txt --exclude_fields date,code,timestamp,code_name
"""
logger.info("start dump features......")
if calendar_path is not None:
# read calendar from calendar file
self._calendars_list = self._read_calendar(Path(calendar_path))
def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.Timestamp]):
if isinstance(file_or_data, pd.DataFrame):
if file_or_data.empty:
return
code = file_or_data.iloc[0][self.symbol_field_name].lower()
df = file_or_data
elif isinstance(file_or_data, Path):
code = self.get_symbol_from_file(file_or_data)
df = self._get_source_data(file_or_data)
else:
raise ValueError(f"not support {type(file_or_data)}")
if df is None or df.empty:
logger.warning(f"{code} data is None or empty")
return
# features save dir
features_dir = self._features_dir.joinpath(code)
features_dir.mkdir(parents=True, exist_ok=True)
self._data_to_bin(df, calendar_list, features_dir)
if not self._calendars_list:
self.dump_calendars()
@abc.abstractmethod
def dump(self):
raise NotImplementedError("dump not implemented!")
self._include_fields = tuple(map(str.strip, include_fields)) if include_fields else self._include_fields
self._exclude_fields = tuple(map(str.strip, exclude_fields)) if exclude_fields else self._exclude_fields
class DumpDataAll(DumpDataBase):
def _get_all_date(self):
logger.info("start get all date......")
all_datetime = set()
date_range_list = []
_fun = partial(self._get_date, as_set=True, is_begin_end=True)
with tqdm(total=len(self.csv_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for _ in executor.map(self._file_to_bin, self.csv_files):
with ProcessPoolExecutor(max_workers=self.works) as executor:
for file_path, ((_begin_time, _end_time), _set_calendars) in zip(
self.csv_files, executor.map(_fun, self.csv_files)
):
all_datetime = all_datetime | _set_calendars
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
_begin_time = self._format_datetime(_begin_time)
_end_time = self._format_datetime(_end_time)
symbol = self.get_symbol_from_file(file_path)
date_range_list.append(f"{self.INSTRUMENTS_SEP.join((symbol.upper(), _begin_time, _end_time))}")
p_bar.update()
self._kwargs["all_datetime_set"] = all_datetime
self._kwargs["date_range_list"] = date_range_list
logger.info("end of get all date.\n")
def _dump_calendars(self):
logger.info("start dump calendars......")
self._calendars_list = sorted(map(pd.Timestamp, self._kwargs["all_datetime_set"]))
self.save_calendars(self._calendars_list)
logger.info("end of calendars dump.\n")
def _dump_instruments(self):
logger.info("start dump instruments......")
self.save_instruments(self._kwargs["date_range_list"])
logger.info("end of instruments dump.\n")
def _dump_features(self):
logger.info("start dump features......")
_dump_func = partial(self._dump_bin, calendar_list=self._calendars_list)
with tqdm(total=len(self.csv_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as executor:
for _ in executor.map(_dump_func, self.csv_files):
p_bar.update()
logger.info("end of features dump.\n")
def dump_calendars(self):
"""dump calendars
def dump(self):
print("dump 2")
self._get_all_date()
self._dump_calendars()
self._dump_instruments()
self._dump_features()
Notes
---------
python dump_bin.py dump_calendars --csv_path <stock data directory or path> --qlib_dir <dump data directory>
Examples
---------
python dump_bin.py dump_calendars --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
"""
logger.info("start dump calendars......")
calendar_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
all_datetime = set()
with tqdm(total=len(self.csv_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for temp_datetime in executor.map(self._get_date_for_df, self.csv_files):
all_datetime = all_datetime | set(temp_datetime)
p_bar.update()
self._calendars_list = sorted(map(pd.Timestamp, all_datetime))
self._calendars_dir.mkdir(parents=True, exist_ok=True)
result_calendar_list = list(map(lambda x: x.strftime(self.calendar_format), self._calendars_list))
np.savetxt(calendar_path, result_calendar_list, fmt="%s", encoding="utf-8")
logger.info("end of calendars dump.\n")
def dump_instruments(self):
"""dump instruments
Notes
---------
python dump_bin.py dump_instruments --csv_path <stock data directory or path> --qlib_dir <dump data directory>
Examples
---------
python dump_bin.py dump_instruments --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
"""
class DumpDataFix(DumpDataAll):
def _dump_instruments(self):
logger.info("start dump instruments......")
symbol_list = list(map(lambda x: x.name[: -len(self.FILE_SUFFIX)], self.csv_files))
_result_list = []
_fun = partial(self._get_date_for_df, is_begin_end=True)
with tqdm(total=len(symbol_list)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as execute:
for symbol, res in zip(symbol_list, execute.map(_fun, self.csv_files)):
if res:
begin_time = res[0]
end_time = res[-1]
_result_list.append(f"{symbol.upper()}\t{begin_time}\t{end_time}")
_fun = partial(self._get_date, is_begin_end=True)
new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files))
with tqdm(total=len(new_stock_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as execute:
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
symbol = self.get_symbol_from_file(file_path).upper()
_dt_map = self._old_instruments.setdefault(symbol, dict())
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
p_bar.update()
self._instruments_dir.mkdir(parents=True, exist_ok=True)
to_path = str(self._instruments_dir.joinpath("all.txt").resolve())
np.savetxt(to_path, _result_list, fmt="%s", encoding="utf-8")
self.save_instruments(pd.DataFrame.from_dict(self._old_instruments, orient="index"))
logger.info("end of instruments dump.\n")
def dump(self, include_fields: str = None, exclude_fields: tuple = None):
"""dump data
def dump(self):
self._calendars_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
# noinspection PyAttributeOutsideInit
self._old_instruments = self._read_instruments(
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
).to_dict(
orient="index"
) # type: dict
self._dump_instruments()
self._dump_features()
class DumpDataUpdate(DumpDataBase):
def __init__(
self,
csv_path: str,
qlib_dir: str,
backup_dir: str = None,
freq: str = "day",
max_workers: int = 16,
date_field_name: str = "date",
file_suffix: str = ".csv",
symbol_field_name: str = "symbol",
exclude_fields: str = "",
include_fields: str = "",
limit_nums: int = None,
):
"""
Parameters
----------
include_fields: str
csv_path: str
stock data path or directory
qlib_dir: str
qlib(dump) data director
backup_dir: str, default None
if backup_dir is not None, backup qlib_dir to backup_dir
freq: str, default "day"
transaction frequency
max_workers: int, default None
number of threads
date_field_name: str, default "date"
the name of the date field in the csv
file_suffix: str, default ".csv"
file suffix
symbol_field_name: str, default "symbol"
symbol field name
include_fields: tuple
dump fields
exclude_fields: str
exclude_fields: tuple
fields not dumped
Examples
---------
python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --include_fields open,close,high,low,volume,factor
python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
limit_nums: int
Use when debugging, default None
"""
if isinstance(exclude_fields, str):
exclude_fields = exclude_fields.split(",")
if isinstance(include_fields, str):
include_fields = include_fields.split(",")
self.dump_calendars()
self.dump_features(include_fields=include_fields, exclude_fields=exclude_fields)
self.dump_instruments()
super().__init__(
csv_path,
qlib_dir,
backup_dir,
freq,
max_workers,
date_field_name,
file_suffix,
symbol_field_name,
exclude_fields,
include_fields,
)
self._mode = self.UPDATE_MODE
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
self._update_instruments = self._read_instruments(
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
).to_dict(
orient="index"
) # type: dict
# load all csv files
self._all_data = self._load_all_source_data() # type: pd.DataFrame
self._update_calendars = sorted(
filter(lambda x: x > self._old_calendar_list[-1], self._all_data[self.date_field_name].unique())
)
self._new_calendar_list = self._old_calendar_list + self._update_calendars
def _load_all_source_data(self):
# NOTE: Need more memory
logger.info("start load all source data....")
all_df = []
def _read_csv(file_path: Path):
if self._include_fields:
_df = pd.read_csv(file_path, usecols=self._include_fields)
else:
_df = pd.read_csv(file_path)
if self.symbol_field_name not in _df.columns:
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
return _df
with tqdm(total=len(self.csv_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for df in executor.map(_read_csv, self.csv_files):
if df:
all_df.append(df)
p_bar.update()
logger.info("end of load all data.\n")
return pd.concat(all_df, sort=False)
def _dump_calendars(self):
pass
def _dump_instruments(self):
pass
def _dump_features(self):
logger.info("start dump features......")
error_code = {}
with ProcessPoolExecutor(max_workers=self.works) as executor:
futures = {}
for _code, _df in self._all_data.groupby(self.symbol_field_name):
_code = str(_code).upper()
_start, _end = self._get_date(_df, is_begin_end=True)
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
continue
if _code in self._update_instruments:
self._update_instruments[_code]["end_time"] = _end
futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
else:
# new stock
_dt_range = self._update_instruments.setdefault(_code, dict())
_dt_range["start_time"] = _start
_dt_range["end_time"] = _end
futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code
for _future in tqdm(as_completed(futures)):
try:
_future.result()
except Exception:
error_code[futures[_future]] = traceback.format_exc()
logger.info(f"dump bin errors {error_code}")
logger.info("end of features dump.\n")
def dump(self):
self.save_calendars(self._new_calendar_list)
self._dump_features()
self.save_instruments(pd.DataFrame.from_dict(self._update_instruments, orient="index"))
if __name__ == "__main__":
fire.Fire(DumpData)
fire.Fire({"dump_all": DumpDataAll, "dump_fix": DumpDataFix, "dump_update": DumpDataUpdate})