mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
refactor && support update bin files
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import shutil
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Union
|
||||
from functools import partial
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
@@ -13,8 +16,20 @@ from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class DumpData(object):
|
||||
FILE_SUFFIX = ".csv"
|
||||
class DumpDataBase:
|
||||
INSTRUMENTS_START_FIELD = "start_datetime"
|
||||
INSTRUMENTS_END_FIELD = "end_datetime"
|
||||
CALENDARS_DIR_NAME = "calendars"
|
||||
FEATURES_DIR_NAME = "features"
|
||||
INSTRUMENTS_DIR_NAME = "instruments"
|
||||
DUMP_FILE_SUFFIX = ".bin"
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
INSTRUMENTS_SEP = "\t"
|
||||
INSTRUMENTS_FILE_NAME = "all.txt"
|
||||
|
||||
UPDATE_MODE = "update"
|
||||
ALL_MODE = "all"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -22,8 +37,13 @@ class DumpData(object):
|
||||
qlib_dir: str,
|
||||
backup_dir: str = None,
|
||||
freq: str = "day",
|
||||
works: int = None,
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
file_suffix: str = ".csv",
|
||||
symbol_field_name: str = "symbol",
|
||||
exclude_fields: str = "",
|
||||
include_fields: str = "",
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -37,80 +57,101 @@ class DumpData(object):
|
||||
if backup_dir is not None, backup qlib_dir to backup_dir
|
||||
freq: str, default "day"
|
||||
transaction frequency
|
||||
works: int, default None
|
||||
max_workers: int, default None
|
||||
number of threads
|
||||
date_field_name: str, default "date"
|
||||
the name of the date field in the csv
|
||||
file_suffix: str, default ".csv"
|
||||
file suffix
|
||||
symbol_field_name: str, default "symbol"
|
||||
symbol field name
|
||||
include_fields: tuple
|
||||
dump fields
|
||||
exclude_fields: tuple
|
||||
fields not dumped
|
||||
limit_nums: int
|
||||
Use when debugging, default None
|
||||
"""
|
||||
csv_path = Path(csv_path).expanduser()
|
||||
self.csv_files = sorted(csv_path.glob(f"*{self.FILE_SUFFIX}") if csv_path.is_dir() else [csv_path])
|
||||
if isinstance(exclude_fields, str):
|
||||
exclude_fields = exclude_fields.split(",")
|
||||
if isinstance(include_fields, str):
|
||||
include_fields = include_fields.split(",")
|
||||
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
|
||||
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
|
||||
self.file_suffix = file_suffix
|
||||
self.symbol_field_name = symbol_field_name
|
||||
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
|
||||
if limit_nums is not None:
|
||||
self.csv_files = self.csv_files[: int(limit_nums)]
|
||||
self.qlib_dir = Path(qlib_dir).expanduser()
|
||||
self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
|
||||
if backup_dir is not None:
|
||||
self._backup_qlib_dir(Path(backup_dir).expanduser())
|
||||
|
||||
self.freq = freq
|
||||
self.calendar_format = "%Y-%m-%d" if self.freq == "day" else "%Y-%m-%d %H:%M:%S"
|
||||
self.calendar_format = self.DAILY_FORMAT if self.freq == "day" else self.HIGH_FREQ_FORMAT
|
||||
|
||||
self.works = works
|
||||
self.works = max_workers
|
||||
self.date_field_name = date_field_name
|
||||
|
||||
self._calendars_dir = self.qlib_dir.joinpath("calendars")
|
||||
self._features_dir = self.qlib_dir.joinpath("features")
|
||||
self._instruments_dir = self.qlib_dir.joinpath("instruments")
|
||||
self._calendars_dir = self.qlib_dir.joinpath(self.CALENDARS_DIR_NAME)
|
||||
self._features_dir = self.qlib_dir.joinpath(self.FEATURES_DIR_NAME)
|
||||
self._instruments_dir = self.qlib_dir.joinpath(self.INSTRUMENTS_DIR_NAME)
|
||||
|
||||
self._calendars_list = []
|
||||
self._include_fields = ()
|
||||
self._exclude_fields = ()
|
||||
|
||||
self._mode = self.ALL_MODE
|
||||
self._kwargs = {}
|
||||
|
||||
def _backup_qlib_dir(self, target_dir: Path):
|
||||
shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve()))
|
||||
|
||||
def _get_date_for_df(self, file_path: Path, *, is_begin_end: bool = False):
|
||||
df = pd.read_csv(str(file_path.resolve()))
|
||||
if df.empty or self.date_field_name not in df.columns.tolist():
|
||||
return []
|
||||
if is_begin_end:
|
||||
return [df[self.date_field_name].min(), df[self.date_field_name].max()]
|
||||
return df[self.date_field_name].tolist()
|
||||
def _format_datetime(self, datetime_d: [str, pd.Timestamp]):
|
||||
datetime_d = pd.Timestamp(datetime_d)
|
||||
return datetime_d.strftime(self.calendar_format)
|
||||
|
||||
def _get_source_data(self, file_path: Path):
|
||||
df = pd.read_csv(str(file_path.resolve()))
|
||||
def _get_date(
|
||||
self, file_or_df: [Path, pd.DataFrame], *, is_begin_end: bool = False, as_set: bool = False
|
||||
) -> Iterable[pd.Timestamp]:
|
||||
if not isinstance(file_or_df, pd.DataFrame):
|
||||
df = self._get_source_data(file_or_df)
|
||||
else:
|
||||
df = file_or_df
|
||||
if df.empty or self.date_field_name not in df.columns.tolist():
|
||||
_calendars = pd.Series()
|
||||
else:
|
||||
_calendars = df[self.date_field_name]
|
||||
|
||||
if is_begin_end and as_set:
|
||||
return (_calendars.min(), _calendars.max()), set(_calendars)
|
||||
elif is_begin_end:
|
||||
return _calendars.min(), _calendars.max()
|
||||
elif as_set:
|
||||
return set(_calendars)
|
||||
else:
|
||||
return _calendars.tolist()
|
||||
|
||||
def _get_source_data(self, file_path: Path) -> pd.DataFrame:
|
||||
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
|
||||
df[self.date_field_name] = df[self.date_field_name].astype(np.datetime64)
|
||||
# df.drop_duplicates([self.date_field_name], inplace=True)
|
||||
return df
|
||||
|
||||
def _file_to_bin(self, file_path: Path = None):
|
||||
code = file_path.name[: -len(self.FILE_SUFFIX)].strip().lower()
|
||||
features_dir = self._features_dir.joinpath(code)
|
||||
features_dir.mkdir(parents=True, exist_ok=True)
|
||||
calendars_df = pd.DataFrame(data=self._calendars_list, columns=[self.date_field_name])
|
||||
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
|
||||
# read csv file
|
||||
df = self._get_source_data(file_path)
|
||||
cal_df = calendars_df[
|
||||
(calendars_df[self.date_field_name] >= df[self.date_field_name].min())
|
||||
& (calendars_df[self.date_field_name] <= df[self.date_field_name].max())
|
||||
]
|
||||
cal_df.set_index(self.date_field_name, inplace=True)
|
||||
df.set_index(self.date_field_name, inplace=True)
|
||||
r_df = df.reindex(cal_df.index)
|
||||
date_index = self._calendars_list.index(r_df.index.min())
|
||||
for field in (
|
||||
def get_symbol_from_file(self, file_path: Path) -> str:
|
||||
return file_path.name[: -len(self.file_suffix)].strip().lower()
|
||||
|
||||
def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
|
||||
return (
|
||||
self._include_fields
|
||||
if self._include_fields
|
||||
else set(r_df.columns) - set(self._exclude_fields)
|
||||
else set(df_columns) - set(self._exclude_fields)
|
||||
if self._exclude_fields
|
||||
else r_df.columns
|
||||
):
|
||||
|
||||
bin_path = features_dir.joinpath(f"{field}.{self.freq}.bin")
|
||||
if field not in r_df.columns:
|
||||
continue
|
||||
r = np.hstack([date_index, r_df[field]]).astype("<f")
|
||||
r.tofile(str(bin_path.resolve()))
|
||||
else df_columns
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _read_calendar(calendar_path: Path):
|
||||
def _read_calendars(calendar_path: Path) -> List[pd.Timestamp]:
|
||||
return sorted(
|
||||
map(
|
||||
pd.Timestamp,
|
||||
@@ -118,133 +159,303 @@ class DumpData(object):
|
||||
)
|
||||
)
|
||||
|
||||
def dump_features(
|
||||
self,
|
||||
calendar_path: str = None,
|
||||
include_fields: tuple = None,
|
||||
exclude_fields: tuple = None,
|
||||
):
|
||||
"""dump features
|
||||
def _read_instruments(self, instrument_path: Path) -> pd.DataFrame:
|
||||
return pd.read_csv(
|
||||
instrument_path,
|
||||
sep=self.INSTRUMENTS_SEP,
|
||||
names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD],
|
||||
)
|
||||
|
||||
Parameters
|
||||
---------
|
||||
calendar_path: str
|
||||
calendar path
|
||||
def save_calendars(self, calendars_data: list):
|
||||
self._calendars_dir.mkdir(parents=True, exist_ok=True)
|
||||
calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
|
||||
result_calendars_list = list(map(lambda x: self._format_datetime(x), calendars_data))
|
||||
np.savetxt(calendars_path, result_calendars_list, fmt="%s", encoding="utf-8")
|
||||
|
||||
include_fields: str
|
||||
dump fields
|
||||
def save_instruments(self, instruments_data: Union[list, pd.DataFrame]):
|
||||
self._instruments_dir.mkdir(parents=True, exist_ok=True)
|
||||
instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
|
||||
if isinstance(instruments_data, pd.DataFrame):
|
||||
instruments_data = instruments_data.loc[:, [self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]]
|
||||
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
|
||||
else:
|
||||
np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")
|
||||
|
||||
exclude_fields: str
|
||||
fields not dumped
|
||||
def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestamp]) -> pd.DataFrame:
|
||||
# calendars
|
||||
calendars_df = pd.DataFrame(data=calendars_list, columns=[self.date_field_name])
|
||||
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
|
||||
cal_df = calendars_df[
|
||||
(calendars_df[self.date_field_name] >= df[self.date_field_name].min())
|
||||
& (calendars_df[self.date_field_name] <= df[self.date_field_name].max())
|
||||
]
|
||||
# align index
|
||||
cal_df.set_index(self.date_field_name, inplace=True)
|
||||
df.set_index(self.date_field_name, inplace=True)
|
||||
r_df = df.reindex(cal_df.index)
|
||||
return r_df
|
||||
|
||||
Notes
|
||||
---------
|
||||
python dump_bin.py dump_features --csv_path <stock data directory or path> --qlib_dir <dump data directory>
|
||||
@staticmethod
|
||||
def get_datetime_index(df: pd.DataFrame, calendar_list: List[pd.Timestamp]) -> int:
|
||||
return calendar_list.index(df.index.min())
|
||||
|
||||
Examples
|
||||
---------
|
||||
def _data_to_bin(self, df: pd.DataFrame, calendar_list: List[pd.Timestamp], features_dir: Path):
|
||||
if df.empty:
|
||||
logger.warning(f"{features_dir.name} data is None or empty")
|
||||
return
|
||||
# align index
|
||||
_df = self.data_merge_calendar(df, self._calendars_list)
|
||||
date_index = self.get_datetime_index(_df, calendar_list)
|
||||
for field in self.get_dump_fields(_df.columns):
|
||||
bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
|
||||
if field not in _df.columns:
|
||||
continue
|
||||
if self._mode == self.UPDATE_MODE:
|
||||
# update
|
||||
with bin_path.open("ab") as fp:
|
||||
np.array(_df[field]).astype("<f").tofile(fp)
|
||||
elif self._mode == self.ALL_MODE:
|
||||
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
|
||||
else:
|
||||
raise ValueError(f"{self._mode} cannot support!")
|
||||
|
||||
# dump all stock
|
||||
python dump_bin.py dump_features --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
|
||||
# dump one stock
|
||||
python dump_bin.py dump_features --csv_path ~/tmp/stock_data/sh600000.csv --qlib_dir ~/tmp/qlib_data --calendar_path ~/tmp/qlib_data/calendar/all.txt --exclude_fields date,code,timestamp,code_name
|
||||
"""
|
||||
logger.info("start dump features......")
|
||||
if calendar_path is not None:
|
||||
# read calendar from calendar file
|
||||
self._calendars_list = self._read_calendar(Path(calendar_path))
|
||||
def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.Timestamp]):
|
||||
if isinstance(file_or_data, pd.DataFrame):
|
||||
if file_or_data.empty:
|
||||
return
|
||||
code = file_or_data.iloc[0][self.symbol_field_name].lower()
|
||||
df = file_or_data
|
||||
elif isinstance(file_or_data, Path):
|
||||
code = self.get_symbol_from_file(file_or_data)
|
||||
df = self._get_source_data(file_or_data)
|
||||
else:
|
||||
raise ValueError(f"not support {type(file_or_data)}")
|
||||
if df is None or df.empty:
|
||||
logger.warning(f"{code} data is None or empty")
|
||||
return
|
||||
# features save dir
|
||||
features_dir = self._features_dir.joinpath(code)
|
||||
features_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._data_to_bin(df, calendar_list, features_dir)
|
||||
|
||||
if not self._calendars_list:
|
||||
self.dump_calendars()
|
||||
@abc.abstractmethod
|
||||
def dump(self):
|
||||
raise NotImplementedError("dump not implemented!")
|
||||
|
||||
self._include_fields = tuple(map(str.strip, include_fields)) if include_fields else self._include_fields
|
||||
self._exclude_fields = tuple(map(str.strip, exclude_fields)) if exclude_fields else self._exclude_fields
|
||||
|
||||
class DumpDataAll(DumpDataBase):
|
||||
def _get_all_date(self):
|
||||
logger.info("start get all date......")
|
||||
all_datetime = set()
|
||||
date_range_list = []
|
||||
_fun = partial(self._get_date, as_set=True, is_begin_end=True)
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as executor:
|
||||
for _ in executor.map(self._file_to_bin, self.csv_files):
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
for file_path, ((_begin_time, _end_time), _set_calendars) in zip(
|
||||
self.csv_files, executor.map(_fun, self.csv_files)
|
||||
):
|
||||
all_datetime = all_datetime | _set_calendars
|
||||
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
|
||||
_begin_time = self._format_datetime(_begin_time)
|
||||
_end_time = self._format_datetime(_end_time)
|
||||
symbol = self.get_symbol_from_file(file_path)
|
||||
date_range_list.append(f"{self.INSTRUMENTS_SEP.join((symbol.upper(), _begin_time, _end_time))}")
|
||||
p_bar.update()
|
||||
self._kwargs["all_datetime_set"] = all_datetime
|
||||
self._kwargs["date_range_list"] = date_range_list
|
||||
logger.info("end of get all date.\n")
|
||||
|
||||
def _dump_calendars(self):
|
||||
logger.info("start dump calendars......")
|
||||
self._calendars_list = sorted(map(pd.Timestamp, self._kwargs["all_datetime_set"]))
|
||||
self.save_calendars(self._calendars_list)
|
||||
logger.info("end of calendars dump.\n")
|
||||
|
||||
def _dump_instruments(self):
|
||||
logger.info("start dump instruments......")
|
||||
self.save_instruments(self._kwargs["date_range_list"])
|
||||
logger.info("end of instruments dump.\n")
|
||||
|
||||
def _dump_features(self):
|
||||
logger.info("start dump features......")
|
||||
_dump_func = partial(self._dump_bin, calendar_list=self._calendars_list)
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
for _ in executor.map(_dump_func, self.csv_files):
|
||||
p_bar.update()
|
||||
|
||||
logger.info("end of features dump.\n")
|
||||
|
||||
def dump_calendars(self):
|
||||
"""dump calendars
|
||||
def dump(self):
|
||||
print("dump 2")
|
||||
self._get_all_date()
|
||||
self._dump_calendars()
|
||||
self._dump_instruments()
|
||||
self._dump_features()
|
||||
|
||||
Notes
|
||||
---------
|
||||
python dump_bin.py dump_calendars --csv_path <stock data directory or path> --qlib_dir <dump data directory>
|
||||
|
||||
Examples
|
||||
---------
|
||||
python dump_bin.py dump_calendars --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
|
||||
"""
|
||||
logger.info("start dump calendars......")
|
||||
calendar_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
|
||||
all_datetime = set()
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as executor:
|
||||
for temp_datetime in executor.map(self._get_date_for_df, self.csv_files):
|
||||
all_datetime = all_datetime | set(temp_datetime)
|
||||
p_bar.update()
|
||||
|
||||
self._calendars_list = sorted(map(pd.Timestamp, all_datetime))
|
||||
self._calendars_dir.mkdir(parents=True, exist_ok=True)
|
||||
result_calendar_list = list(map(lambda x: x.strftime(self.calendar_format), self._calendars_list))
|
||||
np.savetxt(calendar_path, result_calendar_list, fmt="%s", encoding="utf-8")
|
||||
logger.info("end of calendars dump.\n")
|
||||
|
||||
def dump_instruments(self):
|
||||
"""dump instruments
|
||||
|
||||
Notes
|
||||
---------
|
||||
python dump_bin.py dump_instruments --csv_path <stock data directory or path> --qlib_dir <dump data directory>
|
||||
|
||||
Examples
|
||||
---------
|
||||
python dump_bin.py dump_instruments --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
|
||||
"""
|
||||
class DumpDataFix(DumpDataAll):
|
||||
def _dump_instruments(self):
|
||||
logger.info("start dump instruments......")
|
||||
symbol_list = list(map(lambda x: x.name[: -len(self.FILE_SUFFIX)], self.csv_files))
|
||||
_result_list = []
|
||||
_fun = partial(self._get_date_for_df, is_begin_end=True)
|
||||
with tqdm(total=len(symbol_list)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as execute:
|
||||
for symbol, res in zip(symbol_list, execute.map(_fun, self.csv_files)):
|
||||
if res:
|
||||
begin_time = res[0]
|
||||
end_time = res[-1]
|
||||
_result_list.append(f"{symbol.upper()}\t{begin_time}\t{end_time}")
|
||||
_fun = partial(self._get_date, is_begin_end=True)
|
||||
new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files))
|
||||
with tqdm(total=len(new_stock_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as execute:
|
||||
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
|
||||
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
|
||||
symbol = self.get_symbol_from_file(file_path).upper()
|
||||
_dt_map = self._old_instruments.setdefault(symbol, dict())
|
||||
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
|
||||
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
|
||||
p_bar.update()
|
||||
|
||||
self._instruments_dir.mkdir(parents=True, exist_ok=True)
|
||||
to_path = str(self._instruments_dir.joinpath("all.txt").resolve())
|
||||
np.savetxt(to_path, _result_list, fmt="%s", encoding="utf-8")
|
||||
self.save_instruments(pd.DataFrame.from_dict(self._old_instruments, orient="index"))
|
||||
logger.info("end of instruments dump.\n")
|
||||
|
||||
def dump(self, include_fields: str = None, exclude_fields: tuple = None):
|
||||
"""dump data
|
||||
def dump(self):
|
||||
self._calendars_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self._old_instruments = self._read_instruments(
|
||||
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
|
||||
).to_dict(
|
||||
orient="index"
|
||||
) # type: dict
|
||||
self._dump_instruments()
|
||||
self._dump_features()
|
||||
|
||||
|
||||
class DumpDataUpdate(DumpDataBase):
|
||||
def __init__(
|
||||
self,
|
||||
csv_path: str,
|
||||
qlib_dir: str,
|
||||
backup_dir: str = None,
|
||||
freq: str = "day",
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
file_suffix: str = ".csv",
|
||||
symbol_field_name: str = "symbol",
|
||||
exclude_fields: str = "",
|
||||
include_fields: str = "",
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
include_fields: str
|
||||
csv_path: str
|
||||
stock data path or directory
|
||||
qlib_dir: str
|
||||
qlib(dump) data director
|
||||
backup_dir: str, default None
|
||||
if backup_dir is not None, backup qlib_dir to backup_dir
|
||||
freq: str, default "day"
|
||||
transaction frequency
|
||||
max_workers: int, default None
|
||||
number of threads
|
||||
date_field_name: str, default "date"
|
||||
the name of the date field in the csv
|
||||
file_suffix: str, default ".csv"
|
||||
file suffix
|
||||
symbol_field_name: str, default "symbol"
|
||||
symbol field name
|
||||
include_fields: tuple
|
||||
dump fields
|
||||
|
||||
exclude_fields: str
|
||||
exclude_fields: tuple
|
||||
fields not dumped
|
||||
|
||||
Examples
|
||||
---------
|
||||
python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --include_fields open,close,high,low,volume,factor
|
||||
python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
|
||||
limit_nums: int
|
||||
Use when debugging, default None
|
||||
"""
|
||||
if isinstance(exclude_fields, str):
|
||||
exclude_fields = exclude_fields.split(",")
|
||||
if isinstance(include_fields, str):
|
||||
include_fields = include_fields.split(",")
|
||||
self.dump_calendars()
|
||||
self.dump_features(include_fields=include_fields, exclude_fields=exclude_fields)
|
||||
self.dump_instruments()
|
||||
super().__init__(
|
||||
csv_path,
|
||||
qlib_dir,
|
||||
backup_dir,
|
||||
freq,
|
||||
max_workers,
|
||||
date_field_name,
|
||||
file_suffix,
|
||||
symbol_field_name,
|
||||
exclude_fields,
|
||||
include_fields,
|
||||
)
|
||||
self._mode = self.UPDATE_MODE
|
||||
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
|
||||
self._update_instruments = self._read_instruments(
|
||||
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
|
||||
).to_dict(
|
||||
orient="index"
|
||||
) # type: dict
|
||||
|
||||
# load all csv files
|
||||
self._all_data = self._load_all_source_data() # type: pd.DataFrame
|
||||
self._update_calendars = sorted(
|
||||
filter(lambda x: x > self._old_calendar_list[-1], self._all_data[self.date_field_name].unique())
|
||||
)
|
||||
self._new_calendar_list = self._old_calendar_list + self._update_calendars
|
||||
|
||||
def _load_all_source_data(self):
|
||||
# NOTE: Need more memory
|
||||
logger.info("start load all source data....")
|
||||
all_df = []
|
||||
|
||||
def _read_csv(file_path: Path):
|
||||
if self._include_fields:
|
||||
_df = pd.read_csv(file_path, usecols=self._include_fields)
|
||||
else:
|
||||
_df = pd.read_csv(file_path)
|
||||
if self.symbol_field_name not in _df.columns:
|
||||
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
|
||||
return _df
|
||||
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as executor:
|
||||
for df in executor.map(_read_csv, self.csv_files):
|
||||
if df:
|
||||
all_df.append(df)
|
||||
p_bar.update()
|
||||
|
||||
logger.info("end of load all data.\n")
|
||||
return pd.concat(all_df, sort=False)
|
||||
|
||||
def _dump_calendars(self):
|
||||
pass
|
||||
|
||||
def _dump_instruments(self):
|
||||
pass
|
||||
|
||||
def _dump_features(self):
|
||||
logger.info("start dump features......")
|
||||
error_code = {}
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
futures = {}
|
||||
for _code, _df in self._all_data.groupby(self.symbol_field_name):
|
||||
_code = str(_code).upper()
|
||||
_start, _end = self._get_date(_df, is_begin_end=True)
|
||||
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
|
||||
continue
|
||||
if _code in self._update_instruments:
|
||||
self._update_instruments[_code]["end_time"] = _end
|
||||
futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
|
||||
else:
|
||||
# new stock
|
||||
_dt_range = self._update_instruments.setdefault(_code, dict())
|
||||
_dt_range["start_time"] = _start
|
||||
_dt_range["end_time"] = _end
|
||||
futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code
|
||||
|
||||
for _future in tqdm(as_completed(futures)):
|
||||
try:
|
||||
_future.result()
|
||||
except Exception:
|
||||
error_code[futures[_future]] = traceback.format_exc()
|
||||
logger.info(f"dump bin errors: {error_code}")
|
||||
|
||||
logger.info("end of features dump.\n")
|
||||
|
||||
def dump(self):
|
||||
self.save_calendars(self._new_calendar_list)
|
||||
self._dump_features()
|
||||
self.save_instruments(pd.DataFrame.from_dict(self._update_instruments, orient="index"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(DumpData)
|
||||
fire.Fire({"dump_all": DumpDataAll, "dump_fix": DumpDataFix, "dump_update": DumpDataUpdate})
|
||||
|
||||
Reference in New Issue
Block a user