mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 10:01:19 +08:00
fix bugs
This commit is contained in:
@@ -167,11 +167,14 @@ class CN1minNumpyQuote(BaseQuote):
|
||||
|
||||
def _agg_data(self, data: IndexData, method):
|
||||
"""Agg data by specific method."""
|
||||
# FIXME: why not call the method of data directly?
|
||||
if method == "sum":
|
||||
return np.nansum(data)
|
||||
elif method == "mean":
|
||||
return np.nanmean(data)
|
||||
elif method == "last":
|
||||
# FIXME: I've never seen that this method was called.
|
||||
# Please merge it with "ts_data_last"
|
||||
return data[-1]
|
||||
elif method == "all":
|
||||
return data.all()
|
||||
|
||||
@@ -10,7 +10,7 @@ Motivation of index_data
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from typing import Tuple, Union, Callable, List
|
||||
from typing import Dict, Tuple, Union, Callable, List
|
||||
import bisect
|
||||
|
||||
import numpy as np
|
||||
@@ -128,8 +128,7 @@ class Index:
|
||||
raise KeyError(f"{item} can't be found in {self}")
|
||||
|
||||
def __or__(self, other: "Index"):
|
||||
idx = Index(idx_list=list(set(self.idx_list) | set(other.idx_list)))
|
||||
return idx
|
||||
return Index(idx_list=list(set(self.idx_list) | set(other.idx_list)))
|
||||
|
||||
def __eq__(self, other: "Index"):
|
||||
# NOTE: np.nan is not supported in the index
|
||||
@@ -283,9 +282,8 @@ class BinaryOps:
|
||||
if isinstance(other, (int, float, np.number)):
|
||||
return self.obj.__class__(self_data_method(other))
|
||||
elif isinstance(other, self.obj.__class__):
|
||||
# TODO: bad interface
|
||||
tmp_data1, tmp_data2 = self.obj._align_indices(other)
|
||||
return self.obj.__class__(self_data_method(tmp_data2.data), *self.obj.indices)
|
||||
other_aligned = self.obj._align_indices(other)
|
||||
return self.obj.__class__(self_data_method(other_aligned.data), *self.obj.indices)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
@@ -369,8 +367,21 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
def columns(self):
|
||||
return self.indices[1]
|
||||
|
||||
def _align_indices(self, other):
|
||||
"""Align index before performing the four arithmetic operations."""
|
||||
def _align_indices(self, other: "IndexData") -> "IndexData":
|
||||
"""
|
||||
Align all indices of `other` to `self` before performing the arithmetic operations.
|
||||
This function will return a new IndexData rather than changing data in `other` inplace
|
||||
|
||||
Parameters
|
||||
----------
|
||||
other : "IndexData"
|
||||
the index in `other` is to be chagned
|
||||
|
||||
Returns
|
||||
-------
|
||||
IndexData:
|
||||
the data in `other` with index aligned to `self`
|
||||
"""
|
||||
raise NotImplementedError(f"please implement _align_indices func")
|
||||
|
||||
def sort_index(self, axis=0, inplace=True):
|
||||
@@ -387,12 +398,12 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
tmp_data = np.absolute(self.data)
|
||||
return self.__class__(tmp_data, *self.indices)
|
||||
|
||||
def replace(self, to_replace: dict):
|
||||
def replace(self, to_replace: Dict[np.number, np.number]):
|
||||
assert isinstance(to_replace, dict)
|
||||
tmp_data = self.data.copy()
|
||||
for num in to_replace:
|
||||
if num in tmp_data:
|
||||
tmp_data[tmp_data == num] = to_replace[num]
|
||||
tmp_data[self.data == num] = to_replace[num]
|
||||
return self.__class__(tmp_data, *self.indices)
|
||||
|
||||
def apply(self, func: Callable):
|
||||
@@ -411,6 +422,7 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
return len(self.data)
|
||||
|
||||
def sum(self, axis=None):
|
||||
# FIXME: weird logic and not general
|
||||
if axis is None:
|
||||
return np.nansum(self.data)
|
||||
elif axis == 0:
|
||||
@@ -423,6 +435,7 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
raise ValueError(f"axis must be None, 0 or 1")
|
||||
|
||||
def mean(self, axis=None):
|
||||
# FIXME: weird logic and not general
|
||||
if axis is None:
|
||||
return np.nanmean(self.data)
|
||||
elif axis == 0:
|
||||
@@ -479,9 +492,9 @@ class SingleData(IndexData):
|
||||
|
||||
def _align_indices(self, other):
|
||||
if self.index == other.index:
|
||||
return self, other
|
||||
return other
|
||||
elif set(self.index) == set(other.index):
|
||||
return self, other.reindex(self.index)
|
||||
return other.reindex(self.index)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
|
||||
@@ -563,7 +576,7 @@ class MultiData(IndexData):
|
||||
|
||||
def _align_indices(self, other):
|
||||
if self.indices == other.indices:
|
||||
return self, other
|
||||
return other
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
|
||||
|
||||
@@ -89,6 +89,12 @@ class IndexDataTest(unittest.TestCase):
|
||||
with self.assertRaises(KeyError):
|
||||
sd.loc["foo"]
|
||||
|
||||
# replace
|
||||
sd = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
sd = sd.replace(dict(zip(range(1, 5), range(2, 6))))
|
||||
print(sd)
|
||||
self.assertTrue(sd.iloc[0] == 2)
|
||||
|
||||
def test_ops(self):
|
||||
sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
|
||||
Reference in New Issue
Block a user