diff --git a/qlib/config.py b/qlib/config.py index d05161772..ac9c3ba65 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -155,6 +155,7 @@ MODE_CONF = { # cache "expression_cache": "DiskExpressionCache", "dataset_cache": "DiskDatasetCache", + "mount_path": None, }, "client": { # data provider config diff --git a/qlib/data/data.py b/qlib/data/data.py index 8331b1802..8fac9edec 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -15,6 +15,7 @@ import importlib import traceback import numpy as np import pandas as pd +from pathlib import Path from multiprocessing import Pool from .cache import H @@ -211,6 +212,20 @@ class InstrumentProvider(abc.ABC): return cls.LIST raise ValueError(f"Unknown instrument type {inst}") + def convert_instruments(self, instrument): + _instruments_map = getattr(self, "_instruments_map", None) + if _instruments_map is None: + _df_list = [] + # FIXME: each process will read these files + for _path in Path(C.get_data_path()).joinpath("instruments").glob("*.txt"): + _df = pd.read_csv(_path, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"]) + _df_list.append(_df.iloc[:, [0, -1]]) + df = pd.concat(_df_list, sort=False).sort_values("save_inst") + df = df.drop_duplicates(subset=["save_inst"], keep="first").fillna(axis=1, method="ffill") + _instruments_map = df.set_index("inst").iloc[:, 0].to_dict() + setattr(self, "_instruments_map", _instruments_map) + return _instruments_map.get(instrument, instrument) + class FeatureProvider(abc.ABC): """Feature provider class @@ -570,19 +585,11 @@ class LocalInstrumentProvider(InstrumentProvider): if not os.path.exists(fname): raise ValueError("instruments not exists for market " + market) _instruments = dict() - with open(fname) as f: - for line in f: - inst_time = line.strip().split() - inst = inst_time[0] - if len(inst_time) == 3: - # `day` - begin = inst_time[1] - end = inst_time[2] - elif len(inst_time) == 5: - # `1min` - begin = inst_time[1] + " " + inst_time[2] - end = inst_time[3] + " " + inst_time[4] - _instruments.setdefault(inst, []).append((pd.Timestamp(begin), pd.Timestamp(end))) + df = pd.read_csv(fname, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"]) + df["start_datetime"] = pd.to_datetime(df["start_datetime"]) + df["end_datetime"] = pd.to_datetime(df["end_datetime"]) + for row in df.itertuples(index=False): + _instruments.setdefault(row[0], []).append((row[1], row[2])) return _instruments def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): @@ -637,6 +644,7 @@ class LocalFeatureProvider(FeatureProvider): def feature(self, instrument, field, start_index, end_index, freq): # validate field = str(field).lower()[1:] + instrument = Inst.convert_instruments(instrument) uri_data = self._uri_data.format(instrument.lower(), field, freq) if not os.path.exists(uri_data): get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field)) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index f32cceba3..79fd6fe5c 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -613,7 +613,9 @@ def exists_qlib_data(qlib_dir): # check instruments code_names = set(map(lambda x: x.name.lower(), features_dir.iterdir())) _instrument = instruments_dir.joinpath("all.txt") - miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names) + df = pd.read_csv(_instrument, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"]) + df = df.iloc[:, [0, -1]].fillna(axis=1, method="ffill") + miss_code = set(df.iloc[:, -1].apply(str.lower)) - set(code_names) if miss_code and any(map(lambda x: "sht" not in x, miss_code)): return False diff --git a/scripts/README.md b/scripts/README.md index 98b01e0c3..88ebdc680 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -20,7 +20,6 @@ python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn ### Downlaod US Data -> The US stock code contains 'PRN', and the directory cannot be created on Windows system: https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows ```bash python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us diff --git a/scripts/data_collector/index.py b/scripts/data_collector/index.py index c5f3854fd..300e6b625 100644 --- a/scripts/data_collector/index.py +++ b/scripts/data_collector/index.py @@ -24,6 +24,7 @@ class IndexBase: INSTRUMENTS_COLUMNS = [SYMBOL_FIELD_NAME, START_DATE_FIELD, END_DATE_FIELD] REMOVE = "remove" ADD = "add" + INST_PREFIX = "" def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3): """ @@ -196,7 +197,11 @@ class IndexBase: _tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns) new_df = new_df.append(_tmp_df, sort=False) - new_df.loc[:, instruments_columns].to_csv( + inst_df = new_df.loc[:, instruments_columns] + _inst_prefix = self.INST_PREFIX.strip() + if _inst_prefix: + inst_df["save_inst"] = inst_df[self.SYMBOL_FIELD_NAME].apply(lambda x: f"{_inst_prefix}{x}") + inst_df.to_csv( self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None ) logger.info(f"parse {self.index_name.lower()} companies finished.") diff --git a/scripts/data_collector/us_index/collector.py b/scripts/data_collector/us_index/collector.py index ea1e974a0..0641437e0 100644 --- a/scripts/data_collector/us_index/collector.py +++ b/scripts/data_collector/us_index/collector.py @@ -33,6 +33,10 @@ WIKI_INDEX_NAME_MAP = { class WIKIIndex(IndexBase): + # NOTE: The US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix + # https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows + INST_PREFIX = "_" + def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3): super(WIKIIndex, self).__init__( index_name=index_name, qlib_dir=qlib_dir, request_retry=request_retry, retry_sleep=retry_sleep diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 855569642..2cf9f4c6a 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -184,9 +184,14 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: names=["symbol", "start_date", "end_date"], ) _all_symbols += ins_df["symbol"].unique().tolist() - _US_SYMBOLS = sorted( - set(map(lambda x: x.replace(".", "-"), filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols))) - ) + + def _format(s_): + s_ = s_.replace(".", "-") + s_ = s_.strip("$") + s_ = s_.strip("*") + return s_ + + _US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols)))) return _US_SYMBOLS diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index 2e44c454e..2bca4f037 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -27,6 +27,7 @@ class DumpDataBase: HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S" INSTRUMENTS_SEP = "\t" INSTRUMENTS_FILE_NAME = "all.txt" + SAVE_INST_FIELD = "save_inst" UPDATE_MODE = "update" ALL_MODE = "all" @@ -44,6 +45,7 @@ class DumpDataBase: exclude_fields: str = "", include_fields: str = "", limit_nums: int = None, + inst_prefix: str = "", ): """ @@ -71,6 +73,9 @@ class DumpDataBase: fields not dumped limit_nums: int Use when debugging, default None + inst_prefix: str + add a column to the instruments file and record the saved instrument name, + the US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix. """ csv_path = Path(csv_path).expanduser() if isinstance(exclude_fields, str): @@ -79,6 +84,7 @@ class DumpDataBase: include_fields = include_fields.split(",") self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields))) self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields))) + self._inst_prefix = inst_prefix.strip() self.file_suffix = file_suffix self.symbol_field_name = symbol_field_name self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path]) @@ -160,12 +166,19 @@ class DumpDataBase: ) def _read_instruments(self, instrument_path: Path) -> pd.DataFrame: - return pd.read_csv( + df = pd.read_csv( instrument_path, sep=self.INSTRUMENTS_SEP, - names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD], + names=[ + self.symbol_field_name, + self.INSTRUMENTS_START_FIELD, + self.INSTRUMENTS_END_FIELD, + self.SAVE_INST_FIELD, + ], ) + return df + def save_calendars(self, calendars_data: list): self._calendars_dir.mkdir(parents=True, exist_ok=True) calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve()) @@ -176,7 +189,13 @@ class DumpDataBase: self._instruments_dir.mkdir(parents=True, exist_ok=True) instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve()) if isinstance(instruments_data, pd.DataFrame): - instruments_data = instruments_data.loc[:, [self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]] + _df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD] + if self._inst_prefix: + _df_fields.append(self.SAVE_INST_FIELD) + instruments_data[self.SAVE_INST_FIELD] = instruments_data[self.symbol_field_name].apply( + lambda x: f"{self._inst_prefix}{x}" + ) + instruments_data = instruments_data.loc[:, _df_fields] instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP) else: np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8") @@ -234,6 +253,7 @@ class DumpDataBase: logger.warning(f"{code} data is None or empty") return # features save dir + code = self._inst_prefix + code if self._inst_prefix else code features_dir = self._features_dir.joinpath(code) features_dir.mkdir(parents=True, exist_ok=True) self._data_to_bin(df, calendar_list, features_dir) @@ -262,7 +282,10 @@ class DumpDataAll(DumpDataBase): _begin_time = self._format_datetime(_begin_time) _end_time = self._format_datetime(_end_time) symbol = self.get_symbol_from_file(file_path) - date_range_list.append(f"{self.INSTRUMENTS_SEP.join((symbol.upper(), _begin_time, _end_time))}") + _inst_fields = [symbol.upper(), _begin_time, _end_time] + if self._inst_prefix: + _inst_fields.append(self._inst_prefix + symbol.upper()) + date_range_list.append(f"{self.INSTRUMENTS_SEP.join(_inst_fields)}") p_bar.update() self._kwargs["all_datetime_set"] = all_datetime self._kwargs["date_range_list"] = date_range_list diff --git a/scripts/get_data.py b/scripts/get_data.py index 4c0595238..f4dba1474 100644 --- a/scripts/get_data.py +++ b/scripts/get_data.py @@ -79,9 +79,6 @@ class GetData: ------- """ - # TODO: The US stock code contains "PRN", and the directory cannot be created on Windows system - if region.lower() == "us": - logger.warning(f"The US stock code contains 'PRN', and the directory cannot be created on Windows system") file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip" self._download_data(file_name.lower(), target_dir)