diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index e2707ad39..cc88528fd 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -690,12 +690,14 @@ class Exchange: f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}" ) - def _get_buy_amount_by_cash_limit(self, trade_price, cash): + def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio): """return the real order amount after cash limit for buying. Parameters ---------- trade_price : float position : cash + cost_ratio : float + Return ---------- float @@ -704,10 +706,10 @@ class Exchange: max_trade_amount = 0 if cash >= self.min_cost: # critical_price means the stock transaction price when the service fee is equal to min_cost. - critical_price = self.min_cost / self.open_cost + self.min_cost + critical_price = self.min_cost / cost_ratio + self.min_cost if cash >= critical_price: - # the service fee is equal to open_cost * trade_amount - max_trade_amount = cash / (1 + self.open_cost) / trade_price + # the service fee is equal to cost_ratio * trade_amount + max_trade_amount = cash / (1 + cost_ratio) / trade_price else: # the service fee is equal to min_cost max_trade_amount = (cash - self.min_cost) / trade_price @@ -765,9 +767,13 @@ class Exchange: if position is not None: cash = position.get_cash() trade_val = order.deal_amount * trade_price - if cash < trade_val + max(trade_val * cost_ratio, self.min_cost): + if cash < max(trade_val * cost_ratio, self.min_cost): + # cash cannot cover cost + order.deal_amount = 0 + self.logger.debug(f"Order clipped due to cost higher than cash: {order}") + elif cash < trade_val + max(trade_val * cost_ratio, self.min_cost): # The money is not enough - max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash) + max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio) order.deal_amount = self.round_amount_by_trade_unit( min(max_buy_amount, order.deal_amount), order.factor )