mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
release-0.5.0 (#1)
* init commit * change the version number * rich the docs&fix cache docs * update index readme * Modify cache class name * Modify sharpe to information_ratio * Modify Group- to Group * add the description of graphical results & fix the backtest docs * fix docs in details * update docs * Update introduction.rst * Update README.md * Update introduction.rst * Update introduction.rst * Update introduction.rst * Update installation.rst * Update installation.rst * Update initialization.rst * Update getdata.rst * Update integration.rst * Update initialization.rst * Update getdata.rst * Update estimator.rst Modify some typos. * Update README.md Modify the typos. * Update initialization.rst * Update data.rst * Update report.rst * Update estimator.rst * Update cumulative_return.py * Update model.rst * Update rank_label.py * Update cumulative_return.py * Update strategy.rst * Update getdata.rst * Update backtest.rst * Update integration.rst * Update getdata.rst * Update introduction.rst * Update introduction.rst * Update README.md * Update report.rst * Update integration.rst Fix typos * Update installation.rst Fix typos * Update getdata.rst * Update initialization.rst Fix typos. 
* add quick start docs & fix details * fix estimator docs & fix strategy docs * fix the cache in data.rst * update documents * Fix Corr && Rsquare * fix data retrieval example to csi300 & fix a data bug * fix filter bug * Fix data collector * Modify model args * add the log & fix README.md\quick.rst * add environment dependencies & add introduction of qlib-server online mode * fix image center format & set log_only of docs to True * fix README.md format * update data preparation & readme logo image * get_data support version * Modify analysis names * Modify analysis graph * update report.rst & data.rst * commit estimator for merge * minimal requirements * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * update estimator * Fix doc urls * fix get_data.py docstring * update test_get_data.py * Update docs * Update docs * Update docs Co-authored-by: bxdd <bxddream@gmail.com> Co-authored-by: zhupr <zhu.pengrong@foxmail.com> Co-authored-by: Wendi Li <wendili.academic@qq.com> Co-authored-by: Dingsu Wang <dingsu.wang@gmail.com> Co-authored-by: bxdd <45119470+bxdd@users.noreply.github.com> Co-authored-by: cslwqxx <cslwqxx@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import sys
|
||||
import bisect
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
@@ -12,16 +13,17 @@ import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.utils import get_hs_calendar_list as get_calendar_list
|
||||
|
||||
|
||||
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/000300cons.xls"
|
||||
|
||||
CSI300_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
|
||||
|
||||
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
|
||||
CSI300_START_DATE = pd.Timestamp("2005-01-01")
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
class CSI300:
|
||||
|
||||
@@ -50,12 +52,7 @@ class CSI300:
|
||||
Returns
|
||||
-------
|
||||
"""
|
||||
# TODO: get calendar from MSN
|
||||
if self._calendar_list is None:
|
||||
logger.info("get all trading date")
|
||||
value_list = requests.get(CSI300_BENCH_URL).json()["data"]["klines"]
|
||||
self._calendar_list = sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), value_list))
|
||||
return self._calendar_list
|
||||
return get_calendar_list(bench=True)
|
||||
|
||||
def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1):
|
||||
"""get trading date by shift
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# TODO: Support collecting data from MSN
|
||||
103
scripts/data_collector/utils.py
Normal file
103
scripts/data_collector/utils.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import re
|
||||
import requests
|
||||
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
|
||||
SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
SH600000_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.600000&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
|
||||
_BENCH_CALENDAR_LIST = None
|
||||
_ALL_CALENDAR_LIST = None
|
||||
_HS_SYMBOLS = None
|
||||
|
||||
|
||||
def get_hs_calendar_list(bench=False) -> list:
    """Return the SH/SZ historical trading calendar.

    The calendar is built from the date field of the daily k-line records
    served by the configured k-line URL. Each calendar is stored in a
    module-level global after the first successful download and that same
    list object is returned on later calls.

    Parameters
    ----------
    bench: bool
        if True, build the calendar from ``CSI300_BENCH_URL``; otherwise
        build it from ``SH600000_BENCH_URL``. By default False.

    Returns
    -------
        list of ``pd.Timestamp`` in ascending order
    """
    global _ALL_CALENDAR_LIST
    global _BENCH_CALENDAR_LIST

    def _fetch_trading_days(kline_url):
        # Every k-line record is a comma-separated string whose first
        # field is the date.
        klines = requests.get(kline_url).json()["data"]["klines"]
        days = [pd.Timestamp(record.split(",")[0]) for record in klines]
        days.sort()
        return days

    # TODO: get calendar from MSN
    if not bench:
        if _ALL_CALENDAR_LIST is None:
            _ALL_CALENDAR_LIST = _fetch_trading_days(SH600000_BENCH_URL)
        return _ALL_CALENDAR_LIST

    if _BENCH_CALENDAR_LIST is None:
        _BENCH_CALENDAR_LIST = _fetch_trading_days(CSI300_BENCH_URL)
    return _BENCH_CALENDAR_LIST
|
||||
|
||||
|
||||
def get_hs_stock_symbols() -> list:
    """Return the SH/SZ stock symbols as ``<code>.<suffix>`` strings.

    One listing page is requested per ``(s_type, suffix)`` pair. The first
    run of digits in each link text of the result list is taken as the
    stock code and joined with the pair's suffix. The sorted, de-duplicated
    result is stored in the module-level ``_HS_SYMBOLS`` and reused by
    later calls.

    Returns
    -------
        sorted list of stock symbols
    """
    global _HS_SYMBOLS
    if _HS_SYMBOLS is not None:
        return _HS_SYMBOLS

    symbols = set()
    for s_type, suffix in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
        page = requests.get(SYMBOLS_URL.format(s_type=s_type))
        link_texts = etree.HTML(page.text).xpath("//div[@class='result']/ul//li/a/text()")
        for text in link_texts:
            code = re.findall(r"\d+", text)[0]
            symbols.add("{}.{}".format(code, suffix))

    _HS_SYMBOLS = sorted(symbols)
    return _HS_SYMBOLS
|
||||
|
||||
|
||||
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
    """Convert a ``<code>.<exchange>`` symbol to ``<exchange><code>`` form.

    The exchange suffixes ``sh`` and ``ss`` (compared case-insensitively)
    are both written as the prefix ``sh``; any other suffix is reused as
    the prefix unchanged.

    Parameters
    ----------
    symbol: str
        symbol containing exactly one ``.`` separating code and exchange
    capital : bool
        upper-case the result if True, lower-case it otherwise;
        by default True

    Returns
    -------
        the symbol with the exchange moved in front of the code
    """
    code, exchange = symbol.split(".")
    prefix = "sh" if exchange.lower() in ("sh", "ss") else exchange
    joined = prefix + code
    if capital:
        return joined.upper()
    return joined.lower()
|
||||
|
||||
|
||||
def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
    """Convert a ``<exchange><code>`` symbol to ``<code>.<exchange>`` form.

    The first two characters of ``symbol`` are treated as the exchange
    prefix and everything after them as the code, e.g. ``SH600000`` becomes
    ``600000.SH``. The exchange characters are moved as-is (no ``sh`` to
    ``ss`` mapping is applied).

    Parameters
    ----------
    symbol: str
        symbol whose first two characters are the exchange prefix
    capital : bool
        upper-case the result if True, lower-case it otherwise;
        by default True

    Returns
    -------
        the symbol with the exchange moved behind the code
    """
    # The exchange is a two-character *prefix*, so split at the front of
    # the string; slicing the last two characters would cut the code.
    exchange, code = symbol[:2], symbol[2:]
    res = f"{code}.{exchange}"
    return res.upper() if capital else res.lower()
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
@@ -11,45 +10,33 @@ import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from dump_bin import DumpData
|
||||
from data_collector.utils import get_hs_calendar_list as get_calendar_list, get_hs_stock_symbols
|
||||
|
||||
SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
|
||||
|
||||
class YahooCollector:
|
||||
def __init__(self, save_dir: [str, Path], max_workers=4):
|
||||
def __init__(self, save_dir: [str, Path], max_workers=4, asynchronous=True, max_collector_count=3):
|
||||
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._stock_list = None
|
||||
self.max_workers = max_workers
|
||||
self._asynchronous = asynchronous
|
||||
self._max_collector_count = max_collector_count
|
||||
|
||||
@property
|
||||
def stock_list(self):
|
||||
if self._stock_list is None:
|
||||
self._stock_list = self.get_stock_list()
|
||||
self._stock_list = get_hs_stock_symbols()
|
||||
return self._stock_list
|
||||
|
||||
@staticmethod
|
||||
def get_stock_list() -> list:
|
||||
_res = set()
|
||||
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
|
||||
resp = requests.get(SYMBOLS_URL.format(s_type=_k))
|
||||
_res |= set(
|
||||
map(
|
||||
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
|
||||
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
|
||||
)
|
||||
)
|
||||
return sorted(list(_res))
|
||||
|
||||
def save_stock(self, symbol, df: pd.DataFrame):
|
||||
"""save stock data to file
|
||||
|
||||
@@ -69,19 +56,16 @@ class YahooCollector:
|
||||
df["symbol"] = symbol
|
||||
df.to_csv(stock_path, index=False)
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data
|
||||
def _collector(self, stock_list):
|
||||
|
||||
"""
|
||||
logger.info("start collector yahoo data......")
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
|
||||
futures = {}
|
||||
p_bar = tqdm(total=len(self.stock_list))
|
||||
for symbols in [
|
||||
self.stock_list[i : i + self.max_workers] for i in range(0, len(self.stock_list), self.max_workers)
|
||||
]:
|
||||
resp = Ticker(symbols, asynchronous=True, max_workers=self.max_workers).history(period="max")
|
||||
p_bar = tqdm(total=len(stock_list))
|
||||
for symbols in [stock_list[i : i + self.max_workers] for i in range(0, len(stock_list), self.max_workers)]:
|
||||
resp = Ticker(symbols, asynchronous=self._asynchronous, max_workers=self.max_workers).history(
|
||||
period="max"
|
||||
)
|
||||
if isinstance(resp, dict):
|
||||
for symbol, df in resp.items():
|
||||
if isinstance(df, pd.DataFrame):
|
||||
@@ -106,12 +90,26 @@ class YahooCollector:
|
||||
logger.error(e)
|
||||
error_symbol.append(futures[future])
|
||||
p_bar.update()
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
return error_symbol
|
||||
|
||||
logger.info(error_symbol)
|
||||
logger.info(len(error_symbol))
|
||||
logger.info(len(self.stock_list))
|
||||
def collector_data(self):
|
||||
"""collector data
|
||||
|
||||
"""
|
||||
logger.info("start collector yahoo data......")
|
||||
stock_list = self.stock_list
|
||||
for i in range(self._max_collector_count):
|
||||
if not stock_list:
|
||||
break
|
||||
logger.info(f"getting data: {i+1}")
|
||||
stock_list = self._collector(stock_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
|
||||
# TODO: from MSN
|
||||
logger.info(f"get bench data: csi300(SH000300)......")
|
||||
df = pd.DataFrame(map(lambda x: x.split(","), requests.get(CSI300_BENCH_URL).json()["data"]["klines"]))
|
||||
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
|
||||
df["date"] = pd.to_datetime(df["date"])
|
||||
@@ -164,8 +162,14 @@ class Run:
|
||||
def _normalize(file_path: Path):
|
||||
columns = ["open", "close", "high", "low", "volume"]
|
||||
df = pd.read_csv(file_path)
|
||||
df.sort_values("date", inplace=True)
|
||||
df.loc[df["volume"] <= 0, set(df.columns) - {"symbol", "date"}] = np.nan
|
||||
df.set_index("date", inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
|
||||
# using China stock market data calendar
|
||||
df = df.reindex(pd.Index(get_calendar_list()))
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {"symbol"}] = np.nan
|
||||
df["factor"] = df["adjclose"] / df["close"]
|
||||
for _col in columns:
|
||||
if _col == "volume":
|
||||
@@ -176,7 +180,8 @@ class Run:
|
||||
df["change"] = _tmp_series / _tmp_series.shift(1) - 1
|
||||
columns += ["change", "factor"]
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
|
||||
df.loc[:, columns + ["date"]].to_csv(self.normalize_dir.joinpath(file_path.name), index=False)
|
||||
df.index.names = ["date"]
|
||||
df.loc[:, columns].to_csv(self.normalize_dir.joinpath(file_path.name))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
|
||||
file_list = list(self.source_dir.glob("*.csv"))
|
||||
@@ -192,12 +197,13 @@ class Run:
|
||||
$ python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
|
||||
|
||||
"""
|
||||
|
||||
def _adj(file_path: Path):
|
||||
df = pd.read_csv(file_path)
|
||||
df = df.loc[:, ["open", "close", "high", "low", "volume", "change", "factor"]]
|
||||
df = df.loc[:, ["open", "close", "high", "low", "volume", "change", "factor", "date"]]
|
||||
df.sort_values("date", inplace=True)
|
||||
df = df.set_index("date")
|
||||
df = df.loc[df.first_valid_index():]
|
||||
df = df.loc[df.first_valid_index() :]
|
||||
_close = df["close"].iloc[0]
|
||||
for _col in df.columns:
|
||||
if _col == "volume":
|
||||
@@ -214,7 +220,6 @@ class Run:
|
||||
for _ in worker.map(_adj, file_list):
|
||||
p_bar.update()
|
||||
|
||||
|
||||
def dump_data(self):
|
||||
"""dump yahoo data
|
||||
|
||||
|
||||
@@ -52,21 +52,23 @@ class GetData:
|
||||
for _file in tqdm(zp.namelist()):
|
||||
zp.extract(_file, str(target_dir.resolve()))
|
||||
|
||||
def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data"):
|
||||
def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data", version="v1"):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
version: str
|
||||
data version, value from [v0, v1], by default v1
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
|
||||
python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data --version v1
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = "qlib_data_cn.zip"
|
||||
file_name = f"qlib_data_cn_{version}.zip"
|
||||
self._download_data(file_name, target_dir)
|
||||
|
||||
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
|
||||
|
||||
Reference in New Issue
Block a user