mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
@@ -1,4 +1,4 @@
|
||||
# CSI300/CSI100 History Companies Collection
|
||||
# CSI300/CSI100/CSI500 History Companies Collection
|
||||
|
||||
## Requirements
|
||||
|
||||
@@ -15,7 +15,7 @@ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --m
|
||||
# parse new companies
|
||||
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
# index_name support: CSI300, CSI100
|
||||
# index_name support: CSI300, CSI100, CSI500
|
||||
# help
|
||||
python collector.py --help
|
||||
```
|
||||
|
||||
@@ -12,6 +12,8 @@ from pathlib import Path
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
import baostock as bs
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
@@ -341,6 +343,121 @@ class CSI100(CSIIndex):
|
||||
return 2
|
||||
|
||||
|
||||
class CSI500(CSIIndex):
|
||||
@property
|
||||
def index_code(self) -> str:
|
||||
return "000905"
|
||||
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2007-01-15")
|
||||
|
||||
@property
|
||||
def html_table_index(self) -> int:
|
||||
return 0
|
||||
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
Return
|
||||
--------
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
return self.get_changes_with_history_companies(self.get_history_companies())
|
||||
|
||||
def get_history_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
bs.login()
|
||||
today = pd.datetime.now()
|
||||
date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date
|
||||
ret_list = []
|
||||
col = ["date", "symbol", "code_name"]
|
||||
for date in tqdm(date_range, desc="Download CSI500"):
|
||||
rs = bs.query_zz500_stocks(date=str(date))
|
||||
zz500_stocks = []
|
||||
while (rs.error_code == "0") & rs.next():
|
||||
zz500_stocks.append(rs.get_row_data())
|
||||
result = pd.DataFrame(zz500_stocks, columns=col)
|
||||
result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper())
|
||||
result = self.get_data_from_baostock(date)
|
||||
ret_list.append(result[["date", "symbol"]])
|
||||
bs.logout()
|
||||
return pd.concat(ret_list, sort=False)
|
||||
|
||||
def get_data_from_baostock(self, date) -> pd.DataFrame:
|
||||
"""
|
||||
Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1
|
||||
Avoid a large number of parallel data acquisition,
|
||||
such as 1000 times of concurrent data acquisition, because IP will be blocked
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
date symbol code_name
|
||||
SH600039 2007-01-15 四川路桥
|
||||
SH600051 2020-01-15 宁波联合
|
||||
dtypes:
|
||||
date: pd.Timestamp
|
||||
symbol: str
|
||||
code_name: str
|
||||
"""
|
||||
col = ["date", "symbol", "code_name"]
|
||||
rs = bs.query_zz500_stocks(date=str(date))
|
||||
zz500_stocks = []
|
||||
while (rs.error_code == "0") & rs.next():
|
||||
zz500_stocks.append(rs.get_row_data())
|
||||
result = pd.DataFrame(zz500_stocks, columns=col)
|
||||
result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper())
|
||||
return result
|
||||
|
||||
def get_new_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
|
||||
symbol start_date end_date
|
||||
SH600000 2000-01-01 2099-12-31
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
start_date: pd.Timestamp
|
||||
end_date: pd.Timestamp
|
||||
"""
|
||||
logger.info("get new companies......")
|
||||
today = datetime.date.today()
|
||||
bs.login()
|
||||
result = self.get_data_from_baostock(today)
|
||||
bs.logout()
|
||||
df = result[["date", "symbol"]]
|
||||
df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
|
||||
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str))
|
||||
df[self.START_DATE_FIELD] = self.bench_start_date
|
||||
logger.info("end of get new companies.")
|
||||
return df
|
||||
|
||||
|
||||
def get_instruments(
|
||||
qlib_dir: str,
|
||||
index_name: str,
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
baostock
|
||||
logure
|
||||
fire
|
||||
requests
|
||||
pandas
|
||||
lxml
|
||||
loguru
|
||||
tqdm
|
||||
@@ -27,6 +27,7 @@ SZSE_CALENDAR_URL = "http://www.szse.cn/api/report/exchange/onepersistenthour/mo
|
||||
|
||||
CALENDAR_BENCH_URL_MAP = {
|
||||
"CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"),
|
||||
"CSI500": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
|
||||
"CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"),
|
||||
# NOTE: Use the time series of SH600000 as the sequence of all stocks
|
||||
"ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
|
||||
@@ -35,7 +36,6 @@ CALENDAR_BENCH_URL_MAP = {
|
||||
"IN_ALL": "^NSEI",
|
||||
}
|
||||
|
||||
|
||||
_BENCH_CALENDAR_LIST = None
|
||||
_ALL_CALENDAR_LIST = None
|
||||
_HS_SYMBOLS = None
|
||||
@@ -232,13 +232,16 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
resp = requests.get(url)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
try:
|
||||
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
raise
|
||||
|
||||
if len(_symbols) < 8000:
|
||||
raise ValueError("request error")
|
||||
|
||||
return _symbols
|
||||
|
||||
@deco_retry
|
||||
@@ -271,6 +274,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
resp = requests.post(url, json=_parms)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
try:
|
||||
_symbols = [_v["symbolTicker"].replace("-", "-P") for _v in resp.json()]
|
||||
except Exception as e:
|
||||
@@ -425,10 +429,12 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3):
|
||||
try:
|
||||
_result = func(*args, **kwargs)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{func.__name__}: {_i} :{e}")
|
||||
if _i == _retry:
|
||||
raise
|
||||
|
||||
time.sleep(retry_sleep)
|
||||
return _result
|
||||
|
||||
|
||||
Reference in New Issue
Block a user