1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 18:11:18 +08:00

fix index data bug

This commit is contained in:
wangwenxi.handsome
2021-08-18 14:18:19 +00:00
committed by you-n-g
parent 16b954866f
commit e134c358fd
2 changed files with 67 additions and 87 deletions

View File

@@ -203,7 +203,7 @@ class NumpyQuote(BaseQuote):
elif method is None:
stock_data = self.data[stock_id][start_id:end_id, self.columns[fields]]
stock_dates = self.dates_list[stock_id][start_id:end_id].to_list()
return IndexData(stock_data, [stock_id], stock_dates)
return IndexData(stock_data, stock_dates)
else:
agg_stock_data = self._agg_data(self.data[stock_id][start_id:end_id, self.columns[fields]], method)
@@ -721,7 +721,7 @@ class NumpyOrderIndicator(BaseOrderIndicator):
self.data = np.zeros((len(NumpyOrderIndicator.ROW), len(metric)))
self.column = list(metric.keys())
self.column_map = dict(zip(self.column, range(len(self.column))))
metric_column = list(metric.keys())
if self.column != metric_column:
assert len(set(self.column) - set(metric_column)) == 0
@@ -753,9 +753,9 @@ class NumpyOrderIndicator(BaseOrderIndicator):
def get_index_data(self, metric):
if self._if_valid_metric(metric):
return IndexData(self.data[NumpyOrderIndicator.ROW_MAP[metric]], [metric], self.column)
return IndexData(self.data[NumpyOrderIndicator.ROW_MAP[metric]], self.column)
else:
return IndexData([], [], [])
return IndexData([], [])
def get_metric_series(self, metric: str) -> Union[pd.Series]:
if self._if_valid_metric(metric):
@@ -819,52 +819,80 @@ class NumpyOrderIndicator(BaseOrderIndicator):
class IndexData:
def __init__(self, data, row, column):
def __init__(self, data, column):
if isinstance(data, list):
self.data = np.array([data])
self.data = np.array(data)
elif isinstance(data, np.ndarray):
if data.ndim == 1:
self.data = data[np.newaxis, :]
elif data.ndim == 2:
self.data = data
else:
raise ValueError(f"the dimension of data must <= 2")
self.data = data
else:
raise ValueError(f"data must be list or np.ndarray")
self.data = data
self.ndim = self.data.ndim
assert isinstance(row, list)
self.row = row
self.row_map = dict(zip(self.row, range(len(self.row))))
assert isinstance(column, list)
self.col = column
self.col_map = dict(zip(self.col, range(len(self.col))))
def reindex(self, new_column):
tmp_data = self.data.copy()
for row_id, row in enumerate(self.row):
for col_id, col in new_column:
if col in self.col:
tmp_data[row_id, col_id] = self.data[row_id, self.row_map[col]]
else:
tmp_data[row_id, col_id] = np.NaN
return IndexData(tmp_data, self.row, list(new_column))
assert self.ndim == 1
tmp_data = np.full(len(new_column), np.NaN)
for col_id, col in enumerate(new_column):
if col in self.col:
tmp_data[col_id] = self.data[self.col_map[col]]
return IndexData(tmp_data, list(new_column))
def to_dict(self):
assert len(self.row) == 1
if self.data.size == 0:
return {col: np.NaN for col in self.col}
assert self.ndim == 1
return dict(zip(self.col, self.data.tolist()))
def keep_positive(self, limit=1e-08):
assert self.ndim == 1
new_col = []
new_data = []
for col_id, col in enumerate(self.col):
if self.data[col_id] < 1e-08:
continue
else:
new_col.append(col)
new_data.append(self.data[col_id])
return IndexData(new_data, new_col)
def sum(self, axis=None):
if axis is None:
return np.nansum(self.data)
if axis == 0:
assert self.ndim == 2
tmp_data = np.nansum(self.data, axis=0)
return IndexData(tmp_data, self.col)
else:
return dict(zip(self.col, self.data[0, :].tolist()))
raise NotImplementedError(f"axis must be 0 or None")
def __mul__(self, other):
if isinstance(other, IndexData):
assert self.ndim == other.ndim
assert self.col == other.col
assert len(self.data) == len(other.data)
return IndexData(self.data * other.data, self.col)
else:
return NotImplemented
def __truediv__(self, other):
if isinstance(other, IndexData):
assert self.ndim == other.ndim
assert self.col == other.col
assert len(self.data) == len(other.data)
return IndexData(self.data / other.data, self.col)
else:
return NotImplemented
def __len__(self):
return len(self.col)
@staticmethod
def concat_by_col(index_data_list):
# get all col and row
all_col = set()
all_row = []
for index_data in index_data_list:
all_col = all_col | set(index_data.col)
all_row.append(index_data.row[0])
all_col = list(all_col)
all_col.sort()
all_col_map = dict(zip(all_col, range(len(all_col))))
@@ -874,52 +902,4 @@ class IndexData:
for data_id, index_data in enumerate(index_data_list):
now_data_map = [all_col_map[col] for col in index_data.col]
tmp_data[data_id, now_data_map] = index_data.data
return IndexData(tmp_data, all_row, all_col)
def sum(self, axis = None):
if axis is None:
return np.nansum(self.data)
if axis == 0:
tmp_data = np.nansum(self.data, axis=0)
return IndexData(tmp_data, [self.row[0]], self.col)
else:
raise NotImplementedError(f"axis must be 0 or None")
def keep_positive(self, limit = 1e-08):
assert len(self.row) == 1
new_col = []
new_data = []
for col_id, col in enumerate(self.col):
if self.data[0: col_id] < 1e-08:
continue
else:
new_col.append(col)
new_data.append(self.data[0: col_id])
return IndexData(new_data, self.row, new_col)
def __mul__(self, other):
if isinstance(other, IndexData):
assert len(self.row) == len(other.row)
assert self.col == other.col
return IndexData(self.data * other.data, ["mul"], self.col)
else:
return NotImplemented
def __truediv__(self, other):
if isinstance(other, IndexData):
assert len(self.row) == len(other.row)
assert self.col == other.col
return IndexData(self.data / other.data, ["div"], self.col)
else:
return NotImplemented
def __len__(self):
return len(self.col)
return IndexData(tmp_data, all_col)

View File

@@ -391,7 +391,7 @@ class Indicator:
return None, None
if isinstance(price_s, (int, float)):
price_s = IndexData([price_s], [inst], [trade_start_time])
price_s = IndexData([price_s], [trade_start_time])
# NOTE: there are some zeros in the trading price. These cases are known meaningless
# for aligning the previous logic, remove it.
@@ -402,16 +402,15 @@ class Indicator:
if agg == "vwap":
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
if isinstance(volume_s, (int, float)):
volume_s = IndexData([volume_s], [inst], [trade_start_time])
volume_s = IndexData([volume_s], [trade_start_time])
volume_s = volume_s.reindex(price_s.col)
elif agg == "twap":
volume_s = IndexData([1 for i in range(price_s.col)], [inst], price_s.col)
volume_s = IndexData([1 for i in range(len(price_s.col))], price_s.col)
else:
raise NotImplementedError(f"This type of input is not supported")
base_volume = volume_s.sum()
base_price = (price_s * volume_s).sum() / base_volume
return base_price, base_volume
def _agg_base_price(
@@ -451,6 +450,7 @@ class Indicator:
for oi, (dec, start, end) in zip(inner_order_indicators, decision_list):
bp_s = oi.get_index_data("base_price").reindex(trade_dir.col)
bv_s = oi.get_index_data("base_volume").reindex(trade_dir.col)
bp_new, bv_new = {}, {}
for pr, v, (inst, direction) in zip(bp_s.data, bv_s.data, zip(trade_dir.col, trade_dir.data)):
if np.isnan(pr):
@@ -468,16 +468,16 @@ class Indicator:
else:
bp_new[inst], bv_new[inst] = pr, v
bp_new = IndexData(list(bp_new.values()), ["base_price"], list(bp_new.keys()))
bv_new = IndexData(list(bv_new.values()), ["base_volume"], list(bv_new.keys()))
bp_new = IndexData(list(bp_new.values()), list(bp_new.keys()))
bv_new = IndexData(list(bv_new.values()), list(bv_new.keys()))
bp_all.append(bp_new)
bv_all.append(bv_new)
bp_all = IndexData.concat_by_col(bp_all)
bv_all = IndexData.concat_by_col(bv_all)
base_volume = bv_all.sum(axis = 0)
base_volume = bv_all.sum(axis=0)
self.order_indicator.assign("base_volume", base_volume.to_dict())
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis = 0) / base_volume).to_dict())
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=0) / base_volume).to_dict())
def _agg_order_price_advantage(self):
def if_empty_func(trade_price):