diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index 19dbe87ce..dbfbd4a0e 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -8,9 +8,9 @@ from .account import Account if TYPE_CHECKING: from ..strategy.base import BaseStrategy + from .executor import BaseExecutor from .position import Position from .exchange import Exchange -from .executor import BaseExecutor from .backtest import backtest_loop from .backtest import collect_data_loop from .order import Order @@ -155,6 +155,7 @@ def get_strategy_executor( # - for avoiding recursive import # - typing annotations is not reliable from ..strategy.base import BaseStrategy + from .executor import BaseExecutor trade_account = create_account_instance( start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 806f88a96..9b9a25c23 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -75,17 +75,7 @@ class Account: ): self._pos_type = pos_type self._port_metr_enabled = port_metr_enabled - self.init_vars(init_cash, position_dict, freq, benchmark_config) - def is_port_metr_enabled(self): - """ - Is portfolio-based metrics enabled. - """ - return self._port_metr_enabled and not self.current.skip_update() - - def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict): - - # init cash self.init_cash = init_cash self.current: BasePosition = init_instance_by_config( { @@ -100,8 +90,19 @@ class Account: self.accum_info = AccumulatedInfo() self.report = None self.positions = {} + + # in of reset ignore None values + self.benchmark_config = benchmark_config + self.freq = freq + self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True) + def is_port_metr_enabled(self): + """ + Is portfolio-based metrics enabled. + """ + return self._port_metr_enabled and not self.current.skip_update() + def reset_report(self, freq, benchmark_config): # portfolio related metrics if self.is_port_metr_enabled(): diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index a22754885..ea1d012eb 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -512,7 +512,7 @@ class Exchange: def _get_factor_or_raise_erorr(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None): """Please refer to the docs of get_amount_of_trade_unit""" if factor is None: - if stock_id is not None and start_time is not None and end_time is not None : + if stock_id is not None and start_time is not None and end_time is not None: factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time) else: raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None") @@ -537,15 +537,16 @@ class Exchange: the end time of trading range """ if not self.trade_w_adj_price and self.trade_unit is not None: - factor = self._get_factor_or_raise_erorr(factor=factor, - stock_id=stock_id, - start_time=start_time, - end_time=end_time) + factor = self._get_factor_or_raise_erorr( + factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time + ) return self.trade_unit / factor else: return None - def round_amount_by_trade_unit(self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None): + def round_amount_by_trade_unit( + self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None + ): """Parameter Please refer to the docs of get_amount_of_trade_unit @@ -555,10 +556,9 @@ class Exchange: """ if not self.trade_w_adj_price and self.trade_unit is not None: # the minimal amount is 1. Add 0.1 for solving precision problem. - factor = self._get_factor_or_raise_erorr(factor=factor, - stock_id=stock_id, - start_time=start_time, - end_time=end_time) + factor = self._get_factor_or_raise_erorr( + factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time + ) return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor return deal_amount diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 6b64bf3b1..84cae2568 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -80,11 +80,12 @@ class Report: def init_bench(self, freq=None, benchmark_config=None): if freq is not None: self.freq = freq - if benchmark_config is not None: - self.benchmark_config = benchmark_config + self.benchmark_config = benchmark_config self.bench = self._cal_benchmark(self.benchmark_config, self.freq) def _cal_benchmark(self, benchmark_config, freq): + if benchmark_config is None: + return None benchmark = benchmark_config.get("benchmark", CSI300_BENCH) if benchmark is None: return None diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 1ec054e45..b42c4f578 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -63,9 +63,9 @@ class TWAPStrategy(BaseStrategy): stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time ): continue - _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(stock_id=order.stock_id, - start_time=order.start_time, - end_time=order.end_time) + _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit( + stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time + ) _order_amount = None # considering trade unit if _amount_trade_unit is None: @@ -169,9 +169,9 @@ class SBBStrategyBase(BaseStrategy): self.trade_trend[order.stock_id] = _pred_trend continue # get amount of one trade unit - _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(stock_id=order.stock_id, - start_time=order.start_time, - end_time=order.end_time) + _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit( + stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time + ) if _pred_trend == self.TREND_MID: _order_amount = None # considering trade unit @@ -471,9 +471,9 @@ class ACStrategy(BaseStrategy): if sig_sam is None or np.isnan(sig_sam): # no signal, TWAP - _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(stock_id=order.stock_id, - start_time=order.start_time, - end_time=order.end_time) + _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit( + stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time + ) if _amount_trade_unit is None: # divide the order into equal parts, and trade one part _order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step) @@ -494,10 +494,9 @@ class ACStrategy(BaseStrategy): np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1)) ) / np.sinh(kappa * trade_len) _order_amount = order.amount * amount_ratio - _order_amount = self.trade_exchange.round_amount_by_trade_unit(_order_amount, - stock_id=order.stock_id, - start_time=order.start_time, - end_time=order.end_time) + _order_amount = self.trade_exchange.round_amount_by_trade_unit( + _order_amount, stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time + ) if order.direction == order.SELL: # sell all amount at last @@ -584,8 +583,11 @@ class FileOrderStrategy(BaseStrategy): """ def __init__( - self, file: Union[IO, str, Path, pd.DataFrame], - trade_range: Union[Tuple[int, int], TradeRange] = None, *args, **kwargs + self, + file: Union[IO, str, Path, pd.DataFrame], + trade_range: Union[Tuple[int, int], TradeRange] = None, + *args, + **kwargs, ): """ diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 7a267b511..c47d2494f 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -2,9 +2,10 @@ # Licensed under the MIT License. from __future__ import annotations from typing import TYPE_CHECKING + if TYPE_CHECKING: from qlib.backtest.exchange import Exchange -from qlib.backtest.position import BasePosition + from qlib.backtest.position import BasePosition from typing import List, Tuple, Union from ..model.base import BaseModel