From 457dcaa4667970f6a02163c32f64d008d28edca6 Mon Sep 17 00:00:00 2001 From: Gaurav <2796gaurav@gmail.com> Date: Wed, 14 Jul 2021 20:12:00 +0530 Subject: [PATCH] cleaned with black --- scripts/data_collector/utils.py | 57 ++++-- scripts/data_collector/yahoo/collector.py | 219 +++++++++++++++++----- 2 files changed, 217 insertions(+), 59 deletions(-) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 72bd1be18..62d6e4eb1 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -69,8 +69,15 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: calendar = _CALENDAR_MAP.get(bench_code, None) if calendar is None: if bench_code.startswith("US_"): - df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max") - calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist() + df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history( + interval="1d", period="max" + ) + calendar = ( + df.index.get_level_values(level="date") + .map(pd.Timestamp) + .unique() + .tolist() + ) else: if bench_code.upper() == "ALL": @@ -78,7 +85,9 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: def _get_calendar(month): _cal = [] try: - resp = requests.get(SZSE_CALENDAR_URL.format(month=month, random=random.random)).json() + resp = requests.get( + SZSE_CALENDAR_URL.format(month=month, random=random.random) + ).json() for _r in resp["data"]: if int(_r["jybz"]): _cal.append(pd.Timestamp(_r["jyrq"])) @@ -86,7 +95,11 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: raise ValueError(f"{month}-->{e}") return _cal - month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M") + month_range = pd.date_range( + start="2000-01", + end=pd.Timestamp.now() + pd.Timedelta(days=31), + freq="M", + ) calendar = [] for _m in month_range: cal = _get_calendar(_m.strftime("%Y-%m")) @@ -156,7 +169,9 @@ def get_calendar_list_by_ratio( p_bar.update() logger.info(f"count how many funds have founded in this day......") - _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} + _dict_count_founding = { + date: _number_all_funds for date in _dict_count_trade.keys() + } # dict{date:count} with tqdm(total=_number_all_funds) as p_bar: for oldest_date in all_oldest_list: for date in _dict_count_founding.keys(): @@ -166,7 +181,8 @@ def get_calendar_list_by_ratio( calendar = [ date for date in _dict_count_trade - if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count) + if _dict_count_trade[date] + >= max(int(_dict_count_founding[date] * threshold), minimum_count) ] return calendar @@ -188,7 +204,9 @@ def get_hs_stock_symbols() -> list: _res |= set( map( lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), - etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), + etree.HTML(resp.text).xpath( + "//div[@class='result']/ul//li/a/text()" + ), ) ) time.sleep(3) @@ -232,7 +250,10 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: if resp.status_code != 200: raise ValueError("request error") try: - _symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()] + _symbols = [ + _v["f12"].replace("_", "-P") + for _v in resp.json()["data"]["diff"].values() + ] except Exception as e: logger.warning(f"request error: {e}") raise @@ -294,7 +315,14 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: s_ = s_.strip("*") return s_ - _US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols)))) + _US_SYMBOLS = sorted( + set( + map( + _format, + filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols), + ) + ) + ) return _US_SYMBOLS @@ -313,7 +341,7 @@ def get_in_stock_symbols(qlib_data_path: [str, Path] = None) -> list: url = f"https://www1.nseindia.com/content/equities/EQUITY_L.csv" df = pd.read_csv(url) df = df.rename(columns={"SYMBOL": "Symbol"}) - df['Symbol'] = df['Symbol'] + ".NS" + df["Symbol"] = df["Symbol"] + ".NS" _symbols = df["Symbol"].dropna() _symbols = _symbols.unique().tolist() return _symbols @@ -357,7 +385,10 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: raise ValueError("request error") try: _symbols = [] - for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")): + for sub_data in re.findall( + r"[\[](.*?)[\]]", + resp.content.decode().split("= [")[-1].replace("];", ""), + ): data = sub_data.replace('"', "").replace("'", "") # TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE'] _symbols.append(data.split(",")[0]) @@ -436,7 +467,9 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3): return deco_func(retry) if callable(retry) else deco_func -def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1): +def get_trading_date_by_shift( + trading_list: list, trading_date: pd.Timestamp, shift: int = 1 +): """get trading date by shift Parameters diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index e262dac19..cf97a0c7e 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -93,7 +93,9 @@ class YahooCollector(BaseCollector): def init_datetime(self): if self.interval == self.INTERVAL_1min: - self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN) + self.start_datetime = max( + self.start_datetime, self.DEFAULT_START_DATETIME_1MIN + ) elif self.interval == self.INTERVAL_1d: pass else: @@ -117,7 +119,9 @@ class YahooCollector(BaseCollector): raise NotImplementedError("rewrite get_timezone") @staticmethod - def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False): + def get_data_from_remote( + symbol, interval, start, end, show_1min_logging: bool = False + ): error_msg = f"{symbol}-{interval}-{start}-{end}" def _show_logging_func(): @@ -126,13 +130,16 @@ class YahooCollector(BaseCollector): interval = "1m" if interval in ["1m", "1min"] else interval try: - _resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end) + _resp = Ticker(symbol, asynchronous=False).history( + interval=interval, start=start, end=end + ) if isinstance(_resp, pd.DataFrame): return _resp.reset_index() elif isinstance(_resp, dict): _temp_data = _resp.get(symbol, {}) if isinstance(_temp_data, str) or ( - isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None + isinstance(_resp, dict) + and _temp_data.get("indicators", {}).get("quote", None) is None ): _show_logging_func() else: @@ -141,7 +148,11 @@ class YahooCollector(BaseCollector): logger.warning(f"{error_msg}:{e}") def get_data( - self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + self, + symbol: str, + interval: str, + start_datetime: pd.Timestamp, + end_datetime: pd.Timestamp, ) -> pd.DataFrame: @deco_retry(retry_sleep=self.delay) def _get_simple(start_, end_): @@ -214,21 +225,35 @@ class YahooCollectorCN1d(YahooCollectorCN): _format = "%Y%m%d" _begin = self.start_datetime.strftime(_format) _end = (self.end_datetime + pd.Timedelta(days=-1)).strftime(_format) - for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items(): + for _index_name, _index_code in { + "csi300": "000300", + "csi100": "000903", + }.items(): logger.info(f"get bench data: {_index_name}({_index_code})......") try: df = pd.DataFrame( map( lambda x: x.split(","), - requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[ - "data" - ]["klines"], + requests.get( + INDEX_BENCH_URL.format( + index_code=_index_code, begin=_begin, end=_end + ) + ).json()["data"]["klines"], ) ) except Exception as e: logger.warning(f"get {_index_name} error: {e}") continue - df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"] + df.columns = [ + "date", + "open", + "close", + "high", + "low", + "volume", + "money", + "change", + ] df["date"] = pd.to_datetime(df["date"]) df = df.astype(float, errors="ignore") df["adjclose"] = df["close"] @@ -283,7 +308,7 @@ class YahooCollectorUS1min(YahooCollectorUS): class YahooCollectorIN(YahooCollector, ABC): def get_instrument_list(self): logger.info("get INDIA stock symbols......") - symbols = get_in_stock_symbols() + symbols = get_in_stock_symbols() logger.info(f"get {len(symbols)} symbols.") return symbols @@ -326,13 +351,18 @@ class YahooNormalize(BaseNormalize): df = df.reindex( pd.DataFrame(index=calendar_list) .loc[ - pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timestamp(df.index.min()) + .date() : pd.Timestamp(df.index.max()) + .date() + pd.Timedelta(hours=23, minutes=59) ] .index ) df.sort_index(inplace=True) - df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan + df.loc[ + (df["volume"] <= 0) | np.isnan(df["volume"]), + set(df.columns) - {symbol_field_name}, + ] = np.nan _tmp_series = df["close"].fillna(method="ffill") _tmp_shift_series = _tmp_series.shift(1) if last_close is not None: @@ -347,7 +377,9 @@ class YahooNormalize(BaseNormalize): def normalize(self, df: pd.DataFrame) -> pd.DataFrame: # normalize - df = self.normalize_yahoo(df, self._calendar_list, self._date_field_name, self._symbol_field_name) + df = self.normalize_yahoo( + df, self._calendar_list, self._date_field_name, self._symbol_field_name + ) # adjusted price df = self.adjusted_price(df) return df @@ -418,7 +450,11 @@ class YahooNormalize1d(YahooNormalize, ABC): class YahooNormalize1dExtend(YahooNormalize1d): def __init__( - self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + self, + old_qlib_data_dir: [str, Path], + date_field_name: str = "date", + symbol_field_name: str = "symbol", + **kwargs, ): """ @@ -447,7 +483,9 @@ class YahooNormalize1dExtend(YahooNormalize1d): return df def _get_close(self, df: pd.DataFrame, field_name: str): - _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper() + _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][ + self._symbol_field_name + ].upper() _df = self.old_qlib_data.loc(axis=0)[_symbol] _close = _df.loc[_df.last_valid_index()][field_name] return _close @@ -467,7 +505,9 @@ class YahooNormalize1dExtend(YahooNormalize1d): return _close def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp: - _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper() + _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][ + self._symbol_field_name + ].upper() try: _df = self.old_qlib_data.loc(axis=0)[_symbol] _date = _df.index.max() @@ -495,7 +535,11 @@ class YahooNormalize1dExtend(YahooNormalize1d): ) # normalize df = self.normalize_yahoo( - df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close + df, + self._calendar_list, + self._date_field_name, + self._symbol_field_name, + last_close=_last_close, ) # adjusted price df = self.adjusted_price(df) @@ -533,10 +577,14 @@ class YahooNormalize1min(YahooNormalize, ABC): data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"] """ - data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end) + data_1d = YahooCollector.get_data_from_remote( + self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end + ) if not (data_1d is None or data_1d.empty): _class_name = self.__class__.__name__.replace("min", "d") - _class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name) + _class: type(YahooNormalize) = getattr( + importlib.import_module("collector"), _class_name + ) data_1d_obj = _class(self._date_field_name, self._symbol_field_name) data_1d = data_1d_obj.normalize(data_1d) return data_1d @@ -549,8 +597,12 @@ class YahooNormalize1min(YahooNormalize, ABC): df = df.sort_values(self._date_field_name) symbol = df.iloc[0][self._symbol_field_name] # get 1d data from yahoo - _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) - _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT) + _start = pd.Timestamp(df[self._date_field_name].min()).strftime( + self.DAILY_FORMAT + ) + _end = ( + pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1) + ).strftime(self.DAILY_FORMAT) data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end) data_1d = data_1d.copy() if data_1d is None or data_1d.empty: @@ -561,7 +613,9 @@ class YahooNormalize1min(YahooNormalize, ABC): # NOTE: volume is np.nan or volume <= 0, paused = 1 # FIXME: find a more accurate data source data_1d["paused"] = 0 - data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1 + data_1d.loc[ + (data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused" + ] = 1 data_1d = data_1d.set_index(self._date_field_name) # add factor from 1d data @@ -569,7 +623,9 @@ class YahooNormalize1min(YahooNormalize, ABC): # - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits. # - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits. # - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)` - df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) + df["date_tmp"] = df[self._date_field_name].apply( + lambda x: pd.Timestamp(x).date() + ) df.set_index("date_tmp", inplace=True) df.loc[:, "factor"] = data_1d["close"] / df["close"] df.loc[:, "paused"] = data_1d["paused"] @@ -580,12 +636,16 @@ class YahooNormalize1min(YahooNormalize, ABC): df.set_index(self._date_field_name, inplace=True) df = df.reindex( self.generate_1min_from_daily( - pd.to_datetime(data_1d.reset_index()[self._date_field_name].drop_duplicates()) + pd.to_datetime( + data_1d.reset_index()[ + self._date_field_name + ].drop_duplicates() + ) ) ) - df[self._symbol_field_name] = df.loc[df[self._symbol_field_name].first_valid_index()][ - self._symbol_field_name - ] + df[self._symbol_field_name] = df.loc[ + df[self._symbol_field_name].first_valid_index() + ][self._symbol_field_name] df.index.names = [self._date_field_name] df.reset_index(inplace=True) for _col in self.COLUMNS: @@ -603,7 +663,9 @@ class YahooNormalize1min(YahooNormalize, ABC): def calc_paused_num(self, df: pd.DataFrame): _symbol = df.iloc[0][self._symbol_field_name] df = df.copy() - df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) + df["_tmp_date"] = df[self._date_field_name].apply( + lambda x: pd.Timestamp(x).date() + ) # remove data that starts and ends with `np.nan` all day all_data = [] # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan @@ -623,7 +685,10 @@ class YahooNormalize1min(YahooNormalize, ABC): self._date_field_name, self._symbol_field_name, } - if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all(): + if ( + _df.loc[:, check_fields].isna().values.all() + or (_df["volume"] == 0).all() + ): all_nan_nums += 1 not_nan_nums = 0 _df["paused"] = 1 @@ -658,7 +723,11 @@ class YahooNormalize1minOffline(YahooNormalize1min): """Normalised to 1min using local 1d data""" def __init__( - self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + self, + qlib_data_1d_dir: [str, Path], + date_field_name: str = "date", + symbol_field_name: str = "symbol", + **kwargs, ): """ @@ -672,7 +741,9 @@ class YahooNormalize1minOffline(YahooNormalize1min): symbol field name, default is symbol """ self.qlib_data_1d_dir = qlib_data_1d_dir - super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name) + super(YahooNormalize1minOffline, self).__init__( + date_field_name, symbol_field_name + ) self._all_1d_data = self._get_all_1d_data() def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: @@ -687,9 +758,19 @@ class YahooNormalize1minOffline(YahooNormalize1min): from qlib.data import D qlib.init(provider_uri=self.qlib_data_1d_dir) - df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") + df = D.features( + D.instruments("all"), + ["$paused", "$volume", "$factor", "$close"], + freq="day", + ) df.reset_index(inplace=True) - df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True) + df.rename( + columns={ + "datetime": self._date_field_name, + "instrument": self._symbol_field_name, + }, + inplace=True, + ) df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) return df @@ -757,7 +838,11 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline): def symbol_to_yahoo(self, symbol): if "." not in symbol: _exchange = symbol[:2] - _exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange + _exchange = ( + ("ss" if _exchange.islower() else "SS") + if _exchange.lower() == "sh" + else _exchange + ) symbol = symbol[2:] + "." + _exchange return symbol @@ -766,7 +851,14 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline): class Run(BaseRun): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN): + def __init__( + self, + source_dir=None, + normalize_dir=None, + max_workers=1, + interval="1d", + region=REGION_CN, + ): """ Parameters @@ -838,7 +930,13 @@ class Run(BaseRun): $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ super(Run, self).download_data( - max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums + max_collector_count, + delay, + start, + end, + self.interval, + check_data_length, + limit_nums, ) def normalize_data( @@ -873,16 +971,25 @@ class Run(BaseRun): $ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min """ if self.interval.lower() == "1min": - if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists(): + if ( + qlib_data_1d_dir is None + or not Path(qlib_data_1d_dir).expanduser().exists() + ): raise ValueError( "If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir , Reference: https://github.com/zhupr/qlib/tree/support_extend_data/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance" ) super(Run, self).normalize_data( - date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir + date_field_name, + symbol_field_name, + end_date=end_date, + qlib_data_1d_dir=qlib_data_1d_dir, ) def normalize_data_1d_extend( - self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol" + self, + old_qlib_data_dir, + date_field_name: str = "date", + symbol_field_name: str = "symbol", ): """normalize data extend; extending yahoo qlib data(from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data) @@ -1013,19 +1120,30 @@ class Run(BaseRun): # start/end date if trading_date is None: trading_date = datetime.datetime.now().strftime("%Y-%m-%d") - logger.warning(f"trading_date is None, use the current date: {trading_date}") + logger.warning( + f"trading_date is None, use the current date: {trading_date}" + ) if end_date is None: - end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") + end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime( + "%Y-%m-%d" + ) # download qlib 1d data qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve()) if not exists_qlib_data(qlib_data_1d_dir): - GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region) + GetData().qlib_data( + target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region + ) # download data from yahoo # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1 - self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length) + self.download_data( + delay=delay, + start=trading_date, + end=end_date, + check_data_length=check_data_length, + ) # NOTE: a larger max_workers setting here would be faster self.max_workers = ( max(multiprocessing.cpu_count() - 2, 1) @@ -1047,11 +1165,18 @@ class Run(BaseRun): # parse index _region = self.region.lower() if _region not in ["cn", "us"]: - logger.warning(f"Unsupported region: region={_region}, component downloads will be ignored") + logger.warning( + f"Unsupported region: region={_region}, component downloads will be ignored" + ) return - index_list = ["CSI100", "CSI300"] if _region == "cn" else ["SP500", "NASDAQ100", "DJIA", "SP400"] + index_list = ( + ["CSI100", "CSI300"] + if _region == "cn" + else ["SP500", "NASDAQ100", "DJIA", "SP400"] + ) get_instruments = getattr( - importlib.import_module(f"data_collector.{_region}_index.collector"), "get_instruments" + importlib.import_module(f"data_collector.{_region}_index.collector"), + "get_instruments", ) for _index in index_list: get_instruments(str(qlib_data_1d_dir), _index)