From 64130d9407c38d1450e3ad72ebc2e033092b79f6 Mon Sep 17 00:00:00 2001 From: Young Date: Fri, 22 Oct 2021 15:20:45 +0800 Subject: [PATCH] Fix the aggregation function of IndexData --- qlib/backtest/high_performance_ds.py | 5 +++++ qlib/utils/index_data.py | 12 +++++++++--- tests/misc/test_index_data.py | 14 +++++++++++++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index 235bd054b..51847cac3 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -160,6 +160,11 @@ class NumpyQuote(BaseQuote): if is_single_value(start_time, end_time, self.freq, self.region): # this is a very special case. # skip aggregating function to speed-up the query calculation + + # FIXME: + # execution falls through to the else branch in these cases: + # 1) the day before a holiday in daily trading + # 2) the last minute of the day in intraday trading try: return self.data[stock_id].loc[start_time, field] except KeyError: diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index 5e3942db5..06fb42a5e 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -401,6 +401,10 @@ class IndexData(metaclass=index_data_ops_creator): def columns(self): return self.indices[1] + def __getitem__(self, args): + # NOTE: this mimics numpy-array indexing so that numpy aggregating functions such as nansum and nanmean can operate on IndexData + return self.iloc[args] + def _align_indices(self, other: "IndexData") -> "IndexData": """ Align all indices of `other` to `self` before performing the arithmetic operations. 
@@ -409,7 +413,7 @@ class IndexData(metaclass=index_data_ops_creator): Parameters ---------- other : "IndexData" - the index in `other` is to be chagned + the index in `other` is to be changed Returns ------- @@ -455,7 +459,8 @@ class IndexData(metaclass=index_data_ops_creator): """ return len(self.data) - def sum(self, axis=None): + def sum(self, axis=None, dtype=None, out=None): + assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function" # FIXME: weird logic and not general if axis is None: return np.nansum(self.data) @@ -468,7 +473,8 @@ class IndexData(metaclass=index_data_ops_creator): else: raise ValueError(f"axis must be None, 0 or 1") - def mean(self, axis=None): + def mean(self, axis=None, dtype=None, out=None): + assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function" # FIXME: weird logic and not general if axis is None: return np.nanmean(self.data) diff --git a/tests/misc/test_index_data.py b/tests/misc/test_index_data.py index 3cd819a0f..20cda69ff 100644 --- a/tests/misc/test_index_data.py +++ b/tests/misc/test_index_data.py @@ -1,6 +1,5 @@ import numpy as np import pandas as pd - import qlib.utils.index_data as idd import unittest @@ -115,6 +114,19 @@ class IndexDataTest(unittest.TestCase): # sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) # 2 * sd2 + def test_squeeze(self): + sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) + # automatically squeezing + self.assertTrue(not isinstance(np.nansum(sd1), idd.IndexData)) + self.assertTrue(not isinstance(np.sum(sd1), idd.IndexData)) + self.assertTrue(not isinstance(sd1.sum(), idd.IndexData)) + self.assertEqual(np.nansum(sd1), 10) + self.assertEqual(np.sum(sd1), 10) + self.assertEqual(sd1.sum(), 10) + self.assertEqual(np.nanmean(sd1), 2.5) + self.assertEqual(np.mean(sd1), 2.5) + self.assertEqual(sd1.mean(), 2.5) + if __name__ == "__main__": unittest.main()