mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 01:51:18 +08:00
simplify the portfolio-based report
This commit is contained in:
@@ -64,34 +64,49 @@ class AccumulatedInfo:
|
||||
|
||||
class Account:
|
||||
def __init__(
|
||||
self, init_cash: float = 1e9, freq: str = "day", benchmark_config: dict = {}, pos_type: str = "Position"
|
||||
self,
|
||||
init_cash: float = 1e9,
|
||||
freq: str = "day",
|
||||
benchmark_config: dict = {},
|
||||
pos_type: str = "Position",
|
||||
port_metr_enabled: bool = True,
|
||||
):
|
||||
self.pos_type = pos_type
|
||||
self._pos_type = pos_type
|
||||
self._port_metr_enabled = port_metr_enabled
|
||||
self.init_vars(init_cash, freq, benchmark_config)
|
||||
|
||||
def is_port_metr_enabled(self):
|
||||
"""
|
||||
Is portfolio-based metrics enabled.
|
||||
"""
|
||||
return self._port_metr_enabled and not self.current.skip_update()
|
||||
|
||||
def init_vars(self, init_cash, freq: str, benchmark_config: dict):
|
||||
|
||||
# init cash
|
||||
self.init_cash = init_cash
|
||||
self.current: BasePosition = init_instance_by_config(
|
||||
{
|
||||
"class": self.pos_type,
|
||||
"class": self._pos_type,
|
||||
"kwargs": {"cash": init_cash},
|
||||
"module_path": "qlib.backtest.position",
|
||||
}
|
||||
)
|
||||
self.accum_info = AccumulatedInfo()
|
||||
self.report = None
|
||||
self.positions = {}
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
|
||||
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
# portfolio related metrics
|
||||
self.report = Report(freq, benchmark_config)
|
||||
self.positions = {}
|
||||
if self.is_port_metr_enabled():
|
||||
self.report = Report(freq, benchmark_config)
|
||||
self.positions = {}
|
||||
|
||||
# trading related matric(e.g. high-frequency trading)
|
||||
self.indicator = Indicator()
|
||||
|
||||
def reset(self, freq=None, benchmark_config=None, init_report=False):
|
||||
def reset(self, freq=None, benchmark_config=None, init_report=False, port_metr_enabled: bool = None):
|
||||
"""reset freq and report of account
|
||||
|
||||
Parameters
|
||||
@@ -108,6 +123,9 @@ class Account:
|
||||
if benchmark_config is not None:
|
||||
self.benchmark_config = benchmark_config
|
||||
|
||||
if port_metr_enabled is not None:
|
||||
self._port_metr_enabled = port_metr_enabled
|
||||
|
||||
if freq is not None or benchmark_config is not None or init_report:
|
||||
self.reset_report(self.freq, self.benchmark_config)
|
||||
|
||||
@@ -137,7 +155,7 @@ class Account:
|
||||
self.accum_info.add_return_value(profit) # note here do not consider cost
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
if self.current.skip_update():
|
||||
if not self.is_port_metr_enabled():
|
||||
# TODO: supporting polymorphism for account
|
||||
# updating order for infinite position is meaningless
|
||||
return
|
||||
@@ -160,12 +178,14 @@ class Account:
|
||||
def update_bar_count(self):
|
||||
"""at the end of the trading bar, update holding bar, count of stock"""
|
||||
# update holding day count
|
||||
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
|
||||
if not self.current.skip_update():
|
||||
self.current.add_count_all(bar=self.freq)
|
||||
|
||||
def update_current(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""update current to make rtn consistent with earning at the end of bar"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
|
||||
if not self.current.skip_update():
|
||||
stock_list = self.current.get_stock_list()
|
||||
for code in stock_list:
|
||||
@@ -227,7 +247,6 @@ class Account:
|
||||
trade_exchange: Exchange,
|
||||
atomic: bool,
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
generate_report: bool = False,
|
||||
trade_info: list = None,
|
||||
inner_order_indicators: Indicator = None,
|
||||
indicator_config: dict = {},
|
||||
@@ -246,8 +265,6 @@ class Account:
|
||||
whether the trading executor is atomic, which means there is no higher-frequency trading executor inside it
|
||||
- if atomic is True, calculate the indicators with trade_info
|
||||
- else, aggregate indicators with inner indicators
|
||||
generate_report : bool, optional
|
||||
whether to generate report, by default False
|
||||
trade_info : List[(Order, float, float, float)], optional
|
||||
trading information, by default None
|
||||
- necessary if atomic is True
|
||||
@@ -267,7 +284,7 @@ class Account:
|
||||
# TODO: `update_bar_count` and `update_current` should placed in Position and be merged.
|
||||
self.update_bar_count()
|
||||
self.update_current(trade_start_time, trade_end_time, trade_exchange)
|
||||
if generate_report:
|
||||
if self.is_port_metr_enabled():
|
||||
# report is portfolio related analysis
|
||||
self.update_report(trade_start_time, trade_end_time)
|
||||
|
||||
@@ -283,3 +300,16 @@ class Account:
|
||||
|
||||
self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config)
|
||||
self.indicator.record(trade_start_time)
|
||||
|
||||
def get_report(self):
|
||||
"""get the history report and postions instance"""
|
||||
if self.is_port_metr_enabled():
|
||||
_report = self.report.generate_report_dataframe()
|
||||
_positions = self.get_positions()
|
||||
return _report, _positions
|
||||
else:
|
||||
raise ValueError("generate_report should be True if you want to generate report")
|
||||
|
||||
def get_trade_indicator(self) -> Indicator:
|
||||
"""get the trade indicator instance, which has pa/pos/ffr info."""
|
||||
return self.indicator
|
||||
|
||||
@@ -69,13 +69,13 @@ def collect_data_loop(
|
||||
all_executors = trade_executor.get_all_executors()
|
||||
|
||||
all_reports = {
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.get_report()
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_report()
|
||||
for _executor in all_executors
|
||||
if _executor.generate_report
|
||||
if _executor.trade_account.is_port_metr_enabled()
|
||||
}
|
||||
all_indicators = {}
|
||||
for _executor in all_executors:
|
||||
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
|
||||
all_indicators[key] = _executor.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
all_indicators[key + "_obj"] = _executor.get_trade_indicator()
|
||||
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
|
||||
return_value.update({"report": all_reports, "indicator": all_indicators})
|
||||
|
||||
@@ -103,8 +103,10 @@ class BaseExecutor:
|
||||
self.common_infra.update(common_infra)
|
||||
|
||||
if common_infra.has("trade_account"):
|
||||
# NOTE: there is a trick in the code.
|
||||
# copy is used instead of deepcopy. So positions are shared
|
||||
self.trade_account = copy.copy(common_infra.get("trade_account"))
|
||||
self.trade_account.reset(freq=self.time_per_step, init_report=True)
|
||||
self.trade_account.reset(freq=self.time_per_step, init_report=True, port_metr_enabled=self.generate_report)
|
||||
|
||||
def reset(self, track_data: bool = None, common_infra: CommonInfrastructure = None, **kwargs):
|
||||
"""
|
||||
@@ -167,19 +169,6 @@ class BaseExecutor:
|
||||
yield trade_decision
|
||||
return self.execute(trade_decision)
|
||||
|
||||
def get_report(self):
|
||||
"""get the history report and postions instance"""
|
||||
if self.generate_report:
|
||||
_report = self.trade_account.report.generate_report_dataframe()
|
||||
_positions = self.trade_account.get_positions()
|
||||
return _report, _positions
|
||||
else:
|
||||
raise ValueError("generate_report should be True if you want to generate report")
|
||||
|
||||
def get_trade_indicator(self) -> Indicator:
|
||||
"""get the trade indicator instance, which has pa/pos/ffr info."""
|
||||
return self.trade_account.indicator
|
||||
|
||||
def get_all_executors(self):
|
||||
"""get all executors"""
|
||||
return [self]
|
||||
@@ -289,21 +278,19 @@ class NestedExecutor(BaseExecutor):
|
||||
_inner_execute_result = yield from self.inner_executor.collect_data(trade_decision=_inner_trade_decision)
|
||||
|
||||
execute_result.extend(_inner_execute_result)
|
||||
inner_order_indicators.append(self.inner_executor.get_trade_indicator().get_order_indicator())
|
||||
inner_order_indicators.append(self.inner_executor.trade_account.get_trade_indicator().get_order_indicator())
|
||||
|
||||
if hasattr(self, "trade_account"):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
self.trade_exchange,
|
||||
atomic=False,
|
||||
outer_trade_decision=trade_decision,
|
||||
generate_report=self.generate_report,
|
||||
inner_order_indicators=inner_order_indicators,
|
||||
indicator_config=self.indicator_config,
|
||||
)
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
self.trade_exchange,
|
||||
atomic=False,
|
||||
outer_trade_decision=trade_decision,
|
||||
inner_order_indicators=inner_order_indicators,
|
||||
indicator_config=self.indicator_config,
|
||||
)
|
||||
|
||||
self.trade_calendar.step()
|
||||
if return_value is not None:
|
||||
@@ -457,7 +444,6 @@ class SimulatorExecutor(BaseExecutor):
|
||||
self.trade_exchange,
|
||||
atomic=True,
|
||||
outer_trade_decision=trade_decision,
|
||||
generate_report=self.generate_report,
|
||||
trade_info=execute_result,
|
||||
indicator_config=self.indicator_config,
|
||||
)
|
||||
|
||||
@@ -10,6 +10,8 @@ from ..utils import init_instance_by_config
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..backtest.order import BaseTradeDecision
|
||||
|
||||
__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"]
|
||||
|
||||
|
||||
class BaseStrategy:
|
||||
"""Base strategy for trading"""
|
||||
|
||||
Reference in New Issue
Block a user