From 1a1c45981c9b7cb177e747e1a76eb05f2f35bdbe Mon Sep 17 00:00:00 2001 From: zhupr Date: Sun, 20 Dec 2020 23:07:09 +0800 Subject: [PATCH] US stock code supports Windows --- qlib/__init__.py | 5 +- qlib/data/data.py | 36 +++------ qlib/tests/__init__.py | 6 +- qlib/tests/data.py | 91 ++++++++++++++++++++--- qlib/utils/__init__.py | 53 ++++++++++++- scripts/data_collector/yahoo/README.md | 2 +- scripts/data_collector/yahoo/collector.py | 14 ++-- scripts/dump_bin.py | 86 ++++++++++----------- scripts/get_data.py | 1 + setup.py | 2 +- tests/test_get_data.py | 2 +- 11 files changed, 201 insertions(+), 97 deletions(-) diff --git a/qlib/__init__.py b/qlib/__init__.py index 98920ed04..9fd4fffa2 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. -__version__ = "0.6.1.dev" +__version__ = "0.6.1.99.dev" import os @@ -15,7 +15,7 @@ import platform import subprocess from pathlib import Path -from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path +from .utils import can_use_cache, init_instance_by_config, check_qlib_data from .workflow.utils import experiment_exit_handler # init qlib @@ -88,6 +88,7 @@ def init(default_conf="client", **kwargs): R.register(qr) # clean up experiment when python program ends experiment_exit_handler() + check_qlib_data(C) def _mount_nfs_uri(C): diff --git a/qlib/data/data.py b/qlib/data/data.py index c8e2878b4..d3701f1f1 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -15,14 +15,13 @@ import importlib import traceback import numpy as np import pandas as pd -from pathlib import Path from multiprocessing import Pool from .cache import H from ..config import C from .ops import * from ..log import get_module_logger -from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields +from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname from .base import Feature from .cache import DiskDatasetCache, DiskExpressionCache from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path @@ -215,23 +214,6 @@ class InstrumentProvider(abc.ABC): return cls.LIST raise ValueError(f"Unknown instrument type {inst}") - def convert_instruments(self, instrument): - _instruments_map = getattr(self, "_instruments_map", None) - if _instruments_map is None: - _df_list = [] - # FIXME: each process will read these files - for _path in Path(C.get_data_path()).joinpath("instruments").glob("*.txt"): - _df = pd.read_csv(_path, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"]) - _df_list.append(_df.iloc[:, [0, -1]]) - df = pd.concat(_df_list, sort=False) - df["inst"] = df["inst"].astype(str) - df = df.fillna(axis=1, method="ffill") - df = df.sort_values("inst").drop_duplicates(subset=["inst"], keep="first") - df["save_inst"] = df["save_inst"].astype(str) - _instruments_map = df.set_index("inst").iloc[:, 0].to_dict() - setattr(self, "_instruments_map", _instruments_map) - return _instruments_map.get(instrument, instrument) - class FeatureProvider(abc.ABC): """Feature provider class @@ -590,12 +572,16 @@ class LocalInstrumentProvider(InstrumentProvider): fname = self._uri_inst.format(market) if not os.path.exists(fname): raise ValueError("instruments not exists for market " + market) + _instruments = dict() - df = pd.read_csv(fname, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"]) - df["start_datetime"] = pd.to_datetime(df["start_datetime"]) - df["end_datetime"] = pd.to_datetime(df["end_datetime"]) - df["inst"] = df["inst"].astype(str) - df["save_inst"] = df.loc[:, ["inst", "save_inst"]].fillna(axis=1, method="ffill")["save_inst"].astype(str) + df = pd.read_csv( + fname, + sep="\t", + usecols=[0, 1, 2], + names=["inst", "start_datetime", "end_datetime"], + dtype={"inst": str}, + parse_dates=["start_datetime", "end_datetime"], + ) for row in df.itertuples(index=False): _instruments.setdefault(row[0], []).append((row[1], row[2])) return _instruments @@ -652,7 +638,7 @@ class LocalFeatureProvider(FeatureProvider): def feature(self, instrument, field, start_index, end_index, freq): # validate field = str(field).lower()[1:] - instrument = Inst.convert_instruments(instrument) + instrument = code_to_fname(instrument) uri_data = self._uri_data.format(instrument.lower(), field, freq) if not os.path.exists(uri_data): get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field)) diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index a1b33a2a2..af8dc6c1a 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -15,6 +15,10 @@ class TestAutoData(unittest.TestCase): print(f"Qlib data is not found in {provider_uri}") GetData().qlib_data( - name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri + name="qlib_data_simple", + region="cn", + interval="1d", + target_dir=provider_uri, + delete_old=False, ) init(provider_uri=provider_uri, region=REG_CN) diff --git a/qlib/tests/data.py b/qlib/tests/data.py index 66bfb0e29..a1e97aae4 100644 --- a/qlib/tests/data.py +++ b/qlib/tests/data.py @@ -1,14 +1,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + +import re +import qlib +import shutil import zipfile import requests +import datetime from tqdm import tqdm from pathlib import Path from loguru import logger class GetData: + DATASET_VERSION = "v1" REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads" + QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip" def __init__(self, delete_zip_file=False): """ @@ -20,13 +27,24 @@ class GetData: """ self.delete_zip_file = delete_zip_file - def _download_data(self, file_name: str, target_dir: [Path, str]): + def normalize_dataset_version(self, dataset_version: str = None): + if dataset_version is None: + dataset_version = self.DATASET_VERSION + return dataset_version + + def merge_remote_url(self, file_name: str, dataset_version: str = None): + return f"{self.REMOTE_URL}/{self.normalize_dataset_version(dataset_version)}/{file_name}" + + def _download_data( + self, file_name: str, target_dir: [Path, str], delete_old: bool = True, dataset_version: str = None + ): target_dir = Path(target_dir).expanduser() target_dir.mkdir(exist_ok=True, parents=True) + # saved file name + _target_file_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "_" + file_name + target_path = target_dir.joinpath(_target_file_name) - url = f"{self.REMOTE_URL}/{file_name}" - target_path = target_dir.joinpath(file_name) - + url = self.merge_remote_url(file_name, dataset_version) resp = requests.get(url, stream=True) if resp.status_code != 200: raise requests.exceptions.HTTPError() @@ -42,19 +60,59 @@ class GetData: fp.write(chuck) p_bar.update(chuck_size) - self._unzip(target_path, target_dir) + self._unzip(target_path, target_dir, delete_old) if self.delete_zip_file: target_path.unlike() + def check_dataset(self, file_name: str, dataset_version: str = None): + url = self.merge_remote_url(file_name, dataset_version) + resp = requests.get(url, stream=True) + status = True + if resp.status_code == 404: + status = False + return status + @staticmethod - def _unzip(file_path: Path, target_dir: Path): + def _unzip(file_path: Path, target_dir: Path, delete_old: bool = True): + if delete_old: + logger.warning( + f"will delete the old qlib data directory(features, instruments, calendars, features_cache, dataset_cache): {target_dir}" + ) + GetData._delete_qlib_data(target_dir) logger.info(f"{file_path} unzipping......") with zipfile.ZipFile(str(file_path.resolve()), "r") as zp: for _file in tqdm(zp.namelist()): zp.extract(_file, str(target_dir.resolve())) + @staticmethod + def _delete_qlib_data(file_dir: Path): + logger.info(f"delete {file_dir}") + rm_dirs = [] + for _name in ["features", "calendars", "instruments", "features_cache", "dataset_cache"]: + _p = file_dir.joinpath(_name) + if _p.exists(): + rm_dirs.append(str(_p.resolve())) + if rm_dirs: + flag = input( + f"Will be deleted: " + f"\n\t{rm_dirs}" + f"\nIf you do not need to delete {file_dir}, please change the <--target_dir>" + f"\nAre you sure you want to delete, yes(Y/y), no (N/n):" + ) + if str(flag) not in ["Y", "y"]: + exit() + for _p in rm_dirs: + logger.warning(f"delete: {_p}") + shutil.rmtree(_p) + def qlib_data( - self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn" + self, + name="qlib_data", + target_dir="~/.qlib/qlib_data/cn_data", + version=None, + interval="1d", + region="cn", + delete_old=True, ): """download cn qlib data from remote @@ -65,20 +123,31 @@ class GetData: name: str dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data version: str - data version, value from [v0, v1, ..., latest], by default latest + data version, value from [v1, ...], by default None(use script to specify version) interval: str data freq, value from [1d], by default 1d region: str data region, value from [cn, us], by default cn + delete_old: bool + delete an existing directory, by default True Examples --------- - python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn + python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn ------- """ - file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip" - self._download_data(file_name.lower(), target_dir) + qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__)) + + def _get_file_name(v): + return self.QLIB_DATA_NAME.format( + dataset_name=name, region=region.lower(), interval=interval.lower(), qlib_version=v + ) + + file_name = _get_file_name(qlib_version) + if not self.check_dataset(file_name, version): + file_name = _get_file_name("latest") + self._download_data(file_name.lower(), target_dir, delete_old, dataset_version=version) def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"): """download cn csv data from remote diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index a5a4b4a56..c75d6db96 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -26,6 +26,7 @@ import pandas as pd from pathlib import Path from typing import Union, Tuple +from .. import __version__ as qlib_version from ..config import C from ..log import get_module_logger @@ -643,15 +644,28 @@ def exists_qlib_data(qlib_dir): # check instruments code_names = set(map(lambda x: x.name.lower(), features_dir.iterdir())) _instrument = instruments_dir.joinpath("all.txt") - df = pd.read_csv(_instrument, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"]) - df = df.iloc[:, [0, -1]].fillna(axis=1, method="ffill") - miss_code = set(df.iloc[:, -1].apply(str.lower)) - set(code_names) + miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names) if miss_code and any(map(lambda x: "sht" not in x, miss_code)): return False return True +def check_qlib_data(qlib_config): + inst_dir = Path(qlib_config["provider_uri"]).joinpath("instruments") + for _p in inst_dir.glob("*.txt"): + try: + assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, ( + f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:" + f"\n\tIf you are using the data provided by qlib: " + f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset" + f"\n\tIf you are using your own data, please dump the data again: " + f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format" + ) + except AssertionError: + raise + + def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame: """ make the df index sorted @@ -742,3 +756,36 @@ def load_dataset(path_or_obj): elif extension == ".csv": return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1]) raise ValueError(f"unsupported file type `{extension}`") + + +def code_to_fname(code: str): + """stock code to file name + + Parameters + ---------- + code: str + """ + # NOTE: In windows, the following name is I/O device, and the file with the corresponding name cannot be created + # reference: https://superuser.com/questions/86999/why-cant-i-name-a-folder-or-file-con-in-windows + replace_names = ["CON", "PRN", "AUX", "NUL"] + replace_names += [f"COM{i}" for i in range(10)] + replace_names += [f"LPT{i}" for i in range(10)] + + prefix = "_qlib_" + if str(code).upper() in replace_names: + code = prefix + str(code) + + return code + + +def fname_to_code(fname: str): + """file name to stock code + + Parameters + ---------- + fname: str + """ + prefix = "_qlib_" + if fname.startswith(prefix): + fname = fname.lstrip(prefix) + return fname diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md index 1e65aeaed..c7442a553 100644 --- a/scripts/data_collector/yahoo/README.md +++ b/scripts/data_collector/yahoo/README.md @@ -20,7 +20,7 @@ pip install -r requirements.txt ### Download data and Normalize data ```bash -python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d +python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d --normalize_dir ~/.qlib/stock_data/normalize ``` ### Download Data diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 0d41251f1..19c9dcdae 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -18,6 +18,7 @@ from tqdm import tqdm from loguru import logger from yahooquery import Ticker from dateutil.tz import tzlocal +from qlib.utils import code_to_fname CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) @@ -40,7 +41,7 @@ class YahooCollector: end=None, interval="1d", max_workers=4, - max_collector_count=5, + max_collector_count=2, delay=0, check_data_length: bool = False, limit_nums: int = None, @@ -55,7 +56,7 @@ class YahooCollector: max_workers: int workers, default 4 max_collector_count: int - default 5 + default 2 delay: float time.sleep(delay), default 0 interval: str @@ -147,11 +148,10 @@ class YahooCollector: stock_path = self.save_dir.joinpath(f"{symbol}.csv") df["symbol"] = symbol if stock_path.exists(): - with stock_path.open("a") as fp: - df.to_csv(fp, index=False, header=False) + _temp_df = pd.read_csv(stock_path, nrows=0) + df.loc[:, _temp_df.columns].to_csv(stock_path, index=False, header=False, mode="a") else: - with stock_path.open("w") as fp: - df.to_csv(fp, index=False) + df.to_csv(stock_path, index=False, mode="w") def _save_small_data(self, symbol, df): if len(df) <= self.min_numbers_trading: @@ -350,7 +350,7 @@ class YahooCollectorUS(YahooCollector): pass def normalize_symbol(self, symbol): - return symbol.upper() + return code_to_fname(symbol).upper() @property def _timezone(self): diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index bdc227029..4811fd486 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -14,6 +14,7 @@ import numpy as np import pandas as pd from tqdm import tqdm from loguru import logger +from qlib.utils import fname_to_code, code_to_fname class DumpDataBase: @@ -27,7 +28,6 @@ class DumpDataBase: HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S" INSTRUMENTS_SEP = "\t" INSTRUMENTS_FILE_NAME = "all.txt" - SAVE_INST_FIELD = "save_inst" UPDATE_MODE = "update" ALL_MODE = "all" @@ -45,7 +45,6 @@ class DumpDataBase: exclude_fields: str = "", include_fields: str = "", limit_nums: int = None, - inst_prefix: str = "", ): """ @@ -73,9 +72,6 @@ class DumpDataBase: fields not dumped limit_nums: int Use when debugging, default None - inst_prefix: str - add a column to the instruments file and record the saved instrument name, - the US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix. """ csv_path = Path(csv_path).expanduser() if isinstance(exclude_fields, str): @@ -84,7 +80,6 @@ class DumpDataBase: include_fields = include_fields.split(",") self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields))) self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields))) - self._inst_prefix = inst_prefix.strip() self.file_suffix = file_suffix self.symbol_field_name = symbol_field_name self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path]) @@ -145,7 +140,7 @@ class DumpDataBase: return df def get_symbol_from_file(self, file_path: Path) -> str: - return file_path.name[: -len(self.file_suffix)].strip().lower() + return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower()) def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]: return ( @@ -173,7 +168,6 @@ class DumpDataBase: self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD, - self.SAVE_INST_FIELD, ], ) @@ -190,13 +184,11 @@ class DumpDataBase: instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve()) if isinstance(instruments_data, pd.DataFrame): _df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD] - if self._inst_prefix: - _df_fields.append(self.SAVE_INST_FIELD) - instruments_data[self.SAVE_INST_FIELD] = instruments_data[self.symbol_field_name].apply( - lambda x: f"{self._inst_prefix}{x}" - ) instruments_data = instruments_data.loc[:, _df_fields] - instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP) + instruments_data[self.symbol_field_name] = instruments_data[self.symbol_field_name].apply( + lambda x: fname_to_code(x.lower()).upper() + ) + instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP, index=False) else: np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8") @@ -223,26 +215,26 @@ class DumpDataBase: logger.warning(f"{features_dir.name} data is None or empty") return # align index - _df = self.data_merge_calendar(df, self._calendars_list) + _df = self.data_merge_calendar(df, calendar_list) + # used when creating a bin file date_index = self.get_datetime_index(_df, calendar_list) for field in self.get_dump_fields(_df.columns): bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}") if field not in _df.columns: continue - if self._mode == self.UPDATE_MODE: + if bin_path.exists() and self._mode == self.UPDATE_MODE: # update with bin_path.open("ab") as fp: np.array(_df[field]).astype("