mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 18:11:18 +08:00
fix index data bug
This commit is contained in:
committed by
you-n-g
parent
16b954866f
commit
e134c358fd
@@ -203,7 +203,7 @@ class NumpyQuote(BaseQuote):
|
||||
elif method is None:
|
||||
stock_data = self.data[stock_id][start_id:end_id, self.columns[fields]]
|
||||
stock_dates = self.dates_list[stock_id][start_id:end_id].to_list()
|
||||
return IndexData(stock_data, [stock_id], stock_dates)
|
||||
return IndexData(stock_data, stock_dates)
|
||||
else:
|
||||
agg_stock_data = self._agg_data(self.data[stock_id][start_id:end_id, self.columns[fields]], method)
|
||||
|
||||
@@ -721,7 +721,7 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
self.data = np.zeros((len(NumpyOrderIndicator.ROW), len(metric)))
|
||||
self.column = list(metric.keys())
|
||||
self.column_map = dict(zip(self.column, range(len(self.column))))
|
||||
|
||||
|
||||
metric_column = list(metric.keys())
|
||||
if self.column != metric_column:
|
||||
assert len(set(self.column) - set(metric_column)) == 0
|
||||
@@ -753,9 +753,9 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
|
||||
def get_index_data(self, metric):
|
||||
if self._if_valid_metric(metric):
|
||||
return IndexData(self.data[NumpyOrderIndicator.ROW_MAP[metric]], [metric], self.column)
|
||||
return IndexData(self.data[NumpyOrderIndicator.ROW_MAP[metric]], self.column)
|
||||
else:
|
||||
return IndexData([], [], [])
|
||||
return IndexData([], [])
|
||||
|
||||
def get_metric_series(self, metric: str) -> Union[pd.Series]:
|
||||
if self._if_valid_metric(metric):
|
||||
@@ -819,52 +819,80 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
|
||||
|
||||
class IndexData:
|
||||
def __init__(self, data, row, column):
|
||||
def __init__(self, data, column):
|
||||
if isinstance(data, list):
|
||||
self.data = np.array([data])
|
||||
self.data = np.array(data)
|
||||
elif isinstance(data, np.ndarray):
|
||||
if data.ndim == 1:
|
||||
self.data = data[np.newaxis, :]
|
||||
elif data.ndim == 2:
|
||||
self.data = data
|
||||
else:
|
||||
raise ValueError(f"the dimension of data must <= 2")
|
||||
self.data = data
|
||||
else:
|
||||
raise ValueError(f"data must be list or np.ndarray")
|
||||
self.data = data
|
||||
self.ndim = self.data.ndim
|
||||
|
||||
assert isinstance(row, list)
|
||||
self.row = row
|
||||
self.row_map = dict(zip(self.row, range(len(self.row))))
|
||||
assert isinstance(column, list)
|
||||
self.col = column
|
||||
self.col_map = dict(zip(self.col, range(len(self.col))))
|
||||
|
||||
def reindex(self, new_column):
|
||||
tmp_data = self.data.copy()
|
||||
for row_id, row in enumerate(self.row):
|
||||
for col_id, col in new_column:
|
||||
if col in self.col:
|
||||
tmp_data[row_id, col_id] = self.data[row_id, self.row_map[col]]
|
||||
else:
|
||||
tmp_data[row_id, col_id] = np.NaN
|
||||
return IndexData(tmp_data, self.row, list(new_column))
|
||||
assert self.ndim == 1
|
||||
tmp_data = np.full(len(new_column), np.NaN)
|
||||
for col_id, col in enumerate(new_column):
|
||||
if col in self.col:
|
||||
tmp_data[col_id] = self.data[self.col_map[col]]
|
||||
return IndexData(tmp_data, list(new_column))
|
||||
|
||||
def to_dict(self):
|
||||
assert len(self.row) == 1
|
||||
if self.data.size == 0:
|
||||
return {col: np.NaN for col in self.col}
|
||||
assert self.ndim == 1
|
||||
return dict(zip(self.col, self.data.tolist()))
|
||||
|
||||
def keep_positive(self, limit=1e-08):
|
||||
assert self.ndim == 1
|
||||
new_col = []
|
||||
new_data = []
|
||||
for col_id, col in enumerate(self.col):
|
||||
if self.data[col_id] < 1e-08:
|
||||
continue
|
||||
else:
|
||||
new_col.append(col)
|
||||
new_data.append(self.data[col_id])
|
||||
return IndexData(new_data, new_col)
|
||||
|
||||
def sum(self, axis=None):
|
||||
if axis is None:
|
||||
return np.nansum(self.data)
|
||||
if axis == 0:
|
||||
assert self.ndim == 2
|
||||
tmp_data = np.nansum(self.data, axis=0)
|
||||
return IndexData(tmp_data, self.col)
|
||||
else:
|
||||
return dict(zip(self.col, self.data[0, :].tolist()))
|
||||
raise NotImplementedError(f"axis must be 0 or None")
|
||||
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, IndexData):
|
||||
assert self.ndim == other.ndim
|
||||
assert self.col == other.col
|
||||
assert len(self.data) == len(other.data)
|
||||
return IndexData(self.data * other.data, self.col)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __truediv__(self, other):
|
||||
if isinstance(other, IndexData):
|
||||
assert self.ndim == other.ndim
|
||||
assert self.col == other.col
|
||||
assert len(self.data) == len(other.data)
|
||||
return IndexData(self.data / other.data, self.col)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __len__(self):
|
||||
return len(self.col)
|
||||
|
||||
@staticmethod
|
||||
def concat_by_col(index_data_list):
|
||||
# get all col and row
|
||||
all_col = set()
|
||||
all_row = []
|
||||
for index_data in index_data_list:
|
||||
all_col = all_col | set(index_data.col)
|
||||
all_row.append(index_data.row[0])
|
||||
all_col = list(all_col)
|
||||
all_col.sort()
|
||||
all_col_map = dict(zip(all_col, range(len(all_col))))
|
||||
@@ -874,52 +902,4 @@ class IndexData:
|
||||
for data_id, index_data in enumerate(index_data_list):
|
||||
now_data_map = [all_col_map[col] for col in index_data.col]
|
||||
tmp_data[data_id, now_data_map] = index_data.data
|
||||
return IndexData(tmp_data, all_row, all_col)
|
||||
|
||||
def sum(self, axis = None):
|
||||
if axis is None:
|
||||
return np.nansum(self.data)
|
||||
if axis == 0:
|
||||
tmp_data = np.nansum(self.data, axis=0)
|
||||
return IndexData(tmp_data, [self.row[0]], self.col)
|
||||
else:
|
||||
raise NotImplementedError(f"axis must be 0 or None")
|
||||
|
||||
def keep_positive(self, limit = 1e-08):
|
||||
assert len(self.row) == 1
|
||||
new_col = []
|
||||
new_data = []
|
||||
for col_id, col in enumerate(self.col):
|
||||
if self.data[0: col_id] < 1e-08:
|
||||
continue
|
||||
else:
|
||||
new_col.append(col)
|
||||
new_data.append(self.data[0: col_id])
|
||||
return IndexData(new_data, self.row, new_col)
|
||||
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, IndexData):
|
||||
assert len(self.row) == len(other.row)
|
||||
assert self.col == other.col
|
||||
return IndexData(self.data * other.data, ["mul"], self.col)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __truediv__(self, other):
|
||||
if isinstance(other, IndexData):
|
||||
assert len(self.row) == len(other.row)
|
||||
assert self.col == other.col
|
||||
return IndexData(self.data / other.data, ["div"], self.col)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __len__(self):
|
||||
return len(self.col)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
return IndexData(tmp_data, all_col)
|
||||
|
||||
@@ -391,7 +391,7 @@ class Indicator:
|
||||
return None, None
|
||||
|
||||
if isinstance(price_s, (int, float)):
|
||||
price_s = IndexData([price_s], [inst], [trade_start_time])
|
||||
price_s = IndexData([price_s], [trade_start_time])
|
||||
|
||||
# NOTE: there are some zeros in the trading price. These cases are known meaningless
|
||||
# for aligning the previous logic, remove it.
|
||||
@@ -402,16 +402,15 @@ class Indicator:
|
||||
if agg == "vwap":
|
||||
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
|
||||
if isinstance(volume_s, (int, float)):
|
||||
volume_s = IndexData([volume_s], [inst], [trade_start_time])
|
||||
volume_s = IndexData([volume_s], [trade_start_time])
|
||||
volume_s = volume_s.reindex(price_s.col)
|
||||
elif agg == "twap":
|
||||
volume_s = IndexData([1 for i in range(price_s.col)], [inst], price_s.col)
|
||||
volume_s = IndexData([1 for i in range(len(price_s.col))], price_s.col)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
base_volume = volume_s.sum()
|
||||
base_price = (price_s * volume_s).sum() / base_volume
|
||||
|
||||
return base_price, base_volume
|
||||
|
||||
def _agg_base_price(
|
||||
@@ -451,6 +450,7 @@ class Indicator:
|
||||
for oi, (dec, start, end) in zip(inner_order_indicators, decision_list):
|
||||
bp_s = oi.get_index_data("base_price").reindex(trade_dir.col)
|
||||
bv_s = oi.get_index_data("base_volume").reindex(trade_dir.col)
|
||||
|
||||
bp_new, bv_new = {}, {}
|
||||
for pr, v, (inst, direction) in zip(bp_s.data, bv_s.data, zip(trade_dir.col, trade_dir.data)):
|
||||
if np.isnan(pr):
|
||||
@@ -468,16 +468,16 @@ class Indicator:
|
||||
else:
|
||||
bp_new[inst], bv_new[inst] = pr, v
|
||||
|
||||
bp_new = IndexData(list(bp_new.values()), ["base_price"], list(bp_new.keys()))
|
||||
bv_new = IndexData(list(bv_new.values()), ["base_volume"], list(bv_new.keys()))
|
||||
bp_new = IndexData(list(bp_new.values()), list(bp_new.keys()))
|
||||
bv_new = IndexData(list(bv_new.values()), list(bv_new.keys()))
|
||||
bp_all.append(bp_new)
|
||||
bv_all.append(bv_new)
|
||||
bp_all = IndexData.concat_by_col(bp_all)
|
||||
bv_all = IndexData.concat_by_col(bv_all)
|
||||
|
||||
base_volume = bv_all.sum(axis = 0)
|
||||
base_volume = bv_all.sum(axis=0)
|
||||
self.order_indicator.assign("base_volume", base_volume.to_dict())
|
||||
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis = 0) / base_volume).to_dict())
|
||||
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=0) / base_volume).to_dict())
|
||||
|
||||
def _agg_order_price_advantage(self):
|
||||
def if_empty_func(trade_price):
|
||||
|
||||
Reference in New Issue
Block a user