mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 10:31:00 +08:00
* init commit * change the version number * rich the docs&fix cache docs * update index readme * Modify cache class name * Modify sharpe to information_ratio * Modify Group- to Group * add the description of graphical results & fix the backtest docs * fix docs in details * update docs * Update introduction.rst * Update README.md * Update introduction.rst * Update introduction.rst * Update introduction.rst * Update installation.rst * Update installation.rst * Update initialization.rst * Update getdata.rst * Update integration.rst * Update initialization.rst * Update getdata.rst * Update estimator.rst Modify some typos. * Update README.md Modify the typos. * Update initialization.rst * Update data.rst * Update report.rst * Update estimator.rst * Update cumulative_return.py * Update model.rst * Update rank_label.py * Update cumulative_return.py * Update strategy.rst * Update getdata.rst * Update backtest.rst * Update integration.rst * Update getdata.rst * Update introduction.rst * Update introduction.rst * Update README.md * Update report.rst * Update integration.rst Fix typos * Update installation.rst Fix typos * Update getdata.rst * Update initialization.rst Fix typos. 
* add quick start docs&fix detials * fix estimator docs & fix strategy docs * fix the cahce in data.rst * update documents * Fix Corr && Rsquare * fix data retrival example to csi300 & fix a data bug * fix filter bug * Fix data collector * Modift model args * add the log & fix README.md\quick.rst * add enviroment depend & add intoduction of qlib-server online mode * fix image center fomat & set log_only of docs is True * fix README.md format * update data preparation & readme logo image * get_data support version * Modify analysis names * Modify analysis graph * update report.rst & data.rst * commmit estimator for merge * minimal requirements * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update READEME.md * Update READEME.md * update estimator * Fix doc urls * fix get_data.py docstring * update test_get_data.py * Upate docs * Upate docs * Upate docs Co-authored-by: bxdd <bxddream@gmail.com> Co-authored-by: zhupr <zhu.pengrong@foxmail.com> Co-authored-by: Wendi Li <wendili.academic@qq.com> Co-authored-by: Dingsu Wang <dingsu.wang@gmail.com> Co-authored-by: bxdd <45119470+bxdd@users.noreply.github.com> Co-authored-by: cslwqxx <cslwqxx@users.noreply.github.com>
260 lines
9.7 KiB
Python
260 lines
9.7 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
import fire
|
|
import requests
|
|
import numpy as np
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
from loguru import logger
|
|
from yahooquery import Ticker
|
|
|
|
# Directory that contains this script; `Run` creates its default output folders
# ("source", "normalize", "qlib_data") underneath it.
CUR_DIR = Path(__file__).resolve().parent
|
|
# Add the directory two levels above this script to the module search path.
# NOTE(review): presumably this is what lets the `dump_bin` and
# `data_collector.utils` imports below resolve -- it must stay ahead of them.
sys.path.append(str(CUR_DIR.parent.parent))
|
|
from dump_bin import DumpData
|
|
from data_collector.utils import get_hs_calendar_list as get_calendar_list, get_hs_stock_symbols
|
|
|
|
# eastmoney.com K-line endpoint queried for the CSI 300 benchmark (secid=1.000300);
# `YahooCollector.collector_data` stores the response as "sh000300.csv".
# NOTE(review): the query string hard-codes beg=19900101 and end=20220101, so
# presumably no bars after the end date are returned -- confirm and extend if needed.
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
|
|
|
|
|
class YahooCollector:
    """Download price history through ``yahooquery`` and store one CSV per symbol.

    The symbols come from ``get_hs_stock_symbols()``.  After the per-symbol
    download, the CSI 300 benchmark series is fetched from ``CSI300_BENCH_URL``
    and written next to the stock files as ``sh000300.csv``.
    """

    # Seconds to wait for the benchmark HTTP request before giving up, so a
    # stalled connection cannot hang the whole collection run.
    _REQUEST_TIMEOUT = 30

    def __init__(self, save_dir, max_workers=4, asynchronous=True, max_collector_count=3):
        """
        Parameters
        ----------
        save_dir: str or Path
            Directory the CSV files are written to; created if it does not exist.
        max_workers: int
            Size of the thread pool used for saving files.  It is also the number
            of symbols requested per ``Ticker`` call and is forwarded to ``Ticker``.
        asynchronous: bool
            Forwarded to ``yahooquery.Ticker``.
        max_collector_count: int
            Maximum number of collection rounds; every round after the first
            only re-requests the symbols that failed in the previous one.
        """
        self.save_dir = Path(save_dir).expanduser().resolve()
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self._stock_list = None
        self.max_workers = max_workers
        self._asynchronous = asynchronous
        self._max_collector_count = max_collector_count

    @property
    def stock_list(self):
        """Symbols to download; loaded lazily from ``get_hs_stock_symbols()`` and cached."""
        if self._stock_list is None:
            self._stock_list = get_hs_stock_symbols()
        return self._stock_list

    def save_stock(self, symbol, df: pd.DataFrame):
        """Save the data of one stock to ``<save_dir>/<sh|sz><code>.csv``.

        Parameters
        ----------
        symbol: str
            Stock code in ``<code>.<suffix>`` form.  A suffix of "ss" (any case)
            is mapped to the "sh" prefix, every other suffix to "sz".
        df : pd.DataFrame
            Data to save.  A "symbol" column holding the converted symbol is
            written to the file; the frame passed in is left untouched.

        Raises
        ------
        ValueError
            If ``df`` is empty.
        """
        if df.empty:
            raise ValueError("df is empty")

        parts = symbol.split(".")
        prefix = "sh" if parts[-1].lower() == "ss" else "sz"
        qlib_symbol = f"{prefix}{parts[0]}"
        stock_path = self.save_dir.joinpath(f"{qlib_symbol}.csv")
        # assign() works on a copy: the caller's frame (possibly a groupby
        # slice of a larger response) must not be modified from a worker thread.
        df.assign(symbol=qlib_symbol).to_csv(stock_path, index=False)

    def _collector(self, stock_list):
        """Download and save every symbol in ``stock_list`` once.

        Returns
        -------
        list
            Symbols for which no DataFrame was returned or whose save failed.
        """
        error_symbol = []
        chunk_size = self.max_workers
        with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
            futures = {}
            with tqdm(total=len(stock_list)) as p_bar:
                for start in range(0, len(stock_list), chunk_size):
                    symbols = stock_list[start : start + chunk_size]
                    resp = Ticker(symbols, asynchronous=self._asynchronous, max_workers=self.max_workers).history(
                        period="max"
                    )
                    if isinstance(resp, dict):
                        # One entry per symbol; a non-DataFrame value means the
                        # request for that symbol did not yield data.
                        for symbol, df in resp.items():
                            if isinstance(df, pd.DataFrame):
                                frame = df.reset_index().rename(columns={"index": "date"})
                                futures[worker.submit(self.save_stock, symbol, frame)] = symbol
                            else:
                                error_symbol.append(symbol)
                    else:
                        # NOTE(review): symbols that are absent from a DataFrame
                        # response are not recorded as errors here -- confirm
                        # whether they should be retried.
                        for symbol, df in resp.reset_index().groupby("symbol"):
                            futures[worker.submit(self.save_stock, symbol, df)] = symbol
                    # The last chunk may be shorter than chunk_size.
                    p_bar.update(len(symbols))

            with tqdm(total=len(futures)) as p_bar:
                for future in as_completed(futures):
                    try:
                        future.result()
                    except Exception as e:
                        # Best effort: remember the symbol so the caller can retry it.
                        logger.error(e)
                        error_symbol.append(futures[future])
                    p_bar.update()

        if error_symbol:
            logger.warning(f"error symbols: {error_symbol}")
        logger.info(f"error symbol nums: {len(error_symbol)}")
        logger.info(f"current get symbol nums: {len(stock_list)}")
        return error_symbol

    def _download_csi300_benchmark(self):
        """Fetch the CSI 300 K-lines from ``CSI300_BENCH_URL`` and save them as ``sh000300.csv``."""
        resp = requests.get(CSI300_BENCH_URL, timeout=self._REQUEST_TIMEOUT)
        # Fail with a clear HTTP error instead of a confusing JSON decode error.
        resp.raise_for_status()
        klines = resp.json()["data"]["klines"]
        df = pd.DataFrame([line.split(",") for line in klines])
        df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
        df["date"] = pd.to_datetime(df["date"])
        df = df.astype(float, errors="ignore")
        # The downloaded series is used as-is, so the adjusted close equals the close.
        df["adjclose"] = df["close"]
        df.to_csv(self.save_dir.joinpath("sh000300.csv"), index=False)

    def collector_data(self):
        """Download all stocks (retrying failures) and then the CSI 300 benchmark."""
        logger.info("start collector yahoo data......")
        stock_list = self.stock_list
        for i in range(self._max_collector_count):
            if not stock_list:
                break
            logger.info(f"getting data: {i+1}")
            # Each round returns the symbols that failed; only those are retried.
            stock_list = self._collector(stock_list)
            logger.info(f"{i+1} finish.")
        if stock_list:
            logger.warning(f"symbols still failing after {self._max_collector_count} rounds: {len(stock_list)}")

        # TODO: from MSN
        logger.info("get bench data: csi300(SH000300)......")
        self._download_csi300_benchmark()
|
|
|
|
|
|
class Run:
    """Command-line entry points: download -> normalize -> manual adjust -> dump."""

    def __init__(self, source_dir=None, normalize_dir=None, qlib_dir=None, max_workers=4):
        """

        Parameters
        ----------
        source_dir: str
            The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
        normalize_dir: str
            Directory for normalize data, default "Path(__file__).parent/normalize"
        qlib_dir: str
            qlib data dir; usage of provider_uri, default "Path(__file__).parent/qlib_data"
        max_workers: int
            Concurrent number, default is 4
        """
        self.source_dir = self._prepare_dir(source_dir, "source")
        self.normalize_dir = self._prepare_dir(normalize_dir, "normalize")
        self.qlib_dir = self._prepare_dir(qlib_dir, "qlib_data")
        self.max_workers = max_workers

    @staticmethod
    def _prepare_dir(path, default_name):
        """Resolve ``path`` (or ``CUR_DIR/default_name`` when it is None) and create it."""
        if path is None:
            path = CUR_DIR.joinpath(default_name)
        path = Path(path).expanduser().resolve()
        path.mkdir(parents=True, exist_ok=True)
        return path

    def _map_over_files(self, func, file_list):
        """Apply ``func`` to every path in ``file_list`` on a thread pool with a progress bar."""
        with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
            with tqdm(total=len(file_list)) as p_bar:
                # Iterating the results re-raises any exception from a worker.
                for _ in worker.map(func, file_list):
                    p_bar.update()

    @staticmethod
    def _normalize_frame(df, calendar):
        """Align one raw stock frame to ``calendar`` and apply the adjustment factor.

        Parameters
        ----------
        df: pd.DataFrame
            Raw data with "date", "open", "close", "high", "low", "volume" and
            "adjclose" columns (a "symbol" column is tolerated).
        calendar: iterable
            Dates the result is reindexed to.

        Returns
        -------
        pd.DataFrame
            Indexed by "date" with the columns
            open, close, high, low, volume, change, factor.
        """
        columns = ["open", "close", "high", "low", "volume"]
        df = df.set_index("date")
        df.index = pd.to_datetime(df.index)

        # using China stock market data calendar
        df = df.reindex(pd.Index(calendar)).sort_index()

        # Rows without a positive volume carry no usable data: blank every
        # column except the symbol.  A list is required here -- pandas no
        # longer accepts a set as an indexer.
        no_trade = (df["volume"] <= 0) | df["volume"].isna()
        df.loc[no_trade, [col for col in df.columns if col != "symbol"]] = np.nan

        df["factor"] = df["adjclose"] / df["close"]
        for _col in columns:
            if _col == "volume":
                df[_col] = df[_col] / df["factor"]
            else:
                df[_col] = df[_col] * df["factor"]

        # Change is measured against the last known close so that gaps in the
        # calendar do not break the ratio.
        _tmp_series = df["close"].ffill()
        df["change"] = _tmp_series / _tmp_series.shift(1) - 1

        columns = columns + ["change", "factor"]
        # Blank the derived columns on rows without a positive volume as well.
        no_trade = (df["volume"] <= 0) | df["volume"].isna()
        df.loc[no_trade, columns] = np.nan
        df.index.names = ["date"]
        return df.loc[:, columns]

    @staticmethod
    def _adjust_frame(df):
        """Rescale a normalized frame so that its first valid close equals 1.

        Rows before the first valid close are dropped.  Price columns and
        "factor" are divided by that close, "volume" is multiplied by it and
        "change" is left untouched.  A frame without any valid close (empty or
        all-NaN) is returned unscaled.

        Returns
        -------
        pd.DataFrame
            Frame with "date" restored as a regular column.
        """
        df = df.loc[:, ["open", "close", "high", "low", "volume", "change", "factor", "date"]]
        df = df.sort_values("date").set_index("date")
        first_valid = df["close"].first_valid_index()
        if first_valid is None:
            # Nothing to scale by; avoids indexing into an empty frame.
            return df.reset_index()

        df = df.loc[first_valid:].copy()
        _close = df["close"].iloc[0]
        scaled_cols = [col for col in df.columns if col not in ("volume", "change")]
        df[scaled_cols] = df[scaled_cols] / _close
        df["volume"] = df["volume"] * _close
        return df.reset_index()

    def normalize_data(self):
        """normalize data

        Examples
        ---------
            $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize

        """
        file_list = list(self.source_dir.glob("*.csv"))
        # Fetch the calendar once instead of once per file; skip it entirely
        # when there is nothing to normalize.
        calendar = get_calendar_list() if file_list else []

        def _normalize(file_path: Path):
            df = pd.read_csv(file_path)
            self._normalize_frame(df, calendar).to_csv(self.normalize_dir.joinpath(file_path.name))

        self._map_over_files(_normalize, file_list)

    def manual_adj_data(self):
        """manual adjust data

        Examples
        --------
            $ python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize

        """

        def _adj(file_path: Path):
            # Each file is rewritten in place inside normalize_dir.
            df = pd.read_csv(file_path)
            self._adjust_frame(df).to_csv(self.normalize_dir.joinpath(file_path.name), index=False)

        self._map_over_files(_adj, list(self.normalize_dir.glob("*.csv")))

    def dump_data(self):
        """dump yahoo data

        Examples
        ---------
            $ python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data

        """
        DumpData(csv_path=self.normalize_dir, qlib_dir=self.qlib_dir, works=self.max_workers).dump(
            include_fields="close,open,high,low,volume,change,factor"
        )

    def download_data(self):
        """download data from Internet

        Examples
        ---------
            $ python collector.py download_data --source_dir ~/.qlib/stock_data/source

        """
        YahooCollector(self.source_dir, max_workers=self.max_workers).collector_data()

    def collector_data(self):
        """download -> normalize -> dump data

        Examples
        -------
            $ python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
        """
        self.download_data()
        self.normalize_data()
        self.manual_adj_data()
        self.dump_data()
|
|
|
|
|
|
# Script entry point: hand the `Run` class to python-fire, so that its methods
# are invoked from the command line as shown in their docstring examples
# (e.g. `python collector.py normalize_data --source_dir ...`).
if __name__ == "__main__":
    fire.Fire(Run)
|