1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 09:31:18 +08:00
This commit is contained in:
wangwenxi.handsome
2021-09-01 16:20:52 +00:00
committed by you-n-g
parent 919380597b
commit f71b0c1189
2 changed files with 49 additions and 8 deletions

View File

@@ -604,13 +604,21 @@ class NumpyOrderIndicator(BaseOrderIndicator):
@staticmethod
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
    """Sum the given metrics across several indicators into ``order_indicator``.

    Parameters
    ----------
    order_indicator
        the indicator whose ``data[metric]`` entries are overwritten with the sums.
    indicators : list
        the indicators to sum; each must expose ``data[metric]`` for every requested metric.
    metrics : Union[str, List[str]]
        one metric name or a list of metric names.
    fill_value : float
        forwarded to ``idd.sum_by_index`` as the value used for missing entries.
    """
    # Normalize before any indexing: metrics[0] on a bare str would be its
    # first character, not the metric name.
    if isinstance(metrics, str):
        metrics = [metrics]

    # get all index(stock_id)
    # NOTE(review): only the first metric's index is consulted; this assumes
    # every metric of an indicator shares the same index - confirm.
    stocks = set()
    for indicator in indicators:
        # set(np.ndarray.tolist()) is faster than set(np.ndarray)
        stocks.update(indicator.data[metrics[0]].index.tolist())
    stocks = sorted(stocks)

    # add metric by index
    for metric in metrics:
        order_indicator.data[metric] = idd.sum_by_index(
            [indicator.data[metric] for indicator in indicators], stocks, fill_value
        )
def __repr__(self):
    """Delegate to the repr of ``self.data``."""
    return repr(self.data)

View File

@@ -22,7 +22,7 @@ def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
Parameters
----------
index_data_list : List[SingleData]
data_list : List[SingleData]
the list of all SingleData to concat.
Returns
@@ -52,6 +52,36 @@ def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
raise ValueError(f"axis must be 0 or 1")
def sum_by_index(data_list: Union["SingleData"], new_index: list, fill_value=0) -> "SingleData":
    """Sum all SingleData item by item over a new index.

    Parameters
    ----------
    data_list : List[SingleData]
        the list of all SingleData to sum.
    new_index : list
        the new_index of new SingleData.
    fill_value : float
        the value a SingleData contributes for an index item it does not
        contain, or for which its value is NaN.

    Returns
    -------
    SingleData
        the SingleData with new_index and values after sum.
    """
    # Convert once up front so the inner loop works on plain dict lookups.
    dict_list = [data.to_dict() for data in data_list]
    data_sum = {}
    for key in new_index:
        item_sum = 0
        for data in dict_list:
            # NaN never compares equal to anything (itself included), so an
            # equality/inequality test cannot detect it; np.isnan is required.
            # NOTE(review): assumes the stored values are numeric - confirm.
            if key in data and not np.isnan(data[key]):
                item_sum += data[key]
            else:
                item_sum += fill_value
        data_sum[key] = item_sum
    return SingleData(data_sum)
class Index:
"""
This is for indexing(rows or columns)
@@ -155,6 +185,10 @@ class Index:
idx._is_sorted = True
return idx, sorted_idx
def tolist(self):
    """Return the index labels as a list.

    The conversion is delegated to ``self.idx_list.tolist()``.
    """
    labels = self.idx_list
    return labels.tolist()
class LocIndexer:
"""
@@ -529,8 +563,7 @@ class SingleData(IndexData):
tmp_data = np.full(len(index), fill_value, dtype=np.float64)
for index_id, index_item in enumerate(index):
try:
item_data = self.loc[index_item]
tmp_data[index_id] = item_data if item_data != np.NaN else fill_value
tmp_data[index_id] = self.loc[index_item]
except KeyError:
pass
return SingleData(tmp_data, index)
@@ -542,7 +575,7 @@ class SingleData(IndexData):
common_index, _ = common_index.sort()
tmp_data1 = self.reindex(common_index, fill_value)
tmp_data2 = other.reindex(common_index, fill_value)
return tmp_data1 + tmp_data2
return tmp_data1.fillna(fill_value) + tmp_data2.fillna(fill_value)
def to_dict(self):
"""convert SingleData to dict.