1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

init commit

This commit is contained in:
Young
2020-09-22 01:43:21 +00:00
parent aa51e5aad3
commit 99ebd87cba
131 changed files with 20218 additions and 0 deletions

View File

@@ -0,0 +1,14 @@
# CSI300 History Companies Collection
## Requirements
```bash
pip install -r requirements.txt
```
## Collector Data
```bash
python collector.py parse_instruments --qlib_dir ~/.qlib/stock_data/qlib_data
```

View File

@@ -0,0 +1,213 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import bisect
from io import BytesIO
from pathlib import Path
import fire
import requests
import pandas as pd
from lxml import etree
from loguru import logger
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/000300cons.xls"
CSI300_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
CSI300_START_DATE = pd.Timestamp("2005-01-01")
CUR_DIR = Path(__file__).resolve().parent
class CSI300:
REMOVE = "remove"
ADD = "add"
def __init__(self, qlib_dir=None):
"""
Parameters
----------
qlib_dir: str
qlib data dir, default "Path(__file__).parent/qlib_data"
"""
if qlib_dir is None:
qlib_dir = CUR_DIR.joinpath("qlib_data")
self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments")
self.instruments_dir.mkdir(exist_ok=True, parents=True)
self._calendar_list = None
@property
def calendar_list(self) -> list:
"""get history trading date
Returns
-------
"""
# TODO: get calendar from MSN
if self._calendar_list is None:
logger.info("get all trading date")
value_list = requests.get(CSI300_BENCH_URL).json()["data"]["klines"]
self._calendar_list = sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), value_list))
return self._calendar_list
def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1):
"""get trading date by shift
Parameters
----------
shift : int
shift, default is 1
trading_date : pd.Timestamp
trading date
Returns
-------
"""
left_index = bisect.bisect_left(self.calendar_list, trading_date)
try:
res = self.calendar_list[left_index + shift]
except IndexError:
res = trading_date
return res
def _get_changes(self) -> pd.DataFrame:
"""get companies changes
Returns
-------
"""
logger.info("get companies changes......")
res = []
for _url in self._get_change_notices_url():
_df = self._read_change_from_url(_url)
res.append(_df)
logger.info("get companies changes finish")
return pd.concat(res)
@staticmethod
def normalize_symbol(symbol):
symbol = f"{int(symbol):06}"
return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}"
def _read_change_from_url(self, url: str) -> pd.DataFrame:
"""read change from url
Parameters
----------
url : str
change url
Returns
-------
"""
resp = requests.get(url)
_text = resp.text
date_list = re.findall(r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日", _text)
if len(date_list) >= 2:
add_date = pd.Timestamp("-".join(date_list[0]))
else:
_date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0]))
add_date = self._get_trading_date_by_shift(_date, shift=0)
remove_date = self._get_trading_date_by_shift(add_date, shift=-1)
logger.info(f"get {add_date} changes")
try:
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
_io = BytesIO(requests.get(f"http://www.csindex.com.cn{excel_url}").content)
df_map = pd.read_excel(_io, sheet_name=None)
tmp = []
for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]:
_df = df_map[_s_name]
_df = _df.loc[_df["指数代码"] == "000300", ["证券代码"]]
_df = _df.applymap(self.normalize_symbol)
_df.columns = ["symbol"]
_df["type"] = _type
_df["date"] = _date
tmp.append(_df)
df = pd.concat(tmp)
except Exception:
df = None
for _df in pd.read_html(resp.content):
if _df.shape[-1] != 4:
continue
tmp = []
for _s, _type, _date in [
(_df.iloc[2:, 0], self.REMOVE, remove_date),
(_df.iloc[2:, 2], self.ADD, add_date),
]:
_tmp_df = pd.DataFrame()
_tmp_df["symbol"] = _s.map(self.normalize_symbol)
_tmp_df["type"] = _type
_tmp_df["date"] = _date
tmp.append(_tmp_df)
df = pd.concat(tmp)
break
return df
@staticmethod
def _get_change_notices_url() -> list:
"""get change notices url
Returns
-------
"""
resp = requests.get(CSI300_CHANGES_URL)
html = etree.HTML(resp.text)
return html.xpath("//*[@id='itemContainer']//li/a/@href")
def _get_new_companies(self):
logger.info("get new companies")
_io = BytesIO(requests.get(NEW_COMPANIES_URL).content)
df = pd.read_excel(_io)
df = df.iloc[:, [0, 4]]
df.columns = ["end_date", "symbol"]
df["symbol"] = df["symbol"].map(self.normalize_symbol)
df["end_date"] = pd.to_datetime(df["end_date"])
df["start_date"] = CSI300_START_DATE
return df
def parse_instruments(self):
"""parse csi300.txt
Examples
-------
$ python collector.py parse_instruments --qlib_dir ~/.qlib/qlib_data/cn_data
"""
logger.info("start parse csi300 companies.....")
instruments_columns = ["symbol", "start_date", "end_date"]
changers_df = self._get_changes()
new_df = self._get_new_companies()
logger.info("parse history companies by changes......")
for _row in changers_df.sort_values("date", ascending=False).itertuples(index=False):
if _row.type == self.ADD:
min_end_date = new_df.loc[new_df["symbol"] == _row.symbol, "end_date"].min()
new_df.loc[
(new_df["end_date"] == min_end_date) & (new_df["symbol"] == _row.symbol), "start_date"
] = _row.date
else:
_tmp_df = pd.DataFrame(
[[_row.symbol, CSI300_START_DATE, _row.date]], columns=["symbol", "start_date", "end_date"]
)
new_df = new_df.append(_tmp_df, sort=False)
new_df.loc[:, instruments_columns].to_csv(
self.instruments_dir.joinpath("csi300.txt"), sep="\t", index=False, header=None
)
logger.info("parse csi300 companies finished.")
if __name__ == "__main__":
fire.Fire(CSI300)

View File

@@ -0,0 +1,6 @@
logure
fire
requests
pandas
lxml
loguru

View File

@@ -0,0 +1 @@
# TODO: Support collecting data from MSN

View File

@@ -0,0 +1,38 @@
# Collect Data From Yahoo Finance
## Requirements
```bash
pip install -r requirements.txt
```
## Collector Data
### Download data -> Normalize data -> Dump data
```bash
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
```
### Download Data From Yahoo Finance
```bash
python collector.py download_data --source_dir ~/.qlib/stock_data/source
```
### Normalize Yahoo Finance Data
```bash
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize
```
### Manual Ajust Yahoo Finance Data
```bash
python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
```
### Dump Yahoo Finance Data
```bash
python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
```

View File

@@ -0,0 +1,254 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import sys
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
import fire
import requests
import numpy as np
import pandas as pd
from tqdm import tqdm
from lxml import etree
from loguru import logger
from yahooquery import Ticker
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from dump_bin import DumpData
SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
class YahooCollector:
def __init__(self, save_dir: [str, Path], max_workers=4):
self.save_dir = Path(save_dir).expanduser().resolve()
self.save_dir.mkdir(parents=True, exist_ok=True)
self._stock_list = None
self.max_workers = max_workers
@property
def stock_list(self):
if self._stock_list is None:
self._stock_list = self.get_stock_list()
return self._stock_list
@staticmethod
def get_stock_list() -> list:
_res = set()
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
resp = requests.get(SYMBOLS_URL.format(s_type=_k))
_res |= set(
map(
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
)
)
return sorted(list(_res))
def save_stock(self, symbol, df: pd.DataFrame):
"""save stock data to file
Parameters
----------
symbol: str
stock code
df : pd.DataFrame
df.columns must contain "symbol" and "datetime"
"""
if df.empty:
raise ValueError("df is empty")
symbol_s = symbol.split(".")
symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}"
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
df["symbol"] = symbol
df.to_csv(stock_path, index=False)
def collector_data(self):
"""collector data
"""
logger.info("start collector yahoo data......")
error_symbol = []
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
futures = {}
p_bar = tqdm(total=len(self.stock_list))
for symbols in [
self.stock_list[i : i + self.max_workers] for i in range(0, len(self.stock_list), self.max_workers)
]:
resp = Ticker(symbols, asynchronous=True, max_workers=self.max_workers).history(period="max")
if isinstance(resp, dict):
for symbol, df in resp.items():
if isinstance(df, pd.DataFrame):
futures[
worker.submit(
self.save_stock, symbol, df.reset_index().rename(columns={"index": "date"})
)
] = symbol
else:
error_symbol.append(symbol)
else:
for symbol, df in resp.reset_index().groupby("symbol"):
futures[worker.submit(self.save_stock, symbol, df)] = symbol
p_bar.update(self.max_workers)
p_bar.close()
with tqdm(total=len(futures.values())) as p_bar:
for future in as_completed(futures):
try:
future.result()
except Exception as e:
logger.error(e)
error_symbol.append(futures[future])
p_bar.update()
logger.info(error_symbol)
logger.info(len(error_symbol))
logger.info(len(self.stock_list))
# TODO: from MSN
df = pd.DataFrame(map(lambda x: x.split(","), requests.get(CSI300_BENCH_URL).json()["data"]["klines"]))
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
df["date"] = pd.to_datetime(df["date"])
df = df.astype(float, errors="ignore")
df["adjclose"] = df["close"]
df.to_csv(self.save_dir.joinpath("sh000300.csv"), index=False)
class Run:
def __init__(self, source_dir=None, normalize_dir=None, qlib_dir=None, max_workers=4):
"""
Parameters
----------
source_dir: str
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
qlib_dir: str
qlib data dir; usage of provider_uri, default "Path(__file__).parent/qlib_data"
max_workers: int
Concurrent number, default is 4
"""
if source_dir is None:
source_dir = CUR_DIR.joinpath("source")
self.source_dir = Path(source_dir).expanduser().resolve()
self.source_dir.mkdir(parents=True, exist_ok=True)
if normalize_dir is None:
normalize_dir = CUR_DIR.joinpath("normalize")
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
self.normalize_dir.mkdir(parents=True, exist_ok=True)
if qlib_dir is None:
qlib_dir = CUR_DIR.joinpath("qlib_data")
self.qlib_dir = Path(qlib_dir).expanduser().resolve()
self.qlib_dir.mkdir(parents=True, exist_ok=True)
self.max_workers = max_workers
def normalize_data(self):
"""normalize data
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize
"""
def _normalize(file_path: Path):
columns = ["open", "close", "high", "low", "volume"]
df = pd.read_csv(file_path)
df.sort_values("date", inplace=True)
df.loc[df["volume"] <= 0, set(df.columns) - {"symbol", "date"}] = np.nan
df["factor"] = df["adjclose"] / df["close"]
for _col in columns:
if _col == "volume":
df[_col] = df[_col] / df["factor"]
else:
df[_col] = df[_col] * df["factor"]
_tmp_series = df["close"].fillna(method="ffill")
df["change"] = _tmp_series / _tmp_series.shift(1) - 1
columns += ["change", "factor"]
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
df.loc[:, columns + ["date"]].to_csv(self.normalize_dir.joinpath(file_path.name), index=False)
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
file_list = list(self.source_dir.glob("*.csv"))
with tqdm(total=len(file_list)) as p_bar:
for _ in worker.map(_normalize, file_list):
p_bar.update()
def manual_adj_data(self):
"""manual adjust data
Examples
--------
$ python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
"""
def _adj(file_path: Path):
df = pd.read_csv(file_path)
df = df.loc[:, ["open", "close", "high", "low", "volume", "change", "factor"]]
df.sort_values("date", inplace=True)
df = df.set_index("date")
df = df.loc[df.first_valid_index():]
_close = df["close"].iloc[0]
for _col in df.columns:
if _col == "volume":
df[_col] = df[_col] * _close
elif _col != "change":
df[_col] = df[_col] / _close
else:
pass
df.reset_index().to_csv(self.normalize_dir.joinpath(file_path.name), index=False)
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
file_list = list(self.normalize_dir.glob("*.csv"))
with tqdm(total=len(file_list)) as p_bar:
for _ in worker.map(_adj, file_list):
p_bar.update()
def dump_data(self):
"""dump yahoo data
Examples
---------
$ python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
"""
DumpData(csv_path=self.normalize_dir, qlib_dir=self.qlib_dir, works=self.max_workers).dump(
include_fields="close,open,high,low,volume,change,factor"
)
def download_data(self):
"""download data from Internet
Examples
---------
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source
"""
YahooCollector(self.source_dir, max_workers=self.max_workers).collector_data()
def collector_data(self):
"""download -> normalize -> dump data
Examples
-------
$ python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
"""
self.download_data()
self.normalize_data()
self.manual_adj_data()
self.dump_data()
if __name__ == "__main__":
fire.Fire(Run)

View File

@@ -0,0 +1,9 @@
logure
fire
requests
numpy
pandas
tqdm
lxml
loguru
yahooquery

250
scripts/dump_bin.py Normal file
View File

@@ -0,0 +1,250 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import shutil
from pathlib import Path
from functools import partial
from concurrent.futures import ThreadPoolExecutor
import fire
import numpy as np
import pandas as pd
from tqdm import tqdm
from loguru import logger
class DumpData(object):
FILE_SUFFIX = ".csv"
def __init__(
self,
csv_path: str,
qlib_dir: str,
backup_dir: str = None,
freq: str = "day",
works: int = None,
date_field_name: str = "date",
):
"""
Parameters
----------
csv_path: str
stock data path or directory
qlib_dir: str
qlib(dump) data director
backup_dir: str, default None
if backup_dir is not None, backup qlib_dir to backup_dir
freq: str, default "day"
transaction frequency
works: int, default None
number of threads
date_field_name: str, default "date"
the name of the date field in the csv
"""
csv_path = Path(csv_path).expanduser()
self.csv_files = sorted(csv_path.glob(f"*{self.FILE_SUFFIX}") if csv_path.is_dir() else [csv_path])
self.qlib_dir = Path(qlib_dir).expanduser()
self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
if backup_dir is not None:
self._backup_qlib_dir(Path(backup_dir).expanduser())
self.freq = freq
self.calendar_format = "%Y-%m-%d" if self.freq == "day" else "%Y-%m-%d %H:%M:%S"
self.works = works
self.date_field_name = date_field_name
self._calendars_dir = self.qlib_dir.joinpath("calendars")
self._features_dir = self.qlib_dir.joinpath("features")
self._instruments_dir = self.qlib_dir.joinpath("instruments")
self._calendars_list = []
self._include_fields = ()
self._exclude_fields = ()
def _backup_qlib_dir(self, target_dir: Path):
shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve()))
def _get_date_for_df(self, file_path: Path, *, is_begin_end: bool = False):
df = pd.read_csv(str(file_path.resolve()))
if df.empty or self.date_field_name not in df.columns.tolist():
return []
if is_begin_end:
return [df[self.date_field_name].min(), df[self.date_field_name].max()]
return df[self.date_field_name].tolist()
def _get_source_data(self, file_path: Path):
df = pd.read_csv(str(file_path.resolve()))
df[self.date_field_name] = df[self.date_field_name].astype(np.datetime64)
return df
def _file_to_bin(self, file_path: Path = None):
code = file_path.name[: -len(self.FILE_SUFFIX)].strip().lower()
features_dir = self._features_dir.joinpath(code)
features_dir.mkdir(parents=True, exist_ok=True)
calendars_df = pd.DataFrame(data=self._calendars_list, columns=[self.date_field_name])
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
# read csv file
df = self._get_source_data(file_path)
cal_df = calendars_df[
(calendars_df[self.date_field_name] >= df[self.date_field_name].min())
& (calendars_df[self.date_field_name] <= df[self.date_field_name].max())
]
cal_df.set_index(self.date_field_name, inplace=True)
df.set_index(self.date_field_name, inplace=True)
r_df = df.reindex(cal_df.index)
date_index = self._calendars_list.index(r_df.index.min())
for field in (
self._include_fields
if self._include_fields
else set(r_df.columns) - set(self._exclude_fields)
if self._exclude_fields
else r_df.columns
):
bin_path = features_dir.joinpath(f"{field}.{self.freq}.bin")
if field not in r_df.columns:
continue
r = np.hstack([date_index, r_df[field]]).astype("<f")
r.tofile(str(bin_path.resolve()))
@staticmethod
def _read_calendar(calendar_path: Path):
return sorted(
map(
pd.Timestamp,
pd.read_csv(calendar_path, header=None).loc[:, 0].tolist(),
)
)
def dump_features(
self,
calendar_path: str = None,
include_fields: tuple = None,
exclude_fields: tuple = None,
):
"""dump features
Parameters
---------
calendar_path: str
calendar path
include_fields: str
dump fields
exclude_fields: str
fields not dumped
Notes
---------
python dump_bin.py dump_features --csv_path <stock data directory or path> --qlib_dir <dump data directory>
Examples
---------
# dump all stock
python dump_bin.py dump_features --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
# dump one stock
python dump_bin.py dump_features --csv_path ~/tmp/stock_data/sh600000.csv --qlib_dir ~/tmp/qlib_data --calendar_path ~/tmp/qlib_data/calendar/all.txt --exclude_fields date,code,timestamp,code_name
"""
logger.info("start dump features......")
if calendar_path is not None:
# read calendar from calendar file
self._calendars_list = self._read_calendar(Path(calendar_path))
if not self._calendars_list:
self.dump_calendars()
self._include_fields = tuple(map(str.strip, include_fields)) if include_fields else self._include_fields
self._exclude_fields = tuple(map(str.strip, exclude_fields)) if exclude_fields else self._exclude_fields
with tqdm(total=len(self.csv_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for _ in executor.map(self._file_to_bin, self.csv_files):
p_bar.update()
logger.info("end of features dump.\n")
def dump_calendars(self):
"""dump calendars
Notes
---------
python dump_bin.py dump_calendars --csv_path <stock data directory or path> --qlib_dir <dump data directory>
Examples
---------
python dump_bin.py dump_calendars --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
"""
logger.info("start dump calendars......")
calendar_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
all_datetime = set()
with tqdm(total=len(self.csv_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for temp_datetime in executor.map(self._get_date_for_df, self.csv_files):
all_datetime = all_datetime | set(temp_datetime)
p_bar.update()
self._calendars_list = sorted(map(pd.Timestamp, all_datetime))
self._calendars_dir.mkdir(parents=True, exist_ok=True)
result_calendar_list = list(map(lambda x: x.strftime(self.calendar_format), self._calendars_list))
np.savetxt(calendar_path, result_calendar_list, fmt="%s", encoding="utf-8")
logger.info("end of calendars dump.\n")
def dump_instruments(self):
"""dump instruments
Notes
---------
python dump_bin.py dump_instruments --csv_path <stock data directory or path> --qlib_dir <dump data directory>
Examples
---------
python dump_bin.py dump_instruments --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
"""
logger.info("start dump instruments......")
symbol_list = list(map(lambda x: x.name[: -len(self.FILE_SUFFIX)], self.csv_files))
_result_list = []
_fun = partial(self._get_date_for_df, is_begin_end=True)
with tqdm(total=len(symbol_list)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as execute:
for symbol, res in zip(symbol_list, execute.map(_fun, self.csv_files)):
if res:
begin_time = res[0]
end_time = res[-1]
_result_list.append(f"{symbol.upper()}\t{begin_time}\t{end_time}")
p_bar.update()
self._instruments_dir.mkdir(parents=True, exist_ok=True)
to_path = str(self._instruments_dir.joinpath("all.txt").resolve())
np.savetxt(to_path, _result_list, fmt="%s", encoding="utf-8")
logger.info("end of instruments dump.\n")
def dump(self, include_fields: str = None, exclude_fields: tuple = None):
"""dump data
Parameters
----------
include_fields: str
dump fields
exclude_fields: str
fields not dumped
Examples
---------
python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --include_fields open,close,high,low,volume,factor
python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
"""
if isinstance(exclude_fields, str):
exclude_fields = exclude_fields.split(",")
if isinstance(include_fields, str):
include_fields = include_fields.split(",")
self.dump_calendars()
self.dump_features(include_fields=include_fields, exclude_fields=exclude_fields)
self.dump_instruments()
if __name__ == "__main__":
fire.Fire(DumpData)

91
scripts/get_data.py Normal file
View File

@@ -0,0 +1,91 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import fire
import zipfile
import requests
from tqdm import tqdm
from pathlib import Path
from loguru import logger
class GetData:
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
def __init__(self, delete_zip_file=False):
"""
Parameters
----------
delete_zip_file : bool, optional
Whether to delete the zip file, value from True or False, by default False
"""
self.delete_zip_file = delete_zip_file
def _download_data(self, file_name: str, target_dir: [Path, str]):
target_dir = Path(target_dir).expanduser()
target_dir.mkdir(exist_ok=True, parents=True)
url = f"{self.REMOTE_URL}/{file_name}"
target_path = target_dir.joinpath(file_name)
resp = requests.get(url, stream=True)
if resp.status_code != 200:
raise requests.exceptions.HTTPError()
chuck_size = 1024
logger.info(f"{file_name} downloading......")
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
with target_path.open("wb") as fp:
for chuck in resp.iter_content(chunk_size=chuck_size):
fp.write(chuck)
p_bar.update(chuck_size)
self._unzip(target_path, target_dir)
if self.delete_zip_file:
target_path.unlike()
@staticmethod
def _unzip(file_path: Path, target_dir: Path):
logger.info(f"{file_path} unzipping......")
with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
for _file in tqdm(zp.namelist()):
zp.extract(_file, str(target_dir.resolve()))
def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data"):
"""download cn qlib data from remote
Parameters
----------
target_dir: str
data save directory
Examples
---------
python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
-------
"""
file_name = "qlib_data_cn.zip"
self._download_data(file_name, target_dir)
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
"""download cn csv data from remote
Parameters
----------
target_dir: str
data save directory
Examples
---------
python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
-------
"""
file_name = "csv_data_cn.zip"
self._download_data(file_name, target_dir)
if __name__ == "__main__":
fire.Fire(GetData)