1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

fix us instruments

This commit is contained in:
zhupr
2020-11-22 15:26:27 +08:00
parent 8958656222
commit bac16060ff
9 changed files with 70 additions and 26 deletions

View File

@@ -27,6 +27,7 @@ class DumpDataBase:
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
INSTRUMENTS_SEP = "\t"
INSTRUMENTS_FILE_NAME = "all.txt"
SAVE_INST_FIELD = "save_inst"
UPDATE_MODE = "update"
ALL_MODE = "all"
@@ -44,6 +45,7 @@ class DumpDataBase:
exclude_fields: str = "",
include_fields: str = "",
limit_nums: int = None,
inst_prefix: str = "",
):
"""
@@ -71,6 +73,9 @@ class DumpDataBase:
fields not dumped
limit_nums: int
Use when debugging, default None
inst_prefix: str
add a column to the instruments file and record the saved instrument name,
the US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix.
"""
csv_path = Path(csv_path).expanduser()
if isinstance(exclude_fields, str):
@@ -79,6 +84,7 @@ class DumpDataBase:
include_fields = include_fields.split(",")
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
self._inst_prefix = inst_prefix.strip()
self.file_suffix = file_suffix
self.symbol_field_name = symbol_field_name
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
@@ -160,12 +166,19 @@ class DumpDataBase:
)
def _read_instruments(self, instrument_path: Path) -> pd.DataFrame:
return pd.read_csv(
df = pd.read_csv(
instrument_path,
sep=self.INSTRUMENTS_SEP,
names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD],
names=[
self.symbol_field_name,
self.INSTRUMENTS_START_FIELD,
self.INSTRUMENTS_END_FIELD,
self.SAVE_INST_FIELD,
],
)
return df
def save_calendars(self, calendars_data: list):
self._calendars_dir.mkdir(parents=True, exist_ok=True)
calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
@@ -176,7 +189,13 @@ class DumpDataBase:
self._instruments_dir.mkdir(parents=True, exist_ok=True)
instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
if isinstance(instruments_data, pd.DataFrame):
instruments_data = instruments_data.loc[:, [self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]]
_df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]
if self._inst_prefix:
_df_fields.append(self.SAVE_INST_FIELD)
instruments_data[self.SAVE_INST_FIELD] = instruments_data[self.symbol_field_name].apply(
lambda x: f"{self._inst_prefix}{x}"
)
instruments_data = instruments_data.loc[:, _df_fields]
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
else:
np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")
@@ -234,6 +253,7 @@ class DumpDataBase:
logger.warning(f"{code} data is None or empty")
return
# features save dir
code = self._inst_prefix + code if self._inst_prefix else code
features_dir = self._features_dir.joinpath(code)
features_dir.mkdir(parents=True, exist_ok=True)
self._data_to_bin(df, calendar_list, features_dir)
@@ -262,7 +282,10 @@ class DumpDataAll(DumpDataBase):
_begin_time = self._format_datetime(_begin_time)
_end_time = self._format_datetime(_end_time)
symbol = self.get_symbol_from_file(file_path)
date_range_list.append(f"{self.INSTRUMENTS_SEP.join((symbol.upper(), _begin_time, _end_time))}")
_inst_fields = [symbol.upper(), _begin_time, _end_time]
if self._inst_prefix:
_inst_fields.append(self._inst_prefix + symbol.upper())
date_range_list.append(f"{self.INSTRUMENTS_SEP.join(_inst_fields)}")
p_bar.update()
self._kwargs["all_datetime_set"] = all_datetime
self._kwargs["date_range_list"] = date_range_list