1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

add CSI500 data collector

This commit is contained in:
BigTreei
2020-12-26 16:51:52 +08:00
committed by Linlang Lv (iSoftStone)
parent ec8969a3ae
commit 74cc21fc2c
4 changed files with 146 additions and 92 deletions

View File

@@ -1,4 +1,4 @@
# CSI300/CSI100 History Companies Collection
# CSI300/CSI100/CSI500 History Companies Collection
## Requirements
@@ -15,7 +15,7 @@ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --m
# parse new companies
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
# index_name support: CSI300, CSI100
# index_name support: CSI300, CSI100, CSI500
# help
python collector.py --help
```

View File

@@ -5,6 +5,7 @@ import re
import abc
import sys
import importlib
from tqdm import tqdm
from io import BytesIO
from typing import List, Iterable
from pathlib import Path
@@ -12,6 +13,8 @@ from pathlib import Path
import fire
import requests
import pandas as pd
import baostock as bs
from lxml import etree
from loguru import logger
CUR_DIR = Path(__file__).resolve().parent
@@ -44,6 +47,7 @@ def retry_request(url: str, method: str = "get", exclude_status: List = None):
class CSIIndex(IndexBase):
@property
def calendar_list(self) -> List[pd.Timestamp]:
"""get history trading date
@@ -70,20 +74,20 @@ class CSIIndex(IndexBase):
@abc.abstractmethod
def bench_start_date(self) -> pd.Timestamp:
"""
Returns
-------
index start date
"""
Returns
-------
index start date
"""
raise NotImplementedError("rewrite bench_start_date")
@property
@abc.abstractmethod
def index_code(self) -> str:
"""
Returns
-------
index code
"""
Returns
-------
index code
"""
raise NotImplementedError("rewrite index_code")
@property
@@ -91,10 +95,10 @@ class CSIIndex(IndexBase):
def html_table_index(self) -> int:
"""Which table of changes in html
CSI300: 0
CSI100: 1
:return:
"""
CSI300: 0
CSI100: 1
:return:
"""
raise NotImplementedError()
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
@@ -145,15 +149,15 @@ class CSIIndex(IndexBase):
def normalize_symbol(symbol: str) -> str:
"""
Parameters
----------
symbol: str
symbol
Parameters
----------
symbol: str
symbol
Returns
-------
symbol
"""
Returns
-------
symbol
"""
symbol = f"{int(symbol):06}"
return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}"
@@ -210,10 +214,10 @@ class CSIIndex(IndexBase):
def _read_change_from_url(self, url: str) -> pd.DataFrame:
"""read change from url
Parameters
----------
url : str
change url
Parameters
----------
url : str
change url
Returns
-------
@@ -284,12 +288,12 @@ class CSIIndex(IndexBase):
def get_new_companies(self) -> pd.DataFrame:
"""
Returns
-------
pd.DataFrame:
Returns
-------
pd.DataFrame:
symbol start_date end_date
SH600000 2000-01-01 2099-12-31
symbol start_date end_date
SH600000 2000-01-01 2099-12-31
dtypes:
symbol: str
@@ -314,6 +318,7 @@ class CSIIndex(IndexBase):
class CSI300(CSIIndex):
@property
def index_code(self):
return "000300"
@@ -324,10 +329,11 @@ class CSI300(CSIIndex):
@property
def html_table_index(self):
return 1
return 0
class CSI100(CSIIndex):
@property
def index_code(self):
return "000903"
@@ -338,16 +344,54 @@ class CSI100(CSIIndex):
@property
def html_table_index(self):
return 2
return 1
class CSI500(CSIIndex):
    """CSI500 (index code ``000905``) constituent collector.

    ``get_changes`` is derived from the constituent snapshots returned by
    ``get_history_companies``, which are downloaded through baostock.
    """

    @property
    def index_code(self):
        return "000905"

    @property
    def bench_start_date(self) -> pd.Timestamp:
        return pd.Timestamp("2007-01-15")

    @property
    def html_table_index(self):
        return 0

    def get_changes(self):
        """Build the change records from the downloaded constituent snapshots."""
        return self.get_changes_with_history_companies(self.get_history_companies())

    def get_history_companies(self):
        """Download CSI500 constituent snapshots, one every 7 days, via baostock.

        Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1

        Avoid a large number of parallel data acquisitions (such as 1000
        concurrent requests), because the IP will be blocked.

        Returns
        -------
        pd.DataFrame
            One row per (snapshot, constituent) with columns ``date`` and
            ``symbol``. The symbol is the baostock code with the dot removed
            and upper-cased (e.g. ``sh.600000`` would become ``SH600000``).
        """
        login_result = bs.login()
        if login_result.error_code != "0":
            # Keep the original best-effort behaviour, but make the failure visible.
            logger.warning(f"baostock login failed: {login_result.error_msg}")

        columns = ["date", "symbol", "code_name"]
        frames = []
        try:
            # pd.Timestamp.now() replaces the deprecated/removed ``pd.datetime`` alias;
            # the start date is taken from bench_start_date instead of a second literal.
            snapshot_dates = pd.date_range(start=self.bench_start_date, end=pd.Timestamp.now(), freq="7D")
            for snapshot_date in tqdm(snapshot_dates.date, desc="Download CSI500"):
                rs = bs.query_zz500_stocks(date=str(snapshot_date))
                if rs.error_code != "0":
                    logger.warning(f"query_zz500_stocks failed for {snapshot_date}: {rs.error_msg}")
                rows = []
                # ``and`` short-circuits, so next() is never called on a failed query.
                while rs.error_code == "0" and rs.next():
                    rows.append(rs.get_row_data())
                snapshot = pd.DataFrame(rows, columns=columns)
                snapshot["symbol"] = snapshot["symbol"].str.replace(".", "", regex=False).str.upper()
                frames.append(snapshot[["date", "symbol"]])
        finally:
            # Always release the baostock session, even if a download step raises.
            bs.logout()
        return pd.concat(frames, sort=False)
def get_instruments(
qlib_dir: str,
index_name: str,
method: str = "parse_instruments",
freq: str = "day",
request_retry: int = 5,
retry_sleep: int = 3,
qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
):
"""
@@ -366,13 +410,13 @@ def get_instruments(
retry_sleep: int
request sleep, by default 3
Examples
-------
# parse instruments
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
Examples
-------
# parse instruments
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
# parse new companies
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
# parse new companies
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
"""
_cur_module = importlib.import_module("data_collector.cn_index.collector")

View File

@@ -1,5 +1,8 @@
baostock
logure
fire
requests
pandas
lxml
loguru
tqdm

View File

@@ -28,15 +28,14 @@ SZSE_CALENDAR_URL = "http://www.szse.cn/api/report/exchange/onepersistenthour/mo
# Per-benchmark calendar source, keyed by benchmark name: the CN entries are
# URLs formatted from CALENDAR_URL_BASE, the US/IN entries are symbol strings.
CALENDAR_BENCH_URL_MAP = {
"CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"),
"CSI500": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
"CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"),
# NOTE(review): the original note said the SH600000 time series is used for all
# stocks, but this URL is built with bench_code "000905" (same as "CSI500") - confirm.
"ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
# NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks
"US_ALL": "^GSPC",
"IN_ALL": "^NSEI",
}
_BENCH_CALENDAR_LIST = None
_ALL_CALENDAR_LIST = None
_HS_SYMBOLS = None
@@ -53,15 +52,15 @@ MINIMUM_SYMBOLS_NUM = 3900
def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
"""get SH/SZ history calendar list
Parameters
----------
bench_code: str
value from ["CSI300", "CSI500", "ALL", "US_ALL"]
Parameters
----------
bench_code: str
value from ["CSI300", "CSI500", "ALL", "US_ALL"]
Returns
-------
history calendar list
"""
Returns
-------
history calendar list
"""
logger.info(f"get calendar list: {bench_code}......")
@@ -178,10 +177,10 @@ def get_calendar_list_by_ratio(
def get_hs_stock_symbols() -> list:
"""get SH/SZ stock symbols
Returns
-------
stock symbols
"""
Returns
-------
stock symbols
"""
global _HS_SYMBOLS
def _get_symbol():
@@ -222,10 +221,10 @@ def get_hs_stock_symbols() -> list:
def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
"""get US stock symbols
Returns
-------
stock symbols
"""
Returns
-------
stock symbols
"""
global _US_SYMBOLS
@deco_retry
@@ -234,13 +233,16 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
resp = requests.get(url)
if resp.status_code != 200:
raise ValueError("request error")
try:
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
except Exception as e:
logger.warning(f"request error: {e}")
raise
if len(_symbols) < 8000:
raise ValueError("request error")
return _symbols
@deco_retry
@@ -273,6 +275,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
resp = requests.post(url, json=_parms)
if resp.status_code != 200:
raise ValueError("request error")
try:
_symbols = [_v["symbolTicker"].replace("-", "-P") for _v in resp.json()]
except Exception as e:
@@ -413,16 +416,16 @@ def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
"""symbol suffix to prefix
Parameters
----------
symbol: str
symbol
capital : bool
by default True
Returns
-------
Parameters
----------
symbol: str
symbol
capital : bool
by default True
Returns
-------
"""
"""
code, exchange = symbol.split(".")
if exchange.lower() in ["sh", "ss"]:
res = f"sh{code}"
@@ -434,22 +437,24 @@ def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
    """Convert an exchange-prefixed symbol to the ``code.exchange`` suffix form.

    Parameters
    ----------
    symbol: str
        symbol whose first two characters are the exchange, e.g. ``sh600000``
    capital : bool
        return the result in upper case when True (the default), otherwise
        in lower case

    Returns
    -------
    str
        the symbol in suffix form, e.g. ``600000.SH``
    """
    # The exchange is the two-character *prefix*. Slicing from the tail
    # (symbol[:-2] / symbol[-2:]) would split the numeric code instead,
    # turning "sh600000" into "SH6000.00".
    exchange, code = symbol[:2], symbol[2:]
    res = f"{code}.{exchange}"
    return res.upper() if capital else res.lower()
def deco_retry(retry: int = 5, retry_sleep: int = 3):
def deco_func(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
_retry = 5 if callable(retry) else retry
@@ -458,10 +463,12 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3):
try:
_result = func(*args, **kwargs)
break
except Exception as e:
logger.warning(f"{func.__name__}: {_i} :{e}")
if _i == _retry:
raise
time.sleep(retry_sleep)
return _result
@@ -473,19 +480,19 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3):
def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1):
"""get trading date by shift
Parameters
----------
trading_list: list
trading calendar list
shift : int
shift, default is 1
Parameters
----------
trading_list: list
trading calendar list
shift : int
shift, default is 1
trading_date : pd.Timestamp
trading date
Returns
-------
trading_date : pd.Timestamp
trading date
Returns
-------
"""
"""
trading_date = pd.Timestamp(trading_date)
left_index = bisect.bisect_left(trading_list, trading_date)
try: