mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 11:30:57 +08:00
Fix the aggregation function of IndexData
This commit is contained in:
@@ -160,6 +160,11 @@ class NumpyQuote(BaseQuote):
|
||||
if is_single_value(start_time, end_time, self.freq, self.region):
|
||||
# this is a very special case.
|
||||
# skip the aggregating function to speed up the query calculation
|
||||
|
||||
# FIXME:
|
||||
# the query will fall through to the else logic in the following cases:
|
||||
# 1) the day before holiday when daily trading
|
||||
# 2) the last minute of the day when intraday trading
|
||||
try:
|
||||
return self.data[stock_id].loc[start_time, field]
|
||||
except KeyError:
|
||||
|
||||
@@ -401,6 +401,10 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
def columns(self):
|
||||
return self.indices[1]
|
||||
|
||||
def __getitem__(self, args):
|
||||
# NOTE: this tries to behave like a numpy array so that it is compatible with numpy aggregating functions such as nansum and nanmean
|
||||
return self.iloc[args]
|
||||
|
||||
def _align_indices(self, other: "IndexData") -> "IndexData":
|
||||
"""
|
||||
Align all indices of `other` to `self` before performing the arithmetic operations.
|
||||
@@ -409,7 +413,7 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
Parameters
|
||||
----------
|
||||
other : "IndexData"
|
||||
the index in `other` is to be chagned
|
||||
the index in `other` is to be changed
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -455,7 +459,8 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
"""
|
||||
return len(self.data)
|
||||
|
||||
def sum(self, axis=None):
|
||||
def sum(self, axis=None, dtype=None, out=None):
|
||||
assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function"
|
||||
# FIXME: weird logic and not general
|
||||
if axis is None:
|
||||
return np.nansum(self.data)
|
||||
@@ -468,7 +473,8 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
else:
|
||||
raise ValueError(f"axis must be None, 0 or 1")
|
||||
|
||||
def mean(self, axis=None):
|
||||
def mean(self, axis=None, dtype=None, out=None):
|
||||
assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function"
|
||||
# FIXME: weird logic and not general
|
||||
if axis is None:
|
||||
return np.nanmean(self.data)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import qlib.utils.index_data as idd
|
||||
|
||||
import unittest
|
||||
@@ -115,6 +114,19 @@ class IndexDataTest(unittest.TestCase):
|
||||
# sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
# 2 * sd2
|
||||
|
||||
def test_squeeze(self):
|
||||
sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
# automatically squeezing
|
||||
self.assertTrue(not isinstance(np.nansum(sd1), idd.IndexData))
|
||||
self.assertTrue(not isinstance(np.sum(sd1), idd.IndexData))
|
||||
self.assertTrue(not isinstance(sd1.sum(), idd.IndexData))
|
||||
self.assertEqual(np.nansum(sd1), 10)
|
||||
self.assertEqual(np.sum(sd1), 10)
|
||||
self.assertEqual(sd1.sum(), 10)
|
||||
self.assertEqual(np.nanmean(sd1), 2.5)
|
||||
self.assertEqual(np.mean(sd1), 2.5)
|
||||
self.assertEqual(sd1.mean(), 2.5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user