diff --git a/qlib/contrib/backtest/account.py b/qlib/contrib/backtest/account.py index a614f08b6..b2e307fe5 100644 --- a/qlib/contrib/backtest/account.py +++ b/qlib/contrib/backtest/account.py @@ -167,3 +167,4 @@ class Account: def save_account(self, account_path): self.current.save_position(account_path / "position.xlsx", self.last_trade_date) self.report.save_report(account_path / "report.csv") + diff --git a/qlib/contrib/backtest/exchange.py b/qlib/contrib/backtest/exchange.py index cbb3d7932..369b3aef3 100644 --- a/qlib/contrib/backtest/exchange.py +++ b/qlib/contrib/backtest/exchange.py @@ -93,7 +93,7 @@ class Exchange: self.limit_threshold = limit_threshold # TODO: the quote, trade_dates, codes are not necessray. # It is just for performance consideration. - if trade_dates is not None and trade_dates: + if trade_dates is not None and len(trade_dates): start_date, end_date = trade_dates[0], trade_dates[-1] else: self.logger.warning("trade_dates have not been assigned, all dates will be loaded") @@ -325,7 +325,7 @@ class Exchange: deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor) if deal_amount == 0: continue - if deal_amount > 0: + elif deal_amount > 0: # buy stock buy_order_list.append( Order( @@ -423,3 +423,4 @@ class Exchange: raise NotImplementedError("order type {} error".format(order.type)) return trade_val, trade_cost +