mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 03:50:57 +08:00
fix metric calculation error
This commit is contained in:
@@ -528,6 +528,9 @@ class PandasSingleMetric(SingleMetric):
|
||||
def reindex(self, index, fill_value):
|
||||
return self.__class__(self.metric.reindex(index, fill_value=fill_value))
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.metric)
|
||||
|
||||
|
||||
class PandasOrderIndicator(BaseOrderIndicator):
|
||||
"""
|
||||
@@ -567,6 +570,9 @@ class PandasOrderIndicator(BaseOrderIndicator):
|
||||
tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
|
||||
order_indicator.assign(metric, tmp_metric.metric)
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.data)
|
||||
|
||||
|
||||
class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
"""
|
||||
@@ -605,3 +611,6 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
for indicator in indicators:
|
||||
tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
|
||||
order_indicator.data[metric] = tmp_metric
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.data)
|
||||
|
||||
@@ -11,7 +11,7 @@ import pandas as pd
|
||||
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.order import BaseTradeDecision, Order, OrderDir
|
||||
from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
from .high_performance_ds import BaseOrderIndicator, PandasOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
from .order import IdxTradeRange
|
||||
@@ -255,7 +255,7 @@ class Indicator:
|
||||
|
||||
# order indicator is metrics for a single order for a specific step
|
||||
self.order_indicator_his = OrderedDict()
|
||||
self.order_indicator = self.order_indicator_cls()
|
||||
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
|
||||
|
||||
# trade indicator is metrics for all orders for a specific step
|
||||
self.trade_indicator_his = OrderedDict()
|
||||
@@ -265,7 +265,7 @@ class Indicator:
|
||||
|
||||
# def reset(self, trade_calendar: TradeCalendarManager):
|
||||
def reset(self):
|
||||
self.order_indicator = self.order_indicator_cls()
|
||||
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
|
||||
self.trade_indicator = OrderedDict()
|
||||
# self._trade_calendar = trade_calendar
|
||||
|
||||
|
||||
@@ -280,7 +280,7 @@ class BinaryOps:
|
||||
self_data_method = getattr(self.obj.data, self.method_name)
|
||||
|
||||
if isinstance(other, (int, float, np.number)):
|
||||
return self.obj.__class__(self_data_method(other))
|
||||
return self.obj.__class__(self_data_method(other), *self.obj.indices)
|
||||
elif isinstance(other, self.obj.__class__):
|
||||
other_aligned = self.obj._align_indices(other)
|
||||
return self.obj.__class__(self_data_method(other_aligned.data), *self.obj.indices)
|
||||
@@ -450,6 +450,12 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
def isna(self):
|
||||
return self.__class__(np.isnan(self.data), *self.indices)
|
||||
|
||||
def fillna(self, value=0.0, inplace: bool = False):
|
||||
if inplace:
|
||||
self.data = np.nan_to_num(self.data, nan=value)
|
||||
else:
|
||||
return self.__class__(np.nan_to_num(self.data, nan=value), *self.indices)
|
||||
|
||||
def count(self):
|
||||
return len(self.data[~np.isnan(self.data)])
|
||||
|
||||
@@ -507,6 +513,8 @@ class SingleData(IndexData):
|
||||
----------
|
||||
new_index : list
|
||||
new index
|
||||
fill_value:
|
||||
what value to fill if index is missing
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -531,7 +539,7 @@ class SingleData(IndexData):
|
||||
common_index, _ = common_index.sort()
|
||||
tmp_data1 = self.reindex(common_index, fill_value)
|
||||
tmp_data2 = other.reindex(common_index, fill_value)
|
||||
return tmp_data1 + tmp_data2
|
||||
return tmp_data1.fillna(fill_value) + tmp_data2.fillna(fill_value)
|
||||
|
||||
def to_dict(self):
|
||||
"""convert SingleData to dict.
|
||||
|
||||
@@ -99,6 +99,19 @@ class IndexDataTest(unittest.TestCase):
|
||||
sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
print(sd1 + sd2)
|
||||
new_sd = sd2 * 2
|
||||
self.assertTrue(new_sd.index == sd2.index)
|
||||
|
||||
sd1 = idd.SingleData([1, 2, None, 4], index=["foo", "bar", "f", "g"])
|
||||
sd2 = idd.SingleData([1, 2, 3, None], index=["foo", "bar", "f", "g"])
|
||||
self.assertTrue(np.isnan((sd1 + sd2).iloc[3]))
|
||||
self.assertTrue(sd1.add(sd2).sum() == 13)
|
||||
|
||||
def test_todo(self):
|
||||
pass
|
||||
# here are some examples which do not affect the current system, but it is weird not to support it
|
||||
# sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
# 2 * sd2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user