1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 10:01:19 +08:00
This commit is contained in:
Young
2021-08-31 09:52:32 +00:00
committed by you-n-g
parent 9a74471ab6
commit 5f0ee6ce68
3 changed files with 35 additions and 13 deletions

View File

@@ -167,11 +167,14 @@ class CN1minNumpyQuote(BaseQuote):
def _agg_data(self, data: IndexData, method):
"""Agg data by specific method."""
# FIXME: why not call the method of data directly?
if method == "sum":
return np.nansum(data)
elif method == "mean":
return np.nanmean(data)
elif method == "last":
# FIXME: I've never seen that this method was called.
# Please merge it with "ts_data_last"
return data[-1]
elif method == "all":
return data.all()

View File

@@ -10,7 +10,7 @@ Motivation of index_data
"""
from functools import partial
from typing import Tuple, Union, Callable, List
from typing import Dict, Tuple, Union, Callable, List
import bisect
import numpy as np
@@ -128,8 +128,7 @@ class Index:
raise KeyError(f"{item} can't be found in {self}")
def __or__(self, other: "Index"):
idx = Index(idx_list=list(set(self.idx_list) | set(other.idx_list)))
return idx
return Index(idx_list=list(set(self.idx_list) | set(other.idx_list)))
def __eq__(self, other: "Index"):
# NOTE: np.nan is not supported in the index
@@ -283,9 +282,8 @@ class BinaryOps:
if isinstance(other, (int, float, np.number)):
return self.obj.__class__(self_data_method(other))
elif isinstance(other, self.obj.__class__):
# TODO: bad interface
tmp_data1, tmp_data2 = self.obj._align_indices(other)
return self.obj.__class__(self_data_method(tmp_data2.data), *self.obj.indices)
other_aligned = self.obj._align_indices(other)
return self.obj.__class__(self_data_method(other_aligned.data), *self.obj.indices)
else:
return NotImplemented
@@ -369,8 +367,21 @@ class IndexData(metaclass=index_data_ops_creator):
def columns(self):
return self.indices[1]
def _align_indices(self, other):
"""Align index before performing the four arithmetic operations."""
def _align_indices(self, other: "IndexData") -> "IndexData":
"""
Align all indices of `other` to `self` before performing the arithmetic operations.
This function will return a new IndexData rather than changing data in `other` inplace
Parameters
----------
other : "IndexData"
the index in `other` is to be chagned
Returns
-------
IndexData:
the data in `other` with index aligned to `self`
"""
raise NotImplementedError(f"please implement _align_indices func")
def sort_index(self, axis=0, inplace=True):
@@ -387,12 +398,12 @@ class IndexData(metaclass=index_data_ops_creator):
tmp_data = np.absolute(self.data)
return self.__class__(tmp_data, *self.indices)
def replace(self, to_replace: dict):
def replace(self, to_replace: Dict[np.number, np.number]):
assert isinstance(to_replace, dict)
tmp_data = self.data.copy()
for num in to_replace:
if num in tmp_data:
tmp_data[tmp_data == num] = to_replace[num]
tmp_data[self.data == num] = to_replace[num]
return self.__class__(tmp_data, *self.indices)
def apply(self, func: Callable):
@@ -411,6 +422,7 @@ class IndexData(metaclass=index_data_ops_creator):
return len(self.data)
def sum(self, axis=None):
# FIXME: weird logic and not general
if axis is None:
return np.nansum(self.data)
elif axis == 0:
@@ -423,6 +435,7 @@ class IndexData(metaclass=index_data_ops_creator):
raise ValueError(f"axis must be None, 0 or 1")
def mean(self, axis=None):
# FIXME: weird logic and not general
if axis is None:
return np.nanmean(self.data)
elif axis == 0:
@@ -479,9 +492,9 @@ class SingleData(IndexData):
def _align_indices(self, other):
if self.index == other.index:
return self, other
return other
elif set(self.index) == set(other.index):
return self, other.reindex(self.index)
return other.reindex(self.index)
else:
raise ValueError(
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
@@ -563,7 +576,7 @@ class MultiData(IndexData):
def _align_indices(self, other):
if self.indices == other.indices:
return self, other
return other
else:
raise ValueError(
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"

View File

@@ -89,6 +89,12 @@ class IndexDataTest(unittest.TestCase):
with self.assertRaises(KeyError):
sd.loc["foo"]
# replace
sd = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
sd = sd.replace(dict(zip(range(1, 5), range(2, 6))))
print(sd)
self.assertTrue(sd.iloc[0] == 2)
def test_ops(self):
sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])