diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 89f5a2c4a..0121a904e 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -405,7 +405,7 @@ class NestedExecutor(BaseExecutor): execute_result.extend(_inner_execute_result) inner_order_indicators.append( - self.inner_executor.trade_account.get_trade_indicator().get_order_indicator() + self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True) ) else: # do nothing and just step forward diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index 123725832..c60d3f97e 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -3,7 +3,7 @@ import logging -from typing import List, Tuple, Union, Callable, Iterable, Dict +from typing import List, Text, Tuple, Union, Callable, Iterable, Dict from collections import OrderedDict import inspect @@ -280,6 +280,21 @@ class BaseOrderIndicator: pass + def to_series(self) -> Dict[Text, pd.Series]: + """return the metrics as pandas series + + for example: { "ffr": + SH600068 NaN + SH600079 1.0 + SH600266 NaN + ... + SZ300692 NaN + SZ300719 NaN, + ... + } + """ + raise NotImplementedError(f"Please implement the `to_series` method") + class PandasSingleMetric: """Each SingleMetric is based on pd.Series.""" @@ -429,3 +444,6 @@ class PandasOrderIndicator(BaseOrderIndicator): tmp_metric = tmp_metric.add(indicator.data[metric], fill_value) metric_dict[metric] = tmp_metric.metric return metric_dict + + def to_series(self): + return {k: v.metric for k, v in self.data.items()} diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index e37642244..fb1eeedfa 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -274,8 +274,8 @@ class Indicator: # self._trade_calendar = trade_calendar def record(self, trade_start_time): - self.order_indicator_his[trade_start_time] = self.order_indicator.data - self.trade_indicator_his[trade_start_time] = self.trade_indicator + self.order_indicator_his[trade_start_time] = self.get_order_indicator() + self.trade_indicator_his[trade_start_time] = self.get_trade_indicator() def _update_order_trade_info(self, trade_info: list): amount = dict() @@ -587,8 +587,10 @@ class Indicator: ) ) - def get_order_indicator(self): - return self.order_indicator + def get_order_indicator(self, raw: bool = False): + if raw: + return self.order_indicator + return self.order_indicator.to_series() def get_trade_indicator(self): return self.trade_indicator diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 57ca005ff..eabbe357b 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -63,11 +63,11 @@ class TWAPStrategy(BaseStrategy): trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) order_list = [] for order in self.outer_trade_decision.get_decision(): - # if not tradable, continue - if not self.trade_exchange.is_stock_tradable( - stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time - ): - continue + # Don't peek the future information + # if not self.trade_exchange.is_stock_tradable( + # stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time + # ): + # continue _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit( stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time )