diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 58f57ed73..8d4739251 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -102,10 +102,11 @@ class Exchange: # TODO: the quote, trade_dates, codes are not necessray. # It is just for performance consideration. + self.limit_type = BaseQuote._get_limit_type(limit_threshold) if limit_threshold is None: if C.region == REG_CN: self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold") - elif self._get_limit_type(limit_threshold) == self.LT_FLT and abs(limit_threshold) > 0.1: + elif self.limit_type == BaseQuote.LT_FLT and abs(limit_threshold) > 0.1: if C.region == REG_CN: self.logger.warning(f"limit_threshold may not be set to a reasonable value") @@ -127,10 +128,9 @@ class Exchange: # $change is for calculating the limit of the stock necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"} - if self._get_limit_type(limit_threshold) == self.LT_TP_EXP: + if self.limit_type == BaseQuote.LT_TP_EXP: for exp in limit_threshold: necessary_fields.add(exp) - subscribe_fields = list(necessary_fields | set(subscribe_fields)) all_fields = list(necessary_fields | set(subscribe_fields)) self.all_fields = all_fields @@ -140,94 +140,22 @@ class Exchange: self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold self.volume_threshold = volume_threshold self.extra_quote = extra_quote - self.set_quote(codes, start_time, end_time) - def set_quote(self, codes, start_time, end_time): - if len(codes) == 0: - codes = D.instruments() - - self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna( - subset=["$close"] + # init quote + self.quote = PandasQuote( + start_time = self.start_time, + end_time = self.end_time, + freq = self.freq, + codes = self.codes, + all_fields = self.all_fields, + limit_threshold = self.limit_threshold, + buy_price = self.buy_price, + sell_price = self.sell_price, + extra_quote = self.extra_quote, ) - self.quote.columns = self.all_fields - - for attr in "buy_price", "sell_price": - pstr = getattr(self, attr) # price string - if self.quote[pstr].isna().any(): - self.logger.warning("{} field data contains nan.".format(pstr)) - - if self.quote["$factor"].isna().any(): - # The 'factor.day.bin' file not exists, and `factor` field contains `nan` - # Use adjusted price - self.trade_w_adj_price = True - self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.") - if self.trade_unit is not None: - self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.") - - else: - # The `factor.day.bin` file exists and all data `close` and `factor` are not `nan` - # Use normal price - self.trade_w_adj_price = False - - # update limit - self._update_limit() - - quote_df = self.quote - if self.extra_quote is not None: - # process extra_quote - if "$close" not in self.extra_quote: - raise ValueError("$close is necessray in extra_quote") - for attr in "buy_price", "sell_price": - pstr = getattr(self, attr) # price string - if pstr not in self.extra_quote.columns: - self.extra_quote[pstr] = self.extra_quote["$close"] - self.logger.warning(f"No {pstr} set for extra_quote. Use $close as {pstr}.") - if "$factor" not in self.extra_quote.columns: - self.extra_quote["$factor"] = 1.0 - self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.") - if "limit_sell" not in self.extra_quote.columns: - self.extra_quote["limit_sell"] = False - self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.") - if "limit_buy" not in self.extra_quote.columns: - self.extra_quote["limit_buy"] = False - self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.") - - assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"} - quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0) - - quote_dict = {} - for stock_id, stock_val in quote_df.groupby(level="instrument"): - quote_dict[stock_id] = stock_val.droplevel(level="instrument") - - self.quote = quote_dict - - LT_TP_EXP = "(exp)" # Tuple[str, str] - LT_FLT = "float" # float - LT_NONE = "none" # none - - def _get_limit_type(self, limit_threshold): - if isinstance(limit_threshold, Tuple): - return self.LT_TP_EXP - elif isinstance(limit_threshold, float): - return self.LT_FLT - elif limit_threshold is None: - return self.LT_NONE - else: - raise NotImplementedError(f"This type of `limit_threshold` is not supported") - - def _update_limit(self): - # check limit_threshold - lt_type = self._get_limit_type(self.limit_threshold) - if lt_type == self.LT_NONE: - self.quote["limit_buy"] = False - self.quote["limit_sell"] = False - elif lt_type == self.LT_TP_EXP: - # set limit - self.quote["limit_buy"] = self.quote[self.limit_threshold[0]] - self.quote["limit_sell"] = self.quote[self.limit_threshold[1]] - elif lt_type == self.LT_FLT: - self.quote["limit_buy"] = self.quote["$change"].ge(self.limit_threshold) - self.quote["limit_sell"] = self.quote["$change"].le(-self.limit_threshold) # pylint: disable=E1130 + self.trade_w_adj_price = self.quote.get_trade_w_adj_price() + if(self.trade_w_adj_price and (self.trade_unit is not None)): + self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.") def check_stock_limit(self, stock_id, start_time, end_time, direction=None): """ @@ -241,20 +169,20 @@ class Exchange: """ if direction is None: - buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all") - sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all") + buy_limit = self.quote.get_data(stock_id, start_time, end_time, fields="limit_buy", method="all") + sell_limit = self.quote.get_data(stock_id, start_time, end_time, fields="limit_sell", method="all") return buy_limit or sell_limit elif direction == Order.BUY: - return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all") + return self.quote.get_data(stock_id, start_time, end_time, fields="limit_buy", method="all") elif direction == Order.SELL: - return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all") + return self.quote.get_data(stock_id, start_time, end_time, fields="limit_sell", method="all") else: raise ValueError(f"direction {direction} is not supported!") def check_stock_suspended(self, stock_id, start_time, end_time): # is suspended - if stock_id in self.quote: - return resam_ts_data(self.quote[stock_id], start_time, end_time, method=None) is None + if stock_id in self.quote.get_all_stock(): + return self.quote.get_data(stock_id, start_time, end_time) is None else: return True @@ -313,13 +241,13 @@ class Exchange: return trade_val, trade_cost, trade_price def get_quote_info(self, stock_id, start_time, end_time, method=ts_data_last): - return resam_ts_data(self.quote[stock_id], start_time, end_time, method=method) + return self.quote.get_data(stock_id, start_time, end_time, method=method) def get_close(self, stock_id, start_time, end_time, method=ts_data_last): - return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method=method) + return self.quote.get_data(stock_id, start_time, end_time, fields="$close", method=method) def get_volume(self, stock_id, start_time, end_time, method="sum"): - return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method=method) + return self.quote.get_data(stock_id, start_time, end_time, fields="$volume", method=method) def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method=ts_data_last): if direction == OrderDir.SELL: @@ -328,7 +256,7 @@ class Exchange: pstr = self.buy_price else: raise NotImplementedError(f"This type of input is not supported") - deal_price = resam_ts_data(self.quote[stock_id][pstr], start_time, end_time, method=method) + deal_price = self.quote.get_data(stock_id, start_time, end_time, fields=pstr, method=method) if method is not None and (np.isclose(deal_price, 0.0) or np.isnan(deal_price)): self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!") self.logger.warning(f"setting deal_price to close price") @@ -343,9 +271,9 @@ class Exchange: `None`: if the stock is suspended `None` may be returned `float`: return factor if the factor exists """ - if stock_id not in self.quote: + if stock_id not in self.quote.get_all_stock(): return None - return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method=ts_data_last) + return self.quote.get_data(stock_id, start_time, end_time, fields="$factor", method=ts_data_last) def generate_amount_position_from_weight_position( self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY @@ -596,3 +524,145 @@ class Exchange: # cache to avoid recreate the same instance self._order_helper = OrderHelper(self) return self._order_helper + + +class BaseQuote: + + def __init__(self): + self.logger = get_module_logger("online operator", level=logging.INFO) + + def _update_limit(self, limit_threshold): + raise NotImplementedError(f"Please implement the `_update_limit` method") + + def get_trade_w_adj_price(self): + raise NotImplementedError(f"Please implement the `get_trade_w_adj_price` method") + + def get_all_stock(self): + raise NotImplementedError(f"Please implement the `get_all_stock` method") + + def get_data(self, stock_id, start_time, end_time, fields, method): + raise NotImplementedError(f"Please implement the `get_data` method") + + LT_TP_EXP = "(exp)" # Tuple[str, str] + LT_FLT = "float" # float + LT_NONE = "none" # none + + @staticmethod + def _get_limit_type(limit_threshold): + if isinstance(limit_threshold, Tuple): + return BaseQuote.LT_TP_EXP + elif isinstance(limit_threshold, float): + return BaseQuote.LT_FLT + elif limit_threshold is None: + return BaseQuote.LT_NONE + else: + raise NotImplementedError(f"This type of `limit_threshold` is not supported") + + +class PandasQuote(BaseQuote): + + def __init__( + self, + start_time, + end_time, + freq, + codes, + all_fields, + limit_threshold, + buy_price, + sell_price, + extra_quote + ): + + super().__init__() + + # get stock data from qlib + if len(codes) == 0: + codes = D.instruments() + self.data = D.features( + codes, + all_fields, + start_time, + end_time, + freq=freq, + disk_cache=True + ).dropna(subset=["$close"]) + self.data.columns = all_fields + + # check buy_price data and sell_price data + self.buy_price = buy_price + self.sell_price = sell_price + for attr in "buy_price", "sell_price": + pstr = getattr(self, attr) # price string + if self.data[pstr].isna().any(): + self.logger.warning("{} field data contains nan.".format(pstr)) + + # update trade_w_adj_price + if self.data["$factor"].isna().any(): + # The 'factor.day.bin' file not exists, and `factor` field contains `nan` + # Use adjusted price + self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.") + self.trade_w_adj_price = True + else: + # The `factor.day.bin` file exists and all data `close` and `factor` are not `nan` + # Use normal price + self.trade_w_adj_price = False + + # update limit + self._update_limit(limit_threshold) + + # concat extra_quote + quote_df = self.data + if extra_quote is not None: + # process extra_quote + if "$close" not in extra_quote: + raise ValueError("$close is necessray in extra_quote") + for attr in "buy_price", "sell_price": + pstr = getattr(self, attr) # price string + if pstr not in extra_quote.columns: + extra_quote[pstr] = extra_quote["$close"] + self.logger.warning(f"No {pstr} set for extra_quote. Use $close as {pstr}.") + if "$factor" not in extra_quote.columns: + extra_quote["$factor"] = 1.0 + self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.") + if "limit_sell" not in extra_quote.columns: + extra_quote["limit_sell"] = False + self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.") + if "limit_buy" not in extra_quote.columns: + extra_quote["limit_buy"] = False + self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.") + assert set(extra_quote.columns) == set(quote_df.columns) - {"$change"} + quote_df = pd.concat([quote_df, extra_quote], sort=False, axis=0) + + quote_dict = {} + for stock_id, stock_val in quote_df.groupby(level="instrument"): + quote_dict[stock_id] = stock_val.droplevel(level="instrument") + self.data = quote_dict + + def _update_limit(self, limit_threshold): + # check limit_threshold + limit_type = self._get_limit_type(limit_threshold) + if limit_type == self.LT_NONE: + self.data["limit_buy"] = False + self.data["limit_sell"] = False + elif limit_type == self.LT_TP_EXP: + # set limit + self.data["limit_buy"] = self.data[limit_threshold[0]] + self.data["limit_sell"] = self.data[limit_threshold[1]] + elif limit_type == self.LT_FLT: + self.data["limit_buy"] = self.data["$change"].ge(limit_threshold) + self.data["limit_sell"] = self.data["$change"].le(-limit_threshold) # pylint: disable=E1130 + + def get_all_stock(self): + return self.data.keys() + + def get_data(self, stock_id, start_time, end_time, fields = None, method = None): + if(fields is None): + return resam_ts_data(self.data[stock_id], start_time, end_time, method=method) + elif(isinstance(fields, (str, list))): + return resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method) + else: + raise ValueError(f"fields must be None, str or list") + + def get_trade_w_adj_price(self): + return self.trade_w_adj_price \ No newline at end of file