diff --git a/scripts/data_collector/cn_index/README.md b/scripts/data_collector/cn_index/README.md index 82f17eb5c..f9352e711 100644 --- a/scripts/data_collector/cn_index/README.md +++ b/scripts/data_collector/cn_index/README.md @@ -1,4 +1,4 @@ -# CSI300/CSI100 History Companies Collection +# CSI300/CSI100/CSI500 History Companies Collection ## Requirements @@ -15,7 +15,7 @@ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --m # parse new companies python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies -# index_name support: CSI300, CSI100 +# index_name support: CSI300, CSI100, CSI500 # help python collector.py --help ``` diff --git a/scripts/data_collector/cn_index/collector.py b/scripts/data_collector/cn_index/collector.py index 6f3eee255..7c23b2394 100644 --- a/scripts/data_collector/cn_index/collector.py +++ b/scripts/data_collector/cn_index/collector.py @@ -5,6 +5,7 @@ import re import abc import sys import importlib +from tqdm import tqdm from io import BytesIO from typing import List, Iterable from pathlib import Path @@ -12,6 +13,8 @@ from pathlib import Path import fire import requests import pandas as pd +import baostock as bs +from lxml import etree from loguru import logger CUR_DIR = Path(__file__).resolve().parent @@ -44,6 +47,7 @@ def retry_request(url: str, method: str = "get", exclude_status: List = None): class CSIIndex(IndexBase): + @property def calendar_list(self) -> List[pd.Timestamp]: """get history trading date @@ -70,20 +74,20 @@ class CSIIndex(IndexBase): @abc.abstractmethod def bench_start_date(self) -> pd.Timestamp: """ - Returns - ------- - index start date - """ + Returns + ------- + index start date + """ raise NotImplementedError("rewrite bench_start_date") @property @abc.abstractmethod def index_code(self) -> str: """ - Returns - ------- - index code - """ + Returns + ------- + index code + """ raise NotImplementedError("rewrite index_code") @property @@ -91,10 +95,10 @@ class CSIIndex(IndexBase): def html_table_index(self) -> int: """Which table of changes in html - CSI300: 0 - CSI100: 1 - :return: - """ + CSI300: 0 + CSI100: 1 + :return: + """ raise NotImplementedError() def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame: @@ -145,15 +149,15 @@ class CSIIndex(IndexBase): def normalize_symbol(symbol: str) -> str: """ - Parameters - ---------- - symbol: str - symbol + Parameters + ---------- + symbol: str + symbol - Returns - ------- - symbol - """ + Returns + ------- + symbol + """ symbol = f"{int(symbol):06}" return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}" @@ -210,10 +214,10 @@ class CSIIndex(IndexBase): def _read_change_from_url(self, url: str) -> pd.DataFrame: """read change from url - Parameters - ---------- - url : str - change url + Parameters + ---------- + url : str + change url Returns ------- @@ -284,12 +288,12 @@ class CSIIndex(IndexBase): def get_new_companies(self) -> pd.DataFrame: """ - Returns - ------- - pd.DataFrame: + Returns + ------- + pd.DataFrame: - symbol start_date end_date - SH600000 2000-01-01 2099-12-31 + symbol start_date end_date + SH600000 2000-01-01 2099-12-31 dtypes: symbol: str @@ -314,6 +318,7 @@ class CSIIndex(IndexBase): class CSI300(CSIIndex): + @property def index_code(self): return "000300" @@ -324,10 +329,11 @@ class CSI300(CSIIndex): @property def html_table_index(self): - return 1 + return 0 class CSI100(CSIIndex): + @property def index_code(self): return "000903" @@ -338,16 +344,54 @@ class CSI100(CSIIndex): @property def html_table_index(self): - return 2 + return 1 + + +class CSI500(CSIIndex): + + @property + def index_code(self): + return "000905" + + @property + def bench_start_date(self) -> pd.Timestamp: + return pd.Timestamp("2007-01-15") + + @property + def html_table_index(self): + return 0 + + def get_changes(self): + return self.get_changes_with_history_companies(self.get_history_companies()) + + def get_history_companies(self): + """ + Data source:http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1 + Avoid a large number of parallel data acquisition, + such as 1000 times of concurrent data acquisition, because IP will be blocked + Returns + ------- + + """ + lg = bs.login() + today = pd.datetime.now() + date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date + ret_list = [] + col = ["date", "symbol", "code_name"] + for date in tqdm(date_range, desc="Download CSI500"): + rs = bs.query_zz500_stocks(date=str(date)) + zz500_stocks = [] + while (rs.error_code == "0") & rs.next(): + zz500_stocks.append(rs.get_row_data()) + result = pd.DataFrame(zz500_stocks, columns=col) + result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper()) + ret_list.append(result[["date", "symbol"]]) + bs.logout() + return pd.concat(ret_list, sort=False) def get_instruments( - qlib_dir: str, - index_name: str, - method: str = "parse_instruments", - freq: str = "day", - request_retry: int = 5, - retry_sleep: int = 3, + qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3 ): """ @@ -366,13 +410,13 @@ def get_instruments( retry_sleep: int request sleep, by default 3 - Examples - ------- - # parse instruments - $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments + Examples + ------- + # parse instruments + $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments - # parse new companies - $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies + # parse new companies + $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies """ _cur_module = importlib.import_module("data_collector.cn_index.collector") diff --git a/scripts/data_collector/cn_index/requirements.txt b/scripts/data_collector/cn_index/requirements.txt index 1d846b504..fffdc25e9 100644 --- a/scripts/data_collector/cn_index/requirements.txt +++ b/scripts/data_collector/cn_index/requirements.txt @@ -1,5 +1,8 @@ +baostock +logure fire requests pandas lxml loguru +tqdm \ No newline at end of file diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 33e3a047f..eb3c8da17 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -28,15 +28,14 @@ SZSE_CALENDAR_URL = "http://www.szse.cn/api/report/exchange/onepersistenthour/mo CALENDAR_BENCH_URL_MAP = { "CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"), + "CSI500": CALENDAR_URL_BASE.format(market=1, bench_code="000905"), "CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"), # NOTE: Use the time series of SH600000 as the sequence of all stocks "ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"), # NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks "US_ALL": "^GSPC", - "IN_ALL": "^NSEI", } - _BENCH_CALENDAR_LIST = None _ALL_CALENDAR_LIST = None _HS_SYMBOLS = None @@ -53,15 +52,15 @@ MINIMUM_SYMBOLS_NUM = 3900 def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: """get SH/SZ history calendar list - Parameters - ---------- - bench_code: str - value from ["CSI300", "CSI500", "ALL", "US_ALL"] + Parameters + ---------- + bench_code: str + value from ["CSI300", "CSI500", "ALL", "US_ALL"] - Returns - ------- - history calendar list - """ + Returns + ------- + history calendar list + """ logger.info(f"get calendar list: {bench_code}......") @@ -178,10 +177,10 @@ def get_calendar_list_by_ratio( def get_hs_stock_symbols() -> list: """get SH/SZ stock symbols - Returns - ------- - stock symbols - """ + Returns + ------- + stock symbols + """ global _HS_SYMBOLS def _get_symbol(): @@ -222,10 +221,10 @@ def get_hs_stock_symbols() -> list: def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: """get US stock symbols - Returns - ------- - stock symbols - """ + Returns + ------- + stock symbols + """ global _US_SYMBOLS @deco_retry @@ -234,13 +233,16 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: resp = requests.get(url) if resp.status_code != 200: raise ValueError("request error") + try: _symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()] except Exception as e: logger.warning(f"request error: {e}") raise + if len(_symbols) < 8000: raise ValueError("request error") + return _symbols @deco_retry @@ -273,6 +275,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: resp = requests.post(url, json=_parms) if resp.status_code != 200: raise ValueError("request error") + try: _symbols = [_v["symbolTicker"].replace("-", "-P") for _v in resp.json()] except Exception as e: @@ -413,16 +416,16 @@ def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list: def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str: """symbol suffix to prefix - Parameters - ---------- - symbol: str - symbol - capital : bool - by default True - Returns - ------- + Parameters + ---------- + symbol: str + symbol + capital : bool + by default True + Returns + ------- - """ + """ code, exchange = symbol.split(".") if exchange.lower() in ["sh", "ss"]: res = f"sh{code}" @@ -434,22 +437,24 @@ def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str: def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str: """symbol prefix to sufix - Parameters - ---------- - symbol: str - symbol - capital : bool - by default True - Returns - ------- + Parameters + ---------- + symbol: str + symbol + capital : bool + by default True + Returns + ------- - """ + """ res = f"{symbol[:-2]}.{symbol[-2:]}" return res.upper() if capital else res.lower() def deco_retry(retry: int = 5, retry_sleep: int = 3): + def deco_func(func): + @functools.wraps(func) def wrapper(*args, **kwargs): _retry = 5 if callable(retry) else retry @@ -458,10 +463,12 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3): try: _result = func(*args, **kwargs) break + except Exception as e: logger.warning(f"{func.__name__}: {_i} :{e}") if _i == _retry: raise + time.sleep(retry_sleep) return _result @@ -473,19 +480,19 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3): def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1): """get trading date by shift - Parameters - ---------- - trading_list: list - trading calendar list - shift : int - shift, default is 1 + Parameters + ---------- + trading_list: list + trading calendar list + shift : int + shift, default is 1 - trading_date : pd.Timestamp - trading date - Returns - ------- + trading_date : pd.Timestamp + trading date + Returns + ------- - """ + """ trading_date = pd.Timestamp(trading_date) left_index = bisect.bisect_left(trading_list, trading_date) try: