1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 20:11:08 +08:00

US stock code supports Windows

This commit is contained in:
zhupr
2020-12-20 23:07:09 +08:00
parent df556532d0
commit 1a1c45981c
11 changed files with 201 additions and 97 deletions

View File

@@ -14,6 +14,7 @@ import numpy as np
import pandas as pd
from tqdm import tqdm
from loguru import logger
from qlib.utils import fname_to_code, code_to_fname
class DumpDataBase:
@@ -27,7 +28,6 @@ class DumpDataBase:
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
INSTRUMENTS_SEP = "\t"
INSTRUMENTS_FILE_NAME = "all.txt"
SAVE_INST_FIELD = "save_inst"
UPDATE_MODE = "update"
ALL_MODE = "all"
@@ -45,7 +45,6 @@ class DumpDataBase:
exclude_fields: str = "",
include_fields: str = "",
limit_nums: int = None,
inst_prefix: str = "",
):
"""
@@ -73,9 +72,6 @@ class DumpDataBase:
fields not dumped
limit_nums: int
Use when debugging, default None
inst_prefix: str
add a column to the instruments file and record the saved instrument name,
the US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix.
"""
csv_path = Path(csv_path).expanduser()
if isinstance(exclude_fields, str):
@@ -84,7 +80,6 @@ class DumpDataBase:
include_fields = include_fields.split(",")
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
self._inst_prefix = inst_prefix.strip()
self.file_suffix = file_suffix
self.symbol_field_name = symbol_field_name
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
@@ -145,7 +140,7 @@ class DumpDataBase:
return df
def get_symbol_from_file(self, file_path: Path) -> str:
return file_path.name[: -len(self.file_suffix)].strip().lower()
return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())
def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
return (
@@ -173,7 +168,6 @@ class DumpDataBase:
self.symbol_field_name,
self.INSTRUMENTS_START_FIELD,
self.INSTRUMENTS_END_FIELD,
self.SAVE_INST_FIELD,
],
)
@@ -190,13 +184,11 @@ class DumpDataBase:
instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
if isinstance(instruments_data, pd.DataFrame):
_df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]
if self._inst_prefix:
_df_fields.append(self.SAVE_INST_FIELD)
instruments_data[self.SAVE_INST_FIELD] = instruments_data[self.symbol_field_name].apply(
lambda x: f"{self._inst_prefix}{x}"
)
instruments_data = instruments_data.loc[:, _df_fields]
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
instruments_data[self.symbol_field_name] = instruments_data[self.symbol_field_name].apply(
lambda x: fname_to_code(x.lower()).upper()
)
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP, index=False)
else:
np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")
@@ -223,26 +215,26 @@ class DumpDataBase:
logger.warning(f"{features_dir.name} data is None or empty")
return
# align index
_df = self.data_merge_calendar(df, self._calendars_list)
_df = self.data_merge_calendar(df, calendar_list)
# used when creating a bin file
date_index = self.get_datetime_index(_df, calendar_list)
for field in self.get_dump_fields(_df.columns):
bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
if field not in _df.columns:
continue
if self._mode == self.UPDATE_MODE:
if bin_path.exists() and self._mode == self.UPDATE_MODE:
# update
with bin_path.open("ab") as fp:
np.array(_df[field]).astype("<f").tofile(fp)
elif self._mode == self.ALL_MODE:
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
else:
raise ValueError(f"{self._mode} cannot support!")
# append; self._mode == self.ALL_MODE or not bin_path.exists()
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.Timestamp]):
if isinstance(file_or_data, pd.DataFrame):
if file_or_data.empty:
return
code = file_or_data.iloc[0][self.symbol_field_name].lower()
code = fname_to_code(file_or_data.iloc[0][self.symbol_field_name].lower())
df = file_or_data
elif isinstance(file_or_data, Path):
code = self.get_symbol_from_file(file_or_data)
@@ -253,8 +245,7 @@ class DumpDataBase:
logger.warning(f"{code} data is None or empty")
return
# features save dir
code = self._inst_prefix + code if self._inst_prefix else code
features_dir = self._features_dir.joinpath(code)
features_dir = self._features_dir.joinpath(code_to_fname(code).lower())
features_dir.mkdir(parents=True, exist_ok=True)
self._data_to_bin(df, calendar_list, features_dir)
@@ -283,8 +274,6 @@ class DumpDataAll(DumpDataBase):
_end_time = self._format_datetime(_end_time)
symbol = self.get_symbol_from_file(file_path)
_inst_fields = [symbol.upper(), _begin_time, _end_time]
if self._inst_prefix:
_inst_fields.append(self._inst_prefix + symbol.upper())
date_range_list.append(f"{self.INSTRUMENTS_SEP.join(_inst_fields)}")
p_bar.update()
self._kwargs["all_datetime_set"] = all_datetime
@@ -323,12 +312,18 @@ class DumpDataFix(DumpDataAll):
def _dump_instruments(self):
logger.info("start dump instruments......")
_fun = partial(self._get_date, is_begin_end=True)
new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files))
new_stock_files = sorted(
filter(
lambda x: fname_to_code(x.name[: -len(self.file_suffix)].strip().lower()).upper()
not in self._old_instruments,
self.csv_files,
)
)
with tqdm(total=len(new_stock_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as execute:
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
symbol = self.get_symbol_from_file(file_path).upper()
symbol = fname_to_code(self.get_symbol_from_file(file_path).lower()).upper()
_dt_map = self._old_instruments.setdefault(symbol, dict())
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
@@ -406,10 +401,10 @@ class DumpDataUpdate(DumpDataBase):
)
self._mode = self.UPDATE_MODE
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
self._update_instruments = self._read_instruments(
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
).to_dict(
orient="index"
self._update_instruments = (
self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))
.set_index([self.symbol_field_name])
.to_dict(orient="index")
) # type: dict
# load all csv files
@@ -425,10 +420,7 @@ class DumpDataUpdate(DumpDataBase):
all_df = []
def _read_csv(file_path: Path):
if self._include_fields:
_df = pd.read_csv(file_path, usecols=self._include_fields)
else:
_df = pd.read_csv(file_path)
_df = pd.read_csv(file_path, parse_dates=[self.date_field_name])
if self.symbol_field_name not in _df.columns:
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
return _df
@@ -436,7 +428,7 @@ class DumpDataUpdate(DumpDataBase):
with tqdm(total=len(self.csv_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for df in executor.map(_read_csv, self.csv_files):
if df:
if not df.empty:
all_df.append(df)
p_bar.update()
@@ -455,25 +447,27 @@ class DumpDataUpdate(DumpDataBase):
with ProcessPoolExecutor(max_workers=self.works) as executor:
futures = {}
for _code, _df in self._all_data.groupby(self.symbol_field_name):
_code = str(_code).upper()
_code = fname_to_code(str(_code).lower()).upper()
_start, _end = self._get_date(_df, is_begin_end=True)
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
continue
if _code in self._update_instruments:
self._update_instruments[_code]["end_time"] = _end
self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
else:
# new stock
_dt_range = self._update_instruments.setdefault(_code, dict())
_dt_range["start_time"] = _start
_dt_range["end_time"] = _end
_dt_range[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_start)
_dt_range[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code
for _future in tqdm(as_completed(futures)):
try:
_future.result()
except Exception:
error_code[futures[_future]] = traceback.format_exc()
with tqdm(total=len(futures)) as p_bar:
for _future in as_completed(futures):
try:
_future.result()
except Exception:
error_code[futures[_future]] = traceback.format_exc()
p_bar.update()
logger.info(f"dump bin errors {error_code}")
logger.info("end of features dump.\n")
@@ -481,7 +475,9 @@ class DumpDataUpdate(DumpDataBase):
def dump(self):
self.save_calendars(self._new_calendar_list)
self._dump_features()
self.save_instruments(pd.DataFrame.from_dict(self._update_instruments, orient="index"))
df = pd.DataFrame.from_dict(self._update_instruments, orient="index")
df.index.names = [self.symbol_field_name]
self.save_instruments(df.reset_index())
if __name__ == "__main__":