From 571d27cba7949c65efdfa6b5f48fee8a9c1759e5 Mon Sep 17 00:00:00 2001 From: Young Date: Wed, 14 Jul 2021 13:05:36 +0000 Subject: [PATCH] exchange support expression buy sell limit --- qlib/backtest/exchange.py | 63 +++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 3794651dc..58f57ed73 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -26,7 +26,7 @@ class Exchange: codes="all", deal_price: Union[str, Tuple[str], List[str]] = None, subscribe_fields=[], - limit_threshold=None, + limit_threshold: Union[Tuple[str, str], float, None] = None, volume_threshold=None, open_cost=0.0015, close_cost=0.0025, @@ -41,7 +41,7 @@ class Exchange: :param end_time: closed end time for backtest :param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50) - :param deal_price: Union[str, Tuple[str], List[str]] + :param deal_price: Union[str, Tuple[str, str], List[str]] The `deal_price` supports following two types of input - : str - (, ): Tuple[str] or List[str] @@ -51,8 +51,16 @@ class Exchange: - for example '$close', '$open', '$vwap' ("close" is OK. `Exchange` will help to prepend "$" to the expression) - :param subscribe_fields: list, subscribe fields - :param limit_threshold: float, 0.1 for example, default None + :param subscribe_fields: list, subscribe fields. This expressions will be added to the query and `self.quote`. + It is useful when users want more fields to be queried + + :param limit_threshold: Union[Tuple[str, str], float, None] + 1) `None`: no limitation + 2) float, 0.1 for example, default None + 3) Tuple[str, str]: (, + ) + `False` value indicates the stock is tradable + `True` value indicates the stock is limited and not tradable :param volume_threshold: float, 0.1 for example, default None :param open_cost: cost rate for open, default 0.0015 :param close_cost: cost rate for close, default 0.0025 @@ -97,7 +105,7 @@ class Exchange: if limit_threshold is None: if C.region == REG_CN: self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold") - elif abs(limit_threshold) > 0.1: + elif self._get_limit_type(limit_threshold) == self.LT_FLT and abs(limit_threshold) > 0.1: if C.region == REG_CN: self.logger.warning(f"limit_threshold may not be set to a reasonable value") @@ -119,13 +127,17 @@ class Exchange: # $change is for calculating the limit of the stock necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"} + if self._get_limit_type(limit_threshold) == self.LT_TP_EXP: + for exp in limit_threshold: + necessary_fields.add(exp) subscribe_fields = list(necessary_fields | set(subscribe_fields)) all_fields = list(necessary_fields | set(subscribe_fields)) + self.all_fields = all_fields self.open_cost = open_cost self.close_cost = close_cost self.min_cost = min_cost - self.limit_threshold = limit_threshold + self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold self.volume_threshold = volume_threshold self.extra_quote = extra_quote self.set_quote(codes, start_time, end_time) @@ -133,6 +145,7 @@ class Exchange: def set_quote(self, codes, start_time, end_time): if len(codes) == 0: codes = D.instruments() + self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna( subset=["$close"] ) @@ -157,13 +170,7 @@ class Exchange: self.trade_w_adj_price = False # update limit - # check limit_threshold - if self.limit_threshold is None: - self.quote["limit_buy"] = False - self.quote["limit_sell"] = False - else: - # set limit - self._update_limit(buy_limit=self.limit_threshold, sell_limit=self.limit_threshold) + self._update_limit() quote_df = self.quote if self.extra_quote is not None: @@ -194,9 +201,33 @@ class Exchange: self.quote = quote_dict - def _update_limit(self, buy_limit, sell_limit): - self.quote["limit_buy"] = self.quote["$change"].ge(buy_limit) - self.quote["limit_sell"] = self.quote["$change"].le(-sell_limit) + LT_TP_EXP = "(exp)" # Tuple[str, str] + LT_FLT = "float" # float + LT_NONE = "none" # none + + def _get_limit_type(self, limit_threshold): + if isinstance(limit_threshold, Tuple): + return self.LT_TP_EXP + elif isinstance(limit_threshold, float): + return self.LT_FLT + elif limit_threshold is None: + return self.LT_NONE + else: + raise NotImplementedError(f"This type of `limit_threshold` is not supported") + + def _update_limit(self): + # check limit_threshold + lt_type = self._get_limit_type(self.limit_threshold) + if lt_type == self.LT_NONE: + self.quote["limit_buy"] = False + self.quote["limit_sell"] = False + elif lt_type == self.LT_TP_EXP: + # set limit + self.quote["limit_buy"] = self.quote[self.limit_threshold[0]] + self.quote["limit_sell"] = self.quote[self.limit_threshold[1]] + elif lt_type == self.LT_FLT: + self.quote["limit_buy"] = self.quote["$change"].ge(self.limit_threshold) + self.quote["limit_sell"] = self.quote["$change"].le(-self.limit_threshold) # pylint: disable=E1130 def check_stock_limit(self, stock_id, start_time, end_time, direction=None): """