From f71b0c11894c19bfad6026efdc6c39cc8b746a98 Mon Sep 17 00:00:00 2001 From: "wangwenxi.handsome" Date: Wed, 1 Sep 2021 16:20:52 +0000 Subject: [PATCH] 250s --- qlib/backtest/high_performance_ds.py | 16 ++++++++--- qlib/utils/index_data.py | 41 +++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index 86d631bfa..97310ffb6 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -604,13 +604,21 @@ class NumpyOrderIndicator(BaseOrderIndicator): @staticmethod def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0): + # get all index(stock_id) + stocks = set() + for indicator in indicators: + # set(np.ndarray.tolist()) is faster than set(np.ndarray) + stocks = stocks | set(indicator.data[metrics[0]].index.tolist()) + stocks = list(stocks) + stocks.sort() + + # add metric by index if isinstance(metrics, str): metrics = [metrics] for metric in metrics: - tmp_metric = idd.SingleData() - for indicator in indicators: - tmp_metric = tmp_metric.add(indicator.data[metric], fill_value) - order_indicator.data[metric] = tmp_metric + order_indicator.data[metric] = idd.sum_by_index( + [indicator.data[metric] for indicator in indicators], stocks, fill_value + ) def __repr__(self): return repr(self.data) diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index 52cc385e0..9bd059add 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -22,7 +22,7 @@ def concat(data_list: Union["SingleData"], axis=0) -> "MultiData": Parameters ---------- - index_data_list : List[SingleData] + data_list : List[SingleData] the list of all SingleData to concat. Returns @@ -52,6 +52,36 @@ def concat(data_list: Union["SingleData"], axis=0) -> "MultiData": raise ValueError(f"axis must be 0 or 1") +def sum_by_index(data_list: Union["SingleData"], new_index: list, fill_value=0) -> "SingleData": + """concat all SingleData by new index. + + Parameters + ---------- + data_list : List[SingleData] + the list of all SingleData to sum. + new_index : list + the new_index of new SingleData. + fill_value : float + fill the missing values ​​or replace np.NaN. + + Returns + ------- + SingleData + the SingleData with new_index and values after sum. + """ + data_list = [data.to_dict() for data in data_list] + data_sum = {} + for id in new_index: + item_sum = 0 + for data in data_list: + if id in data and data[id] != np.NaN: + item_sum += data[id] + else: + item_sum += fill_value + data_sum[id] = item_sum + return SingleData(data_sum) + + class Index: """ This is for indexing(rows or columns) @@ -155,6 +185,10 @@ class Index: idx._is_sorted = True return idx, sorted_idx + def tolist(self): + """return the index with the format of list.""" + return self.idx_list.tolist() + class LocIndexer: """ @@ -529,8 +563,7 @@ class SingleData(IndexData): tmp_data = np.full(len(index), fill_value, dtype=np.float64) for index_id, index_item in enumerate(index): try: - item_data = self.loc[index_item] - tmp_data[index_id] = item_data if item_data != np.NaN else fill_value + tmp_data[index_id] = self.loc[index_item] except KeyError: pass return SingleData(tmp_data, index) @@ -542,7 +575,7 @@ class SingleData(IndexData): common_index, _ = common_index.sort() tmp_data1 = self.reindex(common_index, fill_value) tmp_data2 = other.reindex(common_index, fill_value) - return tmp_data1 + tmp_data2 + return tmp_data1.fillna(fill_value) + tmp_data2.fillna(fill_value) def to_dict(self): """convert SingleData to dict.