mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
init commit
This commit is contained in:
14
scripts/data_collector/csi/README.md
Normal file
14
scripts/data_collector/csi/README.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# CSI300 History Companies Collection
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
```bash
|
||||
python collector.py parse_instruments --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
```
|
||||
|
||||
213
scripts/data_collector/csi/collector.py
Normal file
213
scripts/data_collector/csi/collector.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import bisect
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
|
||||
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/000300cons.xls"
|
||||
|
||||
CSI300_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
|
||||
|
||||
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
|
||||
CSI300_START_DATE = pd.Timestamp("2005-01-01")
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
class CSI300:
|
||||
|
||||
REMOVE = "remove"
|
||||
ADD = "add"
|
||||
|
||||
def __init__(self, qlib_dir=None):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir: str
|
||||
qlib data dir, default "Path(__file__).parent/qlib_data"
|
||||
"""
|
||||
|
||||
if qlib_dir is None:
|
||||
qlib_dir = CUR_DIR.joinpath("qlib_data")
|
||||
self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments")
|
||||
self.instruments_dir.mkdir(exist_ok=True, parents=True)
|
||||
self._calendar_list = None
|
||||
|
||||
@property
|
||||
def calendar_list(self) -> list:
|
||||
"""get history trading date
|
||||
|
||||
Returns
|
||||
-------
|
||||
"""
|
||||
# TODO: get calendar from MSN
|
||||
if self._calendar_list is None:
|
||||
logger.info("get all trading date")
|
||||
value_list = requests.get(CSI300_BENCH_URL).json()["data"]["klines"]
|
||||
self._calendar_list = sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), value_list))
|
||||
return self._calendar_list
|
||||
|
||||
def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1):
|
||||
"""get trading date by shift
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shift : int
|
||||
shift, default is 1
|
||||
|
||||
trading_date : pd.Timestamp
|
||||
trading date
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
left_index = bisect.bisect_left(self.calendar_list, trading_date)
|
||||
try:
|
||||
res = self.calendar_list[left_index + shift]
|
||||
except IndexError:
|
||||
res = trading_date
|
||||
return res
|
||||
|
||||
def _get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
logger.info("get companies changes......")
|
||||
res = []
|
||||
for _url in self._get_change_notices_url():
|
||||
_df = self._read_change_from_url(_url)
|
||||
res.append(_df)
|
||||
logger.info("get companies changes finish")
|
||||
return pd.concat(res)
|
||||
|
||||
@staticmethod
|
||||
def normalize_symbol(symbol):
|
||||
symbol = f"{int(symbol):06}"
|
||||
return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}"
|
||||
|
||||
def _read_change_from_url(self, url: str) -> pd.DataFrame:
|
||||
"""read change from url
|
||||
|
||||
Parameters
|
||||
----------
|
||||
url : str
|
||||
change url
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
resp = requests.get(url)
|
||||
_text = resp.text
|
||||
|
||||
date_list = re.findall(r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日", _text)
|
||||
if len(date_list) >= 2:
|
||||
add_date = pd.Timestamp("-".join(date_list[0]))
|
||||
else:
|
||||
_date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0]))
|
||||
add_date = self._get_trading_date_by_shift(_date, shift=0)
|
||||
remove_date = self._get_trading_date_by_shift(add_date, shift=-1)
|
||||
logger.info(f"get {add_date} changes")
|
||||
try:
|
||||
|
||||
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
|
||||
_io = BytesIO(requests.get(f"http://www.csindex.com.cn{excel_url}").content)
|
||||
df_map = pd.read_excel(_io, sheet_name=None)
|
||||
tmp = []
|
||||
for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]:
|
||||
_df = df_map[_s_name]
|
||||
_df = _df.loc[_df["指数代码"] == "000300", ["证券代码"]]
|
||||
_df = _df.applymap(self.normalize_symbol)
|
||||
_df.columns = ["symbol"]
|
||||
_df["type"] = _type
|
||||
_df["date"] = _date
|
||||
tmp.append(_df)
|
||||
df = pd.concat(tmp)
|
||||
except Exception:
|
||||
df = None
|
||||
for _df in pd.read_html(resp.content):
|
||||
if _df.shape[-1] != 4:
|
||||
continue
|
||||
tmp = []
|
||||
for _s, _type, _date in [
|
||||
(_df.iloc[2:, 0], self.REMOVE, remove_date),
|
||||
(_df.iloc[2:, 2], self.ADD, add_date),
|
||||
]:
|
||||
_tmp_df = pd.DataFrame()
|
||||
_tmp_df["symbol"] = _s.map(self.normalize_symbol)
|
||||
_tmp_df["type"] = _type
|
||||
_tmp_df["date"] = _date
|
||||
tmp.append(_tmp_df)
|
||||
df = pd.concat(tmp)
|
||||
break
|
||||
return df
|
||||
|
||||
@staticmethod
|
||||
def _get_change_notices_url() -> list:
|
||||
"""get change notices url
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
resp = requests.get(CSI300_CHANGES_URL)
|
||||
html = etree.HTML(resp.text)
|
||||
return html.xpath("//*[@id='itemContainer']//li/a/@href")
|
||||
|
||||
def _get_new_companies(self):
|
||||
|
||||
logger.info("get new companies")
|
||||
_io = BytesIO(requests.get(NEW_COMPANIES_URL).content)
|
||||
df = pd.read_excel(_io)
|
||||
df = df.iloc[:, [0, 4]]
|
||||
df.columns = ["end_date", "symbol"]
|
||||
df["symbol"] = df["symbol"].map(self.normalize_symbol)
|
||||
df["end_date"] = pd.to_datetime(df["end_date"])
|
||||
df["start_date"] = CSI300_START_DATE
|
||||
return df
|
||||
|
||||
def parse_instruments(self):
|
||||
"""parse csi300.txt
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py parse_instruments --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
"""
|
||||
logger.info("start parse csi300 companies.....")
|
||||
instruments_columns = ["symbol", "start_date", "end_date"]
|
||||
changers_df = self._get_changes()
|
||||
new_df = self._get_new_companies()
|
||||
logger.info("parse history companies by changes......")
|
||||
for _row in changers_df.sort_values("date", ascending=False).itertuples(index=False):
|
||||
if _row.type == self.ADD:
|
||||
min_end_date = new_df.loc[new_df["symbol"] == _row.symbol, "end_date"].min()
|
||||
new_df.loc[
|
||||
(new_df["end_date"] == min_end_date) & (new_df["symbol"] == _row.symbol), "start_date"
|
||||
] = _row.date
|
||||
else:
|
||||
_tmp_df = pd.DataFrame(
|
||||
[[_row.symbol, CSI300_START_DATE, _row.date]], columns=["symbol", "start_date", "end_date"]
|
||||
)
|
||||
new_df = new_df.append(_tmp_df, sort=False)
|
||||
|
||||
new_df.loc[:, instruments_columns].to_csv(
|
||||
self.instruments_dir.joinpath("csi300.txt"), sep="\t", index=False, header=None
|
||||
)
|
||||
logger.info("parse csi300 companies finished.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(CSI300)
|
||||
6
scripts/data_collector/csi/requirements.txt
Normal file
6
scripts/data_collector/csi/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
logure
|
||||
fire
|
||||
requests
|
||||
pandas
|
||||
lxml
|
||||
loguru
|
||||
1
scripts/data_collector/msn/README.md
Normal file
1
scripts/data_collector/msn/README.md
Normal file
@@ -0,0 +1 @@
|
||||
# TODO: Support collecting data from MSN
|
||||
38
scripts/data_collector/yahoo/README.md
Normal file
38
scripts/data_collector/yahoo/README.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# Collect Data From Yahoo Finance
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
### Download data -> Normalize data -> Dump data
|
||||
```bash
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
```
|
||||
|
||||
### Download Data From Yahoo Finance
|
||||
|
||||
```bash
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source
|
||||
```
|
||||
|
||||
### Normalize Yahoo Finance Data
|
||||
|
||||
```bash
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize
|
||||
```
|
||||
|
||||
### Manual Ajust Yahoo Finance Data
|
||||
|
||||
```bash
|
||||
python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
|
||||
```
|
||||
|
||||
### Dump Yahoo Finance Data
|
||||
|
||||
```bash
|
||||
python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
```
|
||||
254
scripts/data_collector/yahoo/collector.py
Normal file
254
scripts/data_collector/yahoo/collector.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from dump_bin import DumpData
|
||||
|
||||
SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
|
||||
|
||||
class YahooCollector:
|
||||
def __init__(self, save_dir: [str, Path], max_workers=4):
|
||||
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._stock_list = None
|
||||
self.max_workers = max_workers
|
||||
|
||||
@property
|
||||
def stock_list(self):
|
||||
if self._stock_list is None:
|
||||
self._stock_list = self.get_stock_list()
|
||||
return self._stock_list
|
||||
|
||||
@staticmethod
|
||||
def get_stock_list() -> list:
|
||||
_res = set()
|
||||
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
|
||||
resp = requests.get(SYMBOLS_URL.format(s_type=_k))
|
||||
_res |= set(
|
||||
map(
|
||||
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
|
||||
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
|
||||
)
|
||||
)
|
||||
return sorted(list(_res))
|
||||
|
||||
def save_stock(self, symbol, df: pd.DataFrame):
|
||||
"""save stock data to file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
stock code
|
||||
df : pd.DataFrame
|
||||
df.columns must contain "symbol" and "datetime"
|
||||
"""
|
||||
if df.empty:
|
||||
raise ValueError("df is empty")
|
||||
|
||||
symbol_s = symbol.split(".")
|
||||
symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}"
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
df.to_csv(stock_path, index=False)
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data
|
||||
|
||||
"""
|
||||
logger.info("start collector yahoo data......")
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
|
||||
futures = {}
|
||||
p_bar = tqdm(total=len(self.stock_list))
|
||||
for symbols in [
|
||||
self.stock_list[i : i + self.max_workers] for i in range(0, len(self.stock_list), self.max_workers)
|
||||
]:
|
||||
resp = Ticker(symbols, asynchronous=True, max_workers=self.max_workers).history(period="max")
|
||||
if isinstance(resp, dict):
|
||||
for symbol, df in resp.items():
|
||||
if isinstance(df, pd.DataFrame):
|
||||
futures[
|
||||
worker.submit(
|
||||
self.save_stock, symbol, df.reset_index().rename(columns={"index": "date"})
|
||||
)
|
||||
] = symbol
|
||||
else:
|
||||
error_symbol.append(symbol)
|
||||
else:
|
||||
for symbol, df in resp.reset_index().groupby("symbol"):
|
||||
futures[worker.submit(self.save_stock, symbol, df)] = symbol
|
||||
p_bar.update(self.max_workers)
|
||||
p_bar.close()
|
||||
|
||||
with tqdm(total=len(futures.values())) as p_bar:
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
error_symbol.append(futures[future])
|
||||
p_bar.update()
|
||||
|
||||
logger.info(error_symbol)
|
||||
logger.info(len(error_symbol))
|
||||
logger.info(len(self.stock_list))
|
||||
|
||||
# TODO: from MSN
|
||||
df = pd.DataFrame(map(lambda x: x.split(","), requests.get(CSI300_BENCH_URL).json()["data"]["klines"]))
|
||||
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
|
||||
df["date"] = pd.to_datetime(df["date"])
|
||||
df = df.astype(float, errors="ignore")
|
||||
df["adjclose"] = df["close"]
|
||||
df.to_csv(self.save_dir.joinpath("sh000300.csv"), index=False)
|
||||
|
||||
|
||||
class Run:
|
||||
def __init__(self, source_dir=None, normalize_dir=None, qlib_dir=None, max_workers=4):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
qlib_dir: str
|
||||
qlib data dir; usage of provider_uri, default "Path(__file__).parent/qlib_data"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
"""
|
||||
if source_dir is None:
|
||||
source_dir = CUR_DIR.joinpath("source")
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if normalize_dir is None:
|
||||
normalize_dir = CUR_DIR.joinpath("normalize")
|
||||
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
|
||||
self.normalize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if qlib_dir is None:
|
||||
qlib_dir = CUR_DIR.joinpath("qlib_data")
|
||||
self.qlib_dir = Path(qlib_dir).expanduser().resolve()
|
||||
self.qlib_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.max_workers = max_workers
|
||||
|
||||
def normalize_data(self):
|
||||
"""normalize data
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize
|
||||
|
||||
"""
|
||||
|
||||
def _normalize(file_path: Path):
|
||||
columns = ["open", "close", "high", "low", "volume"]
|
||||
df = pd.read_csv(file_path)
|
||||
df.sort_values("date", inplace=True)
|
||||
df.loc[df["volume"] <= 0, set(df.columns) - {"symbol", "date"}] = np.nan
|
||||
df["factor"] = df["adjclose"] / df["close"]
|
||||
for _col in columns:
|
||||
if _col == "volume":
|
||||
df[_col] = df[_col] / df["factor"]
|
||||
else:
|
||||
df[_col] = df[_col] * df["factor"]
|
||||
_tmp_series = df["close"].fillna(method="ffill")
|
||||
df["change"] = _tmp_series / _tmp_series.shift(1) - 1
|
||||
columns += ["change", "factor"]
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
|
||||
df.loc[:, columns + ["date"]].to_csv(self.normalize_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
|
||||
file_list = list(self.source_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(_normalize, file_list):
|
||||
p_bar.update()
|
||||
|
||||
def manual_adj_data(self):
|
||||
"""manual adjust data
|
||||
|
||||
Examples
|
||||
--------
|
||||
$ python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
|
||||
|
||||
"""
|
||||
def _adj(file_path: Path):
|
||||
df = pd.read_csv(file_path)
|
||||
df = df.loc[:, ["open", "close", "high", "low", "volume", "change", "factor"]]
|
||||
df.sort_values("date", inplace=True)
|
||||
df = df.set_index("date")
|
||||
df = df.loc[df.first_valid_index():]
|
||||
_close = df["close"].iloc[0]
|
||||
for _col in df.columns:
|
||||
if _col == "volume":
|
||||
df[_col] = df[_col] * _close
|
||||
elif _col != "change":
|
||||
df[_col] = df[_col] / _close
|
||||
else:
|
||||
pass
|
||||
df.reset_index().to_csv(self.normalize_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
|
||||
file_list = list(self.normalize_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(_adj, file_list):
|
||||
p_bar.update()
|
||||
|
||||
|
||||
def dump_data(self):
|
||||
"""dump yahoo data
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
|
||||
"""
|
||||
DumpData(csv_path=self.normalize_dir, qlib_dir=self.qlib_dir, works=self.max_workers).dump(
|
||||
include_fields="close,open,high,low,volume,change,factor"
|
||||
)
|
||||
|
||||
def download_data(self):
|
||||
"""download data from Internet
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source
|
||||
|
||||
"""
|
||||
YahooCollector(self.source_dir, max_workers=self.max_workers).collector_data()
|
||||
|
||||
def collector_data(self):
|
||||
"""download -> normalize -> dump data
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
"""
|
||||
self.download_data()
|
||||
self.normalize_data()
|
||||
self.manual_adj_data()
|
||||
self.dump_data()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(Run)
|
||||
9
scripts/data_collector/yahoo/requirements.txt
Normal file
9
scripts/data_collector/yahoo/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
logure
|
||||
fire
|
||||
requests
|
||||
numpy
|
||||
pandas
|
||||
tqdm
|
||||
lxml
|
||||
loguru
|
||||
yahooquery
|
||||
250
scripts/dump_bin.py
Normal file
250
scripts/dump_bin.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class DumpData(object):
|
||||
FILE_SUFFIX = ".csv"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
csv_path: str,
|
||||
qlib_dir: str,
|
||||
backup_dir: str = None,
|
||||
freq: str = "day",
|
||||
works: int = None,
|
||||
date_field_name: str = "date",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
csv_path: str
|
||||
stock data path or directory
|
||||
qlib_dir: str
|
||||
qlib(dump) data director
|
||||
backup_dir: str, default None
|
||||
if backup_dir is not None, backup qlib_dir to backup_dir
|
||||
freq: str, default "day"
|
||||
transaction frequency
|
||||
works: int, default None
|
||||
number of threads
|
||||
date_field_name: str, default "date"
|
||||
the name of the date field in the csv
|
||||
"""
|
||||
csv_path = Path(csv_path).expanduser()
|
||||
self.csv_files = sorted(csv_path.glob(f"*{self.FILE_SUFFIX}") if csv_path.is_dir() else [csv_path])
|
||||
self.qlib_dir = Path(qlib_dir).expanduser()
|
||||
self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
|
||||
if backup_dir is not None:
|
||||
self._backup_qlib_dir(Path(backup_dir).expanduser())
|
||||
|
||||
self.freq = freq
|
||||
self.calendar_format = "%Y-%m-%d" if self.freq == "day" else "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
self.works = works
|
||||
self.date_field_name = date_field_name
|
||||
|
||||
self._calendars_dir = self.qlib_dir.joinpath("calendars")
|
||||
self._features_dir = self.qlib_dir.joinpath("features")
|
||||
self._instruments_dir = self.qlib_dir.joinpath("instruments")
|
||||
|
||||
self._calendars_list = []
|
||||
self._include_fields = ()
|
||||
self._exclude_fields = ()
|
||||
|
||||
def _backup_qlib_dir(self, target_dir: Path):
|
||||
shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve()))
|
||||
|
||||
def _get_date_for_df(self, file_path: Path, *, is_begin_end: bool = False):
|
||||
df = pd.read_csv(str(file_path.resolve()))
|
||||
if df.empty or self.date_field_name not in df.columns.tolist():
|
||||
return []
|
||||
if is_begin_end:
|
||||
return [df[self.date_field_name].min(), df[self.date_field_name].max()]
|
||||
return df[self.date_field_name].tolist()
|
||||
|
||||
def _get_source_data(self, file_path: Path):
|
||||
df = pd.read_csv(str(file_path.resolve()))
|
||||
df[self.date_field_name] = df[self.date_field_name].astype(np.datetime64)
|
||||
return df
|
||||
|
||||
def _file_to_bin(self, file_path: Path = None):
|
||||
code = file_path.name[: -len(self.FILE_SUFFIX)].strip().lower()
|
||||
features_dir = self._features_dir.joinpath(code)
|
||||
features_dir.mkdir(parents=True, exist_ok=True)
|
||||
calendars_df = pd.DataFrame(data=self._calendars_list, columns=[self.date_field_name])
|
||||
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
|
||||
# read csv file
|
||||
df = self._get_source_data(file_path)
|
||||
cal_df = calendars_df[
|
||||
(calendars_df[self.date_field_name] >= df[self.date_field_name].min())
|
||||
& (calendars_df[self.date_field_name] <= df[self.date_field_name].max())
|
||||
]
|
||||
cal_df.set_index(self.date_field_name, inplace=True)
|
||||
df.set_index(self.date_field_name, inplace=True)
|
||||
r_df = df.reindex(cal_df.index)
|
||||
date_index = self._calendars_list.index(r_df.index.min())
|
||||
for field in (
|
||||
self._include_fields
|
||||
if self._include_fields
|
||||
else set(r_df.columns) - set(self._exclude_fields)
|
||||
if self._exclude_fields
|
||||
else r_df.columns
|
||||
):
|
||||
|
||||
bin_path = features_dir.joinpath(f"{field}.{self.freq}.bin")
|
||||
if field not in r_df.columns:
|
||||
continue
|
||||
r = np.hstack([date_index, r_df[field]]).astype("<f")
|
||||
r.tofile(str(bin_path.resolve()))
|
||||
|
||||
@staticmethod
|
||||
def _read_calendar(calendar_path: Path):
|
||||
return sorted(
|
||||
map(
|
||||
pd.Timestamp,
|
||||
pd.read_csv(calendar_path, header=None).loc[:, 0].tolist(),
|
||||
)
|
||||
)
|
||||
|
||||
def dump_features(
|
||||
self,
|
||||
calendar_path: str = None,
|
||||
include_fields: tuple = None,
|
||||
exclude_fields: tuple = None,
|
||||
):
|
||||
"""dump features
|
||||
|
||||
Parameters
|
||||
---------
|
||||
calendar_path: str
|
||||
calendar path
|
||||
|
||||
include_fields: str
|
||||
dump fields
|
||||
|
||||
exclude_fields: str
|
||||
fields not dumped
|
||||
|
||||
Notes
|
||||
---------
|
||||
python dump_bin.py dump_features --csv_path <stock data directory or path> --qlib_dir <dump data directory>
|
||||
|
||||
Examples
|
||||
---------
|
||||
|
||||
# dump all stock
|
||||
python dump_bin.py dump_features --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
|
||||
# dump one stock
|
||||
python dump_bin.py dump_features --csv_path ~/tmp/stock_data/sh600000.csv --qlib_dir ~/tmp/qlib_data --calendar_path ~/tmp/qlib_data/calendar/all.txt --exclude_fields date,code,timestamp,code_name
|
||||
"""
|
||||
logger.info("start dump features......")
|
||||
if calendar_path is not None:
|
||||
# read calendar from calendar file
|
||||
self._calendars_list = self._read_calendar(Path(calendar_path))
|
||||
|
||||
if not self._calendars_list:
|
||||
self.dump_calendars()
|
||||
|
||||
self._include_fields = tuple(map(str.strip, include_fields)) if include_fields else self._include_fields
|
||||
self._exclude_fields = tuple(map(str.strip, exclude_fields)) if exclude_fields else self._exclude_fields
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as executor:
|
||||
for _ in executor.map(self._file_to_bin, self.csv_files):
|
||||
p_bar.update()
|
||||
|
||||
logger.info("end of features dump.\n")
|
||||
|
||||
def dump_calendars(self):
|
||||
"""dump calendars
|
||||
|
||||
Notes
|
||||
---------
|
||||
python dump_bin.py dump_calendars --csv_path <stock data directory or path> --qlib_dir <dump data directory>
|
||||
|
||||
Examples
|
||||
---------
|
||||
python dump_bin.py dump_calendars --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
|
||||
"""
|
||||
logger.info("start dump calendars......")
|
||||
calendar_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
|
||||
all_datetime = set()
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as executor:
|
||||
for temp_datetime in executor.map(self._get_date_for_df, self.csv_files):
|
||||
all_datetime = all_datetime | set(temp_datetime)
|
||||
p_bar.update()
|
||||
|
||||
self._calendars_list = sorted(map(pd.Timestamp, all_datetime))
|
||||
self._calendars_dir.mkdir(parents=True, exist_ok=True)
|
||||
result_calendar_list = list(map(lambda x: x.strftime(self.calendar_format), self._calendars_list))
|
||||
np.savetxt(calendar_path, result_calendar_list, fmt="%s", encoding="utf-8")
|
||||
logger.info("end of calendars dump.\n")
|
||||
|
||||
def dump_instruments(self):
|
||||
"""dump instruments
|
||||
|
||||
Notes
|
||||
---------
|
||||
python dump_bin.py dump_instruments --csv_path <stock data directory or path> --qlib_dir <dump data directory>
|
||||
|
||||
Examples
|
||||
---------
|
||||
python dump_bin.py dump_instruments --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
|
||||
"""
|
||||
logger.info("start dump instruments......")
|
||||
symbol_list = list(map(lambda x: x.name[: -len(self.FILE_SUFFIX)], self.csv_files))
|
||||
_result_list = []
|
||||
_fun = partial(self._get_date_for_df, is_begin_end=True)
|
||||
with tqdm(total=len(symbol_list)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as execute:
|
||||
for symbol, res in zip(symbol_list, execute.map(_fun, self.csv_files)):
|
||||
if res:
|
||||
begin_time = res[0]
|
||||
end_time = res[-1]
|
||||
_result_list.append(f"{symbol.upper()}\t{begin_time}\t{end_time}")
|
||||
p_bar.update()
|
||||
|
||||
self._instruments_dir.mkdir(parents=True, exist_ok=True)
|
||||
to_path = str(self._instruments_dir.joinpath("all.txt").resolve())
|
||||
np.savetxt(to_path, _result_list, fmt="%s", encoding="utf-8")
|
||||
logger.info("end of instruments dump.\n")
|
||||
|
||||
def dump(self, include_fields: str = None, exclude_fields: tuple = None):
|
||||
"""dump data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
include_fields: str
|
||||
dump fields
|
||||
|
||||
exclude_fields: str
|
||||
fields not dumped
|
||||
|
||||
Examples
|
||||
---------
|
||||
python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --include_fields open,close,high,low,volume,factor
|
||||
python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
|
||||
"""
|
||||
if isinstance(exclude_fields, str):
|
||||
exclude_fields = exclude_fields.split(",")
|
||||
if isinstance(include_fields, str):
|
||||
include_fields = include_fields.split(",")
|
||||
self.dump_calendars()
|
||||
self.dump_features(include_fields=include_fields, exclude_fields=exclude_fields)
|
||||
self.dump_instruments()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(DumpData)
|
||||
91
scripts/get_data.py
Normal file
91
scripts/get_data.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import fire
|
||||
import zipfile
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class GetData:
|
||||
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
|
||||
|
||||
def __init__(self, delete_zip_file=False):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
delete_zip_file : bool, optional
|
||||
Whether to delete the zip file, value from True or False, by default False
|
||||
"""
|
||||
self.delete_zip_file = delete_zip_file
|
||||
|
||||
def _download_data(self, file_name: str, target_dir: [Path, str]):
|
||||
target_dir = Path(target_dir).expanduser()
|
||||
target_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
url = f"{self.REMOTE_URL}/{file_name}"
|
||||
target_path = target_dir.joinpath(file_name)
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
if resp.status_code != 200:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
chuck_size = 1024
|
||||
logger.info(f"{file_name} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
for chuck in resp.iter_content(chunk_size=chuck_size):
|
||||
fp.write(chuck)
|
||||
p_bar.update(chuck_size)
|
||||
|
||||
self._unzip(target_path, target_dir)
|
||||
if self.delete_zip_file:
|
||||
target_path.unlike()
|
||||
|
||||
@staticmethod
|
||||
def _unzip(file_path: Path, target_dir: Path):
|
||||
logger.info(f"{file_path} unzipping......")
|
||||
with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
|
||||
for _file in tqdm(zp.namelist()):
|
||||
zp.extract(_file, str(target_dir.resolve()))
|
||||
|
||||
def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data"):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = "qlib_data_cn.zip"
|
||||
self._download_data(file_name, target_dir)
|
||||
|
||||
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
|
||||
"""download cn csv data from remote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = "csv_data_cn.zip"
|
||||
self._download_data(file_name, target_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(GetData)
|
||||
Reference in New Issue
Block a user