diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index 74927e2be..bb75ca8f6 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -167,11 +167,14 @@ class CN1minNumpyQuote(BaseQuote): def _agg_data(self, data: IndexData, method): """Agg data by specific method.""" + # FIXME: why not call the method of data directly? if method == "sum": return np.nansum(data) elif method == "mean": return np.nanmean(data) elif method == "last": + # FIXME: I've never seen that this method was called. + # Please merge it with "ts_data_last" return data[-1] elif method == "all": return data.all() diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index 505f0dd33..c8d6bebee 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -10,7 +10,7 @@ Motivation of index_data """ from functools import partial -from typing import Tuple, Union, Callable, List +from typing import Dict, Tuple, Union, Callable, List import bisect import numpy as np @@ -128,8 +128,7 @@ class Index: raise KeyError(f"{item} can't be found in {self}") def __or__(self, other: "Index"): - idx = Index(idx_list=list(set(self.idx_list) | set(other.idx_list))) - return idx + return Index(idx_list=list(set(self.idx_list) | set(other.idx_list))) def __eq__(self, other: "Index"): # NOTE: np.nan is not supported in the index @@ -283,9 +282,8 @@ class BinaryOps: if isinstance(other, (int, float, np.number)): return self.obj.__class__(self_data_method(other)) elif isinstance(other, self.obj.__class__): - # TODO: bad interface - tmp_data1, tmp_data2 = self.obj._align_indices(other) - return self.obj.__class__(self_data_method(tmp_data2.data), *self.obj.indices) + other_aligned = self.obj._align_indices(other) + return self.obj.__class__(self_data_method(other_aligned.data), *self.obj.indices) else: return NotImplemented @@ -369,8 +367,21 @@ class IndexData(metaclass=index_data_ops_creator): def columns(self): return self.indices[1] - def _align_indices(self, other): - """Align index before performing the four arithmetic operations.""" + def _align_indices(self, other: "IndexData") -> "IndexData": + """ + Align all indices of `other` to `self` before performing the arithmetic operations. + This function will return a new IndexData rather than changing data in `other` inplace + + Parameters + ---------- + other : "IndexData" + the index in `other` is to be chagned + + Returns + ------- + IndexData: + the data in `other` with index aligned to `self` + """ raise NotImplementedError(f"please implement _align_indices func") def sort_index(self, axis=0, inplace=True): @@ -387,12 +398,12 @@ class IndexData(metaclass=index_data_ops_creator): tmp_data = np.absolute(self.data) return self.__class__(tmp_data, *self.indices) - def replace(self, to_replace: dict): + def replace(self, to_replace: Dict[np.number, np.number]): assert isinstance(to_replace, dict) tmp_data = self.data.copy() for num in to_replace: if num in tmp_data: - tmp_data[tmp_data == num] = to_replace[num] + tmp_data[self.data == num] = to_replace[num] return self.__class__(tmp_data, *self.indices) def apply(self, func: Callable): @@ -411,6 +422,7 @@ class IndexData(metaclass=index_data_ops_creator): return len(self.data) def sum(self, axis=None): + # FIXME: weird logic and not general if axis is None: return np.nansum(self.data) elif axis == 0: @@ -423,6 +435,7 @@ class IndexData(metaclass=index_data_ops_creator): raise ValueError(f"axis must be None, 0 or 1") def mean(self, axis=None): + # FIXME: weird logic and not general if axis is None: return np.nanmean(self.data) elif axis == 0: @@ -479,9 +492,9 @@ class SingleData(IndexData): def _align_indices(self, other): if self.index == other.index: - return self, other + return other elif set(self.index) == set(other.index): - return self, other.reindex(self.index) + return other.reindex(self.index) else: raise ValueError( f"The indexes of self and other do not meet the requirements of the four arithmetic operations" @@ -563,7 +576,7 @@ class MultiData(IndexData): def _align_indices(self, other): if self.indices == other.indices: - return self, other + return other else: raise ValueError( f"The indexes of self and other do not meet the requirements of the four arithmetic operations" diff --git a/tests/misc/test_index_data.py b/tests/misc/test_index_data.py index caa9b1897..c7a80fb0f 100644 --- a/tests/misc/test_index_data.py +++ b/tests/misc/test_index_data.py @@ -89,6 +89,12 @@ class IndexDataTest(unittest.TestCase): with self.assertRaises(KeyError): sd.loc["foo"] + # replace + sd = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) + sd = sd.replace(dict(zip(range(1, 5), range(2, 6)))) + print(sd) + self.assertTrue(sd.iloc[0] == 2) + def test_ops(self): sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])