From 5a2bb14273236702f0165d9b1bd5e1270e57d683 Mon Sep 17 00:00:00 2001 From: Linlang Date: Wed, 28 Jan 2026 14:42:49 +0800 Subject: [PATCH] fix: lint with black --- scripts/data_collector/utils.py | 110 +++++++------------------------- 1 file changed, 24 insertions(+), 86 deletions(-) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index d595d7ff6..f67791f2d 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -53,9 +53,7 @@ CHROME_UA_POOL = [ "AppleWebKit/537.36 (KHTML, like Gecko) " "Chrome/121.0.0.0 Safari/537.36", # Linux - "Mozilla/5.0 (X11; Linux x86_64) " - "AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/120.0.0.0 Safari/537.36", + "Mozilla/5.0 (X11; Linux x86_64) " "AppleWebKit/537.36 (KHTML, like Gecko) " "Chrome/120.0.0.0 Safari/537.36", ] _BENCH_CALENDAR_LIST = None @@ -129,26 +127,11 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: calendar = _CALENDAR_MAP.get(bench_code, None) if calendar is None: - if ( - bench_code.startswith("US_") - or bench_code.startswith("IN_") - or bench_code.startswith("BR_") - ): + if bench_code.startswith("US_") or bench_code.startswith("IN_") or bench_code.startswith("BR_"): print(Ticker(CALENDAR_BENCH_URL_MAP[bench_code])) - print( - Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history( - interval="1d", period="max" - ) - ) - df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history( - interval="1d", period="max" - ) - calendar = ( - df.index.get_level_values(level="date") - .map(pd.Timestamp) - .unique() - .tolist() - ) + print(Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")) + df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max") + calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist() else: if bench_code.upper() == "ALL": import akshare as ak # pylint: disable=C0415 @@ -157,10 +140,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: trade_date_list = trade_date_df["trade_date"].tolist() trade_date_list = [pd.Timestamp(d) for d in trade_date_list] dates = pd.DatetimeIndex(trade_date_list) - filtered_dates = dates[ - (dates >= "2000-01-04") - & (dates <= pd.Timestamp.today().normalize()) - ] + filtered_dates = dates[(dates >= "2000-01-04") & (dates <= pd.Timestamp.today().normalize())] calendar = filtered_dates.tolist() else: calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code]) @@ -225,9 +205,7 @@ def get_calendar_list_by_ratio( p_bar.update() logger.info(f"count how many funds have founded in this day......") - _dict_count_founding = { - date: _number_all_funds for date in _dict_count_trade - } # dict{date:count} + _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade} # dict{date:count} with tqdm(total=_number_all_funds) as p_bar: for oldest_date in all_oldest_list: for date in _dict_count_founding.keys(): @@ -235,9 +213,7 @@ def get_calendar_list_by_ratio( _dict_count_founding[date] -= 1 calendar = [ - date - for date, count in _dict_count_trade.items() - if count >= max(int(count * threshold), minimum_count) + date for date, count in _dict_count_trade.items() if count >= max(int(count * threshold), minimum_count) ] return calendar @@ -289,21 +265,14 @@ def get_hs_stock_symbols() -> list: data = resp.json() # Check if response contains valid data - if ( - not data - or "data" not in data - or not data["data"] - or "diff" not in data["data"] - ): + if not data or "data" not in data or not data["data"] or "diff" not in data["data"]: logger.warning(f"Invalid response structure on page {page}") break # fetch the current page data current_symbols = [_v["f12"] for _v in data["data"]["diff"]] - if ( - not current_symbols - ): # It's the last page if there is no data in current page + if not current_symbols: # It's the last page if there is no data in current page logger.info(f"Last page reached: {page - 1}") break @@ -324,9 +293,7 @@ def get_hs_stock_symbols() -> list: f"Request to {base_url} failed with status code {resp.status_code}" ) from e except Exception as e: - logger.warning( - "An error occurred while extracting data from the response." - ) + logger.warning("An error occurred while extracting data from the response.") raise if len(_symbols) < 3900: @@ -384,10 +351,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: raise ValueError("request error") try: - _symbols = [ - _v["f12"].replace("_", "-P") - for _v in resp.json()["data"]["diff"].values() - ] + _symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()] except Exception as e: logger.warning(f"request error: {e}") raise @@ -529,9 +493,7 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list: children = tbody.findChildren("a", recursive=True) for child in children: - _symbols.append( - str(child).rsplit('"', maxsplit=1)[-1].split(">")[1].split("<")[0] - ) + _symbols.append(str(child).rsplit('"', maxsplit=1)[-1].split(">")[1].split("<")[0]) return _symbols @@ -659,9 +621,7 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3): return deco_func(retry) if callable(retry) else deco_func -def get_trading_date_by_shift( - trading_list: list, trading_date: pd.Timestamp, shift: int = 1 -): +def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1): """get trading date by shift Parameters @@ -759,9 +719,7 @@ def get_instruments( $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies """ - _cur_module = importlib.import_module( - "data_collector.{}.collector".format(market_index) - ) + _cur_module = importlib.import_module("data_collector.{}.collector".format(market_index)) obj = getattr(_cur_module, f"{index_name.upper()}Index")( qlib_dir=qlib_dir, index_name=index_name, @@ -772,9 +730,7 @@ def get_instruments( getattr(obj, method)() -def _get_all_1d_data( - _date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame -): +def _get_all_1d_data(_date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame): df = copy.deepcopy(_1d_data_all) df.reset_index(inplace=True) df.rename( @@ -843,12 +799,8 @@ def calc_adjusted_price( df[_date_field_name] = pd.to_datetime(df[_date_field_name]) # get 1d data from qlib _start = pd.Timestamp(df[_date_field_name].min()).strftime("%Y-%m-%d") - _end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime( - "%Y-%m-%d" - ) - data_1d: pd.DataFrame = get_1d_data( - _date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all - ) + _end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") + data_1d: pd.DataFrame = get_1d_data(_date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all) data_1d = data_1d.copy() if data_1d is None or data_1d.empty: df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"] @@ -868,38 +820,27 @@ def calc_adjusted_price( # - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)` def _calc_factor(df_1d: pd.DataFrame): try: - _date = pd.Timestamp( - pd.Timestamp(df_1d[_date_field_name].iloc[0]).date() - ) - df_1d["factor"] = ( - data_1d.loc[_date]["close"] - / df_1d.loc[df_1d["close"].last_valid_index()]["close"] - ) + _date = pd.Timestamp(pd.Timestamp(df_1d[_date_field_name].iloc[0]).date()) + df_1d["factor"] = data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"] df_1d["paused"] = data_1d.loc[_date]["paused"] except Exception: df_1d["factor"] = np.nan df_1d["paused"] = np.nan return df_1d - df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply( - _calc_factor - ) + df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply(_calc_factor) if consistent_1d: # the date sequence is consistent with 1d df.set_index(_date_field_name, inplace=True) df = df.reindex( generate_minutes_calendar_from_daily( - calendars=pd.to_datetime( - data_1d.reset_index()[_date_field_name].drop_duplicates() - ), + calendars=pd.to_datetime(data_1d.reset_index()[_date_field_name].drop_duplicates()), freq=frequence, am_range=("09:30:00", "11:29:00"), pm_range=("13:00:00", "14:59:00"), ) ) - df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][ - _symbol_field_name - ] + df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][_symbol_field_name] df.index.names = [_date_field_name] df.reset_index(inplace=True) for _col in ["open", "close", "high", "low", "volume"]: @@ -941,10 +882,7 @@ def calc_paused_num(df: pd.DataFrame, _date_field_name, _symbol_field_name): _date_field_name, _symbol_field_name, } - if ( - _df.loc[:, list(check_fields)].isna().values.all() - or (_df["volume"] == 0).all() - ): + if _df.loc[:, list(check_fields)].isna().values.all() or (_df["volume"] == 0).all(): all_nan_nums += 1 not_nan_nums = 0 _df["paused"] = 1