diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 3da1ebfdc..4fc01d8e2 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -246,8 +246,8 @@ class Exchange: order, trade_account.current if trade_account else position ) # update account - if trade_val > 0: - # If the order can only be deal 0 trade_val. Nothing to be updated + if order.deal_amount > 1e-5: + # If the order can only be deal 0 aomount. Nothing to be updated # Otherwise, it will result some stock with 0 amount in the position if trade_account: trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price) @@ -454,30 +454,21 @@ class Exchange: trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time) if order.direction == Order.SELL: # sell - current_amount = position.get_stock_amount(order.stock_id) - if position is not None: + current_amount = ( + position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0 + ) if np.isclose(order.amount, current_amount): # when selling last stock. The amount don't need rounding order.deal_amount = order.amount elif order.amount > current_amount: - self.logger.warning( - f"order amount {order.amount} is greater than current amount {current_amount}, {current_amount} amount of stock is dealed" - ) order.deal_amount = self.round_amount_by_trade_unit(current_amount, order.factor) else: order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) else: # TODO: We don't know current position. # We choose to sell all - - if not np.isclose(order.amount, current_amount) and order.amount > current_amount: - self.logger.warning( - f"order amount {order.amount} is greater than current amount {current_amount}, {current_amount} amount of stock is dealed" - ) - order.deal_amount = current_amount - else: - order.deal_amount = order.amount + order.deal_amount = order.amount order.deal_amount = self._get_amount_by_volume( order.stock_id, order.start_time, order.end_time, order.deal_amount diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index c6368606a..92b549063 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -73,6 +73,9 @@ class Position: def del_stock(self, stock_id): del self.position[stock_id] + def check_stock(self, stock_id): + return stock_id in self.position + def update_order(self, order, trade_val, cost, trade_price): # handle order, order is a order class, defined in exchange.py if order.direction == Order.BUY: