1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 01:51:18 +08:00

simplify the portfolio-based report

This commit is contained in:
Young
2021-07-06 11:09:25 +00:00
parent 03d6facbd2
commit 4e41e9c8f2
4 changed files with 62 additions and 44 deletions

View File

@@ -64,34 +64,49 @@ class AccumulatedInfo:
class Account:
def __init__(
self, init_cash: float = 1e9, freq: str = "day", benchmark_config: dict = {}, pos_type: str = "Position"
self,
init_cash: float = 1e9,
freq: str = "day",
benchmark_config: dict = {},
pos_type: str = "Position",
port_metr_enabled: bool = True,
):
self.pos_type = pos_type
self._pos_type = pos_type
self._port_metr_enabled = port_metr_enabled
self.init_vars(init_cash, freq, benchmark_config)
def is_port_metr_enabled(self):
"""
Is portfolio-based metrics enabled.
"""
return self._port_metr_enabled and not self.current.skip_update()
def init_vars(self, init_cash, freq: str, benchmark_config: dict):
# init cash
self.init_cash = init_cash
self.current: BasePosition = init_instance_by_config(
{
"class": self.pos_type,
"class": self._pos_type,
"kwargs": {"cash": init_cash},
"module_path": "qlib.backtest.position",
}
)
self.accum_info = AccumulatedInfo()
self.report = None
self.positions = {}
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
def reset_report(self, freq, benchmark_config):
# portfolio related metrics
self.report = Report(freq, benchmark_config)
self.positions = {}
if self.is_port_metr_enabled():
self.report = Report(freq, benchmark_config)
self.positions = {}
# trading related matric(e.g. high-frequency trading)
self.indicator = Indicator()
def reset(self, freq=None, benchmark_config=None, init_report=False):
def reset(self, freq=None, benchmark_config=None, init_report=False, port_metr_enabled: bool = None):
"""reset freq and report of account
Parameters
@@ -108,6 +123,9 @@ class Account:
if benchmark_config is not None:
self.benchmark_config = benchmark_config
if port_metr_enabled is not None:
self._port_metr_enabled = port_metr_enabled
if freq is not None or benchmark_config is not None or init_report:
self.reset_report(self.freq, self.benchmark_config)
@@ -137,7 +155,7 @@ class Account:
self.accum_info.add_return_value(profit) # note here do not consider cost
def update_order(self, order, trade_val, cost, trade_price):
if self.current.skip_update():
if not self.is_port_metr_enabled():
# TODO: supporting polymorphism for account
# updating order for infinite position is meaningless
return
@@ -160,12 +178,14 @@ class Account:
def update_bar_count(self):
"""at the end of the trading bar, update holding bar, count of stock"""
# update holding day count
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
if not self.current.skip_update():
self.current.add_count_all(bar=self.freq)
def update_current(self, trade_start_time, trade_end_time, trade_exchange):
"""update current to make rtn consistent with earning at the end of bar"""
# update price for stock in the position and the profit from changed_price
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
if not self.current.skip_update():
stock_list = self.current.get_stock_list()
for code in stock_list:
@@ -227,7 +247,6 @@ class Account:
trade_exchange: Exchange,
atomic: bool,
outer_trade_decision: BaseTradeDecision,
generate_report: bool = False,
trade_info: list = None,
inner_order_indicators: Indicator = None,
indicator_config: dict = {},
@@ -246,8 +265,6 @@ class Account:
whether the trading executor is atomic, which means there is no higher-frequency trading executor inside it
- if atomic is True, calculate the indicators with trade_info
- else, aggregate indicators with inner indicators
generate_report : bool, optional
whether to generate report, by default False
trade_info : List[(Order, float, float, float)], optional
trading information, by default None
- necessary if atomic is True
@@ -267,7 +284,7 @@ class Account:
# TODO: `update_bar_count` and `update_current` should placed in Position and be merged.
self.update_bar_count()
self.update_current(trade_start_time, trade_end_time, trade_exchange)
if generate_report:
if self.is_port_metr_enabled():
# report is portfolio related analysis
self.update_report(trade_start_time, trade_end_time)
@@ -283,3 +300,16 @@ class Account:
self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config)
self.indicator.record(trade_start_time)
def get_report(self):
"""get the history report and postions instance"""
if self.is_port_metr_enabled():
_report = self.report.generate_report_dataframe()
_positions = self.get_positions()
return _report, _positions
else:
raise ValueError("generate_report should be True if you want to generate report")
def get_trade_indicator(self) -> Indicator:
"""get the trade indicator instance, which has pa/pos/ffr info."""
return self.indicator

View File

@@ -69,13 +69,13 @@ def collect_data_loop(
all_executors = trade_executor.get_all_executors()
all_reports = {
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.get_report()
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_report()
for _executor in all_executors
if _executor.generate_report
if _executor.trade_account.is_port_metr_enabled()
}
all_indicators = {}
for _executor in all_executors:
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
all_indicators[key] = _executor.get_trade_indicator().generate_trade_indicators_dataframe()
all_indicators[key + "_obj"] = _executor.get_trade_indicator()
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
return_value.update({"report": all_reports, "indicator": all_indicators})

View File

@@ -103,8 +103,10 @@ class BaseExecutor:
self.common_infra.update(common_infra)
if common_infra.has("trade_account"):
# NOTE: there is a trick in the code.
# copy is used instead of deepcopy. So positions are shared
self.trade_account = copy.copy(common_infra.get("trade_account"))
self.trade_account.reset(freq=self.time_per_step, init_report=True)
self.trade_account.reset(freq=self.time_per_step, init_report=True, port_metr_enabled=self.generate_report)
def reset(self, track_data: bool = None, common_infra: CommonInfrastructure = None, **kwargs):
"""
@@ -167,19 +169,6 @@ class BaseExecutor:
yield trade_decision
return self.execute(trade_decision)
def get_report(self):
"""get the history report and postions instance"""
if self.generate_report:
_report = self.trade_account.report.generate_report_dataframe()
_positions = self.trade_account.get_positions()
return _report, _positions
else:
raise ValueError("generate_report should be True if you want to generate report")
def get_trade_indicator(self) -> Indicator:
"""get the trade indicator instance, which has pa/pos/ffr info."""
return self.trade_account.indicator
def get_all_executors(self):
"""get all executors"""
return [self]
@@ -289,21 +278,19 @@ class NestedExecutor(BaseExecutor):
_inner_execute_result = yield from self.inner_executor.collect_data(trade_decision=_inner_trade_decision)
execute_result.extend(_inner_execute_result)
inner_order_indicators.append(self.inner_executor.get_trade_indicator().get_order_indicator())
inner_order_indicators.append(self.inner_executor.trade_account.get_trade_indicator().get_order_indicator())
if hasattr(self, "trade_account"):
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
self.trade_account.update_bar_end(
trade_start_time,
trade_end_time,
self.trade_exchange,
atomic=False,
outer_trade_decision=trade_decision,
generate_report=self.generate_report,
inner_order_indicators=inner_order_indicators,
indicator_config=self.indicator_config,
)
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
self.trade_account.update_bar_end(
trade_start_time,
trade_end_time,
self.trade_exchange,
atomic=False,
outer_trade_decision=trade_decision,
inner_order_indicators=inner_order_indicators,
indicator_config=self.indicator_config,
)
self.trade_calendar.step()
if return_value is not None:
@@ -457,7 +444,6 @@ class SimulatorExecutor(BaseExecutor):
self.trade_exchange,
atomic=True,
outer_trade_decision=trade_decision,
generate_report=self.generate_report,
trade_info=execute_result,
indicator_config=self.indicator_config,
)

View File

@@ -10,6 +10,8 @@ from ..utils import init_instance_by_config
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
from ..backtest.order import BaseTradeDecision
__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"]
class BaseStrategy:
"""Base strategy for trading"""