From 8aee853a1145effe8dd9cf5835319a0ee090d7da Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 4 Jun 2021 00:55:10 +0800 Subject: [PATCH] update Exchange --- qlib/backtest/exchange.py | 93 ++++++++++++++++++++++++++++++--------- 1 file changed, 73 insertions(+), 20 deletions(-) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index de2df98be..3da1ebfdc 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -26,6 +26,7 @@ class Exchange: deal_price=None, subscribe_fields=[], limit_threshold=None, + volume_threshold=None, open_cost=0.0015, close_cost=0.0025, trade_unit=None, @@ -41,6 +42,7 @@ class Exchange: :param deal_price: str, 'close', 'open', 'vwap' :param subscribe_fields: list, subscribe fields :param limit_threshold: float, 0.1 for example, default None + :param volume_threshold: float, 0.1 for example, default None :param open_cost: cost rate for open, default 0.0015 :param close_cost: cost rate for close, default 0.0025 :param trade_unit: trade unit, 100 for China A market @@ -60,6 +62,7 @@ class Exchange: self.freq = freq self.start_time = start_time self.end_time = end_time + if trade_unit is None: trade_unit = C.trade_unit if limit_threshold is None: @@ -70,7 +73,6 @@ class Exchange: self.logger = get_module_logger("online operator", level=logging.INFO) self.trade_unit = trade_unit - # TODO: the quote, trade_dates, codes are not necessray. # It is just for performance consideration. if limit_threshold is None: @@ -100,7 +102,7 @@ class Exchange: self.close_cost = close_cost self.min_cost = min_cost self.limit_threshold = limit_threshold - + self.volume_threshold = volume_threshold self.extra_quote = extra_quote self.set_quote(codes, start_time, end_time) @@ -120,14 +122,19 @@ class Exchange: # Use adjusted price self.trade_w_adj_price = True self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.") + if self.trade_unit is not None: + self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.") + else: # The `factor.day.bin` file exists and all data `close` and `factor` are not `nan` # Use normal price self.trade_w_adj_price = False + # update limit # check limit_threshold if self.limit_threshold is None: - self.quote["limit"] = False + self.quote["limit_buy"] = False + self.quote["limit_sell"] = False else: # set limit self._update_limit(buy_limit=self.limit_threshold, sell_limit=self.limit_threshold) @@ -143,9 +150,13 @@ class Exchange: if "$factor" not in self.extra_quote.columns: self.extra_quote["$factor"] = 1.0 self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.") - if "limit" not in self.extra_quote.columns: - self.extra_quote["limit"] = False - self.logger.warning("No limit set for extra_quote. All stock will be tradable.") + if "limit_sell" not in self.extra_quote.columns: + self.extra_quote["limit_sell"] = False + self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.") + if "limit_buy" not in self.extra_quote.columns: + self.extra_quote["limit_buy"] = False + self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.") + assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"} quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0) @@ -160,15 +171,30 @@ class Exchange: self.quote = quote_dict def _update_limit(self, buy_limit, sell_limit): - self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False) + self.quote["limit_buy"] = ~self.quote["$change"].lt(buy_limit) + self.quote["limit_sell"] = ~self.quote["$change"].gt(-sell_limit) - def check_stock_limit(self, stock_id, start_time, end_time): - """Parameter - stock_id - trade_date - is limtited + def check_stock_limit(self, stock_id, start_time, end_time, direction=None): """ - return resam_ts_data(self.quote[stock_id]["limit"], start_time, end_time, method="all").iloc[0] + Parameters + ---------- + direction : int, optional + trade direction, by default None + - if direction is None, check if tradable for buying and selling. + - if direction == Order.BUY, check the if tradable for buying + - if direction == Order.SELL, check the sell limit for selling. + + """ + if direction is None: + buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0] + sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0] + return buy_limit or sell_limit + elif direction == Order.BUY: + return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0] + elif direction == Order.SELL: + return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0] + else: + raise ValueError(f"direction {direction} is not supported!") def check_stock_suspended(self, stock_id, start_time, end_time): # is suspended @@ -177,11 +203,11 @@ class Exchange: else: return True - def is_stock_tradable(self, stock_id, start_time, end_time): + def is_stock_tradable(self, stock_id, start_time, end_time, direction=None): # check if stock can be traded # same as check in check_order if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit( - stock_id, start_time, end_time + stock_id, start_time, end_time, direction ): return False else: @@ -190,7 +216,7 @@ class Exchange: def check_order(self, order): # check limit and suspended if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit( - order.stock_id, order.start_time, order.end_time + order.stock_id, order.start_time, order.end_time, order.direction ): return False else: @@ -393,7 +419,7 @@ class Exchange: return value def get_amount_of_trade_unit(self, factor): - if not self.trade_w_adj_price: + if not self.trade_w_adj_price and self.trade_unit is not None: return self.trade_unit / factor else: return None @@ -404,11 +430,18 @@ class Exchange: factor : float, adjusted factor return : float, real amount """ - if not self.trade_w_adj_price: + if not self.trade_w_adj_price and self.trade_unit is not None: # the minimal amount is 1. Add 0.1 for solving precision problem. return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor return deal_amount + def _get_amount_by_volume(self, stock_id, trade_start_time, trade_end_time, deal_amount): + if self.volume_threshold is not None: + tradable_amount = self.get_volume(stock_id, trade_start_time, trade_end_time) * self.volume_threshold + return max(min(tradable_amount, deal_amount), 0) + else: + return deal_amount + def _calc_trade_info_by_order(self, order, position): """ Calculation of trade info @@ -421,17 +454,34 @@ class Exchange: trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time) if order.direction == Order.SELL: # sell + current_amount = position.get_stock_amount(order.stock_id) + if position is not None: - if np.isclose(order.amount, position.get_stock_amount(order.stock_id)): + if np.isclose(order.amount, current_amount): # when selling last stock. The amount don't need rounding order.deal_amount = order.amount + elif order.amount > current_amount: + self.logger.warning( + f"order amount {order.amount} is greater than current amount {current_amount}, {current_amount} amount of stock is dealed" + ) + order.deal_amount = self.round_amount_by_trade_unit(current_amount, order.factor) else: order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) else: # TODO: We don't know current position. # We choose to sell all - order.deal_amount = order.amount + if not np.isclose(order.amount, current_amount) and order.amount > current_amount: + self.logger.warning( + f"order amount {order.amount} is greater than current amount {current_amount}, {current_amount} amount of stock is dealed" + ) + order.deal_amount = current_amount + else: + order.deal_amount = order.amount + + order.deal_amount = self._get_amount_by_volume( + order.stock_id, order.start_time, order.end_time, order.deal_amount + ) trade_val = order.deal_amount * trade_price trade_cost = max(trade_val * self.close_cost, self.min_cost) elif order.direction == Order.BUY: @@ -451,6 +501,9 @@ class Exchange: # Unknown amount of money. Just round the amount order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) + order.deal_amount = self._get_amount_by_volume( + order.stock_id, order.start_time, order.end_time, order.deal_amount + ) trade_val = order.deal_amount * trade_price trade_cost = trade_val * self.open_cost else: