mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
US stock code supports Windows
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
__version__ = "0.6.1.dev"
|
||||
__version__ = "0.6.1.99.dev"
|
||||
|
||||
|
||||
import os
|
||||
@@ -15,7 +15,7 @@ import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path
|
||||
from .utils import can_use_cache, init_instance_by_config, check_qlib_data
|
||||
from .workflow.utils import experiment_exit_handler
|
||||
|
||||
# init qlib
|
||||
@@ -88,6 +88,7 @@ def init(default_conf="client", **kwargs):
|
||||
R.register(qr)
|
||||
# clean up experiment when python program ends
|
||||
experiment_exit_handler()
|
||||
check_qlib_data(C)
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
|
||||
@@ -15,14 +15,13 @@ import importlib
|
||||
import traceback
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from multiprocessing import Pool
|
||||
|
||||
from .cache import H
|
||||
from ..config import C
|
||||
from .ops import *
|
||||
from ..log import get_module_logger
|
||||
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields
|
||||
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
|
||||
from .base import Feature
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
@@ -215,23 +214,6 @@ class InstrumentProvider(abc.ABC):
|
||||
return cls.LIST
|
||||
raise ValueError(f"Unknown instrument type {inst}")
|
||||
|
||||
def convert_instruments(self, instrument):
|
||||
_instruments_map = getattr(self, "_instruments_map", None)
|
||||
if _instruments_map is None:
|
||||
_df_list = []
|
||||
# FIXME: each process will read these files
|
||||
for _path in Path(C.get_data_path()).joinpath("instruments").glob("*.txt"):
|
||||
_df = pd.read_csv(_path, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
|
||||
_df_list.append(_df.iloc[:, [0, -1]])
|
||||
df = pd.concat(_df_list, sort=False)
|
||||
df["inst"] = df["inst"].astype(str)
|
||||
df = df.fillna(axis=1, method="ffill")
|
||||
df = df.sort_values("inst").drop_duplicates(subset=["inst"], keep="first")
|
||||
df["save_inst"] = df["save_inst"].astype(str)
|
||||
_instruments_map = df.set_index("inst").iloc[:, 0].to_dict()
|
||||
setattr(self, "_instruments_map", _instruments_map)
|
||||
return _instruments_map.get(instrument, instrument)
|
||||
|
||||
|
||||
class FeatureProvider(abc.ABC):
|
||||
"""Feature provider class
|
||||
@@ -590,12 +572,16 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
fname = self._uri_inst.format(market)
|
||||
if not os.path.exists(fname):
|
||||
raise ValueError("instruments not exists for market " + market)
|
||||
|
||||
_instruments = dict()
|
||||
df = pd.read_csv(fname, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
|
||||
df["start_datetime"] = pd.to_datetime(df["start_datetime"])
|
||||
df["end_datetime"] = pd.to_datetime(df["end_datetime"])
|
||||
df["inst"] = df["inst"].astype(str)
|
||||
df["save_inst"] = df.loc[:, ["inst", "save_inst"]].fillna(axis=1, method="ffill")["save_inst"].astype(str)
|
||||
df = pd.read_csv(
|
||||
fname,
|
||||
sep="\t",
|
||||
usecols=[0, 1, 2],
|
||||
names=["inst", "start_datetime", "end_datetime"],
|
||||
dtype={"inst": str},
|
||||
parse_dates=["start_datetime", "end_datetime"],
|
||||
)
|
||||
for row in df.itertuples(index=False):
|
||||
_instruments.setdefault(row[0], []).append((row[1], row[2]))
|
||||
return _instruments
|
||||
@@ -652,7 +638,7 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
def feature(self, instrument, field, start_index, end_index, freq):
|
||||
# validate
|
||||
field = str(field).lower()[1:]
|
||||
instrument = Inst.convert_instruments(instrument)
|
||||
instrument = code_to_fname(instrument)
|
||||
uri_data = self._uri_data.format(instrument.lower(), field, freq)
|
||||
if not os.path.exists(uri_data):
|
||||
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
|
||||
|
||||
@@ -15,6 +15,10 @@ class TestAutoData(unittest.TestCase):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
|
||||
name="qlib_data_simple",
|
||||
region="cn",
|
||||
interval="1d",
|
||||
target_dir=provider_uri,
|
||||
delete_old=False,
|
||||
)
|
||||
init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import qlib
|
||||
import shutil
|
||||
import zipfile
|
||||
import requests
|
||||
import datetime
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class GetData:
|
||||
DATASET_VERSION = "v1"
|
||||
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
|
||||
QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip"
|
||||
|
||||
def __init__(self, delete_zip_file=False):
|
||||
"""
|
||||
@@ -20,13 +27,24 @@ class GetData:
|
||||
"""
|
||||
self.delete_zip_file = delete_zip_file
|
||||
|
||||
def _download_data(self, file_name: str, target_dir: [Path, str]):
|
||||
def normalize_dataset_version(self, dataset_version: str = None):
    """Resolve the dataset version to use.

    Parameters
    ----------
    dataset_version: str
        explicitly requested dataset version; when None, ``self.DATASET_VERSION`` is used

    Returns
    -------
    the given ``dataset_version``, or ``self.DATASET_VERSION`` if it is None
    """
    return self.DATASET_VERSION if dataset_version is None else dataset_version
|
||||
|
||||
def merge_remote_url(self, file_name: str, dataset_version: str = None):
    """Build the remote URL of a dataset file.

    Parameters
    ----------
    file_name: str
        name of the file on the remote server
    dataset_version: str
        dataset version; resolved through ``self.normalize_dataset_version``

    Returns
    -------
    str of the form ``<REMOTE_URL>/<dataset version>/<file_name>``
    """
    _version = self.normalize_dataset_version(dataset_version)
    return f"{self.REMOTE_URL}/{_version}/{file_name}"
|
||||
|
||||
def _download_data(
|
||||
self, file_name: str, target_dir: [Path, str], delete_old: bool = True, dataset_version: str = None
|
||||
):
|
||||
target_dir = Path(target_dir).expanduser()
|
||||
target_dir.mkdir(exist_ok=True, parents=True)
|
||||
# saved file name
|
||||
_target_file_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "_" + file_name
|
||||
target_path = target_dir.joinpath(_target_file_name)
|
||||
|
||||
url = f"{self.REMOTE_URL}/{file_name}"
|
||||
target_path = target_dir.joinpath(file_name)
|
||||
|
||||
url = self.merge_remote_url(file_name, dataset_version)
|
||||
resp = requests.get(url, stream=True)
|
||||
if resp.status_code != 200:
|
||||
raise requests.exceptions.HTTPError()
|
||||
@@ -42,19 +60,59 @@ class GetData:
|
||||
fp.write(chuck)
|
||||
p_bar.update(chuck_size)
|
||||
|
||||
self._unzip(target_path, target_dir)
|
||||
self._unzip(target_path, target_dir, delete_old)
|
||||
if self.delete_zip_file:
|
||||
target_path.unlike()
|
||||
|
||||
def check_dataset(self, file_name: str, dataset_version: str = None):
    """Check whether a dataset file is available on the remote server.

    Parameters
    ----------
    file_name: str
        name of the dataset file
    dataset_version: str
        dataset version, forwarded to ``self.merge_remote_url``

    Returns
    -------
    bool
        False if the server answers 404 for the file URL, True for any other status code
    """
    url = self.merge_remote_url(file_name, dataset_version)
    # stream=True: only the status line/headers are needed here, the body is never read
    resp = requests.get(url, stream=True)
    try:
        return resp.status_code != 404
    finally:
        # a streamed response keeps its connection open until it is consumed or closed
        resp.close()
|
||||
|
||||
@staticmethod
|
||||
def _unzip(file_path: Path, target_dir: Path):
|
||||
def _unzip(file_path: Path, target_dir: Path, delete_old: bool = True):
|
||||
if delete_old:
|
||||
logger.warning(
|
||||
f"will delete the old qlib data directory(features, instruments, calendars, features_cache, dataset_cache): {target_dir}"
|
||||
)
|
||||
GetData._delete_qlib_data(target_dir)
|
||||
logger.info(f"{file_path} unzipping......")
|
||||
with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
|
||||
for _file in tqdm(zp.namelist()):
|
||||
zp.extract(_file, str(target_dir.resolve()))
|
||||
|
||||
@staticmethod
def _delete_qlib_data(file_dir: Path):
    """Delete the qlib data sub-directories of ``file_dir`` after interactive confirmation.

    The sub-directories considered are: features, calendars, instruments,
    features_cache and dataset_cache. Only those that exist are removed.

    Parameters
    ----------
    file_dir: Path
        directory that contains the qlib data

    Raises
    ------
    SystemExit
        if the user does not answer "Y" or "y" at the confirmation prompt
    """
    # local import; `exit()` is injected by the `site` module and is not guaranteed to exist
    import sys

    logger.info(f"delete {file_dir}")
    rm_dirs = [
        str(_p.resolve())
        for _p in (
            file_dir.joinpath(_name)
            for _name in ["features", "calendars", "instruments", "features_cache", "dataset_cache"]
        )
        if _p.exists()
    ]
    if not rm_dirs:
        # nothing to delete, so there is nothing to confirm
        return

    flag = input(
        f"Will be deleted: "
        f"\n\t{rm_dirs}"
        f"\nIf you do not need to delete {file_dir}, please change the <--target_dir>"
        f"\nAre you sure you want to delete, yes(Y/y), no (N/n):"
    )
    if str(flag) not in ["Y", "y"]:
        sys.exit()
    for _p in rm_dirs:
        logger.warning(f"delete: {_p}")
        shutil.rmtree(_p)
|
||||
|
||||
def qlib_data(
|
||||
self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"
|
||||
self,
|
||||
name="qlib_data",
|
||||
target_dir="~/.qlib/qlib_data/cn_data",
|
||||
version=None,
|
||||
interval="1d",
|
||||
region="cn",
|
||||
delete_old=True,
|
||||
):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
@@ -65,20 +123,31 @@ class GetData:
|
||||
name: str
|
||||
dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
|
||||
version: str
|
||||
data version, value from [v0, v1, ..., latest], by default latest
|
||||
data version, value from [v1, ...], by default None(use script to specify version)
|
||||
interval: str
|
||||
data freq, value from [1d], by default 1d
|
||||
region: str
|
||||
data region, value from [cn, us], by default cn
|
||||
delete_old: bool
|
||||
delete an existing directory, by default True
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn
|
||||
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
|
||||
self._download_data(file_name.lower(), target_dir)
|
||||
qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__))
|
||||
|
||||
def _get_file_name(v):
|
||||
return self.QLIB_DATA_NAME.format(
|
||||
dataset_name=name, region=region.lower(), interval=interval.lower(), qlib_version=v
|
||||
)
|
||||
|
||||
file_name = _get_file_name(qlib_version)
|
||||
if not self.check_dataset(file_name, version):
|
||||
file_name = _get_file_name("latest")
|
||||
self._download_data(file_name.lower(), target_dir, delete_old, dataset_version=version)
|
||||
|
||||
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
|
||||
"""download cn csv data from remote
|
||||
|
||||
@@ -26,6 +26,7 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple
|
||||
|
||||
from .. import __version__ as qlib_version
|
||||
from ..config import C
|
||||
from ..log import get_module_logger
|
||||
|
||||
@@ -643,15 +644,28 @@ def exists_qlib_data(qlib_dir):
|
||||
# check instruments
|
||||
code_names = set(map(lambda x: x.name.lower(), features_dir.iterdir()))
|
||||
_instrument = instruments_dir.joinpath("all.txt")
|
||||
df = pd.read_csv(_instrument, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
|
||||
df = df.iloc[:, [0, -1]].fillna(axis=1, method="ffill")
|
||||
miss_code = set(df.iloc[:, -1].apply(str.lower)) - set(code_names)
|
||||
miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
|
||||
if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def check_qlib_data(qlib_config):
    """Check that every instruments file of the qlib data has exactly 3 columns.

    Parameters
    ----------
    qlib_config:
        mapping-like config; ``qlib_config["provider_uri"]`` is the qlib data directory,
        whose ``instruments/*.txt`` files are checked

    Raises
    ------
    AssertionError
        if an instruments file does not have 3 tab-separated columns
    """
    inst_dir = Path(qlib_config["provider_uri"]).joinpath("instruments")
    for _p in inst_dir.glob("*.txt"):
        # nrows=0: only the number of columns is inspected
        _n_columns = len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns)
        if _n_columns != 3:
            # raised explicitly (instead of `assert`) so the check also runs under `python -O`;
            # AssertionError is kept for backward compatibility with existing callers
            raise AssertionError(
                f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:"
                f"\n\tIf you are using the data provided by qlib: "
                f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset"
                f"\n\tIf you are using your own data, please dump the data again: "
                f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format"
            )
|
||||
|
||||
|
||||
def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
|
||||
"""
|
||||
make the df index sorted
|
||||
@@ -742,3 +756,36 @@ def load_dataset(path_or_obj):
|
||||
elif extension == ".csv":
|
||||
return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1])
|
||||
raise ValueError(f"unsupported file type `{extension}`")
|
||||
|
||||
|
||||
def code_to_fname(code: str):
    """Convert a stock code to a file name

    Codes whose upper-cased form is one of the names listed below get the "_qlib_" prefix;
    every other code is returned unchanged.

    Parameters
    ----------
    code: str
        stock code
    """
    # NOTE: On Windows, the following names are I/O devices, and a file with such a name cannot be created
    # reference: https://superuser.com/questions/86999/why-cant-i-name-a-folder-or-file-con-in-windows
    _reserved = {"CON", "PRN", "AUX", "NUL"}
    for _i in range(10):
        _reserved.update((f"COM{_i}", f"LPT{_i}"))

    if str(code).upper() not in _reserved:
        return code
    return "_qlib_" + str(code)
|
||||
|
||||
|
||||
def fname_to_code(fname: str):
    """Convert a file name back to a stock code

    Removes one leading "_qlib_" prefix if present; any other name is returned unchanged.

    Parameters
    ----------
    fname: str
        file name (without suffix)
    """
    prefix = "_qlib_"
    if fname.startswith(prefix):
        # slice instead of str.lstrip: lstrip(prefix) strips any leading run of the characters
        # "_", "q", "l", "i", "b", which would turn "_qlib_lpt1" into "pt1"
        fname = fname[len(prefix):]
    return fname
|
||||
|
||||
@@ -20,7 +20,7 @@ pip install -r requirements.txt
|
||||
|
||||
### Download data and Normalize data
|
||||
```bash
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d --normalize_dir ~/.qlib/stock_data/normalize
|
||||
```
|
||||
|
||||
### Download Data
|
||||
|
||||
@@ -18,6 +18,7 @@ from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.utils import code_to_fname
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
@@ -40,7 +41,7 @@ class YahooCollector:
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_collector_count=5,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
@@ -55,7 +56,7 @@ class YahooCollector:
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 5
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
@@ -147,11 +148,10 @@ class YahooCollector:
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
with stock_path.open("a") as fp:
|
||||
df.to_csv(fp, index=False, header=False)
|
||||
_temp_df = pd.read_csv(stock_path, nrows=0)
|
||||
df.loc[:, _temp_df.columns].to_csv(stock_path, index=False, header=False, mode="a")
|
||||
else:
|
||||
with stock_path.open("w") as fp:
|
||||
df.to_csv(fp, index=False)
|
||||
df.to_csv(stock_path, index=False, mode="w")
|
||||
|
||||
def _save_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
@@ -350,7 +350,7 @@ class YahooCollectorUS(YahooCollector):
|
||||
pass
|
||||
|
||||
def normalize_symbol(self, symbol):
|
||||
return symbol.upper()
|
||||
return code_to_fname(symbol).upper()
|
||||
|
||||
@property
|
||||
def _timezone(self):
|
||||
|
||||
@@ -14,6 +14,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from qlib.utils import fname_to_code, code_to_fname
|
||||
|
||||
|
||||
class DumpDataBase:
|
||||
@@ -27,7 +28,6 @@ class DumpDataBase:
|
||||
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
INSTRUMENTS_SEP = "\t"
|
||||
INSTRUMENTS_FILE_NAME = "all.txt"
|
||||
SAVE_INST_FIELD = "save_inst"
|
||||
|
||||
UPDATE_MODE = "update"
|
||||
ALL_MODE = "all"
|
||||
@@ -45,7 +45,6 @@ class DumpDataBase:
|
||||
exclude_fields: str = "",
|
||||
include_fields: str = "",
|
||||
limit_nums: int = None,
|
||||
inst_prefix: str = "",
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -73,9 +72,6 @@ class DumpDataBase:
|
||||
fields not dumped
|
||||
limit_nums: int
|
||||
Use when debugging, default None
|
||||
inst_prefix: str
|
||||
add a column to the instruments file and record the saved instrument name,
|
||||
the US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix.
|
||||
"""
|
||||
csv_path = Path(csv_path).expanduser()
|
||||
if isinstance(exclude_fields, str):
|
||||
@@ -84,7 +80,6 @@ class DumpDataBase:
|
||||
include_fields = include_fields.split(",")
|
||||
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
|
||||
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
|
||||
self._inst_prefix = inst_prefix.strip()
|
||||
self.file_suffix = file_suffix
|
||||
self.symbol_field_name = symbol_field_name
|
||||
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
|
||||
@@ -145,7 +140,7 @@ class DumpDataBase:
|
||||
return df
|
||||
|
||||
def get_symbol_from_file(self, file_path: Path) -> str:
|
||||
return file_path.name[: -len(self.file_suffix)].strip().lower()
|
||||
return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())
|
||||
|
||||
def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
|
||||
return (
|
||||
@@ -173,7 +168,6 @@ class DumpDataBase:
|
||||
self.symbol_field_name,
|
||||
self.INSTRUMENTS_START_FIELD,
|
||||
self.INSTRUMENTS_END_FIELD,
|
||||
self.SAVE_INST_FIELD,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -190,13 +184,11 @@ class DumpDataBase:
|
||||
instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
|
||||
if isinstance(instruments_data, pd.DataFrame):
|
||||
_df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]
|
||||
if self._inst_prefix:
|
||||
_df_fields.append(self.SAVE_INST_FIELD)
|
||||
instruments_data[self.SAVE_INST_FIELD] = instruments_data[self.symbol_field_name].apply(
|
||||
lambda x: f"{self._inst_prefix}{x}"
|
||||
)
|
||||
instruments_data = instruments_data.loc[:, _df_fields]
|
||||
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
|
||||
instruments_data[self.symbol_field_name] = instruments_data[self.symbol_field_name].apply(
|
||||
lambda x: fname_to_code(x.lower()).upper()
|
||||
)
|
||||
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP, index=False)
|
||||
else:
|
||||
np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")
|
||||
|
||||
@@ -223,26 +215,26 @@ class DumpDataBase:
|
||||
logger.warning(f"{features_dir.name} data is None or empty")
|
||||
return
|
||||
# align index
|
||||
_df = self.data_merge_calendar(df, self._calendars_list)
|
||||
_df = self.data_merge_calendar(df, calendar_list)
|
||||
# used when creating a bin file
|
||||
date_index = self.get_datetime_index(_df, calendar_list)
|
||||
for field in self.get_dump_fields(_df.columns):
|
||||
bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
|
||||
if field not in _df.columns:
|
||||
continue
|
||||
if self._mode == self.UPDATE_MODE:
|
||||
if bin_path.exists() and self._mode == self.UPDATE_MODE:
|
||||
# update
|
||||
with bin_path.open("ab") as fp:
|
||||
np.array(_df[field]).astype("<f").tofile(fp)
|
||||
elif self._mode == self.ALL_MODE:
|
||||
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
|
||||
else:
|
||||
raise ValueError(f"{self._mode} cannot support!")
|
||||
# append; self._mode == self.ALL_MODE or not bin_path.exists()
|
||||
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
|
||||
|
||||
def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.Timestamp]):
|
||||
if isinstance(file_or_data, pd.DataFrame):
|
||||
if file_or_data.empty:
|
||||
return
|
||||
code = file_or_data.iloc[0][self.symbol_field_name].lower()
|
||||
code = fname_to_code(file_or_data.iloc[0][self.symbol_field_name].lower())
|
||||
df = file_or_data
|
||||
elif isinstance(file_or_data, Path):
|
||||
code = self.get_symbol_from_file(file_or_data)
|
||||
@@ -253,8 +245,7 @@ class DumpDataBase:
|
||||
logger.warning(f"{code} data is None or empty")
|
||||
return
|
||||
# features save dir
|
||||
code = self._inst_prefix + code if self._inst_prefix else code
|
||||
features_dir = self._features_dir.joinpath(code)
|
||||
features_dir = self._features_dir.joinpath(code_to_fname(code).lower())
|
||||
features_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._data_to_bin(df, calendar_list, features_dir)
|
||||
|
||||
@@ -283,8 +274,6 @@ class DumpDataAll(DumpDataBase):
|
||||
_end_time = self._format_datetime(_end_time)
|
||||
symbol = self.get_symbol_from_file(file_path)
|
||||
_inst_fields = [symbol.upper(), _begin_time, _end_time]
|
||||
if self._inst_prefix:
|
||||
_inst_fields.append(self._inst_prefix + symbol.upper())
|
||||
date_range_list.append(f"{self.INSTRUMENTS_SEP.join(_inst_fields)}")
|
||||
p_bar.update()
|
||||
self._kwargs["all_datetime_set"] = all_datetime
|
||||
@@ -323,12 +312,18 @@ class DumpDataFix(DumpDataAll):
|
||||
def _dump_instruments(self):
|
||||
logger.info("start dump instruments......")
|
||||
_fun = partial(self._get_date, is_begin_end=True)
|
||||
new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files))
|
||||
new_stock_files = sorted(
|
||||
filter(
|
||||
lambda x: fname_to_code(x.name[: -len(self.file_suffix)].strip().lower()).upper()
|
||||
not in self._old_instruments,
|
||||
self.csv_files,
|
||||
)
|
||||
)
|
||||
with tqdm(total=len(new_stock_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as execute:
|
||||
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
|
||||
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
|
||||
symbol = self.get_symbol_from_file(file_path).upper()
|
||||
symbol = fname_to_code(self.get_symbol_from_file(file_path).lower()).upper()
|
||||
_dt_map = self._old_instruments.setdefault(symbol, dict())
|
||||
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
|
||||
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
|
||||
@@ -406,10 +401,10 @@ class DumpDataUpdate(DumpDataBase):
|
||||
)
|
||||
self._mode = self.UPDATE_MODE
|
||||
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
|
||||
self._update_instruments = self._read_instruments(
|
||||
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
|
||||
).to_dict(
|
||||
orient="index"
|
||||
self._update_instruments = (
|
||||
self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))
|
||||
.set_index([self.symbol_field_name])
|
||||
.to_dict(orient="index")
|
||||
) # type: dict
|
||||
|
||||
# load all csv files
|
||||
@@ -425,10 +420,7 @@ class DumpDataUpdate(DumpDataBase):
|
||||
all_df = []
|
||||
|
||||
def _read_csv(file_path: Path):
|
||||
if self._include_fields:
|
||||
_df = pd.read_csv(file_path, usecols=self._include_fields)
|
||||
else:
|
||||
_df = pd.read_csv(file_path)
|
||||
_df = pd.read_csv(file_path, parse_dates=[self.date_field_name])
|
||||
if self.symbol_field_name not in _df.columns:
|
||||
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
|
||||
return _df
|
||||
@@ -436,7 +428,7 @@ class DumpDataUpdate(DumpDataBase):
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as executor:
|
||||
for df in executor.map(_read_csv, self.csv_files):
|
||||
if df:
|
||||
if not df.empty:
|
||||
all_df.append(df)
|
||||
p_bar.update()
|
||||
|
||||
@@ -455,25 +447,27 @@ class DumpDataUpdate(DumpDataBase):
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
futures = {}
|
||||
for _code, _df in self._all_data.groupby(self.symbol_field_name):
|
||||
_code = str(_code).upper()
|
||||
_code = fname_to_code(str(_code).lower()).upper()
|
||||
_start, _end = self._get_date(_df, is_begin_end=True)
|
||||
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
|
||||
continue
|
||||
if _code in self._update_instruments:
|
||||
self._update_instruments[_code]["end_time"] = _end
|
||||
self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
|
||||
futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
|
||||
else:
|
||||
# new stock
|
||||
_dt_range = self._update_instruments.setdefault(_code, dict())
|
||||
_dt_range["start_time"] = _start
|
||||
_dt_range["end_time"] = _end
|
||||
_dt_range[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_start)
|
||||
_dt_range[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
|
||||
futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code
|
||||
|
||||
for _future in tqdm(as_completed(futures)):
|
||||
try:
|
||||
_future.result()
|
||||
except Exception:
|
||||
error_code[futures[_future]] = traceback.format_exc()
|
||||
with tqdm(total=len(futures)) as p_bar:
|
||||
for _future in as_completed(futures):
|
||||
try:
|
||||
_future.result()
|
||||
except Exception:
|
||||
error_code[futures[_future]] = traceback.format_exc()
|
||||
p_bar.update()
|
||||
logger.info(f"dump bin errors: {error_code}")
|
||||
|
||||
logger.info("end of features dump.\n")
|
||||
@@ -481,7 +475,9 @@ class DumpDataUpdate(DumpDataBase):
|
||||
def dump(self):
|
||||
self.save_calendars(self._new_calendar_list)
|
||||
self._dump_features()
|
||||
self.save_instruments(pd.DataFrame.from_dict(self._update_instruments, orient="index"))
|
||||
df = pd.DataFrame.from_dict(self._update_instruments, orient="index")
|
||||
df.index.names = [self.symbol_field_name]
|
||||
self.save_instruments(df.reset_index())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import fire
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
2
setup.py
2
setup.py
@@ -11,7 +11,7 @@ NAME = "pyqlib"
|
||||
DESCRIPTION = "A Quantitative-research Platform"
|
||||
REQUIRES_PYTHON = ">=3.5.0"
|
||||
|
||||
VERSION = "0.6.1.dev"
|
||||
VERSION = "0.6.1.99.dev"
|
||||
|
||||
# Detect Cython
|
||||
try:
|
||||
|
||||
@@ -37,7 +37,7 @@ class TestGetData(unittest.TestCase):
|
||||
|
||||
def test_0_qlib_data(self):
|
||||
|
||||
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", version="latest")
|
||||
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False)
|
||||
df = D.features(D.instruments("csi300"), self.FIELDS)
|
||||
self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed")
|
||||
self.assertFalse(df.dropna().empty, "get qlib data failed")
|
||||
|
||||
Reference in New Issue
Block a user