diff --git a/scripts/data_collector/cn_index/README.md b/scripts/data_collector/cn_index/README.md
index 82f17eb5c..f9352e711 100644
--- a/scripts/data_collector/cn_index/README.md
+++ b/scripts/data_collector/cn_index/README.md
@@ -1,4 +1,4 @@
-# CSI300/CSI100 History Companies Collection
+# CSI300/CSI100/CSI500 History Companies Collection
 
 ## Requirements
 
@@ -15,7 +15,7 @@ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --m
 # parse new companies
 python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
 
-# index_name support: CSI300, CSI100
+# index_name support: CSI300, CSI100, CSI500
 # help
 python collector.py --help
 ```
diff --git a/scripts/data_collector/cn_index/collector.py b/scripts/data_collector/cn_index/collector.py
index 6f3eee255..e5970c256 100644
--- a/scripts/data_collector/cn_index/collector.py
+++ b/scripts/data_collector/cn_index/collector.py
@@ -12,6 +12,8 @@ from pathlib import Path
 import fire
 import requests
 import pandas as pd
+import baostock as bs
+from tqdm import tqdm
 from loguru import logger
 
 CUR_DIR = Path(__file__).resolve().parent
@@ -341,6 +343,126 @@ class CSI100(CSIIndex):
         return 2
 
 
+class CSI500(CSIIndex):
+    @property
+    def index_code(self) -> str:
+        return "000905"
+
+    @property
+    def bench_start_date(self) -> pd.Timestamp:
+        return pd.Timestamp("2007-01-15")
+
+    @property
+    def html_table_index(self) -> int:
+        return 0
+
+    def get_changes(self) -> pd.DataFrame:
+        """Get the companies added to / removed from the index.
+
+        Returns
+        -------
+            pd.DataFrame:
+                symbol      date        type
+                SH600000  2019-11-11    add
+                SH600000  2020-11-10    remove
+            dtypes:
+                symbol: str
+                date: pd.Timestamp
+                type: str, value from ["add", "remove"]
+        """
+        return self.get_changes_with_history_companies(self.get_history_companies())
+
+    def get_history_companies(self) -> pd.DataFrame:
+        """Download one constituent snapshot every 7 days, from bench_start_date up to now.
+
+        Returns
+        -------
+            pd.DataFrame:
+                date        symbol
+                2007-01-15  SH600039
+                2020-01-15  SH600051
+            dtypes:
+                date: as returned by baostock (presumably str; not converted here)
+                symbol: str
+        """
+        # DatetimeIndex.date yields datetime.date objects, so str(date) is "YYYY-MM-DD"
+        date_range = pd.date_range(start=self.bench_start_date, end=pd.Timestamp.now(), freq="7D").date
+        ret_list = []
+        bs.login()
+        try:
+            for date in tqdm(date_range, desc="Download CSI500"):
+                result = self.get_data_from_baostock(date)
+                ret_list.append(result[["date", "symbol"]])
+        finally:
+            # always close the baostock session, even when a request raises
+            bs.logout()
+        return pd.concat(ret_list, sort=False)
+
+    def get_data_from_baostock(self, date) -> pd.DataFrame:
+        """Query the index constituents for one date; bs.login() must have been called.
+
+        Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1
+        Avoid a large number of parallel data acquisition,
+        such as 1000 times of concurrent data acquisition, because IP will be blocked
+
+        Parameters
+        ----------
+        date:
+            query date, passed to baostock as str(date)
+
+        Returns
+        -------
+            pd.DataFrame:
+                date        symbol    code_name
+                2007-01-15  SH600039  (company name)
+                2020-01-15  SH600051  (company name)
+            dtypes:
+                date: as returned by baostock (presumably str; not converted here)
+                symbol: str
+                code_name: as returned by baostock (presumably str)
+        """
+        col = ["date", "symbol", "code_name"]
+        rs = bs.query_zz500_stocks(date=str(date))
+        zz500_stocks = []
+        # `and` short-circuits, so a failed query is never advanced with rs.next()
+        while rs.error_code == "0" and rs.next():
+            zz500_stocks.append(rs.get_row_data())
+        result = pd.DataFrame(zz500_stocks, columns=col)
+        # drop the "." and upper-case the code so it matches the SH600000 symbol style
+        result["symbol"] = result["symbol"].str.replace(".", "", regex=False).str.upper()
+        return result
+
+    def get_new_companies(self) -> pd.DataFrame:
+        """Get the current index constituents.
+
+        Returns
+        -------
+            pd.DataFrame:
+
+                symbol     start_date    end_date
+                SH600000   2000-01-01    2099-12-31
+
+            dtypes:
+                symbol: str
+                start_date: pd.Timestamp
+                end_date: pd.Timestamp
+        """
+        logger.info("get new companies......")
+        today = pd.Timestamp.now().date()
+        bs.login()
+        try:
+            result = self.get_data_from_baostock(today)
+        finally:
+            bs.logout()
+        # copy: the column assignments below would otherwise hit a slice of `result`
+        df = result[["date", "symbol"]].copy()
+        df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
+        df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str))
+        df[self.START_DATE_FIELD] = self.bench_start_date
+        logger.info("end of get new companies.")
+        return df
+
+
 def get_instruments(
     qlib_dir: str,
     index_name: str,
diff --git a/scripts/data_collector/cn_index/requirements.txt b/scripts/data_collector/cn_index/requirements.txt
index 1d846b504..fffdc25e9 100644
--- a/scripts/data_collector/cn_index/requirements.txt
+++ b/scripts/data_collector/cn_index/requirements.txt
@@ -1,5 +1,7 @@
+baostock
 fire
 requests
 pandas
 lxml
 loguru
+tqdm
diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py
index 19131ec29..1814b75ea 100644
--- a/scripts/data_collector/utils.py
+++ b/scripts/data_collector/utils.py
@@ -27,6 +27,7 @@ SZSE_CALENDAR_URL = "http://www.szse.cn/api/report/exchange/onepersistenthour/mo
 
 CALENDAR_BENCH_URL_MAP = {
     "CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"),
+    "CSI500": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
     "CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"),
     # NOTE: Use the time series of SH600000 as the sequence of all stocks
     "ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
@@ -35,7 +36,6 @@ CALENDAR_BENCH_URL_MAP = {
     "IN_ALL": "^NSEI",
 }
 
-
 _BENCH_CALENDAR_LIST = None
 _ALL_CALENDAR_LIST = None
 _HS_SYMBOLS = None
@@ -232,13 +232,16 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
         resp = requests.get(url)
         if resp.status_code != 200:
             raise ValueError("request error")
+
         try:
             _symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
         except Exception as e:
             logger.warning(f"request error: {e}")
             raise
+
         if len(_symbols) < 8000:
             raise ValueError("request error")
+
         return _symbols
 
     @deco_retry
@@ -271,6 +274,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
         resp = requests.post(url, json=_parms)
         if resp.status_code != 200:
             raise ValueError("request error")
+
         try:
             _symbols = [_v["symbolTicker"].replace("-", "-P") for _v in resp.json()]
         except Exception as e:
@@ -425,10 +429,12 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3):
                 try:
                     _result = func(*args, **kwargs)
                     break
+
                 except Exception as e:
                     logger.warning(f"{func.__name__}: {_i} :{e}")
                     if _i == _retry:
                         raise
+
                 time.sleep(retry_sleep)
             return _result
 