1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 17:41:18 +08:00

fix: strategies for enhancing crawlers

This commit is contained in:
Linlang
2026-01-28 14:39:07 +08:00
parent 8355990ac5
commit fb606ec874

View File

@@ -3,10 +3,12 @@
import re
import copy
import datetime
import importlib
import time
import bisect
import pickle
import random
import requests
import functools
from pathlib import Path
@@ -23,7 +25,7 @@ from bs4 import BeautifulSoup
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20991231"
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={start}&end={end}"
SZSE_CALENDAR_URL = "http://www.szse.cn/api/report/exchange/onepersistenthour/monthList?month={month}&random={random}"
CALENDAR_BENCH_URL_MAP = {
@@ -38,6 +40,24 @@ CALENDAR_BENCH_URL_MAP = {
"BR_ALL": "^BVSP",
}
CHROME_UA_POOL = [
# Windows
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/120.0.0.0 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/121.0.6167.85 Safari/537.36",
# macOS
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/121.0.0.0 Safari/537.36",
# Linux
"Mozilla/5.0 (X11; Linux x86_64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/120.0.0.0 Safari/537.36",
]
_BENCH_CALENDAR_LIST = None
_ALL_CALENDAR_LIST = None
_HS_SYMBOLS = None
@@ -51,6 +71,16 @@ _CALENDAR_MAP = {}
MINIMUM_SYMBOLS_NUM = 3900
def build_headers():
return {
"User-Agent": random.choice(CHROME_UA_POOL),
"Accept": "application/json,text/plain,*/*",
"Accept-Language": "zh-CN,zh;q=0.9",
"Referer": "https://quote.eastmoney.com/",
"Connection": "keep-alive",
}
def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
"""get SH/SZ history calendar list
@@ -67,16 +97,58 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
logger.info(f"get calendar list: {bench_code}......")
def _get_calendar(url):
_value_list = requests.get(url, timeout=None).json()["data"]["klines"]
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
session = requests.Session()
session.headers.update(build_headers())
current_datetime = datetime.datetime.now()
cur_year = current_datetime.year
res_list = []
for per_year in range(2000, cur_year + 1):
start = f"{per_year}0101"
end = f"{per_year}1231"
formatted_url = url.format(start=start, end=end)
try:
resp = session.get(formatted_url, timeout=10)
resp.raise_for_status()
payload = resp.json()
data = payload.get("data")
if not data or "klines" not in data:
continue
klines = data["klines"]
res_list.extend(pd.Timestamp(x.split(",")[0]) for x in klines)
except requests.RequestException as e:
continue
time.sleep(random.uniform(0.5, 1.2))
return sorted(set(res_list))
# _value_list = requests.get(url, timeout=None).json()["data"]["klines"]
# return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
calendar = _CALENDAR_MAP.get(bench_code, None)
if calendar is None:
if bench_code.startswith("US_") or bench_code.startswith("IN_") or bench_code.startswith("BR_"):
if (
bench_code.startswith("US_")
or bench_code.startswith("IN_")
or bench_code.startswith("BR_")
):
print(Ticker(CALENDAR_BENCH_URL_MAP[bench_code]))
print(Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max"))
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
print(
Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(
interval="1d", period="max"
)
)
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(
interval="1d", period="max"
)
calendar = (
df.index.get_level_values(level="date")
.map(pd.Timestamp)
.unique()
.tolist()
)
else:
if bench_code.upper() == "ALL":
import akshare as ak # pylint: disable=C0415
@@ -85,7 +157,10 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
trade_date_list = trade_date_df["trade_date"].tolist()
trade_date_list = [pd.Timestamp(d) for d in trade_date_list]
dates = pd.DatetimeIndex(trade_date_list)
filtered_dates = dates[(dates >= "2000-01-04") & (dates <= pd.Timestamp.today().normalize())]
filtered_dates = dates[
(dates >= "2000-01-04")
& (dates <= pd.Timestamp.today().normalize())
]
calendar = filtered_dates.tolist()
else:
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
@@ -150,7 +225,9 @@ def get_calendar_list_by_ratio(
p_bar.update()
logger.info(f"count how many funds have founded in this day......")
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade} # dict{date:count}
_dict_count_founding = {
date: _number_all_funds for date in _dict_count_trade
} # dict{date:count}
with tqdm(total=_number_all_funds) as p_bar:
for oldest_date in all_oldest_list:
for date in _dict_count_founding.keys():
@@ -158,7 +235,9 @@ def get_calendar_list_by_ratio(
_dict_count_founding[date] -= 1
calendar = [
date for date, count in _dict_count_trade.items() if count >= max(int(count * threshold), minimum_count)
date
for date, count in _dict_count_trade.items()
if count >= max(int(count * threshold), minimum_count)
]
return calendar
@@ -210,14 +289,21 @@ def get_hs_stock_symbols() -> list:
data = resp.json()
# Check if response contains valid data
if not data or "data" not in data or not data["data"] or "diff" not in data["data"]:
if (
not data
or "data" not in data
or not data["data"]
or "diff" not in data["data"]
):
logger.warning(f"Invalid response structure on page {page}")
break
# fetch the current page data
current_symbols = [_v["f12"] for _v in data["data"]["diff"]]
if not current_symbols: # It's the last page if there is no data in current page
if (
not current_symbols
): # It's the last page if there is no data in current page
logger.info(f"Last page reached: {page - 1}")
break
@@ -238,7 +324,9 @@ def get_hs_stock_symbols() -> list:
f"Request to {base_url} failed with status code {resp.status_code}"
) from e
except Exception as e:
logger.warning("An error occurred while extracting data from the response.")
logger.warning(
"An error occurred while extracting data from the response."
)
raise
if len(_symbols) < 3900:
@@ -246,7 +334,11 @@ def get_hs_stock_symbols() -> list:
# Add suffix after the stock code to conform to yahooquery standard, otherwise the data will not be fetched.
_symbols = [
_symbol + ".ss" if _symbol.startswith("6") else _symbol + ".sz" if _symbol.startswith(("0", "3")) else None
(
_symbol + ".ss"
if _symbol.startswith("6")
else _symbol + ".sz" if _symbol.startswith(("0", "3")) else None
)
for _symbol in _symbols
]
_symbols = [_symbol for _symbol in _symbols if _symbol is not None]
@@ -292,7 +384,10 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
raise ValueError("request error")
try:
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
_symbols = [
_v["f12"].replace("_", "-P")
for _v in resp.json()["data"]["diff"].values()
]
except Exception as e:
logger.warning(f"request error: {e}")
raise
@@ -357,7 +452,14 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
s_ = s_.strip("*")
return s_
_US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols))))
_US_SYMBOLS = sorted(
set(
map(
_format,
filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols),
)
)
)
return _US_SYMBOLS
@@ -427,7 +529,9 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
children = tbody.findChildren("a", recursive=True)
for child in children:
_symbols.append(str(child).rsplit('"', maxsplit=1)[-1].split(">")[1].split("<")[0])
_symbols.append(
str(child).rsplit('"', maxsplit=1)[-1].split(">")[1].split("<")[0]
)
return _symbols
@@ -471,7 +575,10 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
raise ValueError("request error")
try:
_symbols = []
for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")):
for sub_data in re.findall(
r"[\[](.*?)[\]]",
resp.content.decode().split("= [")[-1].replace("];", ""),
):
data = sub_data.replace('"', "").replace("'", "")
# TODO: do we need other information, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE']
_symbols.append(data.split(",")[0])
@@ -552,7 +659,9 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3):
return deco_func(retry) if callable(retry) else deco_func
def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1):
def get_trading_date_by_shift(
trading_list: list, trading_date: pd.Timestamp, shift: int = 1
):
"""get trading date by shift
Parameters
@@ -650,17 +759,28 @@ def get_instruments(
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
"""
_cur_module = importlib.import_module("data_collector.{}.collector".format(market_index))
_cur_module = importlib.import_module(
"data_collector.{}.collector".format(market_index)
)
obj = getattr(_cur_module, f"{index_name.upper()}Index")(
qlib_dir=qlib_dir, index_name=index_name, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep
qlib_dir=qlib_dir,
index_name=index_name,
freq=freq,
request_retry=request_retry,
retry_sleep=retry_sleep,
)
getattr(obj, method)()
def _get_all_1d_data(_date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame):
def _get_all_1d_data(
_date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame
):
df = copy.deepcopy(_1d_data_all)
df.reset_index(inplace=True)
df.rename(columns={"datetime": _date_field_name, "instrument": _symbol_field_name}, inplace=True)
df.rename(
columns={"datetime": _date_field_name, "instrument": _symbol_field_name},
inplace=True,
)
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
return df
@@ -723,8 +843,12 @@ def calc_adjusted_price(
df[_date_field_name] = pd.to_datetime(df[_date_field_name])
# get 1d data from qlib
_start = pd.Timestamp(df[_date_field_name].min()).strftime("%Y-%m-%d")
_end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
data_1d: pd.DataFrame = get_1d_data(_date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all)
_end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime(
"%Y-%m-%d"
)
data_1d: pd.DataFrame = get_1d_data(
_date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all
)
data_1d = data_1d.copy()
if data_1d is None or data_1d.empty:
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"]
@@ -744,27 +868,38 @@ def calc_adjusted_price(
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
def _calc_factor(df_1d: pd.DataFrame):
try:
_date = pd.Timestamp(pd.Timestamp(df_1d[_date_field_name].iloc[0]).date())
df_1d["factor"] = data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"]
_date = pd.Timestamp(
pd.Timestamp(df_1d[_date_field_name].iloc[0]).date()
)
df_1d["factor"] = (
data_1d.loc[_date]["close"]
/ df_1d.loc[df_1d["close"].last_valid_index()]["close"]
)
df_1d["paused"] = data_1d.loc[_date]["paused"]
except Exception:
df_1d["factor"] = np.nan
df_1d["paused"] = np.nan
return df_1d
df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply(_calc_factor)
df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply(
_calc_factor
)
if consistent_1d:
# the date sequence is consistent with 1d
df.set_index(_date_field_name, inplace=True)
df = df.reindex(
generate_minutes_calendar_from_daily(
calendars=pd.to_datetime(data_1d.reset_index()[_date_field_name].drop_duplicates()),
calendars=pd.to_datetime(
data_1d.reset_index()[_date_field_name].drop_duplicates()
),
freq=frequence,
am_range=("09:30:00", "11:29:00"),
pm_range=("13:00:00", "14:59:00"),
)
)
df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][_symbol_field_name]
df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][
_symbol_field_name
]
df.index.names = [_date_field_name]
df.reset_index(inplace=True)
for _col in ["open", "close", "high", "low", "volume"]:
@@ -806,7 +941,10 @@ def calc_paused_num(df: pd.DataFrame, _date_field_name, _symbol_field_name):
_date_field_name,
_symbol_field_name,
}
if _df.loc[:, list(check_fields)].isna().values.all() or (_df["volume"] == 0).all():
if (
_df.loc[:, list(check_fields)].isna().values.all()
or (_df["volume"] == 0).all()
):
all_nan_nums += 1
not_nan_nums = 0
_df["paused"] = 1