1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Merge pull request #938 from SunsetWolf/fix-csi500

Fix csi500
This commit is contained in:
you-n-g
2022-03-11 12:09:22 +08:00
committed by GitHub
4 changed files with 129 additions and 3 deletions

View File

@@ -1,4 +1,4 @@
# CSI300/CSI100 History Companies Collection
# CSI300/CSI100/CSI500 History Companies Collection
## Requirements
@@ -15,7 +15,7 @@ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --m
# parse new companies
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
# index_name support: CSI300, CSI100
# index_name support: CSI300, CSI100, CSI500
# help
python collector.py --help
```

View File

@@ -12,6 +12,8 @@ from pathlib import Path
import fire
import requests
import pandas as pd
import baostock as bs
from tqdm import tqdm
from loguru import logger
CUR_DIR = Path(__file__).resolve().parent
@@ -341,6 +343,121 @@ class CSI100(CSIIndex):
return 2
class CSI500(CSIIndex):
@property
def index_code(self) -> str:
return "000905"
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2007-01-15")
@property
def html_table_index(self) -> int:
return 0
def get_changes(self) -> pd.DataFrame:
"""get companies changes
Return
--------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
return self.get_changes_with_history_companies(self.get_history_companies())
def get_history_companies(self) -> pd.DataFrame:
"""
Returns
-------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
bs.login()
today = pd.datetime.now()
date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date
ret_list = []
col = ["date", "symbol", "code_name"]
for date in tqdm(date_range, desc="Download CSI500"):
rs = bs.query_zz500_stocks(date=str(date))
zz500_stocks = []
while (rs.error_code == "0") & rs.next():
zz500_stocks.append(rs.get_row_data())
result = pd.DataFrame(zz500_stocks, columns=col)
result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper())
result = self.get_data_from_baostock(date)
ret_list.append(result[["date", "symbol"]])
bs.logout()
return pd.concat(ret_list, sort=False)
def get_data_from_baostock(self, date) -> pd.DataFrame:
"""
Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1
Avoid a large number of parallel data acquisition,
such as 1000 times of concurrent data acquisition, because IP will be blocked
Returns
-------
pd.DataFrame:
date symbol code_name
SH600039 2007-01-15 四川路桥
SH600051 2020-01-15 宁波联合
dtypes:
date: pd.Timestamp
symbol: str
code_name: str
"""
col = ["date", "symbol", "code_name"]
rs = bs.query_zz500_stocks(date=str(date))
zz500_stocks = []
while (rs.error_code == "0") & rs.next():
zz500_stocks.append(rs.get_row_data())
result = pd.DataFrame(zz500_stocks, columns=col)
result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper())
return result
def get_new_companies(self) -> pd.DataFrame:
"""
Returns
-------
pd.DataFrame:
symbol start_date end_date
SH600000 2000-01-01 2099-12-31
dtypes:
symbol: str
start_date: pd.Timestamp
end_date: pd.Timestamp
"""
logger.info("get new companies......")
today = datetime.date.today()
bs.login()
result = self.get_data_from_baostock(today)
bs.logout()
df = result[["date", "symbol"]]
df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str))
df[self.START_DATE_FIELD] = self.bench_start_date
logger.info("end of get new companies.")
return df
def get_instruments(
qlib_dir: str,
index_name: str,

View File

@@ -1,5 +1,8 @@
baostock
logure
fire
requests
pandas
lxml
loguru
tqdm

View File

@@ -27,6 +27,7 @@ SZSE_CALENDAR_URL = "http://www.szse.cn/api/report/exchange/onepersistenthour/mo
CALENDAR_BENCH_URL_MAP = {
"CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"),
"CSI500": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
"CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"),
# NOTE: Use the time series of SH600000 as the sequence of all stocks
"ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
@@ -35,7 +36,6 @@ CALENDAR_BENCH_URL_MAP = {
"IN_ALL": "^NSEI",
}
_BENCH_CALENDAR_LIST = None
_ALL_CALENDAR_LIST = None
_HS_SYMBOLS = None
@@ -232,13 +232,16 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
resp = requests.get(url)
if resp.status_code != 200:
raise ValueError("request error")
try:
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
except Exception as e:
logger.warning(f"request error: {e}")
raise
if len(_symbols) < 8000:
raise ValueError("request error")
return _symbols
@deco_retry
@@ -271,6 +274,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
resp = requests.post(url, json=_parms)
if resp.status_code != 200:
raise ValueError("request error")
try:
_symbols = [_v["symbolTicker"].replace("-", "-P") for _v in resp.json()]
except Exception as e:
@@ -425,10 +429,12 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3):
try:
_result = func(*args, **kwargs)
break
except Exception as e:
logger.warning(f"{func.__name__}: {_i} :{e}")
if _i == _retry:
raise
time.sleep(retry_sleep)
return _result