1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/scripts/data_collector/cn_index/collector.py
2022-11-28 18:06:29 +08:00

470 lines
16 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import abc
import sys
import datetime
from io import BytesIO
from typing import List, Iterable
from pathlib import Path
import fire
import requests
import pandas as pd
import baostock as bs
from tqdm import tqdm
from loguru import logger
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.index import IndexBase
from data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry
from data_collector.utils import get_instruments
NEW_COMPANIES_URL = "https://csi-web-dev.oss-cn-shanghai-finance-1-pub.aliyuncs.com/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls"
INDEX_CHANGES_URL = "https://www.csindex.com.cn/csindex-home/search/search-content?lang=cn&searchInput=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC&pageNum={page_num}&pageSize={page_size}&sortField=date&dateRange=all&contentType=announcement"
REQ_HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.101 Safari/537.36 Edg/91.0.864.48"
}
@deco_retry
def retry_request(url: str, method: str = "get", exclude_status: List = None):
if exclude_status is None:
exclude_status = []
method_func = getattr(requests, method)
_resp = method_func(url, headers=REQ_HEADERS)
_status = _resp.status_code
if _status not in exclude_status and _status != 200:
raise ValueError(f"response status: {_status}, url={url}")
return _resp
class CSIIndex(IndexBase):
@property
def calendar_list(self) -> List[pd.Timestamp]:
"""get history trading date
Returns
-------
calendar list
"""
_calendar = getattr(self, "_calendar_list", None)
if not _calendar:
_calendar = get_calendar_list(bench_code=self.index_name.upper())
setattr(self, "_calendar_list", _calendar)
return _calendar
@property
def new_companies_url(self) -> str:
return NEW_COMPANIES_URL.format(index_code=self.index_code)
@property
def changes_url(self) -> str:
return INDEX_CHANGES_URL
@property
@abc.abstractmethod
def bench_start_date(self) -> pd.Timestamp:
"""
Returns
-------
index start date
"""
raise NotImplementedError("rewrite bench_start_date")
@property
@abc.abstractmethod
def index_code(self) -> str:
"""
Returns
-------
index code
"""
raise NotImplementedError("rewrite index_code")
@property
def html_table_index(self) -> int:
"""Which table of changes in html
CSI300: 0
CSI100: 1
:return:
"""
raise NotImplementedError("rewrite html_table_index")
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
"""formatting the datetime in an instrument
Parameters
----------
inst_df: pd.DataFrame
inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
Returns
-------
"""
if self.freq != "day":
inst_df[self.START_DATE_FIELD] = inst_df[self.START_DATE_FIELD].apply(
lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=9, minutes=30)).strftime("%Y-%m-%d %H:%M:%S")
)
inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply(
lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=15, minutes=0)).strftime("%Y-%m-%d %H:%M:%S")
)
return inst_df
def get_changes(self) -> pd.DataFrame:
"""get companies changes
Returns
-------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
logger.info("get companies changes......")
res = []
for _url in self._get_change_notices_url():
_df = self._read_change_from_url(_url)
if not _df.empty:
res.append(_df)
logger.info("get companies changes finish")
return pd.concat(res, sort=False)
@staticmethod
def normalize_symbol(symbol: str) -> str:
"""
Parameters
----------
symbol: str
symbol
Returns
-------
symbol
"""
symbol = f"{int(symbol):06}"
return f"SH{symbol}" if symbol.startswith("60") or symbol.startswith("688") else f"SZ{symbol}"
def _parse_excel(self, excel_url: str, add_date: pd.Timestamp, remove_date: pd.Timestamp) -> pd.DataFrame:
content = retry_request(excel_url, exclude_status=[404]).content
_io = BytesIO(content)
df_map = pd.read_excel(_io, sheet_name=None)
with self.cache_dir.joinpath(
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.{excel_url.split('.')[-1]}"
).open("wb") as fp:
fp.write(content)
tmp = []
for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]:
_df = df_map[_s_name]
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
_df = _df.applymap(self.normalize_symbol)
_df.columns = [self.SYMBOL_FIELD_NAME]
_df["type"] = _type
_df[self.DATE_FIELD_NAME] = _date
tmp.append(_df)
df = pd.concat(tmp)
return df
def _parse_table(self, content: str, add_date: pd.DataFrame, remove_date: pd.DataFrame) -> pd.DataFrame:
df = pd.DataFrame()
_tmp_count = 0
for _df in pd.read_html(content):
if _df.shape[-1] != 4 or _df.isnull().loc(0)[0][0]:
continue
_tmp_count += 1
if self.html_table_index + 1 > _tmp_count:
continue
tmp = []
for _s, _type, _date in [
(_df.iloc[2:, 0], self.REMOVE, remove_date),
(_df.iloc[2:, 2], self.ADD, add_date),
]:
_tmp_df = pd.DataFrame()
_tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol)
_tmp_df["type"] = _type
_tmp_df[self.DATE_FIELD_NAME] = _date
tmp.append(_tmp_df)
df = pd.concat(tmp)
df.to_csv(
str(
self.cache_dir.joinpath(
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.csv"
).resolve()
)
)
break
return df
def _read_change_from_url(self, url: str) -> pd.DataFrame:
"""read change from url
The parameter url is from the _get_change_notices_url method.
Determine the stock add_date/remove_date based on the title.
The response contains three cases:
1.Only excel_url(extract data from excel_url)
2.Both the excel_url and the body text(try to extract data from excel_url first, and then try to extract data from body text)
3.Only body text(extract data from body text)
Parameters
----------
url : str
change url
Returns
-------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
resp = retry_request(url).json()["data"]
title = resp["title"]
if not title.startswith("关于"):
return pd.DataFrame()
if "沪深300" not in title:
return pd.DataFrame()
logger.info(f"load index data from https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}")
_text = resp["content"]
date_list = re.findall(r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日", _text)
if len(date_list) >= 2:
add_date = pd.Timestamp("-".join(date_list[0]))
else:
_date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0]))
add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0)
if "盘后" in _text or "市后" in _text:
add_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=1)
remove_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=-1)
excel_url = None
if resp.get("enclosureList", []):
excel_url = resp["enclosureList"][0]["fileUrl"]
else:
excel_url_list = re.findall('.*href="(.*?xls.*?)".*', _text)
if excel_url_list:
excel_url = excel_url_list[0]
if not excel_url.startswith("http"):
excel_url = excel_url if excel_url.startswith("/") else "/" + excel_url
excel_url = f"http://www.csindex.com.cn{excel_url}"
if excel_url:
try:
logger.info(f"get {add_date} changes from the excel, title={title}, excel_url={excel_url}")
df = self._parse_excel(excel_url, add_date, remove_date)
except ValueError:
logger.info(
f"get {add_date} changes from the web page, title={title}, url=https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}"
)
df = self._parse_table(_text, add_date, remove_date)
else:
logger.info(
f"get {add_date} changes from the web page, title={title}, url=https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}"
)
df = self._parse_table(_text, add_date, remove_date)
return df
def _get_change_notices_url(self) -> Iterable[str]:
"""get change notices url
Returns
-------
[url1, url2]
"""
page_num = 1
page_size = 5
data = retry_request(self.changes_url.format(page_size=page_size, page_num=page_num)).json()
data = retry_request(self.changes_url.format(page_size=data["total"], page_num=page_num)).json()
for item in data["data"]:
yield f"https://www.csindex.com.cn/csindex-home/announcement/queryAnnouncementById?id={item['id']}"
def get_new_companies(self) -> pd.DataFrame:
"""
Returns
-------
pd.DataFrame:
symbol start_date end_date
SH600000 2000-01-01 2099-12-31
dtypes:
symbol: str
start_date: pd.Timestamp
end_date: pd.Timestamp
"""
logger.info("get new companies......")
context = retry_request(self.new_companies_url).content
with self.cache_dir.joinpath(
f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}"
).open("wb") as fp:
fp.write(context)
_io = BytesIO(context)
df = pd.read_excel(_io)
df = df.iloc[:, [0, 4]]
df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol)
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str))
df[self.START_DATE_FIELD] = self.bench_start_date
logger.info("end of get new companies.")
return df
class CSI300Index(CSIIndex):
@property
def index_code(self):
return "000300"
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2005-01-01")
@property
def html_table_index(self) -> int:
return 0
class CSI100Index(CSIIndex):
@property
def index_code(self):
return "000903"
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2006-05-29")
@property
def html_table_index(self) -> int:
return 1
class CSI500Index(CSIIndex):
@property
def index_code(self) -> str:
return "000905"
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2007-01-15")
def get_changes(self) -> pd.DataFrame:
"""get companies changes
Return
--------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
return self.get_changes_with_history_companies(self.get_history_companies())
def get_history_companies(self) -> pd.DataFrame:
"""
Returns
-------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
bs.login()
today = pd.Timestamp.now()
date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date
ret_list = []
col = ["date", "symbol", "code_name"]
for date in tqdm(date_range, desc="Download CSI500"):
rs = bs.query_zz500_stocks(date=str(date))
zz500_stocks = []
while (rs.error_code == "0") & rs.next():
zz500_stocks.append(rs.get_row_data())
result = pd.DataFrame(zz500_stocks, columns=col)
result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper())
result = self.get_data_from_baostock(date)
ret_list.append(result[["date", "symbol"]])
bs.logout()
return pd.concat(ret_list, sort=False)
@staticmethod
def get_data_from_baostock(date) -> pd.DataFrame:
"""
Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1
Avoid a large number of parallel data acquisition,
such as 1000 times of concurrent data acquisition, because IP will be blocked
Returns
-------
pd.DataFrame:
date symbol code_name
SH600039 2007-01-15 四川路桥
SH600051 2020-01-15 宁波联合
dtypes:
date: pd.Timestamp
symbol: str
code_name: str
"""
col = ["date", "symbol", "code_name"]
rs = bs.query_zz500_stocks(date=str(date))
zz500_stocks = []
while (rs.error_code == "0") & rs.next():
zz500_stocks.append(rs.get_row_data())
result = pd.DataFrame(zz500_stocks, columns=col)
result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper())
return result
def get_new_companies(self) -> pd.DataFrame:
"""
Returns
-------
pd.DataFrame:
symbol start_date end_date
SH600000 2000-01-01 2099-12-31
dtypes:
symbol: str
start_date: pd.Timestamp
end_date: pd.Timestamp
"""
logger.info("get new companies......")
today = pd.Timestamp.now().normalize()
bs.login()
result = self.get_data_from_baostock(today.strftime("%Y-%m-%d"))
bs.logout()
df = result[["date", "symbol"]]
df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
df[self.END_DATE_FIELD] = today
df[self.START_DATE_FIELD] = self.bench_start_date
logger.info("end of get new companies.")
return df
if __name__ == "__main__":
fire.Fire(get_instruments)