mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Support csi100 data collection && Fix data collector
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import abc
|
||||
import sys
|
||||
import bisect
|
||||
from io import BytesIO
|
||||
@@ -18,14 +19,12 @@ sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.utils import get_hs_calendar_list as get_calendar_list
|
||||
|
||||
|
||||
# Constituent-list spreadsheet of one index; ``{index_code}`` is filled in
# with the index code (for example "000300") before the request is made.
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"

# Search page on csindex.com.cn whose results are the announcements about
# adjustments of index sample stocks (the ``key`` is a URL-encoded query).
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"

# Names used before the collector was generalised beyond CSI300; kept as
# aliases so that any remaining reference still resolves to the same values.
CSI300_CHANGES_URL = INDEX_CHANGES_URL
CSI300_START_DATE = pd.Timestamp("2005-01-01")
|
||||
|
||||
|
||||
class CSI300:
|
||||
class CSIIndex:
|
||||
|
||||
REMOVE = "remove"
|
||||
ADD = "add"
|
||||
@@ -45,6 +44,9 @@ class CSI300:
|
||||
self.instruments_dir.mkdir(exist_ok=True, parents=True)
|
||||
self._calendar_list = None
|
||||
|
||||
self.cache_dir = Path("~/.cache/csi").expanduser().resolve()
|
||||
self.cache_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@property
|
||||
def calendar_list(self) -> list:
|
||||
"""get history trading date
|
||||
@@ -52,7 +54,41 @@ class CSI300:
|
||||
Returns
|
||||
-------
|
||||
"""
|
||||
return get_calendar_list(bench=True)
|
||||
return get_calendar_list(bench_code=self.index_name.upper())
|
||||
|
||||
@property
def new_companies_url(self):
    """URL of the constituent spreadsheet for this index.

    Returns
    -------
        ``NEW_COMPANIES_URL`` with ``index_code`` replaced by ``self.index_code``
    """
    return NEW_COMPANIES_URL.format(index_code=self.index_code)
|
||||
|
||||
@property
def changes_url(self):
    """URL of the page that lists the index change announcements.

    Returns
    -------
        the module-level ``INDEX_CHANGES_URL``
    """
    return INDEX_CHANGES_URL
|
||||
|
||||
@property
@abc.abstractmethod
def bench_start_date(self) -> pd.Timestamp:
    """Start date assigned to constituents whose own start is unknown.

    Subclasses must override this property; the base implementation
    only raises.

    Returns
    -------
        pd.Timestamp

    Raises
    ------
    NotImplementedError
        always, when the base implementation is reached
    """
    raise NotImplementedError()
|
||||
|
||||
@property
@abc.abstractmethod
def index_code(self):
    """Code of the index, e.g. "000300".

    Subclasses must override this property; the base implementation
    only raises. The value is substituted into ``NEW_COMPANIES_URL``.

    Raises
    ------
    NotImplementedError
        always, when the base implementation is reached
    """
    raise NotImplementedError()
|
||||
|
||||
@property
@abc.abstractmethod
def index_name(self):
    """Short name of the index, e.g. "csi300".

    Subclasses must override this property; the base implementation
    only raises. The lower-cased name is used in cache and instrument
    file names, the upper-cased name selects the trading calendar.

    Raises
    ------
    NotImplementedError
        always, when the base implementation is reached
    """
    raise NotImplementedError()
|
||||
|
||||
@property
@abc.abstractmethod
def html_table_index(self):
    """Which table of changes in html

    Zero-based position of this index's table among the four-column
    tables of an announcement page:

        CSI300: 0
        CSI100: 1

    Subclasses must override this property; the base implementation
    only raises.

    Raises
    ------
    NotImplementedError
        always, when the base implementation is reached
    """
    raise NotImplementedError()
|
||||
|
||||
def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1):
|
||||
"""get trading date by shift
|
||||
@@ -119,14 +155,18 @@ class CSI300:
|
||||
remove_date = self._get_trading_date_by_shift(add_date, shift=-1)
|
||||
logger.info(f"get {add_date} changes")
|
||||
try:
|
||||
|
||||
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
|
||||
_io = BytesIO(requests.get(f"http://www.csindex.com.cn{excel_url}").content)
|
||||
content = requests.get(f"http://www.csindex.com.cn{excel_url}").content
|
||||
_io = BytesIO(content)
|
||||
df_map = pd.read_excel(_io, sheet_name=None)
|
||||
with self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.{excel_url.split('.')[-1]}"
|
||||
).open("wb") as fp:
|
||||
fp.write(content)
|
||||
tmp = []
|
||||
for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]:
|
||||
_df = df_map[_s_name]
|
||||
_df = _df.loc[_df["指数代码"] == "000300", ["证券代码"]]
|
||||
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
|
||||
_df = _df.applymap(self.normalize_symbol)
|
||||
_df.columns = ["symbol"]
|
||||
_df["type"] = _type
|
||||
@@ -135,9 +175,13 @@ class CSI300:
|
||||
df = pd.concat(tmp)
|
||||
except Exception:
|
||||
df = None
|
||||
_tmp_count = 0
|
||||
for _df in pd.read_html(resp.content):
|
||||
if _df.shape[-1] != 4:
|
||||
continue
|
||||
_tmp_count += 1
|
||||
if self.html_table_index + 1 > _tmp_count:
|
||||
continue
|
||||
tmp = []
|
||||
for _s, _type, _date in [
|
||||
(_df.iloc[2:, 0], self.REMOVE, remove_date),
|
||||
@@ -149,31 +193,42 @@ class CSI300:
|
||||
_tmp_df["date"] = _date
|
||||
tmp.append(_tmp_df)
|
||||
df = pd.concat(tmp)
|
||||
df.to_csv(
|
||||
str(
|
||||
self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.csv"
|
||||
).resolve()
|
||||
)
|
||||
)
|
||||
break
|
||||
return df
|
||||
|
||||
def _get_change_notices_url(self) -> list:
    """get change notices url

    Requests ``self.changes_url`` and extracts the ``href`` of every link
    found in list items below the element whose id is ``itemContainer``.

    Returns
    -------
        list of the extracted ``href`` values
    """
    # Instance method on purpose: the page to fetch comes from the
    # ``changes_url`` property rather than from a fixed module constant.
    resp = requests.get(self.changes_url)
    html = etree.HTML(resp.text)
    return html.xpath("//*[@id='itemContainer']//li/a/@href")
|
||||
|
||||
def _get_new_companies(self):
    """Download the current constituent list of the index.

    The raw spreadsheet is also written to ``self.cache_dir`` as
    ``<index_name>_new_companies.<extension>``, where the extension is the
    text after the last "." of the download URL.

    Returns
    -------
        pd.DataFrame with the columns "end_date", "symbol" and "start_date";
        "start_date" is ``self.bench_start_date`` for every row
    """
    logger.info("get new companies")
    url = self.new_companies_url
    content = requests.get(url).content

    # Keep the downloaded bytes on disk before parsing them.
    cache_name = f"{self.index_name.lower()}_new_companies.{url.split('.')[-1]}"
    with self.cache_dir.joinpath(cache_name).open("wb") as fp:
        fp.write(content)

    df = pd.read_excel(BytesIO(content))
    # Columns 0 and 4 of the sheet are used as end date and symbol.
    # ``.copy()`` makes the column edits below act on an independent frame
    # instead of a slice of the parsed sheet.
    df = df.iloc[:, [0, 4]].copy()
    df.columns = ["end_date", "symbol"]
    df["symbol"] = df["symbol"].map(self.normalize_symbol)
    df["end_date"] = pd.to_datetime(df["end_date"])
    df["start_date"] = self.bench_start_date
    return df
|
||||
|
||||
def parse_instruments(self):
|
||||
@@ -183,7 +238,7 @@ class CSI300:
|
||||
-------
|
||||
$ python collector.py parse_instruments --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
"""
|
||||
logger.info("start parse csi300 companies.....")
|
||||
logger.info(f"start parse {self.index_name.lower()} companies.....")
|
||||
instruments_columns = ["symbol", "start_date", "end_date"]
|
||||
changers_df = self._get_changes()
|
||||
new_df = self._get_new_companies()
|
||||
@@ -196,15 +251,65 @@ class CSI300:
|
||||
] = _row.date
|
||||
else:
|
||||
_tmp_df = pd.DataFrame(
|
||||
[[_row.symbol, CSI300_START_DATE, _row.date]], columns=["symbol", "start_date", "end_date"]
|
||||
[[_row.symbol, self.bench_start_date, _row.date]], columns=["symbol", "start_date", "end_date"]
|
||||
)
|
||||
new_df = new_df.append(_tmp_df, sort=False)
|
||||
|
||||
new_df.loc[:, instruments_columns].to_csv(
|
||||
self.instruments_dir.joinpath("csi300.txt"), sep="\t", index=False, header=None
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
|
||||
)
|
||||
logger.info("parse csi300 companies finished.")
|
||||
logger.info(f"parse {self.index_name.lower()} companies finished.")
|
||||
|
||||
|
||||
class CSI300(CSIIndex):
    """Settings of the CSI300 index for the ``CSIIndex`` collector."""

    @property
    def index_code(self):
        """Index code: "000300"."""
        return "000300"

    @property
    def index_name(self):
        """Index name: "csi300"."""
        return "csi300"

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """Benchmark start date: 2005-01-01."""
        return pd.Timestamp("2005-01-01")

    @property
    def html_table_index(self):
        """Position of the CSI300 table of changes in html: 0."""
        return 0
|
||||
|
||||
|
||||
class CSI100(CSIIndex):
    """Settings of the CSI100 index for the ``CSIIndex`` collector."""

    @property
    def index_code(self):
        """Index code: "000903"."""
        return "000903"

    @property
    def index_name(self):
        """Index name: "csi100"."""
        return "csi100"

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """Benchmark start date: 2006-05-29."""
        return pd.Timestamp("2006-05-29")

    @property
    def html_table_index(self):
        """Position of the CSI100 table of changes in html: 1."""
        return 1
|
||||
|
||||
|
||||
def parse_instruments(qlib_dir: str):
    """Run ``parse_instruments`` of the CSI300 and the CSI100 collector.

    Parameters
    ----------
    qlib_dir: str
        qlib data dir, e.g. "~/.qlib/qlib_data/cn_data". Required, there is
        no default. "~" is expanded and the directory is created when it
        does not exist yet.
    """
    qlib_dir = Path(qlib_dir).expanduser().resolve()
    qlib_dir.mkdir(exist_ok=True, parents=True)
    # Same order as before: CSI300 first, then CSI100.
    for index_cls in (CSI300, CSI100):
        index_cls(qlib_dir).parse_instruments()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The command line is the module-level ``parse_instruments`` function,
    # which handles both CSI300 and CSI100.
    fire.Fire(parse_instruments)
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import time
|
||||
import pickle
|
||||
import requests
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
@@ -11,39 +14,46 @@ SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_ty
|
||||
# Daily k-line request; ``{bench_code}`` is the six-digit security code.
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"

CALENDAR_BENCH_URL_MAP = {
    "CSI300": CALENDAR_URL_BASE.format(bench_code="000300"),
    "CSI100": CALENDAR_URL_BASE.format(bench_code="000903"),
    # NOTE: Use the time series of SH600000 as the sequence of all stocks
    "ALL": CALENDAR_URL_BASE.format(bench_code="600000"),
}

# Older constant names; same URLs as the corresponding map entries.
CSI300_BENCH_URL = CALENDAR_BENCH_URL_MAP["CSI300"]
SH600000_BENCH_URL = CALENDAR_BENCH_URL_MAP["ALL"]

# Module-level caches, filled on first use.
_BENCH_CALENDAR_LIST = None
_ALL_CALENDAR_LIST = None
_HS_SYMBOLS = None
# bench code -> calendar list
_CALENDAR_MAP = {}

# NOTE: Until 2020-10-20 20:00:00
MINIMUM_SYMBOLS_NUM = 3900
|
||||
|
||||
|
||||
def get_hs_calendar_list(bench_code="CSI300") -> list:
    """get SH/SZ history calendar list

    The calendar of each ``bench_code`` is downloaded once and then served
    from the module-level ``_CALENDAR_MAP`` cache.

    Parameters
    ----------
    bench_code: str
        value from ["CSI300", "CSI100", "ALL"]

    Returns
    -------
        history calendar list

    Raises
    ------
    KeyError
        if ``bench_code`` is neither cached nor a key of ``CALENDAR_BENCH_URL_MAP``
    """

    def _get_calendar(url):
        # The first comma-separated field of every k-line is its date.
        klines = requests.get(url).json()["data"]["klines"]
        return sorted(pd.Timestamp(kline.split(",")[0]) for kline in klines)

    # TODO: get calendar from MSN
    calendar = _CALENDAR_MAP.get(bench_code, None)
    if calendar is None:
        calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
        _CALENDAR_MAP[bench_code] = calendar
    return calendar
|
||||
|
||||
|
||||
def get_hs_stock_symbols() -> list:
|
||||
@@ -54,7 +64,8 @@ def get_hs_stock_symbols() -> list:
|
||||
stock symbols
|
||||
"""
|
||||
global _HS_SYMBOLS
|
||||
if _HS_SYMBOLS is None:
|
||||
|
||||
def _get_symbol():
|
||||
_res = set()
|
||||
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
|
||||
resp = requests.get(SYMBOLS_URL.format(s_type=_k))
|
||||
@@ -64,7 +75,27 @@ def get_hs_stock_symbols() -> list:
|
||||
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
|
||||
)
|
||||
)
|
||||
_HS_SYMBOLS = sorted(list(_res))
|
||||
return _res
|
||||
|
||||
if _HS_SYMBOLS is None:
|
||||
symbols = set()
|
||||
_retry = 60
|
||||
# It may take multiple times to get the complete
|
||||
while len(symbols) < MINIMUM_SYMBOLS_NUM:
|
||||
symbols |= _get_symbol()
|
||||
time.sleep(3)
|
||||
|
||||
symbol_cache_path = Path("~/.cache/hs_symbols_cache.pkl").expanduser().resolve()
|
||||
symbol_cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if symbol_cache_path.exists():
|
||||
with symbol_cache_path.open("rb") as fp:
|
||||
cache_symbols = pickle.load(fp)
|
||||
symbols |= cache_symbols
|
||||
with symbol_cache_path.open("wb") as fp:
|
||||
pickle.dump(symbols, fp)
|
||||
|
||||
_HS_SYMBOLS = sorted(list(symbols))
|
||||
|
||||
return _HS_SYMBOLS
|
||||
|
||||
|
||||
@@ -104,3 +135,7 @@ def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
|
||||
"""
|
||||
res = f"{symbol[:-2]}.{symbol[-2:]}"
|
||||
return res.upper() if capital else res.lower()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke check of the symbol download. Raised explicitly so the check is
    # not removed when Python runs with -O (which strips ``assert``).
    _symbols_num = len(get_hs_stock_symbols())
    if _symbols_num < MINIMUM_SYMBOLS_NUM:
        raise AssertionError(f"got {_symbols_num} symbols, expected at least {MINIMUM_SYMBOLS_NUM}")
|
||||
|
||||
@@ -19,7 +19,7 @@ sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from dump_bin import DumpData
|
||||
from data_collector.utils import get_hs_calendar_list as get_calendar_list, get_hs_stock_symbols
|
||||
|
||||
# Daily k-line request for one index; ``{index_code}`` is the index code.
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
# Older constant name; same URL as the template filled with the CSI300 code.
CSI300_BENCH_URL = INDEX_BENCH_URL.format(index_code="000300")
MIN_NUMBERS_TRADING = 252 / 4
|
||||
|
||||
|
||||
@@ -130,17 +130,23 @@ class YahooCollector:
|
||||
|
||||
logger.warning(f"less than {MIN_NUMBERS_TRADING} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
|
||||
self.download_csi300_data()
|
||||
self.download_index_data()
|
||||
|
||||
def download_index_data(self):
    """Download the daily data of csi300 and csi100 and save it as csv.

    For every index the k-lines are requested from ``INDEX_BENCH_URL`` and
    written to ``self.save_dir`` as ``sh<index_code>.csv`` with the columns
    date, open, close, high, low, volume, money, change and adjclose
    (adjclose is a copy of close).
    """
    # TODO: from MSN
    for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
        logger.info(f"get bench data: {_index_name}({_index_code})......")
        klines = requests.get(INDEX_BENCH_URL.format(index_code=_index_code)).json()["data"]["klines"]
        # Every k-line is one comma-separated record.
        df = pd.DataFrame(map(lambda x: x.split(","), klines))
        df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
        df["date"] = pd.to_datetime(df["date"])
        df = df.astype(float, errors="ignore")
        df["adjclose"] = df["close"]
        df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False)

def download_csi300_data(self):
    """Old name of ``download_index_data``; kept for existing callers."""
    self.download_index_data()
|
||||
|
||||
|
||||
class Run:
|
||||
@@ -192,7 +198,7 @@ class Run:
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
|
||||
# using China stock market data calendar
|
||||
df = df.reindex(pd.Index(get_calendar_list()))
|
||||
df = df.reindex(pd.Index(get_calendar_list("ALL")))
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {"symbol"}] = np.nan
|
||||
@@ -274,8 +280,8 @@ class Run:
|
||||
delay=delay,
|
||||
).collector_data()
|
||||
|
||||
def download_index_data(self):
    """Download the index data through a ``YahooCollector`` on ``self.source_dir``."""
    YahooCollector(self.source_dir).download_index_data()

def download_csi300_data(self):
    """Old name of ``download_index_data``; kept for existing callers."""
    self.download_index_data()
|
||||
|
||||
def download_bench_data(self):
|
||||
"""download bench stock data(SH000300)"""
|
||||
|
||||
Reference in New Issue
Block a user