1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/scripts/dump_bin.py
2020-09-22 01:43:21 +00:00

251 lines
9.4 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import shutil
from pathlib import Path
from functools import partial
from concurrent.futures import ThreadPoolExecutor
import fire
import numpy as np
import pandas as pd
from tqdm import tqdm
from loguru import logger
class DumpData(object):
    """Dump per-instrument CSV files into a qlib data directory.

    Three artefacts are written below ``qlib_dir``:

    * ``calendars/<freq>.txt``: the sorted union of every date found in the CSVs
    * ``features/<code>/<field>.<freq>.bin``: one little-endian float32 file per field
    * ``instruments/all.txt``: ``SYMBOL<TAB>begin<TAB>end`` for every CSV

    The class is handed to ``fire.Fire`` at the bottom of the script, so the
    public methods double as command-line sub-commands.
    """

    FILE_SUFFIX = ".csv"

    def __init__(
        self,
        csv_path: str,
        qlib_dir: str,
        backup_dir: str = None,
        freq: str = "day",
        works: int = None,
        date_field_name: str = "date",
    ):
        """

        Parameters
        ----------
        csv_path: str
            stock data path or directory
        qlib_dir: str
            qlib(dump) data directory
        backup_dir: str, default None
            if backup_dir is not None, backup qlib_dir to backup_dir
        freq: str, default "day"
            transaction frequency
        works: int, default None
            number of threads
        date_field_name: str, default "date"
            the name of the date field in the csv
        """
        csv_path = Path(csv_path).expanduser()
        # A directory contributes every "*.csv" inside it; anything else is treated as one file.
        self.csv_files = sorted(csv_path.glob(f"*{self.FILE_SUFFIX}") if csv_path.is_dir() else [csv_path])
        self.qlib_dir = Path(qlib_dir).expanduser()
        self.backup_dir = None if backup_dir is None else Path(backup_dir).expanduser()
        if self.backup_dir is not None:
            self._backup_qlib_dir(self.backup_dir)

        self.freq = freq
        # Daily calendars are written as dates only; any other frequency keeps the time part.
        self.calendar_format = "%Y-%m-%d" if self.freq == "day" else "%Y-%m-%d %H:%M:%S"
        self.works = works
        self.date_field_name = date_field_name

        self._calendars_dir = self.qlib_dir.joinpath("calendars")
        self._features_dir = self.qlib_dir.joinpath("features")
        self._instruments_dir = self.qlib_dir.joinpath("instruments")

        self._calendars_list = []
        self._include_fields = ()
        self._exclude_fields = ()

    @staticmethod
    def _normalize_fields(fields):
        """Turn a field specification into a tuple of stripped, non-empty names.

        Accepts ``None``, a comma separated string (``"open, close"``) or any
        iterable of names, so programmatic and command-line callers behave the same.
        """
        if not fields:
            return ()
        if isinstance(fields, str):
            fields = fields.split(",")
        stripped = (str(field).strip() for field in fields)
        return tuple(name for name in stripped if name)

    def _backup_qlib_dir(self, target_dir: Path):
        """Copy the whole ``qlib_dir`` tree to ``target_dir`` using ``shutil.copytree``."""
        shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve()))

    def _get_date_for_df(self, file_path: Path, *, is_begin_end: bool = False):
        """Read the date column of one CSV.

        Parameters
        ----------
        file_path: Path
            csv file to read
        is_begin_end: bool, default False
            if True return ``[min, max]`` of the column instead of every value

        Returns
        -------
        list
            raw values as read by pandas (not converted to timestamps); empty when the
            file has no rows, no date column, or only null dates
        """
        df = pd.read_csv(str(file_path.resolve()))
        if df.empty or self.date_field_name not in df.columns:
            return []
        # Null dates cannot be part of a calendar, so they are dropped here.
        dates = df[self.date_field_name].dropna()
        if dates.empty:
            return []
        if is_begin_end:
            return [dates.min(), dates.max()]
        return dates.tolist()

    def _get_source_data(self, file_path: Path):
        """Load one CSV and parse its date column (when present) into datetimes."""
        df = pd.read_csv(str(file_path.resolve()))
        if self.date_field_name in df.columns:
            # pd.to_datetime instead of astype(np.datetime64): pandas 2.x rejects
            # casting to a unit-less datetime64 dtype.
            df[self.date_field_name] = pd.to_datetime(df[self.date_field_name])
        return df

    def _file_to_bin(self, file_path: Path = None):
        """Write ``<field>.<freq>.bin`` files for one instrument.

        The rows are aligned to the slice of the calendar between the first and
        last date of the file, so missing calendar days become NaN. Each file
        holds the calendar position of the first row followed by the values, all
        stored as little-endian float32.

        Files without rows, without the date column, or without any date inside
        the calendar are skipped. Selected fields must be convertible to float.
        """
        code = file_path.name[: -len(self.FILE_SUFFIX)].strip().lower()

        df = self._get_source_data(file_path)
        if df.empty or self.date_field_name not in df.columns:
            return
        # reindex() refuses a duplicated index, so keep the first row of each date.
        df = df.drop_duplicates(self.date_field_name).set_index(self.date_field_name)

        calendar_index = pd.DatetimeIndex(self._calendars_list)
        within_range = (calendar_index >= df.index.min()) & (calendar_index <= df.index.max())
        target_index = calendar_index[within_range]
        if target_index.empty:
            return

        r_df = df.reindex(target_index)
        date_index = self._calendars_list.index(target_index.min())

        if self._include_fields:
            fields = self._include_fields
        elif self._exclude_fields:
            excluded = set(self._exclude_fields)
            # Filter instead of a set difference so the column order is kept.
            fields = [column for column in r_df.columns if column not in excluded]
        else:
            fields = r_df.columns

        features_dir = self._features_dir.joinpath(code)
        features_dir.mkdir(parents=True, exist_ok=True)
        for field in fields:
            if field not in r_df.columns:
                continue
            bin_path = features_dir.joinpath(f"{field}.{self.freq}.bin")
            np.hstack([date_index, r_df[field]]).astype("<f").tofile(str(bin_path.resolve()))

    @staticmethod
    def _read_calendar(calendar_path: Path):
        """Return the sorted timestamps stored in the first column of a header-less calendar file."""
        raw_dates = pd.read_csv(calendar_path, header=None).loc[:, 0].tolist()
        return sorted(pd.Timestamp(value) for value in raw_dates)

    def dump_features(
        self,
        calendar_path: str = None,
        include_fields: tuple = None,
        exclude_fields: tuple = None,
    ):
        """dump features

        Parameters
        ---------
        calendar_path: str
            calendar path; when omitted (or empty) the calendar is built from the csv files
        include_fields: str
            dump fields, as a tuple/list or a comma separated string
        exclude_fields: str
            fields not dumped, as a tuple/list or a comma separated string

        Notes
        ---------
            python dump_bin.py dump_features --csv_path <stock data directory or path> --qlib_dir <dump data directory>

        Examples
        ---------
            # dump all stock
            python dump_bin.py dump_features --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
            # dump one stock
            python dump_bin.py dump_features --csv_path ~/tmp/stock_data/sh600000.csv --qlib_dir ~/tmp/qlib_data --calendar_path ~/tmp/qlib_data/calendars/day.txt --exclude_fields date,code,timestamp,code_name
        """
        logger.info("start dump features......")
        if calendar_path is not None:
            # read calendar from calendar file
            self._calendars_list = self._read_calendar(Path(calendar_path))

        if not self._calendars_list:
            self.dump_calendars()

        # An omitted argument keeps whatever selection is already stored on the instance.
        self._include_fields = self._normalize_fields(include_fields) or self._include_fields
        self._exclude_fields = self._normalize_fields(exclude_fields) or self._exclude_fields

        with tqdm(total=len(self.csv_files)) as p_bar:
            with ThreadPoolExecutor(max_workers=self.works) as executor:
                for _ in executor.map(self._file_to_bin, self.csv_files):
                    p_bar.update()

        logger.info("end of features dump.\n")

    def dump_calendars(self):
        """dump calendars

        Collects the dates of every csv file, stores the sorted result on the
        instance and writes it to ``calendars/<freq>.txt``.

        Notes
        ---------
            python dump_bin.py dump_calendars --csv_path <stock data directory or path> --qlib_dir <dump data directory>

        Examples
        ---------
            python dump_bin.py dump_calendars --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
        """
        logger.info("start dump calendars......")
        calendar_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
        all_datetime = set()
        with tqdm(total=len(self.csv_files)) as p_bar:
            with ThreadPoolExecutor(max_workers=self.works) as executor:
                for temp_datetime in executor.map(self._get_date_for_df, self.csv_files):
                    all_datetime.update(temp_datetime)
                    p_bar.update()

        self._calendars_list = sorted(map(pd.Timestamp, all_datetime))
        self._calendars_dir.mkdir(parents=True, exist_ok=True)
        result_calendar_list = [moment.strftime(self.calendar_format) for moment in self._calendars_list]
        np.savetxt(calendar_path, result_calendar_list, fmt="%s", encoding="utf-8")
        logger.info("end of calendars dump.\n")

    def dump_instruments(self):
        """dump instruments

        Writes one ``SYMBOL<TAB>begin<TAB>end`` line per csv file that contains
        dates to ``instruments/all.txt``.

        Notes
        ---------
            python dump_bin.py dump_instruments --csv_path <stock data directory or path> --qlib_dir <dump data directory>

        Examples
        ---------
            python dump_bin.py dump_instruments --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
        """
        logger.info("start dump instruments......")
        symbol_list = [csv_file.name[: -len(self.FILE_SUFFIX)] for csv_file in self.csv_files]
        _result_list = []
        _fun = partial(self._get_date_for_df, is_begin_end=True)
        with tqdm(total=len(symbol_list)) as p_bar:
            with ThreadPoolExecutor(max_workers=self.works) as execute:
                for symbol, res in zip(symbol_list, execute.map(_fun, self.csv_files)):
                    if res:
                        begin_time, end_time = res[0], res[-1]
                        _result_list.append(f"{symbol.upper()}\t{begin_time}\t{end_time}")
                    p_bar.update()

        self._instruments_dir.mkdir(parents=True, exist_ok=True)
        to_path = str(self._instruments_dir.joinpath("all.txt").resolve())
        np.savetxt(to_path, _result_list, fmt="%s", encoding="utf-8")
        logger.info("end of instruments dump.\n")

    def dump(self, include_fields: str = None, exclude_fields: tuple = None):
        """dump data

        Runs ``dump_calendars``, ``dump_features`` and ``dump_instruments`` in that order.

        Parameters
        ----------
        include_fields: str
            dump fields
        exclude_fields: str
            fields not dumped

        Examples
        ---------
            python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --include_fields open,close,high,low,volume,factor
            python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
        """
        self.dump_calendars()
        # dump_features accepts comma separated strings as well as tuples/lists.
        self.dump_features(include_fields=include_fields, exclude_fields=exclude_fields)
        self.dump_instruments()
if __name__ == "__main__":
    # Hand the class to python-fire, which builds the command-line interface from it.
    fire.Fire(DumpData)