mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix us instruments
This commit is contained in:
@@ -20,7 +20,6 @@ python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
### Downlaod US Data
|
||||
|
||||
> The US stock code contains 'PRN', and the directory cannot be created on Windows system: https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
|
||||
|
||||
```bash
|
||||
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
|
||||
|
||||
@@ -24,6 +24,7 @@ class IndexBase:
|
||||
INSTRUMENTS_COLUMNS = [SYMBOL_FIELD_NAME, START_DATE_FIELD, END_DATE_FIELD]
|
||||
REMOVE = "remove"
|
||||
ADD = "add"
|
||||
INST_PREFIX = ""
|
||||
|
||||
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
|
||||
"""
|
||||
@@ -196,7 +197,11 @@ class IndexBase:
|
||||
_tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns)
|
||||
new_df = new_df.append(_tmp_df, sort=False)
|
||||
|
||||
new_df.loc[:, instruments_columns].to_csv(
|
||||
inst_df = new_df.loc[:, instruments_columns]
|
||||
_inst_prefix = self.INST_PREFIX.strip()
|
||||
if _inst_prefix:
|
||||
inst_df["save_inst"] = inst_df[self.SYMBOL_FIELD_NAME].apply(lambda x: f"{_inst_prefix}{x}")
|
||||
inst_df.to_csv(
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
|
||||
)
|
||||
logger.info(f"parse {self.index_name.lower()} companies finished.")
|
||||
|
||||
@@ -33,6 +33,10 @@ WIKI_INDEX_NAME_MAP = {
|
||||
|
||||
|
||||
class WIKIIndex(IndexBase):
|
||||
# NOTE: The US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix
|
||||
# https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
|
||||
INST_PREFIX = "_"
|
||||
|
||||
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
|
||||
super(WIKIIndex, self).__init__(
|
||||
index_name=index_name, qlib_dir=qlib_dir, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
|
||||
@@ -184,9 +184,14 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
names=["symbol", "start_date", "end_date"],
|
||||
)
|
||||
_all_symbols += ins_df["symbol"].unique().tolist()
|
||||
_US_SYMBOLS = sorted(
|
||||
set(map(lambda x: x.replace(".", "-"), filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols)))
|
||||
)
|
||||
|
||||
def _format(s_):
|
||||
s_ = s_.replace(".", "-")
|
||||
s_ = s_.strip("$")
|
||||
s_ = s_.strip("*")
|
||||
return s_
|
||||
|
||||
_US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols))))
|
||||
|
||||
return _US_SYMBOLS
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ class DumpDataBase:
|
||||
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
INSTRUMENTS_SEP = "\t"
|
||||
INSTRUMENTS_FILE_NAME = "all.txt"
|
||||
SAVE_INST_FIELD = "save_inst"
|
||||
|
||||
UPDATE_MODE = "update"
|
||||
ALL_MODE = "all"
|
||||
@@ -44,6 +45,7 @@ class DumpDataBase:
|
||||
exclude_fields: str = "",
|
||||
include_fields: str = "",
|
||||
limit_nums: int = None,
|
||||
inst_prefix: str = "",
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -71,6 +73,9 @@ class DumpDataBase:
|
||||
fields not dumped
|
||||
limit_nums: int
|
||||
Use when debugging, default None
|
||||
inst_prefix: str
|
||||
add a column to the instruments file and record the saved instrument name,
|
||||
the US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix.
|
||||
"""
|
||||
csv_path = Path(csv_path).expanduser()
|
||||
if isinstance(exclude_fields, str):
|
||||
@@ -79,6 +84,7 @@ class DumpDataBase:
|
||||
include_fields = include_fields.split(",")
|
||||
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
|
||||
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
|
||||
self._inst_prefix = inst_prefix.strip()
|
||||
self.file_suffix = file_suffix
|
||||
self.symbol_field_name = symbol_field_name
|
||||
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
|
||||
@@ -160,12 +166,19 @@ class DumpDataBase:
|
||||
)
|
||||
|
||||
def _read_instruments(self, instrument_path: Path) -> pd.DataFrame:
|
||||
return pd.read_csv(
|
||||
df = pd.read_csv(
|
||||
instrument_path,
|
||||
sep=self.INSTRUMENTS_SEP,
|
||||
names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD],
|
||||
names=[
|
||||
self.symbol_field_name,
|
||||
self.INSTRUMENTS_START_FIELD,
|
||||
self.INSTRUMENTS_END_FIELD,
|
||||
self.SAVE_INST_FIELD,
|
||||
],
|
||||
)
|
||||
|
||||
return df
|
||||
|
||||
def save_calendars(self, calendars_data: list):
|
||||
self._calendars_dir.mkdir(parents=True, exist_ok=True)
|
||||
calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
|
||||
@@ -176,7 +189,13 @@ class DumpDataBase:
|
||||
self._instruments_dir.mkdir(parents=True, exist_ok=True)
|
||||
instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
|
||||
if isinstance(instruments_data, pd.DataFrame):
|
||||
instruments_data = instruments_data.loc[:, [self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]]
|
||||
_df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]
|
||||
if self._inst_prefix:
|
||||
_df_fields.append(self.SAVE_INST_FIELD)
|
||||
instruments_data[self.SAVE_INST_FIELD] = instruments_data[self.symbol_field_name].apply(
|
||||
lambda x: f"{self._inst_prefix}{x}"
|
||||
)
|
||||
instruments_data = instruments_data.loc[:, _df_fields]
|
||||
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
|
||||
else:
|
||||
np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")
|
||||
@@ -234,6 +253,7 @@ class DumpDataBase:
|
||||
logger.warning(f"{code} data is None or empty")
|
||||
return
|
||||
# features save dir
|
||||
code = self._inst_prefix + code if self._inst_prefix else code
|
||||
features_dir = self._features_dir.joinpath(code)
|
||||
features_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._data_to_bin(df, calendar_list, features_dir)
|
||||
@@ -262,7 +282,10 @@ class DumpDataAll(DumpDataBase):
|
||||
_begin_time = self._format_datetime(_begin_time)
|
||||
_end_time = self._format_datetime(_end_time)
|
||||
symbol = self.get_symbol_from_file(file_path)
|
||||
date_range_list.append(f"{self.INSTRUMENTS_SEP.join((symbol.upper(), _begin_time, _end_time))}")
|
||||
_inst_fields = [symbol.upper(), _begin_time, _end_time]
|
||||
if self._inst_prefix:
|
||||
_inst_fields.append(self._inst_prefix + symbol.upper())
|
||||
date_range_list.append(f"{self.INSTRUMENTS_SEP.join(_inst_fields)}")
|
||||
p_bar.update()
|
||||
self._kwargs["all_datetime_set"] = all_datetime
|
||||
self._kwargs["date_range_list"] = date_range_list
|
||||
|
||||
@@ -79,9 +79,6 @@ class GetData:
|
||||
-------
|
||||
|
||||
"""
|
||||
# TODO: The US stock code contains "PRN", and the directory cannot be created on Windows system
|
||||
if region.lower() == "us":
|
||||
logger.warning(f"The US stock code contains 'PRN', and the directory cannot be created on Windows system")
|
||||
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
|
||||
self._download_data(file_name.lower(), target_dir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user