diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 0d89dde87..b394d5823 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -64,34 +64,49 @@ class AccumulatedInfo: class Account: def __init__( - self, init_cash: float = 1e9, freq: str = "day", benchmark_config: dict = {}, pos_type: str = "Position" + self, + init_cash: float = 1e9, + freq: str = "day", + benchmark_config: dict = {}, + pos_type: str = "Position", + port_metr_enabled: bool = True, ): - self.pos_type = pos_type + self._pos_type = pos_type + self._port_metr_enabled = port_metr_enabled self.init_vars(init_cash, freq, benchmark_config) + def is_port_metr_enabled(self): + """ + Is portfolio-based metrics enabled. + """ + return self._port_metr_enabled and not self.current.skip_update() + def init_vars(self, init_cash, freq: str, benchmark_config: dict): # init cash self.init_cash = init_cash self.current: BasePosition = init_instance_by_config( { - "class": self.pos_type, + "class": self._pos_type, "kwargs": {"cash": init_cash}, "module_path": "qlib.backtest.position", } ) self.accum_info = AccumulatedInfo() + self.report = None + self.positions = {} self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True) def reset_report(self, freq, benchmark_config): # portfolio related metrics - self.report = Report(freq, benchmark_config) - self.positions = {} + if self.is_port_metr_enabled(): + self.report = Report(freq, benchmark_config) + self.positions = {} # trading related matric(e.g. high-frequency trading) self.indicator = Indicator() - def reset(self, freq=None, benchmark_config=None, init_report=False): + def reset(self, freq=None, benchmark_config=None, init_report=False, port_metr_enabled: bool = None): """reset freq and report of account Parameters @@ -108,6 +123,9 @@ class Account: if benchmark_config is not None: self.benchmark_config = benchmark_config + if port_metr_enabled is not None: + self._port_metr_enabled = port_metr_enabled + if freq is not None or benchmark_config is not None or init_report: self.reset_report(self.freq, self.benchmark_config) @@ -137,7 +155,7 @@ class Account: self.accum_info.add_return_value(profit) # note here do not consider cost def update_order(self, order, trade_val, cost, trade_price): - if self.current.skip_update(): + if not self.is_port_metr_enabled(): # TODO: supporting polymorphism for account # updating order for infinite position is meaningless return @@ -160,12 +178,14 @@ class Account: def update_bar_count(self): """at the end of the trading bar, update holding bar, count of stock""" # update holding day count + # NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy if not self.current.skip_update(): self.current.add_count_all(bar=self.freq) def update_current(self, trade_start_time, trade_end_time, trade_exchange): """update current to make rtn consistent with earning at the end of bar""" # update price for stock in the position and the profit from changed_price + # NOTE: updating position does not only serve portfolio metrics, it also serve the strategy if not self.current.skip_update(): stock_list = self.current.get_stock_list() for code in stock_list: @@ -227,7 +247,6 @@ class Account: trade_exchange: Exchange, atomic: bool, outer_trade_decision: BaseTradeDecision, - generate_report: bool = False, trade_info: list = None, inner_order_indicators: Indicator = None, indicator_config: dict = {}, @@ -246,8 +265,6 @@ class Account: whether the trading executor is atomic, which means there is no higher-frequency trading executor inside it - if atomic is True, calculate the indicators with trade_info - else, aggregate indicators with inner indicators - generate_report : bool, optional - whether to generate report, by default False trade_info : List[(Order, float, float, float)], optional trading information, by default None - necessary if atomic is True @@ -267,7 +284,7 @@ class Account: # TODO: `update_bar_count` and `update_current` should placed in Position and be merged. self.update_bar_count() self.update_current(trade_start_time, trade_end_time, trade_exchange) - if generate_report: + if self.is_port_metr_enabled(): # report is portfolio related analysis self.update_report(trade_start_time, trade_end_time) @@ -283,3 +300,16 @@ class Account: self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config) self.indicator.record(trade_start_time) + + def get_report(self): + """get the history report and postions instance""" + if self.is_port_metr_enabled(): + _report = self.report.generate_report_dataframe() + _positions = self.get_positions() + return _report, _positions + else: + raise ValueError("generate_report should be True if you want to generate report") + + def get_trade_indicator(self) -> Indicator: + """get the trade indicator instance, which has pa/pos/ffr info.""" + return self.indicator diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index 48d06db6c..573c874b0 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -69,13 +69,13 @@ def collect_data_loop( all_executors = trade_executor.get_all_executors() all_reports = { - "{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.get_report() + "{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_report() for _executor in all_executors - if _executor.generate_report + if _executor.trade_account.is_port_metr_enabled() } all_indicators = {} for _executor in all_executors: key = "{}{}".format(*Freq.parse(_executor.time_per_step)) - all_indicators[key] = _executor.get_trade_indicator().generate_trade_indicators_dataframe() - all_indicators[key + "_obj"] = _executor.get_trade_indicator() + all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe() + all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator() return_value.update({"report": all_reports, "indicator": all_indicators}) diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 14d97e825..adea9dde0 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -103,8 +103,10 @@ class BaseExecutor: self.common_infra.update(common_infra) if common_infra.has("trade_account"): + # NOTE: there is a trick in the code. + # copy is used instead of deepcopy. So positions are shared self.trade_account = copy.copy(common_infra.get("trade_account")) - self.trade_account.reset(freq=self.time_per_step, init_report=True) + self.trade_account.reset(freq=self.time_per_step, init_report=True, port_metr_enabled=self.generate_report) def reset(self, track_data: bool = None, common_infra: CommonInfrastructure = None, **kwargs): """ @@ -167,19 +169,6 @@ class BaseExecutor: yield trade_decision return self.execute(trade_decision) - def get_report(self): - """get the history report and postions instance""" - if self.generate_report: - _report = self.trade_account.report.generate_report_dataframe() - _positions = self.trade_account.get_positions() - return _report, _positions - else: - raise ValueError("generate_report should be True if you want to generate report") - - def get_trade_indicator(self) -> Indicator: - """get the trade indicator instance, which has pa/pos/ffr info.""" - return self.trade_account.indicator - def get_all_executors(self): """get all executors""" return [self] @@ -289,21 +278,19 @@ class NestedExecutor(BaseExecutor): _inner_execute_result = yield from self.inner_executor.collect_data(trade_decision=_inner_trade_decision) execute_result.extend(_inner_execute_result) - inner_order_indicators.append(self.inner_executor.get_trade_indicator().get_order_indicator()) + inner_order_indicators.append(self.inner_executor.trade_account.get_trade_indicator().get_order_indicator()) - if hasattr(self, "trade_account"): - trade_step = self.trade_calendar.get_trade_step() - trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) - self.trade_account.update_bar_end( - trade_start_time, - trade_end_time, - self.trade_exchange, - atomic=False, - outer_trade_decision=trade_decision, - generate_report=self.generate_report, - inner_order_indicators=inner_order_indicators, - indicator_config=self.indicator_config, - ) + trade_step = self.trade_calendar.get_trade_step() + trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) + self.trade_account.update_bar_end( + trade_start_time, + trade_end_time, + self.trade_exchange, + atomic=False, + outer_trade_decision=trade_decision, + inner_order_indicators=inner_order_indicators, + indicator_config=self.indicator_config, + ) self.trade_calendar.step() if return_value is not None: @@ -457,7 +444,6 @@ class SimulatorExecutor(BaseExecutor): self.trade_exchange, atomic=True, outer_trade_decision=trade_decision, - generate_report=self.generate_report, trade_info=execute_result, indicator_config=self.indicator_config, ) diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index c8a326e80..a787c098f 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -10,6 +10,8 @@ from ..utils import init_instance_by_config from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager from ..backtest.order import BaseTradeDecision +__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"] + class BaseStrategy: """Base strategy for trading"""