def return_date_list(source_dir, date_field_name, file_path):
    """Read one raw CSV file and return the values of its date column.

    Defined at module level because ``get_calendar_list_by_ratio`` hands it
    (wrapped in ``functools.partial``) to a ``ProcessPoolExecutor``.

    Parameters
    ----------
    source_dir: str or Path
        Directory that contains the CSV file
    date_field_name: str
        Name of the column holding the dates
    file_path: str
        File name (or path relative to ``source_dir``) of the CSV file

    Returns
    -------
    list of the values found in the ``date_field_name`` column, in file order
    """
    # The first CSV column is used as the index, so it is not available as a
    # regular column afterwards.
    df = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0)

    return df[date_field_name].to_list()


def get_calendar_list_by_ratio(
    source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, max_workers: int = 16
) -> list:
    """Build a calendar list, keeping only the dates that appear in enough files.

    Every entry of ``source_dir`` is read as a CSV file. A date is kept when the
    number of occurrences collected for it is at least
    ``int(number_of_files * threshold)``.

    Parameters
    ----------
    source_dir: str or Path
        The directory where the raw data collected from the Internet is saved
    date_field_name: str
        date field name, default is date
    threshold: float
        minimum ratio of files in which a date must appear to be kept, default 0.5
    max_workers: int
        Concurrent number, default is 16

    Returns
    -------
    history calendar list, in the order in which the dates were first encountered
    """
    logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......")

    # List the directory once so that the progress total, the threshold base and
    # the files actually processed all refer to the same snapshot.
    _file_names = os.listdir(source_dir)
    _number_all_funds = len(_file_names)

    # date -> number of occurrences collected across all files
    _list_all_date = dict()

    _fun = partial(return_date_list, source_dir, date_field_name)

    with tqdm(total=_number_all_funds) as p_bar:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            for date_list in executor.map(_fun, _file_names):
                for date in date_list:
                    # The first occurrence must count as 1; starting at 0
                    # under-counts every date by one.
                    _list_all_date[date] = _list_all_date.get(date, 0) + 1

                p_bar.update()

    _threshold_number = int(_number_all_funds * threshold)

    return [date for date, count in _list_all_date.items() if count >= _threshold_number]