From fb606ec8747f2eeb2a9f96cf50be9554ad5211e2 Mon Sep 17 00:00:00 2001 From: Linlang Date: Wed, 28 Jan 2026 14:39:07 +0800 Subject: [PATCH] fix: strategies for enhancing crawlers --- scripts/data_collector/utils.py | 200 +++++++++++++++++++++++++++----- 1 file changed, 169 insertions(+), 31 deletions(-) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 7b0c05768..d595d7ff6 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -3,10 +3,12 @@ import re import copy +import datetime import importlib import time import bisect import pickle +import random import requests import functools from pathlib import Path @@ -23,7 +25,7 @@ from bs4 import BeautifulSoup HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}" -CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20991231" +CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={start}&end={end}" SZSE_CALENDAR_URL = "http://www.szse.cn/api/report/exchange/onepersistenthour/monthList?month={month}&random={random}" CALENDAR_BENCH_URL_MAP = { @@ -38,6 +40,24 @@ CALENDAR_BENCH_URL_MAP = { "BR_ALL": "^BVSP", } +CHROME_UA_POOL = [ + # Windows + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/120.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/121.0.6167.85 Safari/537.36", + # macOS + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/121.0.0.0 Safari/537.36", + # Linux + "Mozilla/5.0 (X11; Linux x86_64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/120.0.0.0 Safari/537.36", +] + _BENCH_CALENDAR_LIST = None _ALL_CALENDAR_LIST = None _HS_SYMBOLS = None @@ -51,6 +71,16 @@ _CALENDAR_MAP = {} MINIMUM_SYMBOLS_NUM = 3900 +def build_headers(): + return { + "User-Agent": random.choice(CHROME_UA_POOL), + "Accept": "application/json,text/plain,*/*", + "Accept-Language": "zh-CN,zh;q=0.9", + "Referer": "https://quote.eastmoney.com/", + "Connection": "keep-alive", + } + + def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: """get SH/SZ history calendar list @@ -67,16 +97,58 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: logger.info(f"get calendar list: {bench_code}......") def _get_calendar(url): - _value_list = requests.get(url, timeout=None).json()["data"]["klines"] - return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list)) + session = requests.Session() + session.headers.update(build_headers()) + current_datetime = datetime.datetime.now() + cur_year = current_datetime.year + res_list = [] + for per_year in range(2000, cur_year + 1): + start = f"{per_year}0101" + end = f"{per_year}1231" + formatted_url = url.format(start=start, end=end) + try: + resp = session.get(formatted_url, timeout=10) + resp.raise_for_status() + payload = resp.json() + data = payload.get("data") + if not data or "klines" not in data: + continue + + klines = data["klines"] + res_list.extend(pd.Timestamp(x.split(",")[0]) for x in klines) + + except requests.RequestException as e: + continue + + time.sleep(random.uniform(0.5, 1.2)) + + return sorted(set(res_list)) + + # _value_list = requests.get(url, timeout=None).json()["data"]["klines"] + # return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list)) calendar = _CALENDAR_MAP.get(bench_code, None) if calendar is None: - if bench_code.startswith("US_") or bench_code.startswith("IN_") or bench_code.startswith("BR_"): + if ( + bench_code.startswith("US_") + or bench_code.startswith("IN_") + or bench_code.startswith("BR_") + ): print(Ticker(CALENDAR_BENCH_URL_MAP[bench_code])) - print(Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")) - df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max") - calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist() + print( + Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history( + interval="1d", period="max" + ) + ) + df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history( + interval="1d", period="max" + ) + calendar = ( + df.index.get_level_values(level="date") + .map(pd.Timestamp) + .unique() + .tolist() + ) else: if bench_code.upper() == "ALL": import akshare as ak # pylint: disable=C0415 @@ -85,7 +157,10 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: trade_date_list = trade_date_df["trade_date"].tolist() trade_date_list = [pd.Timestamp(d) for d in trade_date_list] dates = pd.DatetimeIndex(trade_date_list) - filtered_dates = dates[(dates >= "2000-01-04") & (dates <= pd.Timestamp.today().normalize())] + filtered_dates = dates[ + (dates >= "2000-01-04") + & (dates <= pd.Timestamp.today().normalize()) + ] calendar = filtered_dates.tolist() else: calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code]) @@ -150,7 +225,9 @@ def get_calendar_list_by_ratio( p_bar.update() logger.info(f"count how many funds have founded in this day......") - _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade} # dict{date:count} + _dict_count_founding = { + date: _number_all_funds for date in _dict_count_trade + } # dict{date:count} with tqdm(total=_number_all_funds) as p_bar: for oldest_date in all_oldest_list: for date in _dict_count_founding.keys(): @@ -158,7 +235,9 @@ def get_calendar_list_by_ratio( _dict_count_founding[date] -= 1 calendar = [ - date for date, count in _dict_count_trade.items() if count >= max(int(count * threshold), minimum_count) + date + for date, count in _dict_count_trade.items() + if count >= max(int(count * threshold), minimum_count) ] return calendar @@ -210,14 +289,21 @@ def get_hs_stock_symbols() -> list: data = resp.json() # Check if response contains valid data - if not data or "data" not in data or not data["data"] or "diff" not in data["data"]: + if ( + not data + or "data" not in data + or not data["data"] + or "diff" not in data["data"] + ): logger.warning(f"Invalid response structure on page {page}") break # fetch the current page data current_symbols = [_v["f12"] for _v in data["data"]["diff"]] - if not current_symbols: # It's the last page if there is no data in current page + if ( + not current_symbols + ): # It's the last page if there is no data in current page logger.info(f"Last page reached: {page - 1}") break @@ -238,7 +324,9 @@ def get_hs_stock_symbols() -> list: f"Request to {base_url} failed with status code {resp.status_code}" ) from e except Exception as e: - logger.warning("An error occurred while extracting data from the response.") + logger.warning( + "An error occurred while extracting data from the response." + ) raise if len(_symbols) < 3900: @@ -246,7 +334,11 @@ def get_hs_stock_symbols() -> list: # Add suffix after the stock code to conform to yahooquery standard, otherwise the data will not be fetched. _symbols = [ - _symbol + ".ss" if _symbol.startswith("6") else _symbol + ".sz" if _symbol.startswith(("0", "3")) else None + ( + _symbol + ".ss" + if _symbol.startswith("6") + else _symbol + ".sz" if _symbol.startswith(("0", "3")) else None + ) for _symbol in _symbols ] _symbols = [_symbol for _symbol in _symbols if _symbol is not None] @@ -292,7 +384,10 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: raise ValueError("request error") try: - _symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()] + _symbols = [ + _v["f12"].replace("_", "-P") + for _v in resp.json()["data"]["diff"].values() + ] except Exception as e: logger.warning(f"request error: {e}") raise @@ -357,7 +452,14 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: s_ = s_.strip("*") return s_ - _US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols)))) + _US_SYMBOLS = sorted( + set( + map( + _format, + filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols), + ) + ) + ) return _US_SYMBOLS @@ -427,7 +529,9 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list: children = tbody.findChildren("a", recursive=True) for child in children: - _symbols.append(str(child).rsplit('"', maxsplit=1)[-1].split(">")[1].split("<")[0]) + _symbols.append( + str(child).rsplit('"', maxsplit=1)[-1].split(">")[1].split("<")[0] + ) return _symbols @@ -471,7 +575,10 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: raise ValueError("request error") try: _symbols = [] - for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")): + for sub_data in re.findall( + r"[\[](.*?)[\]]", + resp.content.decode().split("= [")[-1].replace("];", ""), + ): data = sub_data.replace('"', "").replace("'", "") # TODO: do we need other information, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE'] _symbols.append(data.split(",")[0]) @@ -552,7 +659,9 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3): return deco_func(retry) if callable(retry) else deco_func -def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1): +def get_trading_date_by_shift( + trading_list: list, trading_date: pd.Timestamp, shift: int = 1 +): """get trading date by shift Parameters @@ -650,17 +759,28 @@ def get_instruments( $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies """ - _cur_module = importlib.import_module("data_collector.{}.collector".format(market_index)) + _cur_module = importlib.import_module( + "data_collector.{}.collector".format(market_index) + ) obj = getattr(_cur_module, f"{index_name.upper()}Index")( - qlib_dir=qlib_dir, index_name=index_name, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep + qlib_dir=qlib_dir, + index_name=index_name, + freq=freq, + request_retry=request_retry, + retry_sleep=retry_sleep, ) getattr(obj, method)() -def _get_all_1d_data(_date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame): +def _get_all_1d_data( + _date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame +): df = copy.deepcopy(_1d_data_all) df.reset_index(inplace=True) - df.rename(columns={"datetime": _date_field_name, "instrument": _symbol_field_name}, inplace=True) + df.rename( + columns={"datetime": _date_field_name, "instrument": _symbol_field_name}, + inplace=True, + ) df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) return df @@ -723,8 +843,12 @@ def calc_adjusted_price( df[_date_field_name] = pd.to_datetime(df[_date_field_name]) # get 1d data from qlib _start = pd.Timestamp(df[_date_field_name].min()).strftime("%Y-%m-%d") - _end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") - data_1d: pd.DataFrame = get_1d_data(_date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all) + _end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime( + "%Y-%m-%d" + ) + data_1d: pd.DataFrame = get_1d_data( + _date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all + ) data_1d = data_1d.copy() if data_1d is None or data_1d.empty: df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"] @@ -744,27 +868,38 @@ def calc_adjusted_price( # - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)` def _calc_factor(df_1d: pd.DataFrame): try: - _date = pd.Timestamp(pd.Timestamp(df_1d[_date_field_name].iloc[0]).date()) - df_1d["factor"] = data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"] + _date = pd.Timestamp( + pd.Timestamp(df_1d[_date_field_name].iloc[0]).date() + ) + df_1d["factor"] = ( + data_1d.loc[_date]["close"] + / df_1d.loc[df_1d["close"].last_valid_index()]["close"] + ) df_1d["paused"] = data_1d.loc[_date]["paused"] except Exception: df_1d["factor"] = np.nan df_1d["paused"] = np.nan return df_1d - df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply(_calc_factor) + df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply( + _calc_factor + ) if consistent_1d: # the date sequence is consistent with 1d df.set_index(_date_field_name, inplace=True) df = df.reindex( generate_minutes_calendar_from_daily( - calendars=pd.to_datetime(data_1d.reset_index()[_date_field_name].drop_duplicates()), + calendars=pd.to_datetime( + data_1d.reset_index()[_date_field_name].drop_duplicates() + ), freq=frequence, am_range=("09:30:00", "11:29:00"), pm_range=("13:00:00", "14:59:00"), ) ) - df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][_symbol_field_name] + df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][ + _symbol_field_name + ] df.index.names = [_date_field_name] df.reset_index(inplace=True) for _col in ["open", "close", "high", "low", "volume"]: @@ -806,7 +941,10 @@ def calc_paused_num(df: pd.DataFrame, _date_field_name, _symbol_field_name): _date_field_name, _symbol_field_name, } - if _df.loc[:, list(check_fields)].isna().values.all() or (_df["volume"] == 0).all(): + if ( + _df.loc[:, list(check_fields)].isna().values.all() + or (_df["volume"] == 0).all() + ): all_nan_nums += 1 not_nan_nums = 0 _df["paused"] = 1