1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

US stock code supports Windows

This commit is contained in:
zhupr
2020-12-20 23:07:09 +08:00
parent df556532d0
commit 1a1c45981c
11 changed files with 201 additions and 97 deletions

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
__version__ = "0.6.1.dev"
__version__ = "0.6.1.99.dev"
import os
@@ -15,7 +15,7 @@ import platform
import subprocess
from pathlib import Path
from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path
from .utils import can_use_cache, init_instance_by_config, check_qlib_data
from .workflow.utils import experiment_exit_handler
# init qlib
@@ -88,6 +88,7 @@ def init(default_conf="client", **kwargs):
R.register(qr)
# clean up experiment when python program ends
experiment_exit_handler()
check_qlib_data(C)
def _mount_nfs_uri(C):

View File

@@ -15,14 +15,13 @@ import importlib
import traceback
import numpy as np
import pandas as pd
from pathlib import Path
from multiprocessing import Pool
from .cache import H
from ..config import C
from .ops import *
from ..log import get_module_logger
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
from .base import Feature
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
@@ -215,23 +214,6 @@ class InstrumentProvider(abc.ABC):
return cls.LIST
raise ValueError(f"Unknown instrument type {inst}")
def convert_instruments(self, instrument):
_instruments_map = getattr(self, "_instruments_map", None)
if _instruments_map is None:
_df_list = []
# FIXME: each process will read these files
for _path in Path(C.get_data_path()).joinpath("instruments").glob("*.txt"):
_df = pd.read_csv(_path, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
_df_list.append(_df.iloc[:, [0, -1]])
df = pd.concat(_df_list, sort=False)
df["inst"] = df["inst"].astype(str)
df = df.fillna(axis=1, method="ffill")
df = df.sort_values("inst").drop_duplicates(subset=["inst"], keep="first")
df["save_inst"] = df["save_inst"].astype(str)
_instruments_map = df.set_index("inst").iloc[:, 0].to_dict()
setattr(self, "_instruments_map", _instruments_map)
return _instruments_map.get(instrument, instrument)
class FeatureProvider(abc.ABC):
"""Feature provider class
@@ -590,12 +572,16 @@ class LocalInstrumentProvider(InstrumentProvider):
fname = self._uri_inst.format(market)
if not os.path.exists(fname):
raise ValueError("instruments not exists for market " + market)
_instruments = dict()
df = pd.read_csv(fname, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
df["start_datetime"] = pd.to_datetime(df["start_datetime"])
df["end_datetime"] = pd.to_datetime(df["end_datetime"])
df["inst"] = df["inst"].astype(str)
df["save_inst"] = df.loc[:, ["inst", "save_inst"]].fillna(axis=1, method="ffill")["save_inst"].astype(str)
df = pd.read_csv(
fname,
sep="\t",
usecols=[0, 1, 2],
names=["inst", "start_datetime", "end_datetime"],
dtype={"inst": str},
parse_dates=["start_datetime", "end_datetime"],
)
for row in df.itertuples(index=False):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments
@@ -652,7 +638,7 @@ class LocalFeatureProvider(FeatureProvider):
def feature(self, instrument, field, start_index, end_index, freq):
# validate
field = str(field).lower()[1:]
instrument = Inst.convert_instruments(instrument)
instrument = code_to_fname(instrument)
uri_data = self._uri_data.format(instrument.lower(), field, freq)
if not os.path.exists(uri_data):
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))

View File

@@ -15,6 +15,10 @@ class TestAutoData(unittest.TestCase):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(
name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
name="qlib_data_simple",
region="cn",
interval="1d",
target_dir=provider_uri,
delete_old=False,
)
init(provider_uri=provider_uri, region=REG_CN)

View File

@@ -1,14 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import qlib
import shutil
import zipfile
import requests
import datetime
from tqdm import tqdm
from pathlib import Path
from loguru import logger
class GetData:
DATASET_VERSION = "v1"
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip"
def __init__(self, delete_zip_file=False):
"""
@@ -20,13 +27,24 @@ class GetData:
"""
self.delete_zip_file = delete_zip_file
def _download_data(self, file_name: str, target_dir: [Path, str]):
def normalize_dataset_version(self, dataset_version: str = None):
if dataset_version is None:
dataset_version = self.DATASET_VERSION
return dataset_version
def merge_remote_url(self, file_name: str, dataset_version: str = None):
return f"{self.REMOTE_URL}/{self.normalize_dataset_version(dataset_version)}/{file_name}"
def _download_data(
self, file_name: str, target_dir: [Path, str], delete_old: bool = True, dataset_version: str = None
):
target_dir = Path(target_dir).expanduser()
target_dir.mkdir(exist_ok=True, parents=True)
# saved file name
_target_file_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "_" + file_name
target_path = target_dir.joinpath(_target_file_name)
url = f"{self.REMOTE_URL}/{file_name}"
target_path = target_dir.joinpath(file_name)
url = self.merge_remote_url(file_name, dataset_version)
resp = requests.get(url, stream=True)
if resp.status_code != 200:
raise requests.exceptions.HTTPError()
@@ -42,19 +60,59 @@ class GetData:
fp.write(chuck)
p_bar.update(chuck_size)
self._unzip(target_path, target_dir)
self._unzip(target_path, target_dir, delete_old)
if self.delete_zip_file:
target_path.unlike()
def check_dataset(self, file_name: str, dataset_version: str = None):
url = self.merge_remote_url(file_name, dataset_version)
resp = requests.get(url, stream=True)
status = True
if resp.status_code == 404:
status = False
return status
@staticmethod
def _unzip(file_path: Path, target_dir: Path):
def _unzip(file_path: Path, target_dir: Path, delete_old: bool = True):
if delete_old:
logger.warning(
f"will delete the old qlib data directory(features, instruments, calendars, features_cache, dataset_cache): {target_dir}"
)
GetData._delete_qlib_data(target_dir)
logger.info(f"{file_path} unzipping......")
with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
for _file in tqdm(zp.namelist()):
zp.extract(_file, str(target_dir.resolve()))
@staticmethod
def _delete_qlib_data(file_dir: Path):
logger.info(f"delete {file_dir}")
rm_dirs = []
for _name in ["features", "calendars", "instruments", "features_cache", "dataset_cache"]:
_p = file_dir.joinpath(_name)
if _p.exists():
rm_dirs.append(str(_p.resolve()))
if rm_dirs:
flag = input(
f"Will be deleted: "
f"\n\t{rm_dirs}"
f"\nIf you do not need to delete {file_dir}, please change the <--target_dir>"
f"\nAre you sure you want to delete, yes(Y/y), no (N/n):"
)
if str(flag) not in ["Y", "y"]:
exit()
for _p in rm_dirs:
logger.warning(f"delete: {_p}")
shutil.rmtree(_p)
def qlib_data(
self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"
self,
name="qlib_data",
target_dir="~/.qlib/qlib_data/cn_data",
version=None,
interval="1d",
region="cn",
delete_old=True,
):
"""download cn qlib data from remote
@@ -65,20 +123,31 @@ class GetData:
name: str
dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
version: str
data version, value from [v0, v1, ..., latest], by default latest
data version, value from [v1, ...], by default None(use script to specify version)
interval: str
data freq, value from [1d], by default 1d
region: str
data region, value from [cn, us], by default cn
delete_old: bool
delete an existing directory, by default True
Examples
---------
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
-------
"""
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
self._download_data(file_name.lower(), target_dir)
qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__))
def _get_file_name(v):
return self.QLIB_DATA_NAME.format(
dataset_name=name, region=region.lower(), interval=interval.lower(), qlib_version=v
)
file_name = _get_file_name(qlib_version)
if not self.check_dataset(file_name, version):
file_name = _get_file_name("latest")
self._download_data(file_name.lower(), target_dir, delete_old, dataset_version=version)
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
"""download cn csv data from remote

View File

@@ -26,6 +26,7 @@ import pandas as pd
from pathlib import Path
from typing import Union, Tuple
from .. import __version__ as qlib_version
from ..config import C
from ..log import get_module_logger
@@ -643,15 +644,28 @@ def exists_qlib_data(qlib_dir):
# check instruments
code_names = set(map(lambda x: x.name.lower(), features_dir.iterdir()))
_instrument = instruments_dir.joinpath("all.txt")
df = pd.read_csv(_instrument, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
df = df.iloc[:, [0, -1]].fillna(axis=1, method="ffill")
miss_code = set(df.iloc[:, -1].apply(str.lower)) - set(code_names)
miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
return False
return True
def check_qlib_data(qlib_config):
inst_dir = Path(qlib_config["provider_uri"]).joinpath("instruments")
for _p in inst_dir.glob("*.txt"):
try:
assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, (
f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:"
f"\n\tIf you are using the data provided by qlib: "
f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset"
f"\n\tIf you are using your own data, please dump the data again: "
f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format"
)
except AssertionError:
raise
def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
"""
make the df index sorted
@@ -742,3 +756,36 @@ def load_dataset(path_or_obj):
elif extension == ".csv":
return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1])
raise ValueError(f"unsupported file type `{extension}`")
def code_to_fname(code: str):
"""stock code to file name
Parameters
----------
code: str
"""
# NOTE: In windows, the following name is I/O device, and the file with the corresponding name cannot be created
# reference: https://superuser.com/questions/86999/why-cant-i-name-a-folder-or-file-con-in-windows
replace_names = ["CON", "PRN", "AUX", "NUL"]
replace_names += [f"COM{i}" for i in range(10)]
replace_names += [f"LPT{i}" for i in range(10)]
prefix = "_qlib_"
if str(code).upper() in replace_names:
code = prefix + str(code)
return code
def fname_to_code(fname: str):
"""file name to stock code
Parameters
----------
fname: str
"""
prefix = "_qlib_"
if fname.startswith(prefix):
fname = fname.lstrip(prefix)
return fname

View File

@@ -20,7 +20,7 @@ pip install -r requirements.txt
### Download data and Normalize data
```bash
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d --normalize_dir ~/.qlib/stock_data/normalize
```
### Download Data

View File

@@ -18,6 +18,7 @@ from tqdm import tqdm
from loguru import logger
from yahooquery import Ticker
from dateutil.tz import tzlocal
from qlib.utils import code_to_fname
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
@@ -40,7 +41,7 @@ class YahooCollector:
end=None,
interval="1d",
max_workers=4,
max_collector_count=5,
max_collector_count=2,
delay=0,
check_data_length: bool = False,
limit_nums: int = None,
@@ -55,7 +56,7 @@ class YahooCollector:
max_workers: int
workers, default 4
max_collector_count: int
default 5
default 2
delay: float
time.sleep(delay), default 0
interval: str
@@ -147,11 +148,10 @@ class YahooCollector:
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
df["symbol"] = symbol
if stock_path.exists():
with stock_path.open("a") as fp:
df.to_csv(fp, index=False, header=False)
_temp_df = pd.read_csv(stock_path, nrows=0)
df.loc[:, _temp_df.columns].to_csv(stock_path, index=False, header=False, mode="a")
else:
with stock_path.open("w") as fp:
df.to_csv(fp, index=False)
df.to_csv(stock_path, index=False, mode="w")
def _save_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
@@ -350,7 +350,7 @@ class YahooCollectorUS(YahooCollector):
pass
def normalize_symbol(self, symbol):
return symbol.upper()
return code_to_fname(symbol).upper()
@property
def _timezone(self):

View File

@@ -14,6 +14,7 @@ import numpy as np
import pandas as pd
from tqdm import tqdm
from loguru import logger
from qlib.utils import fname_to_code, code_to_fname
class DumpDataBase:
@@ -27,7 +28,6 @@ class DumpDataBase:
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
INSTRUMENTS_SEP = "\t"
INSTRUMENTS_FILE_NAME = "all.txt"
SAVE_INST_FIELD = "save_inst"
UPDATE_MODE = "update"
ALL_MODE = "all"
@@ -45,7 +45,6 @@ class DumpDataBase:
exclude_fields: str = "",
include_fields: str = "",
limit_nums: int = None,
inst_prefix: str = "",
):
"""
@@ -73,9 +72,6 @@ class DumpDataBase:
fields not dumped
limit_nums: int
Use when debugging, default None
inst_prefix: str
add a column to the instruments file and record the saved instrument name,
the US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix.
"""
csv_path = Path(csv_path).expanduser()
if isinstance(exclude_fields, str):
@@ -84,7 +80,6 @@ class DumpDataBase:
include_fields = include_fields.split(",")
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
self._inst_prefix = inst_prefix.strip()
self.file_suffix = file_suffix
self.symbol_field_name = symbol_field_name
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
@@ -145,7 +140,7 @@ class DumpDataBase:
return df
def get_symbol_from_file(self, file_path: Path) -> str:
return file_path.name[: -len(self.file_suffix)].strip().lower()
return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())
def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
return (
@@ -173,7 +168,6 @@ class DumpDataBase:
self.symbol_field_name,
self.INSTRUMENTS_START_FIELD,
self.INSTRUMENTS_END_FIELD,
self.SAVE_INST_FIELD,
],
)
@@ -190,13 +184,11 @@ class DumpDataBase:
instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
if isinstance(instruments_data, pd.DataFrame):
_df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]
if self._inst_prefix:
_df_fields.append(self.SAVE_INST_FIELD)
instruments_data[self.SAVE_INST_FIELD] = instruments_data[self.symbol_field_name].apply(
lambda x: f"{self._inst_prefix}{x}"
)
instruments_data = instruments_data.loc[:, _df_fields]
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
instruments_data[self.symbol_field_name] = instruments_data[self.symbol_field_name].apply(
lambda x: fname_to_code(x.lower()).upper()
)
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP, index=False)
else:
np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")
@@ -223,26 +215,26 @@ class DumpDataBase:
logger.warning(f"{features_dir.name} data is None or empty")
return
# align index
_df = self.data_merge_calendar(df, self._calendars_list)
_df = self.data_merge_calendar(df, calendar_list)
# used when creating a bin file
date_index = self.get_datetime_index(_df, calendar_list)
for field in self.get_dump_fields(_df.columns):
bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
if field not in _df.columns:
continue
if self._mode == self.UPDATE_MODE:
if bin_path.exists() and self._mode == self.UPDATE_MODE:
# update
with bin_path.open("ab") as fp:
np.array(_df[field]).astype("<f").tofile(fp)
elif self._mode == self.ALL_MODE:
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
else:
raise ValueError(f"{self._mode} cannot support!")
# append; self._mode == self.ALL_MODE or not bin_path.exists()
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.Timestamp]):
if isinstance(file_or_data, pd.DataFrame):
if file_or_data.empty:
return
code = file_or_data.iloc[0][self.symbol_field_name].lower()
code = fname_to_code(file_or_data.iloc[0][self.symbol_field_name].lower())
df = file_or_data
elif isinstance(file_or_data, Path):
code = self.get_symbol_from_file(file_or_data)
@@ -253,8 +245,7 @@ class DumpDataBase:
logger.warning(f"{code} data is None or empty")
return
# features save dir
code = self._inst_prefix + code if self._inst_prefix else code
features_dir = self._features_dir.joinpath(code)
features_dir = self._features_dir.joinpath(code_to_fname(code).lower())
features_dir.mkdir(parents=True, exist_ok=True)
self._data_to_bin(df, calendar_list, features_dir)
@@ -283,8 +274,6 @@ class DumpDataAll(DumpDataBase):
_end_time = self._format_datetime(_end_time)
symbol = self.get_symbol_from_file(file_path)
_inst_fields = [symbol.upper(), _begin_time, _end_time]
if self._inst_prefix:
_inst_fields.append(self._inst_prefix + symbol.upper())
date_range_list.append(f"{self.INSTRUMENTS_SEP.join(_inst_fields)}")
p_bar.update()
self._kwargs["all_datetime_set"] = all_datetime
@@ -323,12 +312,18 @@ class DumpDataFix(DumpDataAll):
def _dump_instruments(self):
logger.info("start dump instruments......")
_fun = partial(self._get_date, is_begin_end=True)
new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files))
new_stock_files = sorted(
filter(
lambda x: fname_to_code(x.name[: -len(self.file_suffix)].strip().lower()).upper()
not in self._old_instruments,
self.csv_files,
)
)
with tqdm(total=len(new_stock_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as execute:
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
symbol = self.get_symbol_from_file(file_path).upper()
symbol = fname_to_code(self.get_symbol_from_file(file_path).lower()).upper()
_dt_map = self._old_instruments.setdefault(symbol, dict())
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
@@ -406,10 +401,10 @@ class DumpDataUpdate(DumpDataBase):
)
self._mode = self.UPDATE_MODE
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
self._update_instruments = self._read_instruments(
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
).to_dict(
orient="index"
self._update_instruments = (
self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))
.set_index([self.symbol_field_name])
.to_dict(orient="index")
) # type: dict
# load all csv files
@@ -425,10 +420,7 @@ class DumpDataUpdate(DumpDataBase):
all_df = []
def _read_csv(file_path: Path):
if self._include_fields:
_df = pd.read_csv(file_path, usecols=self._include_fields)
else:
_df = pd.read_csv(file_path)
_df = pd.read_csv(file_path, parse_dates=[self.date_field_name])
if self.symbol_field_name not in _df.columns:
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
return _df
@@ -436,7 +428,7 @@ class DumpDataUpdate(DumpDataBase):
with tqdm(total=len(self.csv_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for df in executor.map(_read_csv, self.csv_files):
if df:
if not df.empty:
all_df.append(df)
p_bar.update()
@@ -455,25 +447,27 @@ class DumpDataUpdate(DumpDataBase):
with ProcessPoolExecutor(max_workers=self.works) as executor:
futures = {}
for _code, _df in self._all_data.groupby(self.symbol_field_name):
_code = str(_code).upper()
_code = fname_to_code(str(_code).lower()).upper()
_start, _end = self._get_date(_df, is_begin_end=True)
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
continue
if _code in self._update_instruments:
self._update_instruments[_code]["end_time"] = _end
self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
else:
# new stock
_dt_range = self._update_instruments.setdefault(_code, dict())
_dt_range["start_time"] = _start
_dt_range["end_time"] = _end
_dt_range[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_start)
_dt_range[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code
for _future in tqdm(as_completed(futures)):
try:
_future.result()
except Exception:
error_code[futures[_future]] = traceback.format_exc()
with tqdm(total=len(futures)) as p_bar:
for _future in as_completed(futures):
try:
_future.result()
except Exception:
error_code[futures[_future]] = traceback.format_exc()
p_bar.update()
logger.info(f"dump bin errors {error_code}")
logger.info("end of features dump.\n")
@@ -481,7 +475,9 @@ class DumpDataUpdate(DumpDataBase):
def dump(self):
self.save_calendars(self._new_calendar_list)
self._dump_features()
self.save_instruments(pd.DataFrame.from_dict(self._update_instruments, orient="index"))
df = pd.DataFrame.from_dict(self._update_instruments, orient="index")
df.index.names = [self.symbol_field_name]
self.save_instruments(df.reset_index())
if __name__ == "__main__":

View File

@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import fire
from qlib.tests.data import GetData

View File

@@ -11,7 +11,7 @@ NAME = "pyqlib"
DESCRIPTION = "A Quantitative-research Platform"
REQUIRES_PYTHON = ">=3.5.0"
VERSION = "0.6.1.dev"
VERSION = "0.6.1.99.dev"
# Detect Cython
try:

View File

@@ -37,7 +37,7 @@ class TestGetData(unittest.TestCase):
def test_0_qlib_data(self):
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", version="latest")
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False)
df = D.features(D.instruments("csi300"), self.FIELDS)
self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed")
self.assertFalse(df.dropna().empty, "get qlib data failed")