mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
add CSI500 data collector
This commit is contained in:
committed by
Linlang Lv (iSoftStone)
parent
ec8969a3ae
commit
74cc21fc2c
@@ -1,4 +1,4 @@
|
||||
# CSI300/CSI100 History Companies Collection
|
||||
# CSI300/CSI100/CSI500 History Companies Collection
|
||||
|
||||
## Requirements
|
||||
|
||||
@@ -15,7 +15,7 @@ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --m
|
||||
# parse new companies
|
||||
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
# index_name support: CSI300, CSI100
|
||||
# supported index_name values: CSI300, CSI100, CSI500
|
||||
# help
|
||||
python collector.py --help
|
||||
```
|
||||
|
||||
@@ -5,6 +5,7 @@ import re
|
||||
import abc
|
||||
import sys
|
||||
import importlib
|
||||
from tqdm import tqdm
|
||||
from io import BytesIO
|
||||
from typing import List, Iterable
|
||||
from pathlib import Path
|
||||
@@ -12,6 +13,8 @@ from pathlib import Path
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
import baostock as bs
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
@@ -44,6 +47,7 @@ def retry_request(url: str, method: str = "get", exclude_status: List = None):
|
||||
|
||||
|
||||
class CSIIndex(IndexBase):
|
||||
|
||||
@property
|
||||
def calendar_list(self) -> List[pd.Timestamp]:
|
||||
"""get history trading date
|
||||
@@ -70,20 +74,20 @@ class CSIIndex(IndexBase):
|
||||
@abc.abstractmethod
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index start date
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index start date
|
||||
"""
|
||||
raise NotImplementedError("rewrite bench_start_date")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def index_code(self) -> str:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index code
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index code
|
||||
"""
|
||||
raise NotImplementedError("rewrite index_code")
|
||||
|
||||
@property
|
||||
@@ -91,10 +95,10 @@ class CSIIndex(IndexBase):
|
||||
def html_table_index(self) -> int:
|
||||
"""Which table of changes in html
|
||||
|
||||
CSI300: 0
|
||||
CSI100: 1
|
||||
:return:
|
||||
"""
|
||||
CSI300: 0
|
||||
CSI100: 1
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
|
||||
@@ -145,15 +149,15 @@ class CSIIndex(IndexBase):
|
||||
def normalize_symbol(symbol: str) -> str:
    """Convert a bare stock code into an exchange-prefixed symbol.

    Parameters
    ----------
    symbol: str
        numeric stock code, e.g. "600000" or "1"; it is zero-padded to six digits

    Returns
    -------
    str
        "SH"-prefixed symbol when the padded code starts with "60" or "68",
        "SZ"-prefixed symbol otherwise
    """
    symbol = f"{int(symbol):06}"
    # Shanghai A-share codes start with 60 (main board) or 68 (STAR Market, e.g. 688xxx).
    # The previous check only tested "60", which labelled STAR Market stocks as SZ.
    prefix = "SH" if symbol.startswith(("60", "68")) else "SZ"
    return f"{prefix}{symbol}"
|
||||
|
||||
@@ -210,10 +214,10 @@ class CSIIndex(IndexBase):
|
||||
def _read_change_from_url(self, url: str) -> pd.DataFrame:
|
||||
"""read change from url
|
||||
|
||||
Parameters
|
||||
----------
|
||||
url : str
|
||||
change url
|
||||
Parameters
|
||||
----------
|
||||
url : str
|
||||
change url
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -284,12 +288,12 @@ class CSIIndex(IndexBase):
|
||||
def get_new_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
|
||||
symbol start_date end_date
|
||||
SH600000 2000-01-01 2099-12-31
|
||||
symbol start_date end_date
|
||||
SH600000 2000-01-01 2099-12-31
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
@@ -314,6 +318,7 @@ class CSIIndex(IndexBase):
|
||||
|
||||
|
||||
class CSI300(CSIIndex):
|
||||
|
||||
@property
|
||||
def index_code(self):
|
||||
return "000300"
|
||||
@@ -324,10 +329,11 @@ class CSI300(CSIIndex):
|
||||
|
||||
@property
|
||||
def html_table_index(self):
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
class CSI100(CSIIndex):
|
||||
|
||||
@property
|
||||
def index_code(self):
|
||||
return "000903"
|
||||
@@ -338,16 +344,54 @@ class CSI100(CSIIndex):
|
||||
|
||||
@property
|
||||
def html_table_index(self):
|
||||
return 2
|
||||
return 1
|
||||
|
||||
|
||||
class CSI500(CSIIndex):
    """CSI500 index collector.

    Unlike the other CSI collectors in this module, the change history is
    derived from constituent snapshots downloaded through baostock
    (see ``get_history_companies``).
    """

    @property
    def index_code(self):
        return "000905"

    @property
    def bench_start_date(self) -> pd.Timestamp:
        return pd.Timestamp("2007-01-15")

    @property
    def html_table_index(self):
        return 0

    def get_changes(self):
        """Build the change records from the downloaded constituent snapshots."""
        return self.get_changes_with_history_companies(self.get_history_companies())

    def get_history_companies(self):
        """Download CSI500 constituent snapshots from baostock.

        Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1

        One query is issued per 7-day step between ``bench_start_date`` and now.
        Avoid running many of these downloads in parallel (e.g. 1000 concurrent
        requests): the data source may block the IP.

        Returns
        -------
        pd.DataFrame
            columns ``date`` and ``symbol``; ``symbol`` is the baostock code with
            the dot removed and upper-cased (e.g. "sh.600000" -> "SH600000")
        """
        columns = ["date", "symbol", "code_name"]
        snapshots = []
        bs.login()
        try:
            # pd.datetime was removed from pandas; pd.Timestamp.now() is the supported spelling.
            query_dates = pd.date_range(start=self.bench_start_date, end=pd.Timestamp.now(), freq="7D")
            for query_date in tqdm(query_dates, desc="Download CSI500"):
                rs = bs.query_zz500_stocks(date=query_date.strftime("%Y-%m-%d"))
                rows = []
                # `and` short-circuits, so next() is not called on a failed query.
                while rs.error_code == "0" and rs.next():
                    rows.append(rs.get_row_data())
                snapshot = pd.DataFrame(rows, columns=columns)
                snapshot["symbol"] = snapshot["symbol"].str.replace(".", "", regex=False).str.upper()
                snapshots.append(snapshot[["date", "symbol"]])
        finally:
            # Always release the baostock session, even if a query raised.
            bs.logout()
        return pd.concat(snapshots, sort=False)
|
||||
|
||||
|
||||
def get_instruments(
|
||||
qlib_dir: str,
|
||||
index_name: str,
|
||||
method: str = "parse_instruments",
|
||||
freq: str = "day",
|
||||
request_retry: int = 5,
|
||||
retry_sleep: int = 3,
|
||||
qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -366,13 +410,13 @@ def get_instruments(
|
||||
retry_sleep: int
|
||||
request sleep, by default 3
|
||||
|
||||
Examples
|
||||
-------
|
||||
# parse instruments
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
Examples
|
||||
-------
|
||||
# parse instruments
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
# parse new companies
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
"""
|
||||
_cur_module = importlib.import_module("data_collector.cn_index.collector")
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
baostock
|
||||
logure
|
||||
fire
|
||||
requests
|
||||
pandas
|
||||
lxml
|
||||
loguru
|
||||
tqdm
|
||||
@@ -28,15 +28,14 @@ SZSE_CALENDAR_URL = "http://www.szse.cn/api/report/exchange/onepersistenthour/mo
|
||||
|
||||
CALENDAR_BENCH_URL_MAP = {
|
||||
"CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"),
|
||||
"CSI500": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
|
||||
"CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"),
|
||||
# NOTE: Use the calendar of index 000905 (the same bench code as CSI500) as the calendar of all stocks
|
||||
"ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
|
||||
# NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks
|
||||
"US_ALL": "^GSPC",
|
||||
"IN_ALL": "^NSEI",
|
||||
}
|
||||
|
||||
|
||||
_BENCH_CALENDAR_LIST = None
|
||||
_ALL_CALENDAR_LIST = None
|
||||
_HS_SYMBOLS = None
|
||||
@@ -53,15 +52,15 @@ MINIMUM_SYMBOLS_NUM = 3900
|
||||
def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
"""get SH/SZ history calendar list
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bench_code: str
|
||||
value from ["CSI300", "CSI500", "ALL", "US_ALL"]
|
||||
Parameters
|
||||
----------
|
||||
bench_code: str
|
||||
value from ["CSI300", "CSI500", "ALL", "US_ALL"]
|
||||
|
||||
Returns
|
||||
-------
|
||||
history calendar list
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
history calendar list
|
||||
"""
|
||||
|
||||
logger.info(f"get calendar list: {bench_code}......")
|
||||
|
||||
@@ -178,10 +177,10 @@ def get_calendar_list_by_ratio(
|
||||
def get_hs_stock_symbols() -> list:
|
||||
"""get SH/SZ stock symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _HS_SYMBOLS
|
||||
|
||||
def _get_symbol():
|
||||
@@ -222,10 +221,10 @@ def get_hs_stock_symbols() -> list:
|
||||
def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
"""get US stock symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _US_SYMBOLS
|
||||
|
||||
@deco_retry
|
||||
@@ -234,13 +233,16 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
resp = requests.get(url)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
try:
|
||||
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
raise
|
||||
|
||||
if len(_symbols) < 8000:
|
||||
raise ValueError("request error")
|
||||
|
||||
return _symbols
|
||||
|
||||
@deco_retry
|
||||
@@ -273,6 +275,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
resp = requests.post(url, json=_parms)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
try:
|
||||
_symbols = [_v["symbolTicker"].replace("-", "-P") for _v in resp.json()]
|
||||
except Exception as e:
|
||||
@@ -413,16 +416,16 @@ def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
|
||||
"""symbol suffix to prefix
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
capital : bool
|
||||
by default True
|
||||
Returns
|
||||
-------
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
capital : bool
|
||||
by default True
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
"""
|
||||
code, exchange = symbol.split(".")
|
||||
if exchange.lower() in ["sh", "ss"]:
|
||||
res = f"sh{code}"
|
||||
@@ -434,22 +437,24 @@ def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
|
||||
def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
|
||||
"""symbol prefix to sufix
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
capital : bool
|
||||
by default True
|
||||
Returns
|
||||
-------
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
capital : bool
|
||||
by default True
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
"""
|
||||
res = f"{symbol[:-2]}.{symbol[-2:]}"
|
||||
return res.upper() if capital else res.lower()
|
||||
|
||||
|
||||
def deco_retry(retry: int = 5, retry_sleep: int = 3):
|
||||
|
||||
def deco_func(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
_retry = 5 if callable(retry) else retry
|
||||
@@ -458,10 +463,12 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3):
|
||||
try:
|
||||
_result = func(*args, **kwargs)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{func.__name__}: {_i} :{e}")
|
||||
if _i == _retry:
|
||||
raise
|
||||
|
||||
time.sleep(retry_sleep)
|
||||
return _result
|
||||
|
||||
@@ -473,19 +480,19 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3):
|
||||
def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1):
|
||||
"""get trading date by shift
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trading_list: list
|
||||
trading calendar list
|
||||
shift : int
|
||||
shift, default is 1
|
||||
Parameters
|
||||
----------
|
||||
trading_list: list
|
||||
trading calendar list
|
||||
shift : int
|
||||
shift, default is 1
|
||||
|
||||
trading_date : pd.Timestamp
|
||||
trading date
|
||||
Returns
|
||||
-------
|
||||
trading_date : pd.Timestamp
|
||||
trading date
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
"""
|
||||
trading_date = pd.Timestamp(trading_date)
|
||||
left_index = bisect.bisect_left(trading_list, trading_date)
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user