diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index 104be5b9c..d556f303c 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -172,15 +172,25 @@ class BaseSingleMetric: @property def empty(self) -> bool: """If metric is empyt, return True.""" + raise NotImplementedError(f"Please implement the `empty` method") def add(self, other: "BaseSingleMetric", fill_value: float = None) -> "BaseSingleMetric": """Replace np.NaN with fill_value in two metrics and add them.""" + raise NotImplementedError(f"Please implement the `add` method") - def map(self, map_dict: dict) -> "BaseSingleMetric": - """Replace the value of metric according to map_dict.""" - raise NotImplementedError(f"Please implement the `map` method") + def replace(self, replace_dict: dict) -> "BaseSingleMetric": + """Replace the value of metric according to replace_dict.""" + + raise NotImplementedError(f"Please implement the `replace` method") + + def apply(self, func: dict) -> "BaseSingleMetric": + """Replace the value of metric with func(metric). + Currently, the func is only qlib/backtest/order/Order.parse_dir. + """ + + raise NotImplementedError(f"Please implement the 'apply' method") class BaseOrderIndicator: @@ -371,8 +381,11 @@ class PandasSingleMetric: def add(self, other, fill_value=None): return PandasSingleMetric(self.metric.add(other.metric, fill_value=fill_value)) - def map(self, map_dict: dict): - return PandasSingleMetric(self.metric.apply(map_dict)) + def replace(self, replace_dict: dict): + return PandasSingleMetric(self.metric.replace(replace_dict)) + + def apply(self, func: Callable): + return PandasSingleMetric(self.metric.apply(func)) class PandasOrderIndicator(BaseOrderIndicator): @@ -413,6 +426,11 @@ class PandasOrderIndicator(BaseOrderIndicator): for metric in metrics: tmp_metric = PandasSingleMetric({}) for indicator in indicators: - tmp_metric = tmp_metric.add(indicator.data[metric], fill_value) + if(metric == "trade_price"): + tmp_metric = tmp_metric.add( + indicator.data["trade_price"] * indicator.data["deal_amount"], fill_value + ) + else: + tmp_metric = tmp_metric.add(indicator.data[metric], fill_value) metric_dict[metric] = tmp_metric.metric return metric_dict diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 95048ba84..64d00b436 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -308,7 +308,8 @@ class Indicator: def _update_order_fulfill_rate(self): def func(deal_amount, amount): - return deal_amount / amount + tmp_deal_amount = deal_amount.replace({np.NaN: 0}) + return deal_amount / tmp_deal_amount self.order_indicator.transfer(func, "ffr") @@ -323,12 +324,13 @@ class Indicator: self.order_indicator.assign(metric, metric_dict[metric]) def func(trade_price, deal_amount): - return trade_price / deal_amount + tmp_deal_amount = deal_amount.replace({0: np.NaN}) + return trade_price / tmp_deal_amount self.order_indicator.transfer(func, "trade_price") def func_apply(trade_dir): - return trade_dir.map(Order.parse_dir) + return trade_dir.apply(Order.parse_dir) self.order_indicator.transfer(func_apply, "trade_dir")