mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
* feat: download ibovespa index historic composition. Ibovespa (ibov) is the largest index on Brazil's stock exchange. The br_index folder supports downloading new companies for the current index composition, as well as downloading companies from the historic composition of the ibov index. Partially resolves issue #956 * fix: typo — end_date had been written as end_ate * feat: adds support for downloading historic stock prices from Brazil's stock exchange (B3). Together with commit c2f933 it resolves issue #956 * fix: code formatted with black. * wip: creating code logic for Brazil's stock market data normalization * docs: Brazil's stock market data normalization code documentation * fix: code formatted with black * docs: fixed typo * docs: more info about the python version used to generate the requirements.txt file * docs: added BeautifulSoup requirements * feat: removed debug prints * feat: added ibov_index_composition variable as a class attribute of IBOVIndex * feat: added increment to generate the four-month period used by the ibov index * refactor: added get_instruments() method inside utils.py for better code usability. Message in the PR request to understand the context of the change: In the course of reviewing this PR we found two issues. 1. There are multiple places where the get_instruments() method is used, and we feel that scripts.index.py is the best place for the get_instruments() method to go. 2. data_collector.utils has some very generic stuff put inside it. * refactor: improve Brazil's stocks download speed. The reason to use retry=2 is that Yahoo Finance unfortunately does not keep track of the majority of Brazilian stocks. Therefore, the decorator deco_retry with the retry argument set to 5 will keep trying to get the stock data 5 times, which makes downloading Brazilian stocks very slow. 
In the future this may change, but for now I suggest setting the retry argument to 1 or 2 to improve download speed. To implement this, an argument called retry_config was added to YahooCollectorBR1d and YahooCollectorBR1min * fix: added __main__ at the bottom of the script * refactor: changed interface inside each index. Using partial as `fire.Fire(partial(get_instruments, market_index="br_index" ))` will make the interface easier for the user to execute the script. Then all the collector.py CLIs in each folder can remove redundant arguments. * refactor: implemented class interface retry into YahooCollectorBR * docs: added BR as a possible region in the documentation * refactor: make retry attribute part of the interface. This way we don't have to use hasattr to access the retry attribute as previously done
464 lines
15 KiB
Python
464 lines
15 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import re
import abc
import sys
import importlib
from io import BytesIO, StringIO
from typing import List, Iterable
from pathlib import Path

import fire
import requests
import pandas as pd
import baostock as bs
from tqdm import tqdm
from loguru import logger
|
|
|
|
# Absolute directory that contains this collector script.
CUR_DIR = Path(__file__).resolve().parent
# Put the directory two levels up on sys.path so the `data_collector.*`
# imports that follow can be resolved when the script is run directly.
sys.path.append(f"{CUR_DIR.parent.parent}")
|
|
|
|
from data_collector.index import IndexBase
|
|
from data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry
|
|
from data_collector.utils import get_instruments
|
|
|
|
|
|
# Excel file listing the current constituents of an index; fill in `index_code`
# (e.g. "000300") with str.format.
NEW_COMPANIES_URL = (
    "https://csi-web-dev.oss-cn-shanghai-finance-1-pub.aliyuncs.com"
    "/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls"
)

# csindex.com.cn announcement search. The URL-encoded `searchInput` decodes to
# "关于调整沪深300和中证香港100等指数样本" ("On adjusting the samples of CSI 300,
# CSI Hong Kong 100 and other indices"). `page_num` and `page_size` are filled in
# with str.format.
INDEX_CHANGES_URL = (
    "https://www.csindex.com.cn/csindex-home/search/search-content"
    "?lang=cn"
    "&searchInput=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC"
    "&pageNum={page_num}&pageSize={page_size}"
    "&sortField=date&dateRange=all&contentType=announcement"
)

# Browser-like headers sent with every request made by `retry_request`.
REQ_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/91.0.4472.101 Safari/537.36 Edg/91.0.864.48"
    ),
}
|
|
|
|
|
|
@deco_retry
def retry_request(url: str, method: str = "get", exclude_status: List = None):
    """Send an HTTP request with the module's ``REQ_HEADERS``.

    Parameters
    ----------
    url: str
        target URL
    method: str
        name of the ``requests`` function to call, e.g. "get" or "post"
    exclude_status: List
        status codes (besides 200) that are returned instead of raising,
        e.g. ``[404]``; ``None`` means "only 200 is accepted"

    Returns
    -------
    the ``requests.Response`` object

    Raises
    ------
    ValueError
        for any other status code. ``deco_retry`` (from data_collector.utils)
        presumably retries on this exception — confirm there.
    """
    accepted = exclude_status or []
    response = getattr(requests, method)(url, headers=REQ_HEADERS)
    status = response.status_code
    if status == 200 or status in accepted:
        return response
    raise ValueError(f"response status: {status}, url={url}")
|
|
|
|
|
|
class CSIIndex(IndexBase):
    """Base collector for China Securities Index (CSI) constituents.

    Subclasses supply ``index_code``, ``bench_start_date`` and ``html_table_index``.
    This class downloads the current-constituent Excel file and parses historical
    change announcements from www.csindex.com.cn. Attributes such as ``index_name``,
    ``cache_dir``, ``freq`` and the ``*_FIELD*``/``ADD``/``REMOVE`` constants are
    not defined here; presumably they come from ``IndexBase`` (confirm there).
    """

    @property
    def calendar_list(self) -> List[pd.Timestamp]:
        """get history trading date

        Fetched once via ``get_calendar_list`` (keyed by the upper-cased index
        name) and cached on the instance as ``_calendar_list``.

        Returns
        -------
            calendar list
        """
        _calendar = getattr(self, "_calendar_list", None)
        if not _calendar:
            _calendar = get_calendar_list(bench_code=self.index_name.upper())
            setattr(self, "_calendar_list", _calendar)
        return _calendar

    @property
    def new_companies_url(self) -> str:
        """URL of the Excel file listing this index's current constituents."""
        return NEW_COMPANIES_URL.format(index_code=self.index_code)

    @property
    def changes_url(self) -> str:
        """Announcement-search URL template (expects ``page_num``/``page_size``)."""
        return INDEX_CHANGES_URL

    @property
    @abc.abstractmethod
    def bench_start_date(self) -> pd.Timestamp:
        """
        Returns
        -------
            index start date
        """
        raise NotImplementedError("rewrite bench_start_date")

    @property
    @abc.abstractmethod
    def index_code(self) -> str:
        """
        Returns
        -------
            index code, e.g. "000300"; used in ``new_companies_url`` and to filter
            the "指数代码" (index code) column of change Excel files
        """
        raise NotImplementedError("rewrite index_code")

    @property
    @abc.abstractmethod
    def html_table_index(self) -> int:
        """Zero-based position of this index's change table among the 4-column
        HTML tables of an announcement page (see ``_parse_table``).

        Values used by the subclasses in this module:
            CSI500: 0
            CSI300: 1
            CSI100: 2

        :return: int
        """
        raise NotImplementedError()

    def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
        """formatting the datetime in an instrument

        For non-daily ``freq`` the start date is moved to 09:30 and the end date
        to 15:00, both rendered as "%Y-%m-%d %H:%M:%S" strings. For daily data
        ``inst_df`` is returned unchanged.

        Parameters
        ----------
        inst_df: pd.DataFrame
            inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]

        Returns
        -------
            the same DataFrame (modified in place when ``freq != "day"``)
        """
        if self.freq != "day":
            inst_df[self.START_DATE_FIELD] = inst_df[self.START_DATE_FIELD].apply(
                lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=9, minutes=30)).strftime("%Y-%m-%d %H:%M:%S")
            )
            inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply(
                lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=15, minutes=0)).strftime("%Y-%m-%d %H:%M:%S")
            )
        return inst_df

    def get_changes(self) -> pd.DataFrame:
        """get companies changes

        Returns
        -------
            pd.DataFrame:
                symbol      date        type
                SH600000  2019-11-11    add
                SH600000  2020-11-10    remove
            dtypes:
                symbol: str
                date: pd.Timestamp
                type: str, value from ["add", "remove"]

            An empty DataFrame with these columns is returned when no
            announcement yields any change.
        """
        logger.info("get companies changes......")
        res = []
        for _url in self._get_change_notices_url():
            _df = self._read_change_from_url(_url)
            if not _df.empty:
                res.append(_df)
        logger.info("get companies changes finish")
        if not res:
            # pd.concat raises "No objects to concatenate" on an empty list
            logger.warning(f"no index changes found for {self.index_name}")
            return pd.DataFrame(columns=[self.SYMBOL_FIELD_NAME, self.DATE_FIELD_NAME, "type"])
        return pd.concat(res, sort=False)

    @staticmethod
    def normalize_symbol(symbol: str) -> str:
        """Convert a raw security code into an exchange-prefixed symbol.

        Shanghai codes start with "60" (main board) or "68" (STAR Market, e.g.
        688981). All other codes are treated as Shenzhen.

        Parameters
        ----------
        symbol: str
            raw security code; an int or a string, leading zeros optional

        Returns
        -------
            symbol, e.g. "SH600000", "SH688981", "SZ000001"
        """
        symbol = f"{int(symbol):06}"
        return f"SH{symbol}" if symbol.startswith(("60", "68")) else f"SZ{symbol}"

    def _parse_excel(self, excel_url: str, add_date: pd.Timestamp, remove_date: pd.Timestamp) -> pd.DataFrame:
        """Parse added/removed constituents from a change-announcement Excel file.

        The raw file is cached under ``self.cache_dir``. The workbook is expected
        to contain the sheets "调入" (added) and "调出" (removed), each with the
        columns "指数代码" (index code) and "证券代码" (security code).

        Raises
        ------
        ValueError
            e.g. when the download fails or the content is not a readable Excel
            file (a 404 page is passed through to ``pd.read_excel``). The caller
            then falls back to ``_parse_table``.
        """
        content = retry_request(excel_url, exclude_status=[404]).content
        _io = BytesIO(content)
        df_map = pd.read_excel(_io, sheet_name=None)
        with self.cache_dir.joinpath(
            f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.{excel_url.split('.')[-1]}"
        ).open("wb") as fp:
            fp.write(content)
        tmp = []
        for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]:
            _sheet = df_map[_s_name]
            # Series.map instead of DataFrame.applymap, which is deprecated since pandas 2.1
            _df = (
                _sheet.loc[_sheet["指数代码"] == self.index_code, "证券代码"]
                .map(self.normalize_symbol)
                .to_frame(name=self.SYMBOL_FIELD_NAME)
            )
            _df["type"] = _type
            _df[self.DATE_FIELD_NAME] = _date
            tmp.append(_df)
        df = pd.concat(tmp)
        return df

    def _parse_table(self, content: str, add_date: pd.Timestamp, remove_date: pd.Timestamp) -> pd.DataFrame:
        """Parse added/removed constituents from HTML tables in an announcement body.

        Only tables with exactly 4 columns are considered, and the one at position
        ``self.html_table_index`` among them is used. Its first two rows are skipped
        (presumably headers). Column 0 holds removed codes and column 2 added codes.
        The parsed result is cached as CSV under ``self.cache_dir``.

        Returns an empty DataFrame when the content has no (matching) table.
        """
        df = pd.DataFrame()
        try:
            # Wrap the literal HTML in a file-like object: passing a raw string to
            # read_html is deprecated since pandas 2.1.
            tables = pd.read_html(StringIO(content))
        except ValueError:
            # pandas raises ValueError("No tables found") for table-less content
            logger.warning("no html table found in the announcement content")
            return df
        _tmp_count = 0
        for _df in tables:
            if _df.shape[-1] != 4:
                continue
            _tmp_count += 1
            if self.html_table_index + 1 > _tmp_count:
                continue
            tmp = []
            for _s, _type, _date in [
                (_df.iloc[2:, 0], self.REMOVE, remove_date),
                (_df.iloc[2:, 2], self.ADD, add_date),
            ]:
                _tmp_df = pd.DataFrame()
                _tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol)
                _tmp_df["type"] = _type
                _tmp_df[self.DATE_FIELD_NAME] = _date
                tmp.append(_tmp_df)
            df = pd.concat(tmp)
            df.to_csv(
                str(
                    self.cache_dir.joinpath(
                        f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.csv"
                    ).resolve()
                )
            )
            break
        return df

    def _read_change_from_url(self, url: str) -> pd.DataFrame:
        """read change from url

        Only announcements whose title starts with "关于" ("On ...") and mentions
        "沪深300" (CSI 300) are parsed; others yield an empty DataFrame.

        Effective-date logic:
          * if at least two full "YYYY年M月D日" dates appear, the first is the add date;
          * otherwise the first "YYYY年M月" month is mapped to a trading day via
            ``get_trading_date_by_shift(..., shift=0)``;
          * if the text contains "盘后"/"市后" (after the close), the add date is
            moved one trading day later;
          * the remove date is the trading day before the add date.

        Parameters
        ----------
        url : str
            change url

        Returns
        -------
            pd.DataFrame:
                symbol      date        type
                SH600000  2019-11-11    add
                SH600000  2020-11-10    remove
            dtypes:
                symbol: str
                date: pd.Timestamp
                type: str, value from ["add", "remove"]
        """
        resp = retry_request(url).json()["data"]
        title = resp["title"]
        if not title.startswith("关于"):
            return pd.DataFrame()
        if "沪深300" not in title:
            return pd.DataFrame()

        logger.info(f"load index data from https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}")
        _text = resp["content"]
        date_list = re.findall(r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日", _text)
        if len(date_list) >= 2:
            add_date = pd.Timestamp("-".join(date_list[0]))
        else:
            month_list = re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)
            if not month_list:
                # one malformed notice should not abort the whole download
                logger.warning(f"no effective date found in {url}, skipping it")
                return pd.DataFrame()
            _date = pd.Timestamp("-".join(month_list[0]))
            add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0)
        if "盘后" in _text or "市后" in _text:
            add_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=1)
        remove_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=-1)

        # Prefer the attachment listed by the API, otherwise the first .xls link in the body.
        excel_url = None
        if resp.get("enclosureList", []):
            excel_url = resp["enclosureList"][0]["fileUrl"]
        else:
            excel_url_list = re.findall('.*href="(.*?xls.*?)".*', _text)
            if excel_url_list:
                excel_url = excel_url_list[0]
                if not excel_url.startswith("http"):
                    excel_url = excel_url if excel_url.startswith("/") else "/" + excel_url
                    excel_url = f"http://www.csindex.com.cn{excel_url}"
        if excel_url:
            logger.info(f"get {add_date} changes from excel, title={title}, excel_url={excel_url}")
            try:
                df = self._parse_excel(excel_url, add_date, remove_date)
            except ValueError:
                logger.warning(f"error downloading file: {excel_url}, will parse the table from the content")
                df = self._parse_table(_text, add_date, remove_date)
        else:
            logger.info(f"get {add_date} changes from url content, title={title}")
            df = self._parse_table(_text, add_date, remove_date)
        return df

    def _get_change_notices_url(self) -> Iterable[str]:
        """get change notices url

        A first small request reads the total hit count. A second request fetches
        all hits in one page.

        Returns
        -------
            [url1, url2]
        """
        page_num = 1
        page_size = 5
        data = retry_request(self.changes_url.format(page_size=page_size, page_num=page_num)).json()
        data = retry_request(self.changes_url.format(page_size=data["total"], page_num=page_num)).json()
        for item in data["data"]:
            yield f"https://www.csindex.com.cn/csindex-home/announcement/queryAnnouncementById?id={item['id']}"

    def get_new_companies(self) -> pd.DataFrame:
        """Download the current constituents Excel file (also cached under ``cache_dir``).

        Column 0 of the sheet is used as the end date and column 4 as the security
        code (presumably the date and constituent-code columns of the CSI file;
        verify if the file layout changes).

        Returns
        -------
            pd.DataFrame:

                symbol     start_date    end_date
                SH600000   2000-01-01    2099-12-31

            dtypes:
                symbol: str
                start_date: pd.Timestamp
                end_date: pd.Timestamp
        """
        logger.info("get new companies......")
        context = retry_request(self.new_companies_url).content
        with self.cache_dir.joinpath(
            f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}"
        ).open("wb") as fp:
            fp.write(context)
        _io = BytesIO(context)
        df = pd.read_excel(_io)
        # explicit copy: the columns are reassigned below
        df = df.iloc[:, [0, 4]].copy()
        df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
        df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol)
        df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str))
        df[self.START_DATE_FIELD] = self.bench_start_date
        logger.info("end of get new companies.")
        return df
|
|
|
|
|
class CSI300Index(CSIIndex):
    """CSIIndex implementation for the CSI 300 index."""

    # Raw values served by the properties below.
    _CODE = "000300"
    _START = "2005-01-01"
    _TABLE_POSITION = 1

    @property
    def index_code(self):
        """CSI code of the index."""
        return self._CODE

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """First date of the index history."""
        return pd.Timestamp(self._START)

    @property
    def html_table_index(self):
        """Position of the CSI 300 table among the 4-column announcement tables."""
        return self._TABLE_POSITION
|
|
|
|
|
|
class CSI100(CSIIndex):
    """CSIIndex implementation for the CSI 100 index."""

    # Raw values served by the properties below.
    _CODE = "000903"
    _START = "2006-05-29"
    _TABLE_POSITION = 2

    @property
    def index_code(self):
        """CSI code of the index."""
        return self._CODE

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """First date of the index history."""
        return pd.Timestamp(self._START)

    @property
    def html_table_index(self):
        """Position of the CSI 100 table among the 4-column announcement tables."""
        return self._TABLE_POSITION
|
|
|
|
|
|
class CSI500(CSIIndex):
    """CSIIndex implementation for the CSI 500 index.

    Unlike the other indices, its constituent history and current members are
    read from baostock (``bs.query_zz500_stocks``) rather than from csindex.com.cn.
    """

    @property
    def index_code(self) -> str:
        return "000905"

    @property
    def bench_start_date(self) -> pd.Timestamp:
        return pd.Timestamp("2007-01-15")

    @property
    def html_table_index(self) -> int:
        return 0

    def get_changes(self) -> pd.DataFrame:
        """get companies changes

        Derived from the weekly snapshots returned by ``get_history_companies``.

        Return
        --------
            pd.DataFrame:
                symbol      date        type
                SH600000  2019-11-11    add
                SH600000  2020-11-10    remove
            dtypes:
                symbol: str
                date: pd.Timestamp
                type: str, value from ["add", "remove"]
        """
        return self.get_changes_with_history_companies(self.get_history_companies())

    def get_history_companies(self) -> pd.DataFrame:
        """Download CSI 500 constituent snapshots every 7 days since ``bench_start_date``.

        Returns
        -------
            pd.DataFrame:
                date        symbol
                2007-01-15  SH600039
            dtypes:
                date: as returned by baostock (presumably a "YYYY-MM-DD" str)
                symbol: str
        """
        bs.login()
        try:
            # `pd.datetime` was removed in pandas 2.0; use pd.Timestamp.now() instead.
            dates = pd.date_range(start=self.bench_start_date, end=pd.Timestamp.now(), freq="7D").date
            # exactly one baostock query per date (the query used to be issued twice)
            ret_list = [
                self.get_data_from_baostock(date)[["date", "symbol"]]
                for date in tqdm(dates, desc="Download CSI500")
            ]
        finally:
            # always release the baostock session, even if a query raises
            bs.logout()
        return pd.concat(ret_list, sort=False)

    def get_data_from_baostock(self, date) -> pd.DataFrame:
        """Query the CSI 500 constituents for ``date`` from baostock.

        Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1
        Avoid a large number of parallel data acquisitions (e.g. 1000 concurrent
        requests), because the IP will be blocked.

        The callers in this class open the baostock session (``bs.login()``)
        before calling this method.

        Parameters
        ----------
            date
                value whose ``str()`` is a "YYYY-MM-DD" date, e.g. ``datetime.date``

        Returns
        -------
            pd.DataFrame:
                date        symbol    code_name
                2007-01-15  SH600039  <security name>
            dtypes:
                date: as returned by baostock (presumably a "YYYY-MM-DD" str)
                symbol: str, e.g. "sh.600039" normalized to "SH600039"
                code_name: str
            Empty (with these columns) when the query fails.
        """
        col = ["date", "symbol", "code_name"]
        rs = bs.query_zz500_stocks(date=str(date))
        if rs.error_code != "0":
            logger.warning(f"baostock query_zz500_stocks failed for {date}: {rs.error_code} {rs.error_msg}")
        zz500_stocks = []
        while rs.error_code == "0" and rs.next():
            zz500_stocks.append(rs.get_row_data())
        result = pd.DataFrame(zz500_stocks, columns=col)
        result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper())
        return result

    def get_new_companies(self) -> pd.DataFrame:
        """Current CSI 500 constituents from baostock.

        The end date is the date column reported by baostock for today's query.

        Returns
        -------
            pd.DataFrame:

                symbol     start_date    end_date
                SH600000   2000-01-01    2099-12-31

            dtypes:
                symbol: str
                start_date: pd.Timestamp
                end_date: pd.Timestamp
        """
        logger.info("get new companies......")
        # `datetime` is not imported in this module; pandas gives the same local date.
        today = pd.Timestamp.now().date()
        bs.login()
        try:
            result = self.get_data_from_baostock(today)
        finally:
            bs.logout()
        # explicit copy: columns are renamed and reassigned below
        df = result[["date", "symbol"]].copy()
        df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
        df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str))
        df[self.START_DATE_FIELD] = self.bench_start_date
        logger.info("end of get new companies.")
        return df
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: python-fire exposes get_instruments'
    # parameters as CLI arguments.
    fire.Fire(component=get_instruments)
|