mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 09:31:18 +08:00
250s
This commit is contained in:
committed by
you-n-g
parent
919380597b
commit
f71b0c1189
@@ -604,13 +604,21 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
|
||||
@staticmethod
|
||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
|
||||
# get all index(stock_id)
|
||||
stocks = set()
|
||||
for indicator in indicators:
|
||||
# set(np.ndarray.tolist()) is faster than set(np.ndarray)
|
||||
stocks = stocks | set(indicator.data[metrics[0]].index.tolist())
|
||||
stocks = list(stocks)
|
||||
stocks.sort()
|
||||
|
||||
# add metric by index
|
||||
if isinstance(metrics, str):
|
||||
metrics = [metrics]
|
||||
for metric in metrics:
|
||||
tmp_metric = idd.SingleData()
|
||||
for indicator in indicators:
|
||||
tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
|
||||
order_indicator.data[metric] = tmp_metric
|
||||
order_indicator.data[metric] = idd.sum_by_index(
|
||||
[indicator.data[metric] for indicator in indicators], stocks, fill_value
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.data)
|
||||
|
||||
@@ -22,7 +22,7 @@ def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index_data_list : List[SingleData]
|
||||
data_list : List[SingleData]
|
||||
the list of all SingleData to concat.
|
||||
|
||||
Returns
|
||||
@@ -52,6 +52,36 @@ def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
|
||||
raise ValueError(f"axis must be 0 or 1")
|
||||
|
||||
|
||||
def sum_by_index(data_list: Union["SingleData"], new_index: list, fill_value=0) -> "SingleData":
|
||||
"""concat all SingleData by new index.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_list : List[SingleData]
|
||||
the list of all SingleData to sum.
|
||||
new_index : list
|
||||
the new_index of new SingleData.
|
||||
fill_value : float
|
||||
fill the missing values or replace np.NaN.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SingleData
|
||||
the SingleData with new_index and values after sum.
|
||||
"""
|
||||
data_list = [data.to_dict() for data in data_list]
|
||||
data_sum = {}
|
||||
for id in new_index:
|
||||
item_sum = 0
|
||||
for data in data_list:
|
||||
if id in data and data[id] != np.NaN:
|
||||
item_sum += data[id]
|
||||
else:
|
||||
item_sum += fill_value
|
||||
data_sum[id] = item_sum
|
||||
return SingleData(data_sum)
|
||||
|
||||
|
||||
class Index:
|
||||
"""
|
||||
This is for indexing(rows or columns)
|
||||
@@ -155,6 +185,10 @@ class Index:
|
||||
idx._is_sorted = True
|
||||
return idx, sorted_idx
|
||||
|
||||
def tolist(self):
|
||||
"""return the index with the format of list."""
|
||||
return self.idx_list.tolist()
|
||||
|
||||
|
||||
class LocIndexer:
|
||||
"""
|
||||
@@ -529,8 +563,7 @@ class SingleData(IndexData):
|
||||
tmp_data = np.full(len(index), fill_value, dtype=np.float64)
|
||||
for index_id, index_item in enumerate(index):
|
||||
try:
|
||||
item_data = self.loc[index_item]
|
||||
tmp_data[index_id] = item_data if item_data != np.NaN else fill_value
|
||||
tmp_data[index_id] = self.loc[index_item]
|
||||
except KeyError:
|
||||
pass
|
||||
return SingleData(tmp_data, index)
|
||||
@@ -542,7 +575,7 @@ class SingleData(IndexData):
|
||||
common_index, _ = common_index.sort()
|
||||
tmp_data1 = self.reindex(common_index, fill_value)
|
||||
tmp_data2 = other.reindex(common_index, fill_value)
|
||||
return tmp_data1 + tmp_data2
|
||||
return tmp_data1.fillna(fill_value) + tmp_data2.fillna(fill_value)
|
||||
|
||||
def to_dict(self):
|
||||
"""convert SingleData to dict.
|
||||
|
||||
Reference in New Issue
Block a user