1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 03:50:57 +08:00

fix metric calculation error

This commit is contained in:
Young
2021-09-01 00:24:50 +00:00
committed by you-n-g
parent 5f0ee6ce68
commit 5003e49197
4 changed files with 35 additions and 5 deletions

View File

@@ -528,6 +528,9 @@ class PandasSingleMetric(SingleMetric):
def reindex(self, index, fill_value):
return self.__class__(self.metric.reindex(index, fill_value=fill_value))
def __repr__(self):
return repr(self.metric)
class PandasOrderIndicator(BaseOrderIndicator):
"""
@@ -567,6 +570,9 @@ class PandasOrderIndicator(BaseOrderIndicator):
tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
order_indicator.assign(metric, tmp_metric.metric)
def __repr__(self):
return repr(self.data)
class NumpyOrderIndicator(BaseOrderIndicator):
"""
@@ -605,3 +611,6 @@ class NumpyOrderIndicator(BaseOrderIndicator):
for indicator in indicators:
tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
order_indicator.data[metric] = tmp_metric
def __repr__(self):
return repr(self.data)

View File

@@ -11,7 +11,7 @@ import pandas as pd
from qlib.backtest.exchange import Exchange
from qlib.backtest.order import BaseTradeDecision, Order, OrderDir
from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator, SingleMetric
from .high_performance_ds import BaseOrderIndicator, PandasOrderIndicator, NumpyOrderIndicator, SingleMetric
from ..tests.config import CSI300_BENCH
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
from .order import IdxTradeRange
@@ -255,7 +255,7 @@ class Indicator:
# order indicator is metrics for a single order for a specific step
self.order_indicator_his = OrderedDict()
self.order_indicator = self.order_indicator_cls()
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
# trade indicator is metrics for all orders for a specific step
self.trade_indicator_his = OrderedDict()
@@ -265,7 +265,7 @@ class Indicator:
# def reset(self, trade_calendar: TradeCalendarManager):
def reset(self):
self.order_indicator = self.order_indicator_cls()
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
self.trade_indicator = OrderedDict()
# self._trade_calendar = trade_calendar

View File

@@ -280,7 +280,7 @@ class BinaryOps:
self_data_method = getattr(self.obj.data, self.method_name)
if isinstance(other, (int, float, np.number)):
return self.obj.__class__(self_data_method(other))
return self.obj.__class__(self_data_method(other), *self.obj.indices)
elif isinstance(other, self.obj.__class__):
other_aligned = self.obj._align_indices(other)
return self.obj.__class__(self_data_method(other_aligned.data), *self.obj.indices)
@@ -450,6 +450,12 @@ class IndexData(metaclass=index_data_ops_creator):
def isna(self):
return self.__class__(np.isnan(self.data), *self.indices)
def fillna(self, value=0.0, inplace: bool = False):
if inplace:
self.data = np.nan_to_num(self.data, nan=value)
else:
return self.__class__(np.nan_to_num(self.data, nan=value), *self.indices)
def count(self):
return len(self.data[~np.isnan(self.data)])
@@ -507,6 +513,8 @@ class SingleData(IndexData):
----------
new_index : list
new index
fill_value:
what value to fill if index is missing
Returns
-------
@@ -531,7 +539,7 @@ class SingleData(IndexData):
common_index, _ = common_index.sort()
tmp_data1 = self.reindex(common_index, fill_value)
tmp_data2 = other.reindex(common_index, fill_value)
return tmp_data1 + tmp_data2
return tmp_data1.fillna(fill_value) + tmp_data2.fillna(fill_value)
def to_dict(self):
"""convert SingleData to dict.

View File

@@ -99,6 +99,19 @@ class IndexDataTest(unittest.TestCase):
sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
print(sd1 + sd2)
new_sd = sd2 * 2
self.assertTrue(new_sd.index == sd2.index)
sd1 = idd.SingleData([1, 2, None, 4], index=["foo", "bar", "f", "g"])
sd2 = idd.SingleData([1, 2, 3, None], index=["foo", "bar", "f", "g"])
self.assertTrue(np.isnan((sd1 + sd2).iloc[3]))
self.assertTrue(sd1.add(sd2).sum() == 13)
def test_todo(self):
pass
# here are some examples which do not affect the current system, but it is weird not to support it
# sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
# 2 * sd2
if __name__ == "__main__":