diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 62d6e4eb1..ceb0735cb 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -69,15 +69,8 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: calendar = _CALENDAR_MAP.get(bench_code, None) if calendar is None: if bench_code.startswith("US_"): - df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history( - interval="1d", period="max" - ) - calendar = ( - df.index.get_level_values(level="date") - .map(pd.Timestamp) - .unique() - .tolist() - ) + df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max") + calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist() else: if bench_code.upper() == "ALL": @@ -85,9 +78,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: def _get_calendar(month): _cal = [] try: - resp = requests.get( - SZSE_CALENDAR_URL.format(month=month, random=random.random) - ).json() + resp = requests.get(SZSE_CALENDAR_URL.format(month=month, random=random.random)).json() for _r in resp["data"]: if int(_r["jybz"]): _cal.append(pd.Timestamp(_r["jyrq"])) @@ -95,11 +86,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: raise ValueError(f"{month}-->{e}") return _cal - month_range = pd.date_range( - start="2000-01", - end=pd.Timestamp.now() + pd.Timedelta(days=31), - freq="M", - ) + month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M") calendar = [] for _m in month_range: cal = _get_calendar(_m.strftime("%Y-%m")) @@ -169,9 +156,7 @@ def get_calendar_list_by_ratio( p_bar.update() logger.info(f"count how many funds have founded in this day......") - _dict_count_founding = { - date: _number_all_funds for date in _dict_count_trade.keys() - } # dict{date:count} + _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} with tqdm(total=_number_all_funds) as p_bar: for oldest_date in all_oldest_list: for date in _dict_count_founding.keys(): @@ -181,8 +166,7 @@ def get_calendar_list_by_ratio( calendar = [ date for date in _dict_count_trade - if _dict_count_trade[date] - >= max(int(_dict_count_founding[date] * threshold), minimum_count) + if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count) ] return calendar @@ -204,9 +188,7 @@ def get_hs_stock_symbols() -> list: _res |= set( map( lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), - etree.HTML(resp.text).xpath( - "//div[@class='result']/ul//li/a/text()" - ), + etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), ) ) time.sleep(3) @@ -250,10 +232,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: if resp.status_code != 200: raise ValueError("request error") try: - _symbols = [ - _v["f12"].replace("_", "-P") - for _v in resp.json()["data"]["diff"].values() - ] + _symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()] except Exception as e: logger.warning(f"request error: {e}") raise @@ -315,14 +294,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: s_ = s_.strip("*") return s_ - _US_SYMBOLS = sorted( - set( - map( - _format, - filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols), - ) - ) - ) + _US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols)))) return _US_SYMBOLS @@ -385,10 +357,7 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: raise ValueError("request error") try: _symbols = [] - for sub_data in re.findall( - r"[\[](.*?)[\]]", - resp.content.decode().split("= [")[-1].replace("];", ""), - ): + for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")): data = sub_data.replace('"', "").replace("'", "") # TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE'] _symbols.append(data.split(",")[0]) @@ -467,9 +436,7 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3): return deco_func(retry) if callable(retry) else deco_func -def get_trading_date_by_shift( - trading_list: list, trading_date: pd.Timestamp, shift: int = 1 -): +def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1): """get trading date by shift Parameters diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index cf97a0c7e..d7518eca8 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -93,9 +93,7 @@ class YahooCollector(BaseCollector): def init_datetime(self): if self.interval == self.INTERVAL_1min: - self.start_datetime = max( - self.start_datetime, self.DEFAULT_START_DATETIME_1MIN - ) + self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN) elif self.interval == self.INTERVAL_1d: pass else: @@ -119,9 +117,7 @@ class YahooCollector(BaseCollector): raise NotImplementedError("rewrite get_timezone") @staticmethod - def get_data_from_remote( - symbol, interval, start, end, show_1min_logging: bool = False - ): + def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False): error_msg = f"{symbol}-{interval}-{start}-{end}" def _show_logging_func(): @@ -130,16 +126,13 @@ class YahooCollector(BaseCollector): interval = "1m" if interval in ["1m", "1min"] else interval try: - _resp = Ticker(symbol, asynchronous=False).history( - interval=interval, start=start, end=end - ) + _resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end) if isinstance(_resp, pd.DataFrame): return _resp.reset_index() elif isinstance(_resp, dict): _temp_data = _resp.get(symbol, {}) if isinstance(_temp_data, str) or ( - isinstance(_resp, dict) - and _temp_data.get("indicators", {}).get("quote", None) is None + isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None ): _show_logging_func() else: @@ -148,11 +141,7 @@ class YahooCollector(BaseCollector): logger.warning(f"{error_msg}:{e}") def get_data( - self, - symbol: str, - interval: str, - start_datetime: pd.Timestamp, - end_datetime: pd.Timestamp, + self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp ) -> pd.DataFrame: @deco_retry(retry_sleep=self.delay) def _get_simple(start_, end_): @@ -225,35 +214,21 @@ class YahooCollectorCN1d(YahooCollectorCN): _format = "%Y%m%d" _begin = self.start_datetime.strftime(_format) _end = (self.end_datetime + pd.Timedelta(days=-1)).strftime(_format) - for _index_name, _index_code in { - "csi300": "000300", - "csi100": "000903", - }.items(): + for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items(): logger.info(f"get bench data: {_index_name}({_index_code})......") try: df = pd.DataFrame( map( lambda x: x.split(","), - requests.get( - INDEX_BENCH_URL.format( - index_code=_index_code, begin=_begin, end=_end - ) - ).json()["data"]["klines"], + requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[ + "data" + ]["klines"], ) ) except Exception as e: logger.warning(f"get {_index_name} error: {e}") continue - df.columns = [ - "date", - "open", - "close", - "high", - "low", - "volume", - "money", - "change", - ] + df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"] df["date"] = pd.to_datetime(df["date"]) df = df.astype(float, errors="ignore") df["adjclose"] = df["close"] @@ -351,18 +326,13 @@ class YahooNormalize(BaseNormalize): df = df.reindex( pd.DataFrame(index=calendar_list) .loc[ - pd.Timestamp(df.index.min()) - .date() : pd.Timestamp(df.index.max()) - .date() + pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timedelta(hours=23, minutes=59) ] .index ) df.sort_index(inplace=True) - df.loc[ - (df["volume"] <= 0) | np.isnan(df["volume"]), - set(df.columns) - {symbol_field_name}, - ] = np.nan + df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan _tmp_series = df["close"].fillna(method="ffill") _tmp_shift_series = _tmp_series.shift(1) if last_close is not None: @@ -377,9 +347,7 @@ class YahooNormalize(BaseNormalize): def normalize(self, df: pd.DataFrame) -> pd.DataFrame: # normalize - df = self.normalize_yahoo( - df, self._calendar_list, self._date_field_name, self._symbol_field_name - ) + df = self.normalize_yahoo(df, self._calendar_list, self._date_field_name, self._symbol_field_name) # adjusted price df = self.adjusted_price(df) return df @@ -450,11 +418,7 @@ class YahooNormalize1d(YahooNormalize, ABC): class YahooNormalize1dExtend(YahooNormalize1d): def __init__( - self, - old_qlib_data_dir: [str, Path], - date_field_name: str = "date", - symbol_field_name: str = "symbol", - **kwargs, + self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs ): """ @@ -483,9 +447,7 @@ class YahooNormalize1dExtend(YahooNormalize1d): return df def _get_close(self, df: pd.DataFrame, field_name: str): - _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][ - self._symbol_field_name - ].upper() + _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper() _df = self.old_qlib_data.loc(axis=0)[_symbol] _close = _df.loc[_df.last_valid_index()][field_name] return _close @@ -505,9 +467,7 @@ class YahooNormalize1dExtend(YahooNormalize1d): return _close def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp: - _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][ - self._symbol_field_name - ].upper() + _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper() try: _df = self.old_qlib_data.loc(axis=0)[_symbol] _date = _df.index.max() @@ -535,11 +495,7 @@ class YahooNormalize1dExtend(YahooNormalize1d): ) # normalize df = self.normalize_yahoo( - df, - self._calendar_list, - self._date_field_name, - self._symbol_field_name, - last_close=_last_close, + df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close ) # adjusted price df = self.adjusted_price(df) @@ -577,14 +533,10 @@ class YahooNormalize1min(YahooNormalize, ABC): data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"] """ - data_1d = YahooCollector.get_data_from_remote( - self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end - ) + data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end) if not (data_1d is None or data_1d.empty): _class_name = self.__class__.__name__.replace("min", "d") - _class: type(YahooNormalize) = getattr( - importlib.import_module("collector"), _class_name - ) + _class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name) data_1d_obj = _class(self._date_field_name, self._symbol_field_name) data_1d = data_1d_obj.normalize(data_1d) return data_1d @@ -597,12 +549,8 @@ class YahooNormalize1min(YahooNormalize, ABC): df = df.sort_values(self._date_field_name) symbol = df.iloc[0][self._symbol_field_name] # get 1d data from yahoo - _start = pd.Timestamp(df[self._date_field_name].min()).strftime( - self.DAILY_FORMAT - ) - _end = ( - pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1) - ).strftime(self.DAILY_FORMAT) + _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) + _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT) data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end) data_1d = data_1d.copy() if data_1d is None or data_1d.empty: @@ -613,9 +561,7 @@ class YahooNormalize1min(YahooNormalize, ABC): # NOTE: volume is np.nan or volume <= 0, paused = 1 # FIXME: find a more accurate data source data_1d["paused"] = 0 - data_1d.loc[ - (data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused" - ] = 1 + data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1 data_1d = data_1d.set_index(self._date_field_name) # add factor from 1d data @@ -623,9 +569,7 @@ class YahooNormalize1min(YahooNormalize, ABC): # - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits. # - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits. # - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)` - df["date_tmp"] = df[self._date_field_name].apply( - lambda x: pd.Timestamp(x).date() - ) + df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) df.set_index("date_tmp", inplace=True) df.loc[:, "factor"] = data_1d["close"] / df["close"] df.loc[:, "paused"] = data_1d["paused"] @@ -636,16 +580,12 @@ class YahooNormalize1min(YahooNormalize, ABC): df.set_index(self._date_field_name, inplace=True) df = df.reindex( self.generate_1min_from_daily( - pd.to_datetime( - data_1d.reset_index()[ - self._date_field_name - ].drop_duplicates() - ) + pd.to_datetime(data_1d.reset_index()[self._date_field_name].drop_duplicates()) ) ) - df[self._symbol_field_name] = df.loc[ - df[self._symbol_field_name].first_valid_index() - ][self._symbol_field_name] + df[self._symbol_field_name] = df.loc[df[self._symbol_field_name].first_valid_index()][ + self._symbol_field_name + ] df.index.names = [self._date_field_name] df.reset_index(inplace=True) for _col in self.COLUMNS: @@ -663,9 +603,7 @@ class YahooNormalize1min(YahooNormalize, ABC): def calc_paused_num(self, df: pd.DataFrame): _symbol = df.iloc[0][self._symbol_field_name] df = df.copy() - df["_tmp_date"] = df[self._date_field_name].apply( - lambda x: pd.Timestamp(x).date() - ) + df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) # remove data that starts and ends with `np.nan` all day all_data = [] # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan @@ -685,10 +623,7 @@ class YahooNormalize1min(YahooNormalize, ABC): self._date_field_name, self._symbol_field_name, } - if ( - _df.loc[:, check_fields].isna().values.all() - or (_df["volume"] == 0).all() - ): + if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all(): all_nan_nums += 1 not_nan_nums = 0 _df["paused"] = 1 @@ -723,11 +658,7 @@ class YahooNormalize1minOffline(YahooNormalize1min): """Normalised to 1min using local 1d data""" def __init__( - self, - qlib_data_1d_dir: [str, Path], - date_field_name: str = "date", - symbol_field_name: str = "symbol", - **kwargs, + self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs ): """ @@ -741,9 +672,7 @@ class YahooNormalize1minOffline(YahooNormalize1min): symbol field name, default is symbol """ self.qlib_data_1d_dir = qlib_data_1d_dir - super(YahooNormalize1minOffline, self).__init__( - date_field_name, symbol_field_name - ) + super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name) self._all_1d_data = self._get_all_1d_data() def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: @@ -758,19 +687,9 @@ class YahooNormalize1minOffline(YahooNormalize1min): from qlib.data import D qlib.init(provider_uri=self.qlib_data_1d_dir) - df = D.features( - D.instruments("all"), - ["$paused", "$volume", "$factor", "$close"], - freq="day", - ) + df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") df.reset_index(inplace=True) - df.rename( - columns={ - "datetime": self._date_field_name, - "instrument": self._symbol_field_name, - }, - inplace=True, - ) + df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True) df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) return df @@ -838,11 +757,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline): def symbol_to_yahoo(self, symbol): if "." not in symbol: _exchange = symbol[:2] - _exchange = ( - ("ss" if _exchange.islower() else "SS") - if _exchange.lower() == "sh" - else _exchange - ) + _exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange symbol = symbol[2:] + "." + _exchange return symbol @@ -851,14 +766,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline): class Run(BaseRun): - def __init__( - self, - source_dir=None, - normalize_dir=None, - max_workers=1, - interval="1d", - region=REGION_CN, - ): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN): """ Parameters @@ -930,13 +838,7 @@ class Run(BaseRun): $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ super(Run, self).download_data( - max_collector_count, - delay, - start, - end, - self.interval, - check_data_length, - limit_nums, + max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums ) def normalize_data( @@ -971,25 +873,16 @@ class Run(BaseRun): $ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min """ if self.interval.lower() == "1min": - if ( - qlib_data_1d_dir is None - or not Path(qlib_data_1d_dir).expanduser().exists() - ): + if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists(): raise ValueError( "If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir , Reference: https://github.com/zhupr/qlib/tree/support_extend_data/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance" ) super(Run, self).normalize_data( - date_field_name, - symbol_field_name, - end_date=end_date, - qlib_data_1d_dir=qlib_data_1d_dir, + date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir ) def normalize_data_1d_extend( - self, - old_qlib_data_dir, - date_field_name: str = "date", - symbol_field_name: str = "symbol", + self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol" ): """normalize data extend; extending yahoo qlib data(from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data) @@ -1120,30 +1013,19 @@ class Run(BaseRun): # start/end date if trading_date is None: trading_date = datetime.datetime.now().strftime("%Y-%m-%d") - logger.warning( - f"trading_date is None, use the current date: {trading_date}" - ) + logger.warning(f"trading_date is None, use the current date: {trading_date}") if end_date is None: - end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime( - "%Y-%m-%d" - ) + end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") # download qlib 1d data qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve()) if not exists_qlib_data(qlib_data_1d_dir): - GetData().qlib_data( - target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region - ) + GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region) # download data from yahoo # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1 - self.download_data( - delay=delay, - start=trading_date, - end=end_date, - check_data_length=check_data_length, - ) + self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length) # NOTE: a larger max_workers setting here would be faster self.max_workers = ( max(multiprocessing.cpu_count() - 2, 1) @@ -1165,18 +1047,11 @@ class Run(BaseRun): # parse index _region = self.region.lower() if _region not in ["cn", "us"]: - logger.warning( - f"Unsupported region: region={_region}, component downloads will be ignored" - ) + logger.warning(f"Unsupported region: region={_region}, component downloads will be ignored") return - index_list = ( - ["CSI100", "CSI300"] - if _region == "cn" - else ["SP500", "NASDAQ100", "DJIA", "SP400"] - ) + index_list = ["CSI100", "CSI300"] if _region == "cn" else ["SP500", "NASDAQ100", "DJIA", "SP400"] get_instruments = getattr( - importlib.import_module(f"data_collector.{_region}_index.collector"), - "get_instruments", + importlib.import_module(f"data_collector.{_region}_index.collector"), "get_instruments" ) for _index in index_list: get_instruments(str(qlib_data_1d_dir), _index)