From 3c740fc2def966772b117523eec19597ec8b2f04 Mon Sep 17 00:00:00 2001 From: Gaurav <2796gaurav@gmail.com> Date: Wed, 14 Jul 2021 19:54:55 +0530 Subject: [PATCH 01/73] MVP for Indian Stocks in qlib using yahooquery --- scripts/data_collector/utils.py | 42 +++++++++++++++++++++++ scripts/data_collector/yahoo/collector.py | 23 +++++++++++++ 2 files changed, 65 insertions(+) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 1a8d479d9..72bd1be18 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -39,6 +39,7 @@ _BENCH_CALENDAR_LIST = None _ALL_CALENDAR_LIST = None _HS_SYMBOLS = None _US_SYMBOLS = None +_IN_SYMBOLS = None _EN_FUND_SYMBOLS = None _CALENDAR_MAP = {} @@ -298,6 +299,47 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: return _US_SYMBOLS +def get_in_stock_symbols(qlib_data_path: [str, Path] = None) -> list: + """get IN stock symbols + + Returns + ------- + stock symbols + """ + global _IN_SYMBOLS + + @deco_retry + def _get_nifty(): + url = f"https://www1.nseindia.com/content/equities/EQUITY_L.csv" + df = pd.read_csv(url) + df = df.rename(columns={"SYMBOL": "Symbol"}) + df['Symbol'] = df['Symbol'] + ".NS" + _symbols = df["Symbol"].dropna() + _symbols = _symbols.unique().tolist() + return _symbols + + if _IN_SYMBOLS is None: + _all_symbols = _get_nifty() + if qlib_data_path is not None: + for _index in ["nifty"]: + ins_df = pd.read_csv( + Path(qlib_data_path).joinpath(f"instruments/{_index}.txt"), + sep="\t", + names=["symbol", "start_date", "end_date"], + ) + _all_symbols += ins_df["symbol"].unique().tolist() + + def _format(s_): + s_ = s_.replace(".", "-") + s_ = s_.strip("$") + s_ = s_.strip("*") + return s_ + + _IN_SYMBOLS = sorted(set(_all_symbols)) + + return _IN_SYMBOLS + + def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: """get en fund symbols diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py 
index 6a128a5be..e262dac19 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -34,6 +34,7 @@ from data_collector.utils import ( get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols, + get_in_stock_symbols, generate_minutes_calendar_from_daily, ) @@ -279,6 +280,28 @@ class YahooCollectorUS1min(YahooCollectorUS): pass +class YahooCollectorIN(YahooCollector, ABC): + def get_instrument_list(self): + logger.info("get INDIA stock symbols......") + symbols = get_in_stock_symbols() + logger.info(f"get {len(symbols)} symbols.") + return symbols + + def download_index_data(self): + pass + + def normalize_symbol(self, symbol): + return code_to_fname(symbol).upper() + + @property + def _timezone(self): + return "Asia/Kolkata" + + +class YahooCollectorIN1d(YahooCollectorIN): + pass + + class YahooNormalize(BaseNormalize): COLUMNS = ["open", "close", "high", "low", "volume"] DAILY_FORMAT = "%Y-%m-%d" From 457dcaa4667970f6a02163c32f64d008d28edca6 Mon Sep 17 00:00:00 2001 From: Gaurav <2796gaurav@gmail.com> Date: Wed, 14 Jul 2021 20:12:00 +0530 Subject: [PATCH 02/73] cleaned with black --- scripts/data_collector/utils.py | 57 ++++-- scripts/data_collector/yahoo/collector.py | 219 +++++++++++++++++----- 2 files changed, 217 insertions(+), 59 deletions(-) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 72bd1be18..62d6e4eb1 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -69,8 +69,15 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: calendar = _CALENDAR_MAP.get(bench_code, None) if calendar is None: if bench_code.startswith("US_"): - df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max") - calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist() + df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history( + interval="1d", period="max" + ) + calendar = ( + 
df.index.get_level_values(level="date") + .map(pd.Timestamp) + .unique() + .tolist() + ) else: if bench_code.upper() == "ALL": @@ -78,7 +85,9 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: def _get_calendar(month): _cal = [] try: - resp = requests.get(SZSE_CALENDAR_URL.format(month=month, random=random.random)).json() + resp = requests.get( + SZSE_CALENDAR_URL.format(month=month, random=random.random) + ).json() for _r in resp["data"]: if int(_r["jybz"]): _cal.append(pd.Timestamp(_r["jyrq"])) @@ -86,7 +95,11 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: raise ValueError(f"{month}-->{e}") return _cal - month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M") + month_range = pd.date_range( + start="2000-01", + end=pd.Timestamp.now() + pd.Timedelta(days=31), + freq="M", + ) calendar = [] for _m in month_range: cal = _get_calendar(_m.strftime("%Y-%m")) @@ -156,7 +169,9 @@ def get_calendar_list_by_ratio( p_bar.update() logger.info(f"count how many funds have founded in this day......") - _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} + _dict_count_founding = { + date: _number_all_funds for date in _dict_count_trade.keys() + } # dict{date:count} with tqdm(total=_number_all_funds) as p_bar: for oldest_date in all_oldest_list: for date in _dict_count_founding.keys(): @@ -166,7 +181,8 @@ def get_calendar_list_by_ratio( calendar = [ date for date in _dict_count_trade - if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count) + if _dict_count_trade[date] + >= max(int(_dict_count_founding[date] * threshold), minimum_count) ] return calendar @@ -188,7 +204,9 @@ def get_hs_stock_symbols() -> list: _res |= set( map( lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), - etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), + etree.HTML(resp.text).xpath( + 
"//div[@class='result']/ul//li/a/text()" + ), ) ) time.sleep(3) @@ -232,7 +250,10 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: if resp.status_code != 200: raise ValueError("request error") try: - _symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()] + _symbols = [ + _v["f12"].replace("_", "-P") + for _v in resp.json()["data"]["diff"].values() + ] except Exception as e: logger.warning(f"request error: {e}") raise @@ -294,7 +315,14 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: s_ = s_.strip("*") return s_ - _US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols)))) + _US_SYMBOLS = sorted( + set( + map( + _format, + filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols), + ) + ) + ) return _US_SYMBOLS @@ -313,7 +341,7 @@ def get_in_stock_symbols(qlib_data_path: [str, Path] = None) -> list: url = f"https://www1.nseindia.com/content/equities/EQUITY_L.csv" df = pd.read_csv(url) df = df.rename(columns={"SYMBOL": "Symbol"}) - df['Symbol'] = df['Symbol'] + ".NS" + df["Symbol"] = df["Symbol"] + ".NS" _symbols = df["Symbol"].dropna() _symbols = _symbols.unique().tolist() return _symbols @@ -357,7 +385,10 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: raise ValueError("request error") try: _symbols = [] - for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")): + for sub_data in re.findall( + r"[\[](.*?)[\]]", + resp.content.decode().split("= [")[-1].replace("];", ""), + ): data = sub_data.replace('"', "").replace("'", "") # TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE'] _symbols.append(data.split(",")[0]) @@ -436,7 +467,9 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3): return deco_func(retry) if callable(retry) else deco_func -def get_trading_date_by_shift(trading_list: 
list, trading_date: pd.Timestamp, shift: int = 1): +def get_trading_date_by_shift( + trading_list: list, trading_date: pd.Timestamp, shift: int = 1 +): """get trading date by shift Parameters diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index e262dac19..cf97a0c7e 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -93,7 +93,9 @@ class YahooCollector(BaseCollector): def init_datetime(self): if self.interval == self.INTERVAL_1min: - self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN) + self.start_datetime = max( + self.start_datetime, self.DEFAULT_START_DATETIME_1MIN + ) elif self.interval == self.INTERVAL_1d: pass else: @@ -117,7 +119,9 @@ class YahooCollector(BaseCollector): raise NotImplementedError("rewrite get_timezone") @staticmethod - def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False): + def get_data_from_remote( + symbol, interval, start, end, show_1min_logging: bool = False + ): error_msg = f"{symbol}-{interval}-{start}-{end}" def _show_logging_func(): @@ -126,13 +130,16 @@ class YahooCollector(BaseCollector): interval = "1m" if interval in ["1m", "1min"] else interval try: - _resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end) + _resp = Ticker(symbol, asynchronous=False).history( + interval=interval, start=start, end=end + ) if isinstance(_resp, pd.DataFrame): return _resp.reset_index() elif isinstance(_resp, dict): _temp_data = _resp.get(symbol, {}) if isinstance(_temp_data, str) or ( - isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None + isinstance(_resp, dict) + and _temp_data.get("indicators", {}).get("quote", None) is None ): _show_logging_func() else: @@ -141,7 +148,11 @@ class YahooCollector(BaseCollector): logger.warning(f"{error_msg}:{e}") def get_data( - self, symbol: str, interval: str, 
start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + self, + symbol: str, + interval: str, + start_datetime: pd.Timestamp, + end_datetime: pd.Timestamp, ) -> pd.DataFrame: @deco_retry(retry_sleep=self.delay) def _get_simple(start_, end_): @@ -214,21 +225,35 @@ class YahooCollectorCN1d(YahooCollectorCN): _format = "%Y%m%d" _begin = self.start_datetime.strftime(_format) _end = (self.end_datetime + pd.Timedelta(days=-1)).strftime(_format) - for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items(): + for _index_name, _index_code in { + "csi300": "000300", + "csi100": "000903", + }.items(): logger.info(f"get bench data: {_index_name}({_index_code})......") try: df = pd.DataFrame( map( lambda x: x.split(","), - requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[ - "data" - ]["klines"], + requests.get( + INDEX_BENCH_URL.format( + index_code=_index_code, begin=_begin, end=_end + ) + ).json()["data"]["klines"], ) ) except Exception as e: logger.warning(f"get {_index_name} error: {e}") continue - df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"] + df.columns = [ + "date", + "open", + "close", + "high", + "low", + "volume", + "money", + "change", + ] df["date"] = pd.to_datetime(df["date"]) df = df.astype(float, errors="ignore") df["adjclose"] = df["close"] @@ -283,7 +308,7 @@ class YahooCollectorUS1min(YahooCollectorUS): class YahooCollectorIN(YahooCollector, ABC): def get_instrument_list(self): logger.info("get INDIA stock symbols......") - symbols = get_in_stock_symbols() + symbols = get_in_stock_symbols() logger.info(f"get {len(symbols)} symbols.") return symbols @@ -326,13 +351,18 @@ class YahooNormalize(BaseNormalize): df = df.reindex( pd.DataFrame(index=calendar_list) .loc[ - pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timestamp(df.index.min()) + .date() : pd.Timestamp(df.index.max()) + .date() + pd.Timedelta(hours=23, minutes=59) ] 
.index ) df.sort_index(inplace=True) - df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan + df.loc[ + (df["volume"] <= 0) | np.isnan(df["volume"]), + set(df.columns) - {symbol_field_name}, + ] = np.nan _tmp_series = df["close"].fillna(method="ffill") _tmp_shift_series = _tmp_series.shift(1) if last_close is not None: @@ -347,7 +377,9 @@ class YahooNormalize(BaseNormalize): def normalize(self, df: pd.DataFrame) -> pd.DataFrame: # normalize - df = self.normalize_yahoo(df, self._calendar_list, self._date_field_name, self._symbol_field_name) + df = self.normalize_yahoo( + df, self._calendar_list, self._date_field_name, self._symbol_field_name + ) # adjusted price df = self.adjusted_price(df) return df @@ -418,7 +450,11 @@ class YahooNormalize1d(YahooNormalize, ABC): class YahooNormalize1dExtend(YahooNormalize1d): def __init__( - self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + self, + old_qlib_data_dir: [str, Path], + date_field_name: str = "date", + symbol_field_name: str = "symbol", + **kwargs, ): """ @@ -447,7 +483,9 @@ class YahooNormalize1dExtend(YahooNormalize1d): return df def _get_close(self, df: pd.DataFrame, field_name: str): - _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper() + _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][ + self._symbol_field_name + ].upper() _df = self.old_qlib_data.loc(axis=0)[_symbol] _close = _df.loc[_df.last_valid_index()][field_name] return _close @@ -467,7 +505,9 @@ class YahooNormalize1dExtend(YahooNormalize1d): return _close def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp: - _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper() + _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][ + self._symbol_field_name + ].upper() try: _df = self.old_qlib_data.loc(axis=0)[_symbol] _date = 
_df.index.max() @@ -495,7 +535,11 @@ class YahooNormalize1dExtend(YahooNormalize1d): ) # normalize df = self.normalize_yahoo( - df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close + df, + self._calendar_list, + self._date_field_name, + self._symbol_field_name, + last_close=_last_close, ) # adjusted price df = self.adjusted_price(df) @@ -533,10 +577,14 @@ class YahooNormalize1min(YahooNormalize, ABC): data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"] """ - data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end) + data_1d = YahooCollector.get_data_from_remote( + self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end + ) if not (data_1d is None or data_1d.empty): _class_name = self.__class__.__name__.replace("min", "d") - _class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name) + _class: type(YahooNormalize) = getattr( + importlib.import_module("collector"), _class_name + ) data_1d_obj = _class(self._date_field_name, self._symbol_field_name) data_1d = data_1d_obj.normalize(data_1d) return data_1d @@ -549,8 +597,12 @@ class YahooNormalize1min(YahooNormalize, ABC): df = df.sort_values(self._date_field_name) symbol = df.iloc[0][self._symbol_field_name] # get 1d data from yahoo - _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) - _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT) + _start = pd.Timestamp(df[self._date_field_name].min()).strftime( + self.DAILY_FORMAT + ) + _end = ( + pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1) + ).strftime(self.DAILY_FORMAT) data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end) data_1d = data_1d.copy() if data_1d is None or data_1d.empty: @@ -561,7 +613,9 @@ class YahooNormalize1min(YahooNormalize, ABC): # NOTE: volume is 
np.nan or volume <= 0, paused = 1 # FIXME: find a more accurate data source data_1d["paused"] = 0 - data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1 + data_1d.loc[ + (data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused" + ] = 1 data_1d = data_1d.set_index(self._date_field_name) # add factor from 1d data @@ -569,7 +623,9 @@ class YahooNormalize1min(YahooNormalize, ABC): # - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits. # - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits. # - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)` - df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) + df["date_tmp"] = df[self._date_field_name].apply( + lambda x: pd.Timestamp(x).date() + ) df.set_index("date_tmp", inplace=True) df.loc[:, "factor"] = data_1d["close"] / df["close"] df.loc[:, "paused"] = data_1d["paused"] @@ -580,12 +636,16 @@ class YahooNormalize1min(YahooNormalize, ABC): df.set_index(self._date_field_name, inplace=True) df = df.reindex( self.generate_1min_from_daily( - pd.to_datetime(data_1d.reset_index()[self._date_field_name].drop_duplicates()) + pd.to_datetime( + data_1d.reset_index()[ + self._date_field_name + ].drop_duplicates() + ) ) ) - df[self._symbol_field_name] = df.loc[df[self._symbol_field_name].first_valid_index()][ - self._symbol_field_name - ] + df[self._symbol_field_name] = df.loc[ + df[self._symbol_field_name].first_valid_index() + ][self._symbol_field_name] df.index.names = [self._date_field_name] df.reset_index(inplace=True) for _col in self.COLUMNS: @@ -603,7 +663,9 @@ class YahooNormalize1min(YahooNormalize, ABC): def calc_paused_num(self, df: pd.DataFrame): _symbol = df.iloc[0][self._symbol_field_name] df = df.copy() - df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) + df["_tmp_date"] = df[self._date_field_name].apply( + lambda x: 
pd.Timestamp(x).date() + ) # remove data that starts and ends with `np.nan` all day all_data = [] # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan @@ -623,7 +685,10 @@ class YahooNormalize1min(YahooNormalize, ABC): self._date_field_name, self._symbol_field_name, } - if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all(): + if ( + _df.loc[:, check_fields].isna().values.all() + or (_df["volume"] == 0).all() + ): all_nan_nums += 1 not_nan_nums = 0 _df["paused"] = 1 @@ -658,7 +723,11 @@ class YahooNormalize1minOffline(YahooNormalize1min): """Normalised to 1min using local 1d data""" def __init__( - self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + self, + qlib_data_1d_dir: [str, Path], + date_field_name: str = "date", + symbol_field_name: str = "symbol", + **kwargs, ): """ @@ -672,7 +741,9 @@ class YahooNormalize1minOffline(YahooNormalize1min): symbol field name, default is symbol """ self.qlib_data_1d_dir = qlib_data_1d_dir - super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name) + super(YahooNormalize1minOffline, self).__init__( + date_field_name, symbol_field_name + ) self._all_1d_data = self._get_all_1d_data() def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: @@ -687,9 +758,19 @@ class YahooNormalize1minOffline(YahooNormalize1min): from qlib.data import D qlib.init(provider_uri=self.qlib_data_1d_dir) - df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") + df = D.features( + D.instruments("all"), + ["$paused", "$volume", "$factor", "$close"], + freq="day", + ) df.reset_index(inplace=True) - df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True) + df.rename( + columns={ + "datetime": self._date_field_name, + "instrument": self._symbol_field_name, + }, + inplace=True, 
+ ) df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) return df @@ -757,7 +838,11 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline): def symbol_to_yahoo(self, symbol): if "." not in symbol: _exchange = symbol[:2] - _exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange + _exchange = ( + ("ss" if _exchange.islower() else "SS") + if _exchange.lower() == "sh" + else _exchange + ) symbol = symbol[2:] + "." + _exchange return symbol @@ -766,7 +851,14 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline): class Run(BaseRun): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN): + def __init__( + self, + source_dir=None, + normalize_dir=None, + max_workers=1, + interval="1d", + region=REGION_CN, + ): """ Parameters @@ -838,7 +930,13 @@ class Run(BaseRun): $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ super(Run, self).download_data( - max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums + max_collector_count, + delay, + start, + end, + self.interval, + check_data_length, + limit_nums, ) def normalize_data( @@ -873,16 +971,25 @@ class Run(BaseRun): $ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min """ if self.interval.lower() == "1min": - if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists(): + if ( + qlib_data_1d_dir is None + or not Path(qlib_data_1d_dir).expanduser().exists() + ): raise ValueError( "If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir , Reference: 
https://github.com/zhupr/qlib/tree/support_extend_data/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance" ) super(Run, self).normalize_data( - date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir + date_field_name, + symbol_field_name, + end_date=end_date, + qlib_data_1d_dir=qlib_data_1d_dir, ) def normalize_data_1d_extend( - self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol" + self, + old_qlib_data_dir, + date_field_name: str = "date", + symbol_field_name: str = "symbol", ): """normalize data extend; extending yahoo qlib data(from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data) @@ -1013,19 +1120,30 @@ class Run(BaseRun): # start/end date if trading_date is None: trading_date = datetime.datetime.now().strftime("%Y-%m-%d") - logger.warning(f"trading_date is None, use the current date: {trading_date}") + logger.warning( + f"trading_date is None, use the current date: {trading_date}" + ) if end_date is None: - end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") + end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime( + "%Y-%m-%d" + ) # download qlib 1d data qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve()) if not exists_qlib_data(qlib_data_1d_dir): - GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region) + GetData().qlib_data( + target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region + ) # download data from yahoo # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1 - self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length) + self.download_data( + delay=delay, + start=trading_date, + end=end_date, + check_data_length=check_data_length, + ) # NOTE: a larger max_workers setting here would be faster self.max_workers = ( 
max(multiprocessing.cpu_count() - 2, 1) @@ -1047,11 +1165,18 @@ class Run(BaseRun): # parse index _region = self.region.lower() if _region not in ["cn", "us"]: - logger.warning(f"Unsupported region: region={_region}, component downloads will be ignored") + logger.warning( + f"Unsupported region: region={_region}, component downloads will be ignored" + ) return - index_list = ["CSI100", "CSI300"] if _region == "cn" else ["SP500", "NASDAQ100", "DJIA", "SP400"] + index_list = ( + ["CSI100", "CSI300"] + if _region == "cn" + else ["SP500", "NASDAQ100", "DJIA", "SP400"] + ) get_instruments = getattr( - importlib.import_module(f"data_collector.{_region}_index.collector"), "get_instruments" + importlib.import_module(f"data_collector.{_region}_index.collector"), + "get_instruments", ) for _index in index_list: get_instruments(str(qlib_data_1d_dir), _index) From cfcd9fb1f8905fb07f76b30978e6855db3ebd3e3 Mon Sep 17 00:00:00 2001 From: Gaurav <2796gaurav@gmail.com> Date: Thu, 15 Jul 2021 11:24:41 +0530 Subject: [PATCH 03/73] cleaned with black --- scripts/data_collector/utils.py | 55 ++---- scripts/data_collector/yahoo/collector.py | 217 +++++----------------- 2 files changed, 57 insertions(+), 215 deletions(-) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 62d6e4eb1..ceb0735cb 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -69,15 +69,8 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: calendar = _CALENDAR_MAP.get(bench_code, None) if calendar is None: if bench_code.startswith("US_"): - df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history( - interval="1d", period="max" - ) - calendar = ( - df.index.get_level_values(level="date") - .map(pd.Timestamp) - .unique() - .tolist() - ) + df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max") + calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist() else: if bench_code.upper() == 
"ALL": @@ -85,9 +78,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: def _get_calendar(month): _cal = [] try: - resp = requests.get( - SZSE_CALENDAR_URL.format(month=month, random=random.random) - ).json() + resp = requests.get(SZSE_CALENDAR_URL.format(month=month, random=random.random)).json() for _r in resp["data"]: if int(_r["jybz"]): _cal.append(pd.Timestamp(_r["jyrq"])) @@ -95,11 +86,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: raise ValueError(f"{month}-->{e}") return _cal - month_range = pd.date_range( - start="2000-01", - end=pd.Timestamp.now() + pd.Timedelta(days=31), - freq="M", - ) + month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M") calendar = [] for _m in month_range: cal = _get_calendar(_m.strftime("%Y-%m")) @@ -169,9 +156,7 @@ def get_calendar_list_by_ratio( p_bar.update() logger.info(f"count how many funds have founded in this day......") - _dict_count_founding = { - date: _number_all_funds for date in _dict_count_trade.keys() - } # dict{date:count} + _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} with tqdm(total=_number_all_funds) as p_bar: for oldest_date in all_oldest_list: for date in _dict_count_founding.keys(): @@ -181,8 +166,7 @@ def get_calendar_list_by_ratio( calendar = [ date for date in _dict_count_trade - if _dict_count_trade[date] - >= max(int(_dict_count_founding[date] * threshold), minimum_count) + if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count) ] return calendar @@ -204,9 +188,7 @@ def get_hs_stock_symbols() -> list: _res |= set( map( lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), - etree.HTML(resp.text).xpath( - "//div[@class='result']/ul//li/a/text()" - ), + etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), ) ) time.sleep(3) @@ -250,10 +232,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = 
None) -> list: if resp.status_code != 200: raise ValueError("request error") try: - _symbols = [ - _v["f12"].replace("_", "-P") - for _v in resp.json()["data"]["diff"].values() - ] + _symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()] except Exception as e: logger.warning(f"request error: {e}") raise @@ -315,14 +294,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: s_ = s_.strip("*") return s_ - _US_SYMBOLS = sorted( - set( - map( - _format, - filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols), - ) - ) - ) + _US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols)))) return _US_SYMBOLS @@ -385,10 +357,7 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: raise ValueError("request error") try: _symbols = [] - for sub_data in re.findall( - r"[\[](.*?)[\]]", - resp.content.decode().split("= [")[-1].replace("];", ""), - ): + for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")): data = sub_data.replace('"', "").replace("'", "") # TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE'] _symbols.append(data.split(",")[0]) @@ -467,9 +436,7 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3): return deco_func(retry) if callable(retry) else deco_func -def get_trading_date_by_shift( - trading_list: list, trading_date: pd.Timestamp, shift: int = 1 -): +def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1): """get trading date by shift Parameters diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index cf97a0c7e..d7518eca8 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -93,9 +93,7 @@ class YahooCollector(BaseCollector): def init_datetime(self): if self.interval == 
self.INTERVAL_1min: - self.start_datetime = max( - self.start_datetime, self.DEFAULT_START_DATETIME_1MIN - ) + self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN) elif self.interval == self.INTERVAL_1d: pass else: @@ -119,9 +117,7 @@ class YahooCollector(BaseCollector): raise NotImplementedError("rewrite get_timezone") @staticmethod - def get_data_from_remote( - symbol, interval, start, end, show_1min_logging: bool = False - ): + def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False): error_msg = f"{symbol}-{interval}-{start}-{end}" def _show_logging_func(): @@ -130,16 +126,13 @@ class YahooCollector(BaseCollector): interval = "1m" if interval in ["1m", "1min"] else interval try: - _resp = Ticker(symbol, asynchronous=False).history( - interval=interval, start=start, end=end - ) + _resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end) if isinstance(_resp, pd.DataFrame): return _resp.reset_index() elif isinstance(_resp, dict): _temp_data = _resp.get(symbol, {}) if isinstance(_temp_data, str) or ( - isinstance(_resp, dict) - and _temp_data.get("indicators", {}).get("quote", None) is None + isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None ): _show_logging_func() else: @@ -148,11 +141,7 @@ class YahooCollector(BaseCollector): logger.warning(f"{error_msg}:{e}") def get_data( - self, - symbol: str, - interval: str, - start_datetime: pd.Timestamp, - end_datetime: pd.Timestamp, + self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp ) -> pd.DataFrame: @deco_retry(retry_sleep=self.delay) def _get_simple(start_, end_): @@ -225,35 +214,21 @@ class YahooCollectorCN1d(YahooCollectorCN): _format = "%Y%m%d" _begin = self.start_datetime.strftime(_format) _end = (self.end_datetime + pd.Timedelta(days=-1)).strftime(_format) - for _index_name, _index_code in { - "csi300": "000300", - "csi100": "000903", - 
}.items(): + for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items(): logger.info(f"get bench data: {_index_name}({_index_code})......") try: df = pd.DataFrame( map( lambda x: x.split(","), - requests.get( - INDEX_BENCH_URL.format( - index_code=_index_code, begin=_begin, end=_end - ) - ).json()["data"]["klines"], + requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[ + "data" + ]["klines"], ) ) except Exception as e: logger.warning(f"get {_index_name} error: {e}") continue - df.columns = [ - "date", - "open", - "close", - "high", - "low", - "volume", - "money", - "change", - ] + df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"] df["date"] = pd.to_datetime(df["date"]) df = df.astype(float, errors="ignore") df["adjclose"] = df["close"] @@ -351,18 +326,13 @@ class YahooNormalize(BaseNormalize): df = df.reindex( pd.DataFrame(index=calendar_list) .loc[ - pd.Timestamp(df.index.min()) - .date() : pd.Timestamp(df.index.max()) - .date() + pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timedelta(hours=23, minutes=59) ] .index ) df.sort_index(inplace=True) - df.loc[ - (df["volume"] <= 0) | np.isnan(df["volume"]), - set(df.columns) - {symbol_field_name}, - ] = np.nan + df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan _tmp_series = df["close"].fillna(method="ffill") _tmp_shift_series = _tmp_series.shift(1) if last_close is not None: @@ -377,9 +347,7 @@ class YahooNormalize(BaseNormalize): def normalize(self, df: pd.DataFrame) -> pd.DataFrame: # normalize - df = self.normalize_yahoo( - df, self._calendar_list, self._date_field_name, self._symbol_field_name - ) + df = self.normalize_yahoo(df, self._calendar_list, self._date_field_name, self._symbol_field_name) # adjusted price df = self.adjusted_price(df) return df @@ -450,11 +418,7 @@ class YahooNormalize1d(YahooNormalize, ABC): class 
YahooNormalize1dExtend(YahooNormalize1d): def __init__( - self, - old_qlib_data_dir: [str, Path], - date_field_name: str = "date", - symbol_field_name: str = "symbol", - **kwargs, + self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs ): """ @@ -483,9 +447,7 @@ class YahooNormalize1dExtend(YahooNormalize1d): return df def _get_close(self, df: pd.DataFrame, field_name: str): - _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][ - self._symbol_field_name - ].upper() + _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper() _df = self.old_qlib_data.loc(axis=0)[_symbol] _close = _df.loc[_df.last_valid_index()][field_name] return _close @@ -505,9 +467,7 @@ class YahooNormalize1dExtend(YahooNormalize1d): return _close def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp: - _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][ - self._symbol_field_name - ].upper() + _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper() try: _df = self.old_qlib_data.loc(axis=0)[_symbol] _date = _df.index.max() @@ -535,11 +495,7 @@ class YahooNormalize1dExtend(YahooNormalize1d): ) # normalize df = self.normalize_yahoo( - df, - self._calendar_list, - self._date_field_name, - self._symbol_field_name, - last_close=_last_close, + df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close ) # adjusted price df = self.adjusted_price(df) @@ -577,14 +533,10 @@ class YahooNormalize1min(YahooNormalize, ABC): data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"] """ - data_1d = YahooCollector.get_data_from_remote( - self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end - ) + data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end) if not (data_1d is None or 
data_1d.empty): _class_name = self.__class__.__name__.replace("min", "d") - _class: type(YahooNormalize) = getattr( - importlib.import_module("collector"), _class_name - ) + _class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name) data_1d_obj = _class(self._date_field_name, self._symbol_field_name) data_1d = data_1d_obj.normalize(data_1d) return data_1d @@ -597,12 +549,8 @@ class YahooNormalize1min(YahooNormalize, ABC): df = df.sort_values(self._date_field_name) symbol = df.iloc[0][self._symbol_field_name] # get 1d data from yahoo - _start = pd.Timestamp(df[self._date_field_name].min()).strftime( - self.DAILY_FORMAT - ) - _end = ( - pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1) - ).strftime(self.DAILY_FORMAT) + _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) + _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT) data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end) data_1d = data_1d.copy() if data_1d is None or data_1d.empty: @@ -613,9 +561,7 @@ class YahooNormalize1min(YahooNormalize, ABC): # NOTE: volume is np.nan or volume <= 0, paused = 1 # FIXME: find a more accurate data source data_1d["paused"] = 0 - data_1d.loc[ - (data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused" - ] = 1 + data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1 data_1d = data_1d.set_index(self._date_field_name) # add factor from 1d data @@ -623,9 +569,7 @@ class YahooNormalize1min(YahooNormalize, ABC): # - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits. # - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits. 
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)` - df["date_tmp"] = df[self._date_field_name].apply( - lambda x: pd.Timestamp(x).date() - ) + df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) df.set_index("date_tmp", inplace=True) df.loc[:, "factor"] = data_1d["close"] / df["close"] df.loc[:, "paused"] = data_1d["paused"] @@ -636,16 +580,12 @@ class YahooNormalize1min(YahooNormalize, ABC): df.set_index(self._date_field_name, inplace=True) df = df.reindex( self.generate_1min_from_daily( - pd.to_datetime( - data_1d.reset_index()[ - self._date_field_name - ].drop_duplicates() - ) + pd.to_datetime(data_1d.reset_index()[self._date_field_name].drop_duplicates()) ) ) - df[self._symbol_field_name] = df.loc[ - df[self._symbol_field_name].first_valid_index() - ][self._symbol_field_name] + df[self._symbol_field_name] = df.loc[df[self._symbol_field_name].first_valid_index()][ + self._symbol_field_name + ] df.index.names = [self._date_field_name] df.reset_index(inplace=True) for _col in self.COLUMNS: @@ -663,9 +603,7 @@ class YahooNormalize1min(YahooNormalize, ABC): def calc_paused_num(self, df: pd.DataFrame): _symbol = df.iloc[0][self._symbol_field_name] df = df.copy() - df["_tmp_date"] = df[self._date_field_name].apply( - lambda x: pd.Timestamp(x).date() - ) + df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) # remove data that starts and ends with `np.nan` all day all_data = [] # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan @@ -685,10 +623,7 @@ class YahooNormalize1min(YahooNormalize, ABC): self._date_field_name, self._symbol_field_name, } - if ( - _df.loc[:, check_fields].isna().values.all() - or (_df["volume"] == 0).all() - ): + if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all(): all_nan_nums += 1 not_nan_nums = 0 _df["paused"] = 1 @@ -723,11 +658,7 @@ 
class YahooNormalize1minOffline(YahooNormalize1min): """Normalised to 1min using local 1d data""" def __init__( - self, - qlib_data_1d_dir: [str, Path], - date_field_name: str = "date", - symbol_field_name: str = "symbol", - **kwargs, + self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs ): """ @@ -741,9 +672,7 @@ class YahooNormalize1minOffline(YahooNormalize1min): symbol field name, default is symbol """ self.qlib_data_1d_dir = qlib_data_1d_dir - super(YahooNormalize1minOffline, self).__init__( - date_field_name, symbol_field_name - ) + super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name) self._all_1d_data = self._get_all_1d_data() def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: @@ -758,19 +687,9 @@ class YahooNormalize1minOffline(YahooNormalize1min): from qlib.data import D qlib.init(provider_uri=self.qlib_data_1d_dir) - df = D.features( - D.instruments("all"), - ["$paused", "$volume", "$factor", "$close"], - freq="day", - ) + df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") df.reset_index(inplace=True) - df.rename( - columns={ - "datetime": self._date_field_name, - "instrument": self._symbol_field_name, - }, - inplace=True, - ) + df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True) df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) return df @@ -838,11 +757,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline): def symbol_to_yahoo(self, symbol): if "." not in symbol: _exchange = symbol[:2] - _exchange = ( - ("ss" if _exchange.islower() else "SS") - if _exchange.lower() == "sh" - else _exchange - ) + _exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange symbol = symbol[2:] + "." 
+ _exchange return symbol @@ -851,14 +766,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline): class Run(BaseRun): - def __init__( - self, - source_dir=None, - normalize_dir=None, - max_workers=1, - interval="1d", - region=REGION_CN, - ): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN): """ Parameters @@ -930,13 +838,7 @@ class Run(BaseRun): $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ super(Run, self).download_data( - max_collector_count, - delay, - start, - end, - self.interval, - check_data_length, - limit_nums, + max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums ) def normalize_data( @@ -971,25 +873,16 @@ class Run(BaseRun): $ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min """ if self.interval.lower() == "1min": - if ( - qlib_data_1d_dir is None - or not Path(qlib_data_1d_dir).expanduser().exists() - ): + if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists(): raise ValueError( "If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir , Reference: https://github.com/zhupr/qlib/tree/support_extend_data/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance" ) super(Run, self).normalize_data( - date_field_name, - symbol_field_name, - end_date=end_date, - qlib_data_1d_dir=qlib_data_1d_dir, + date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir ) def normalize_data_1d_extend( - self, - old_qlib_data_dir, - date_field_name: str = "date", - symbol_field_name: str = "symbol", + self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol" ): 
"""normalize data extend; extending yahoo qlib data(from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data) @@ -1120,30 +1013,19 @@ class Run(BaseRun): # start/end date if trading_date is None: trading_date = datetime.datetime.now().strftime("%Y-%m-%d") - logger.warning( - f"trading_date is None, use the current date: {trading_date}" - ) + logger.warning(f"trading_date is None, use the current date: {trading_date}") if end_date is None: - end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime( - "%Y-%m-%d" - ) + end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") # download qlib 1d data qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve()) if not exists_qlib_data(qlib_data_1d_dir): - GetData().qlib_data( - target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region - ) + GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region) # download data from yahoo # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1 - self.download_data( - delay=delay, - start=trading_date, - end=end_date, - check_data_length=check_data_length, - ) + self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length) # NOTE: a larger max_workers setting here would be faster self.max_workers = ( max(multiprocessing.cpu_count() - 2, 1) @@ -1165,18 +1047,11 @@ class Run(BaseRun): # parse index _region = self.region.lower() if _region not in ["cn", "us"]: - logger.warning( - f"Unsupported region: region={_region}, component downloads will be ignored" - ) + logger.warning(f"Unsupported region: region={_region}, component downloads will be ignored") return - index_list = ( - ["CSI100", "CSI300"] - if _region == "cn" - else ["SP500", "NASDAQ100", "DJIA", "SP400"] - ) + index_list = ["CSI100", "CSI300"] if _region == "cn" else ["SP500", "NASDAQ100", "DJIA", "SP400"] get_instruments = getattr( - 
importlib.import_module(f"data_collector.{_region}_index.collector"), - "get_instruments", + importlib.import_module(f"data_collector.{_region}_index.collector"), "get_instruments" ) for _index in index_list: get_instruments(str(qlib_data_1d_dir), _index) From d70e5a4f883896c9ce5061d11b6b81e379a3af36 Mon Sep 17 00:00:00 2001 From: Gaurav <2796gaurav@gmail.com> Date: Sat, 17 Jul 2021 10:40:16 +0530 Subject: [PATCH 04/73] add YahooNormalizeIN and YahooNormalizeIN1d --- scripts/data_collector/utils.py | 3 ++- scripts/data_collector/yahoo/collector.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index ceb0735cb..883a1c551 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -32,6 +32,7 @@ CALENDAR_BENCH_URL_MAP = { "ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"), # NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks "US_ALL": "^GSPC", + "IN_ALL": "^NSEI", } @@ -68,7 +69,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: calendar = _CALENDAR_MAP.get(bench_code, None) if calendar is None: - if bench_code.startswith("US_"): + if bench_code.startswith("US_") or bench_code.startswith("IN_"): df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max") calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist() else: diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index d7518eca8..97e674293 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -733,6 +733,16 @@ class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline): return fname_to_code(symbol) +class YahooNormalizeIN: + def _get_calendar_list(self) -> Iterable[pd.Timestamp]: + # TODO: from MSN + return get_calendar_list("IN_ALL") + + +class 
YahooNormalizeIN1d(YahooNormalizeIN, YahooNormalize1d): + pass + + class YahooNormalizeCN: def _get_calendar_list(self) -> Iterable[pd.Timestamp]: # TODO: from MSN From d1c8d885aa9309b884dcdc377cdbcf7b21a40e7c Mon Sep 17 00:00:00 2001 From: Gaurav <2796gaurav@gmail.com> Date: Wed, 21 Jul 2021 17:59:50 +0530 Subject: [PATCH 05/73] cleaned the code --- scripts/data_collector/yahoo/collector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 97e674293..464e9e516 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -735,7 +735,6 @@ class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline): class YahooNormalizeIN: def _get_calendar_list(self) -> Iterable[pd.Timestamp]: - # TODO: from MSN return get_calendar_list("IN_ALL") From 8fa22bd2e151545d7a5ea67ea63504b94cb69dc9 Mon Sep 17 00:00:00 2001 From: 2796gaurav <17353992+2796gaurav@users.noreply.github.com> Date: Wed, 21 Jul 2021 14:16:22 +0530 Subject: [PATCH 06/73] added 1min for IN and also updated readme --- scripts/data_collector/yahoo/README.md | 14 +++++++++++--- scripts/data_collector/yahoo/collector.py | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md index 50f731e38..0deef9c95 100644 --- a/scripts/data_collector/yahoo/README.md +++ b/scripts/data_collector/yahoo/README.md @@ -37,7 +37,7 @@ pip install -r requirements.txt - user can append data to `v2`: [automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance) - **the [benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks) for qlib use `v1`**, *due to the unstable access to historical data by YahooFinance, there are some differences between `v2` and `v1`* - `interval`: `1d` or `1min`, by default `1d` - - `region`: `cn` or 
`us`, by default `cn` + - `region`: `cn` or `us` or `in`, by default `cn` - `delete_old`: delete existing data from `target_dir`(*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True` - `exists_skip`: traget_dir data already exists, skip `get_data`, value from [`True`, `False`], by default `False` - examples: @@ -50,6 +50,10 @@ pip install -r requirements.txt python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d --region us --interval 1d # us 1min python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1min --region us --interval 1min + # in 1d + python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_in_1d --region in --interval 1d + # in 1min + python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_in_1min --region in --interval 1min ``` ### Collector *YahooFinance* data to qlib @@ -60,7 +64,7 @@ pip install -r requirements.txt - `source_dir`: save the directory - `interval`: `1d` or `1min`, by default `1d` > **due to the limitation of the *YahooFinance API*, only the last month's data is available in `1min`** - - `region`: `CN` or `US`, by default `CN` + - `region`: `CN` or `US` or `IN`, by default `CN` - `delay`: `time.sleep(delay)`, by default *0.5* - `start`: start datetime, by default *"2000-01-01"*; *closed interval(including start)* - `end`: end datetime, by default `pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`; *open interval(excluding end)* @@ -78,6 +82,10 @@ pip install -r requirements.txt python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region US # us 1min data python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1min --delay 1 --interval 1min --region US + # in 1d data + python collector.py download_data --source_dir ~/.qlib/stock_data/source/in_1d --start 2020-01-01 --end 2020-12-31 
--delay 1 --interval 1d --region IN + # in 1min data + python collector.py download_data --source_dir ~/.qlib/stock_data/source/in_1min --delay 1 --interval 1min --region IN ``` 2. normalize data: `python scripts/data_collector/yahoo/collector.py normalize_data` @@ -87,7 +95,7 @@ pip install -r requirements.txt - `max_workers`: number of concurrent, by default *1* - `interval`: `1d` or `1min`, by default `1d` > if **`interval == 1min`**, `qlib_data_1d_dir` cannot be `None` - - `region`: `CN` or `US`, by default `CN` + - `region`: `CN` or `US` or `IN`, by default `CN` - `date_field_name`: column *name* identifying time in csv files, by default `date` - `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol` - `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None` diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 464e9e516..defa86198 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -302,6 +302,10 @@ class YahooCollectorIN1d(YahooCollectorIN): pass +class YahooCollectorIN1min(YahooCollectorIN): + pass + + class YahooNormalize(BaseNormalize): COLUMNS = ["open", "close", "high", "low", "volume"] DAILY_FORMAT = "%Y-%m-%d" @@ -742,6 +746,20 @@ class YahooNormalizeIN1d(YahooNormalizeIN, YahooNormalize1d): pass +class YahooNormalizeIN1min(YahooNormalizeIN, YahooNormalize1minOffline): + CALC_PAUSED_NUM = False + + def _get_calendar_list(self) -> Iterable[pd.Timestamp]: + # TODO: support 1min + raise ValueError("Does not support 1min") + + def _get_1d_calendar_list(self): + return get_calendar_list("IN_ALL") + + def symbol_to_yahoo(self, symbol): + return fname_to_code(symbol) + + class YahooNormalizeCN: def _get_calendar_list(self) -> Iterable[pd.Timestamp]: # TODO: from MSN From 1d22ee56d30aa8ddab2dadb1fe8e242a777b215d Mon Sep 17 00:00:00 2001 
From: Young Date: Sun, 25 Jul 2021 16:35:16 +0000 Subject: [PATCH 07/73] recorder support upload both raw file and directory --- qlib/workflow/recorder.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 7edb0ebb9..1b391cbe2 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -299,7 +299,11 @@ class MLflowRecorder(Recorder): def save_objects(self, local_path=None, artifact_path=None, **kwargs): assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly." if local_path is not None: - self.client.log_artifacts(self.id, local_path, artifact_path) + path = Path(local_path) + if path.is_dir(): + self.client.log_artifacts(self.id, local_path, artifact_path) + else: + self.client.log_artifact(self.id, local_path, artifact_path) else: temp_dir = Path(tempfile.mkdtemp()).resolve() for name, data in kwargs.items(): From a6f9dde0067491b3273607c3dbefa515a0f356b0 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Mon, 26 Jul 2021 18:36:09 +0800 Subject: [PATCH 08/73] Update README.md --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 422046c13..2c64ad7d2 100644 --- a/README.md +++ b/README.md @@ -374,9 +374,7 @@ Such overheads greatly slow down the data loading process. Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation. 
# Related Reports -- [【华泰金工林晓明团队】图神经网络选股与Qlib实践——华泰人工智能系列之四十二](https://mp.weixin.qq.com/s/w5fDB6oAv9dO6vlhf1kmhA) - [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/) -- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ) - [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ) - [微矿Qlib:业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ) From dc6859bdd9fa268e6564a911cf7ec738f1c8bb7a Mon Sep 17 00:00:00 2001 From: you-n-g Date: Mon, 26 Jul 2021 19:00:47 +0800 Subject: [PATCH 09/73] Fix docs of QlibRecorder --- qlib/workflow/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 36d0f464e..51a6ed553 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -38,13 +38,13 @@ class QlibRecorder: .. code-block:: Python # start new experiment and recorder - with R.start('test', 'recorder_1'): + with R.start(experiment_name='test', recorder_name='recorder_1'): model.fit(dataset) R.log... ... # further operations # resume previous experiment and recorder - with R.start('test', 'recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder. + with R.start(experiment_name='test', recorder_name='recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder. ... 
# further operations Parameters From 05d28469ad49e3fa833b67ba0b6264573a8a5aed Mon Sep 17 00:00:00 2001 From: you-n-g Date: Thu, 29 Jul 2021 12:06:59 +0800 Subject: [PATCH 10/73] sort index after loader (#538) make sure the fetch method is based on a index-sorted pd.DataFrame --- qlib/data/dataset/handler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index c6338832a..b823728fb 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -18,6 +18,7 @@ from ...config import C from ...utils import parse_config, transform_end_date, init_instance_by_config from ...utils.serial import Serializable from .utils import fetch_df_by_index +from ...utils import lazy_sort_index from pathlib import Path from .loader import DataLoader @@ -146,7 +147,8 @@ class DataHandler(Serializable): # Setup data. # _data may be with multiple column index level. The outer level indicates the feature set name with TimeInspector.logt("Loading data"): - self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time) + # make sure the fetch method is based on a index-sorted pd.DataFrame + self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time)) # TODO: cache CS_ALL = "__all" # return all columns with single-level index column From 9303415666ec0490a77e210e3387526bc9d8f84a Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 29 Jul 2021 04:40:27 +0000 Subject: [PATCH 11/73] refactor online serving rolling api --- ..._config_lightgbm_configurable_dataset.yaml | 2 +- examples/model_rolling/requirements.txt | 1 + qlib/utils/__init__.py | 4 +- qlib/workflow/online/strategy.py | 24 ++--- qlib/workflow/online/update.py | 2 + qlib/workflow/task/gen.py | 102 +++++++++++------- 6 files changed, 81 insertions(+), 54 deletions(-) create mode 100644 examples/model_rolling/requirements.txt diff --git 
a/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml index 335dc2093..78f567eb3 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml @@ -78,4 +78,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/model_rolling/requirements.txt b/examples/model_rolling/requirements.txt new file mode 100644 index 000000000..10ddd5b71 --- /dev/null +++ b/examples/model_rolling/requirements.txt @@ -0,0 +1 @@ +xgboost diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 778d0e17a..2fe9eafed 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -570,9 +570,11 @@ def get_pre_trading_date(trading_date, future=False): def transform_end_date(end_date=None, freq="day"): - """get previous trading date + """handle the end date with various format + If end_date is -1, None, or end_date is greater than the maximum trading day, the last trading date is returned. 
Otherwise, returns the end_date + ---------- end_date: str end trading date diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index 1e8e85c0f..7a923ebad 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -10,6 +10,7 @@ from typing import List, Tuple, Union from qlib.data.data import D from qlib.log import get_module_logger from qlib.model.ens.group import RollingGroup +from qlib.utils import transform_end_date from qlib.workflow.online.utils import OnlineTool, OnlineToolR from qlib.workflow.recorder import Recorder from qlib.workflow.task.collect import Collector, RecorderCollector @@ -118,6 +119,7 @@ class RollingStrategy(OnlineStrategy): task_template = [task_template] self.task_template = task_template self.rg = rolling_gen + assert issubclass(self.rg.__class__, RollingGen), "The rolling strategy relies on the feature if RollingGen" self.tool = OnlineToolR(self.exp_name) self.ta = TimeAdjuster() @@ -174,28 +176,20 @@ class RollingStrategy(OnlineStrategy): Returns: List[dict]: a list of new tasks. 
""" + # TODO: filter recorders by latest test segments is not a necessary latest_records, max_test = self._list_latest(self.tool.online_models()) if max_test is None: self.logger.warn(f"No latest online recorders, no new tasks.") return [] - calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time + calendar_latest = transform_end_date(cur_time) self.logger.info( f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}" ) - if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step: - old_tasks = [] - tasks_tmp = [] - for rec in latest_records: - task = rec.load_object("task") - old_tasks.append(deepcopy(task)) - test_begin = task["dataset"]["kwargs"]["segments"]["test"][0] - # modify the test segment to generate new tasks - task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest) - tasks_tmp.append(task) - new_tasks_tmp = task_generator(tasks_tmp, self.rg) - new_tasks = [task for task in new_tasks_tmp if task not in old_tasks] - return new_tasks - return [] + res = [] + for rec in latest_records: + task = rec.load_object("task") + res.extend(self.rg.gen_following_tasks(task, calendar_latest)) + return res def _list_latest(self, rec_list: List[Recorder]): """ diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index e5dbd413e..f2135a27a 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -105,6 +105,8 @@ class PredUpdater(RecordUpdater): if to_date == None: to_date = D.calendar(freq=freq)[-1] self.to_date = pd.Timestamp(to_date) + # FIXME: it will raise error when running routine with delay trainer + # should we use another predicition updater for delay trainer? 
self.old_pred = record.load_object("pred.pkl") self.last_end = self.old_pred.index.get_level_values("datetime").max() diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index ca7b8ae7f..e60fa4755 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -5,6 +5,7 @@ TaskGenerator module can generate many tasks based on TaskGen and some task temp """ import abc import copy +import pandas as pd from typing import List, Union, Callable from qlib.utils import transform_end_date @@ -139,6 +140,53 @@ class RollingGen(TaskGen): self.test_key = "test" self.train_key = "train" + def _update_task_segs(self, task, segs): + # update segments of this task + task["dataset"]["kwargs"]["segments"] = copy.deepcopy(segs) + if self.ds_extra_mod_func is not None: + self.ds_extra_mod_func(task, self) + + def gen_following_tasks(self, task: dict, test_end: pd.Timestamp) -> List[dict]: + """ + generating following rolling tasks for `task` until test_end + + Parameters + ---------- + task : dict + Qlib task format + test_end : pd.Timestamp + the latest rolling task includes `test_end` + + Returns + ------- + List[dict]: + the following tasks of `task`(`task` itself is excluded) + """ + t = copy.deepcopy(task) + prev_seg = t["dataset"]["kwargs"]["segments"] + while True: + segments = {} + try: + for k, seg in prev_seg.items(): + # decide how to shift + # expanding only for train data, the segments size of test data and valid data won't change + if k == self.train_key and self.rtype == self.ROLL_EX: + rtype = self.ta.SHIFT_EX + else: + rtype = self.ta.SHIFT_SD + # shift the segments data + segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype) + if segments[self.test_key][0] > test_end: + break + except KeyError: + # We reach the end of tasks + # No more rolling + break + + prev_seg = segments + self._update_task_segs(t, segments) + yield t + def generate(self, task: dict) -> List[dict]: """ Converting the task into a rolling task. 
@@ -191,43 +239,23 @@ class RollingGen(TaskGen): """ res = [] - prev_seg = None - test_end = None - while True: - t = copy.deepcopy(task) + t = copy.deepcopy(task) - # calculate segments - if prev_seg is None: - # First rolling - # 1) prepare the end point - segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"])) - test_end = transform_end_date(segments[self.test_key][1]) - # 2) and init test segments - test_start_idx = self.ta.align_idx(segments[self.test_key][0]) - segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) - else: - segments = {} - try: - for k, seg in prev_seg.items(): - # decide how to shift - # expanding only for train data, the segments size of test data and valid data won't change - if k == self.train_key and self.rtype == self.ROLL_EX: - rtype = self.ta.SHIFT_EX - else: - rtype = self.ta.SHIFT_SD - # shift the segments data - segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype) - if segments[self.test_key][0] > test_end: - break - except KeyError: - # We reach the end of tasks - # No more rolling - break + # calculate segments - # update segments of this task - t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments) - prev_seg = segments - if self.ds_extra_mod_func is not None: - self.ds_extra_mod_func(t, self) - res.append(t) + # First rolling + # 1) prepare the end point + segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"])) + test_end = transform_end_date(segments[self.test_key][1]) + # 2) and init test segments + test_start_idx = self.ta.align_idx(segments[self.test_key][0]) + segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) + + # update segments of this task + self._update_task_segs(t, segments) + + res.append(t) + + # Update the following rolling + res.extend(self.gen_following_tasks(t, test_end)) return res From 07655f2d5b92129181f0dfbf24ecc8c0941010cb Mon Sep 17 
00:00:00 2001 From: Dong Zhou Date: Wed, 21 Jul 2021 13:19:07 +0800 Subject: [PATCH 12/73] refactor TRA --- .../TRA/workflow_config_tra_Alpha158.yaml | 125 +++ .../workflow_config_tra_Alpha158_full.yaml | 118 +++ .../TRA/workflow_config_tra_Alpha360.yaml | 119 +++ qlib/contrib/data/dataset.py | 349 ++++++++ qlib/contrib/model/pytorch_tra.py | 820 ++++++++++++++++++ qlib/data/dataset/loader.py | 4 + 6 files changed, 1535 insertions(+) create mode 100644 examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml create mode 100644 examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml create mode 100644 examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml create mode 100644 qlib/contrib/data/dataset.py create mode 100644 qlib/contrib/model/pytorch_tra.py diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml new file mode 100644 index 000000000..59b1c8e73 --- /dev/null +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml @@ -0,0 +1,125 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn + +market: &market csi300 +benchmark: &benchmark SH000300 + +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: FilterCol + kwargs: + fields_group: feature + col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", + "ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5", + "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"] + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +num_states: &num_states 3 + +memory_mode: &memory_mode sample + +tra_config: &tra_config + num_states: 
*num_states + hidden_size: 16 + tau: 1.0 + src_info: LR_TPE + +model_config: &model_config + input_size: 20 + hidden_size: 64 + num_layers: 2 + rnn_arch: LSTM + use_attn: True + dropout: 0.0 + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 + +task: + model: + class: TRAModel + module_path: qlib.contrib.model.pytorch_tra + kwargs: + tra_config: *tra_config + model_config: *model_config + lr: 1e-3 + n_epochs: 100 + max_steps_per_epoch: 100 + early_stop: 10 + smooth_steps: 5 + seed: 0 + logdir: output/Alpha158/router + lamb: 1.0 + rho: 1.0 + transport_method: router + memory_mode: *memory_mode + eval_train: False + eval_test: True + pretrain: False + init_state: + freeze_model: False + freeze_predictors: False + dataset: + class: MTSDatasetH + module_path: qlib.contrib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + seq_len: 60 + horizon: 2 + input_size: + num_states: *num_states + batch_size: 1024 + n_samples: + memory_mode: *memory_mode + drop_last: True + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: {} + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml new file mode 100644 index 000000000..bb49798d4 --- /dev/null +++ 
b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml @@ -0,0 +1,118 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn + +market: &market csi300 +benchmark: &benchmark SH000300 + +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +num_states: &num_states 3 + +memory_mode: &memory_mode sample + +tra_config: &tra_config + num_states: *num_states + hidden_size: 16 + tau: 1.0 + src_info: TPE + +model_config: &model_config + input_size: 158 + hidden_size: 256 + num_layers: 2 + use_attn: True + dropout: 0.2 + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 + +task: + model: + class: TRAModel + module_path: qlib.contrib.model.pytorch_tra + kwargs: + tra_config: *tra_config + model_config: *model_config + lr: 1e-3 + n_epochs: 100 + max_steps_per_epoch: 100 + early_stop: 10 + smooth_steps: 5 + seed: 0 + logdir: output/Alpha158_full/router + lamb: 1.0 + rho: 1.0 + transport_method: router + memory_mode: *memory_mode + eval_train: False + eval_test: True + pretrain: False + init_state: + freeze_model: False + freeze_predictors: False + dataset: + class: MTSDatasetH + module_path: qlib.contrib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + 
valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + seq_len: 60 + horizon: 2 + input_size: + num_states: *num_states + batch_size: 1024 + n_samples: + memory_mode: *memory_mode + drop_last: True + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: {} + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml new file mode 100644 index 000000000..99c7aa42e --- /dev/null +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml @@ -0,0 +1,119 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn + +market: &market csi300 +benchmark: &benchmark SH000300 + +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +num_states: &num_states 3 + +memory_mode: &memory_mode sample + +tra_config: &tra_config + num_states: *num_states + hidden_size: 16 + tau: 1.0 + src_info: LR_TPE + +model_config: &model_config + input_size: 6 + hidden_size: 64 + num_layers: 2 + rnn_arch: LSTM + use_attn: True + dropout: 0.0 + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 
+ min_cost: 5 + +task: + model: + class: TRAModel + module_path: qlib.contrib.model.pytorch_tra + kwargs: + tra_config: *tra_config + model_config: *model_config + lr: 1e-3 + n_epochs: 100 + max_steps_per_epoch: 100 + early_stop: 10 + smooth_steps: 5 + logdir: output/Alpha360/router + seed: 0 + lamb: 1.0 + rho: 1.0 + transport_method: router + memory_mode: *memory_mode + eval_train: False + eval_test: True + pretrain: False + init_state: + freeze_model: False + freeze_predictors: False + dataset: + class: MTSDatasetH + module_path: qlib.contrib.data.dataset + kwargs: + handler: + class: Alpha360 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + seq_len: 60 + horizon: 2 + input_size: 6 + num_states: *num_states + batch_size: 1024 + n_samples: + memory_mode: *memory_mode + drop_last: True + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: {} + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/qlib/contrib/data/dataset.py b/qlib/contrib/data/dataset.py new file mode 100644 index 000000000..8989a6156 --- /dev/null +++ b/qlib/contrib/data/dataset.py @@ -0,0 +1,349 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import copy +import torch +import warnings +import numpy as np +import pandas as pd + +from qlib.utils import init_instance_by_config +from qlib.data.dataset import DatasetH, DataHandler + + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def _to_tensor(x): + if not isinstance(x, torch.Tensor): + return torch.tensor(x, dtype=torch.float, device=device) + return x + + +def _create_ts_slices(index, seq_len): + """ + create time series slices from pandas index + + Args: + index (pd.MultiIndex): pandas multiindex with order + seq_len (int): sequence length + """ + assert isinstance(index, pd.MultiIndex), "unsupported index type" + assert seq_len > 0, "sequence length should be larger than 0" + assert index.is_monotonic_increasing, "index should be sorted" + + # number of dates for each instrument + sample_count_by_insts = index.to_series().groupby(level=0).size().values + + # start index for each instrument + start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1) + start_index_of_insts[0] = 0 + + # all the [start, stop) indices of features + # features between [start, stop) will be used to predict label at `stop - 1` + slices = [] + for cur_loc, cur_cnt in zip(start_index_of_insts, sample_count_by_insts): + for stop in range(1, cur_cnt + 1): + end = cur_loc + stop + start = max(end - seq_len, 0) + slices.append(slice(start, end)) + slices = np.array(slices, dtype="object") + + assert len(slices) == len(index) # the i-th slice = index[i] + + return slices + + +def _get_date_parse_fn(target): + """get date parse function + + This method is used to parse date arguments as target type. 
+ + Example: + _get_date_parse_fn('20120101')('2017-01-01') => '20170101' + _get_date_parse_fn(20120101)('2017-01-01') => 20170101 + """ + if isinstance(target, pd.Timestamp): + _fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01') + elif isinstance(target, int): + _fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201 + elif isinstance(target, str) and len(target) == 8: + _fn = lambda x: str(x).replace("-", "")[:8] # '20200201' + else: + _fn = lambda x: x # '2021-01-01' + return _fn + + +def _maybe_padding(x, seq_len, zeros=None): + """padding 2d