diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index e5534dfcd..38488b1f7 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -203,7 +203,7 @@ class NumpyQuote(BaseQuote): elif method is None: stock_data = self.data[stock_id][start_id:end_id, self.columns[fields]] stock_dates = self.dates_list[stock_id][start_id:end_id].to_list() - return IndexData(stock_data, [stock_id], stock_dates) + return IndexData(stock_data, stock_dates) else: agg_stock_data = self._agg_data(self.data[stock_id][start_id:end_id, self.columns[fields]], method) @@ -721,7 +721,7 @@ class NumpyOrderIndicator(BaseOrderIndicator): self.data = np.zeros((len(NumpyOrderIndicator.ROW), len(metric))) self.column = list(metric.keys()) self.column_map = dict(zip(self.column, range(len(self.column)))) - + metric_column = list(metric.keys()) if self.column != metric_column: assert len(set(self.column) - set(metric_column)) == 0 @@ -753,9 +753,9 @@ class NumpyOrderIndicator(BaseOrderIndicator): def get_index_data(self, metric): if self._if_valid_metric(metric): - return IndexData(self.data[NumpyOrderIndicator.ROW_MAP[metric]], [metric], self.column) + return IndexData(self.data[NumpyOrderIndicator.ROW_MAP[metric]], self.column) else: - return IndexData([], [], []) + return IndexData([], []) def get_metric_series(self, metric: str) -> Union[pd.Series]: if self._if_valid_metric(metric): @@ -819,52 +819,80 @@ class NumpyOrderIndicator(BaseOrderIndicator): class IndexData: - def __init__(self, data, row, column): + def __init__(self, data, column): if isinstance(data, list): - self.data = np.array([data]) + self.data = np.array(data) elif isinstance(data, np.ndarray): - if data.ndim == 1: - self.data = data[np.newaxis, :] - elif data.ndim == 2: - self.data = data - else: - raise ValueError(f"the dimension of data must <= 2") + self.data = data else: raise ValueError(f"data must be list or np.ndarray") - self.data = data + self.ndim = self.data.ndim - assert isinstance(row, list) - self.row = row - self.row_map = dict(zip(self.row, range(len(self.row)))) assert isinstance(column, list) self.col = column self.col_map = dict(zip(self.col, range(len(self.col)))) def reindex(self, new_column): - tmp_data = self.data.copy() - for row_id, row in enumerate(self.row): - for col_id, col in new_column: - if col in self.col: - tmp_data[row_id, col_id] = self.data[row_id, self.row_map[col]] - else: - tmp_data[row_id, col_id] = np.NaN - return IndexData(tmp_data, self.row, list(new_column)) + assert self.ndim == 1 + tmp_data = np.full(len(new_column), np.NaN) + for col_id, col in enumerate(new_column): + if col in self.col: + tmp_data[col_id] = self.data[self.col_map[col]] + return IndexData(tmp_data, list(new_column)) def to_dict(self): - assert len(self.row) == 1 - if self.data.size == 0: - return {col: np.NaN for col in self.col} + assert self.ndim == 1 + return dict(zip(self.col, self.data.tolist())) + + def keep_positive(self, limit=1e-08): + assert self.ndim == 1 + new_col = [] + new_data = [] + for col_id, col in enumerate(self.col): + if self.data[col_id] < 1e-08: + continue + else: + new_col.append(col) + new_data.append(self.data[col_id]) + return IndexData(new_data, new_col) + + def sum(self, axis=None): + if axis is None: + return np.nansum(self.data) + if axis == 0: + assert self.ndim == 2 + tmp_data = np.nansum(self.data, axis=0) + return IndexData(tmp_data, self.col) else: - return dict(zip(self.col, self.data[0, :].tolist())) + raise NotImplementedError(f"axis must be 0 or None") + + def __mul__(self, other): + if isinstance(other, IndexData): + assert self.ndim == other.ndim + assert self.col == other.col + assert len(self.data) == len(other.data) + return IndexData(self.data * other.data, self.col) + else: + return NotImplemented + + def __truediv__(self, other): + if isinstance(other, IndexData): + assert self.ndim == other.ndim + assert self.col == other.col + assert len(self.data) == len(other.data) + return IndexData(self.data / other.data, self.col) + else: + return NotImplemented + + def __len__(self): + return len(self.col) @staticmethod def concat_by_col(index_data_list): # get all col and row all_col = set() - all_row = [] for index_data in index_data_list: all_col = all_col | set(index_data.col) - all_row.append(index_data.row[0]) all_col = list(all_col) all_col.sort() all_col_map = dict(zip(all_col, range(len(all_col)))) @@ -874,52 +902,4 @@ class IndexData: for data_id, index_data in enumerate(index_data_list): now_data_map = [all_col_map[col] for col in index_data.col] tmp_data[data_id, now_data_map] = index_data.data - return IndexData(tmp_data, all_row, all_col) - - def sum(self, axis = None): - if axis is None: - return np.nansum(self.data) - if axis == 0: - tmp_data = np.nansum(self.data, axis=0) - return IndexData(tmp_data, [self.row[0]], self.col) - else: - raise NotImplementedError(f"axis must be 0 or None") - - def keep_positive(self, limit = 1e-08): - assert len(self.row) == 1 - new_col = [] - new_data = [] - for col_id, col in enumerate(self.col): - if self.data[0: col_id] < 1e-08: - continue - else: - new_col.append(col) - new_data.append(self.data[0: col_id]) - return IndexData(new_data, self.row, new_col) - - def __mul__(self, other): - if isinstance(other, IndexData): - assert len(self.row) == len(other.row) - assert self.col == other.col - return IndexData(self.data * other.data, ["mul"], self.col) - else: - return NotImplemented - - def __truediv__(self, other): - if isinstance(other, IndexData): - assert len(self.row) == len(other.row) - assert self.col == other.col - return IndexData(self.data / other.data, ["div"], self.col) - else: - return NotImplemented - - def __len__(self): - return len(self.col) - - - - - - - - \ No newline at end of file + return IndexData(tmp_data, all_col) diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index c59ca4ea2..486c74d8b 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -391,7 +391,7 @@ class Indicator: return None, None if isinstance(price_s, (int, float)): - price_s = IndexData([price_s], [inst], [trade_start_time]) + price_s = IndexData([price_s], [trade_start_time]) # NOTE: there are some zeros in the trading price. These cases are known meaningless # for aligning the previous logic, remove it. @@ -402,16 +402,15 @@ class Indicator: if agg == "vwap": volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None) if isinstance(volume_s, (int, float)): - volume_s = IndexData([volume_s], [inst], [trade_start_time]) + volume_s = IndexData([volume_s], [trade_start_time]) volume_s = volume_s.reindex(price_s.col) elif agg == "twap": - volume_s = IndexData([1 for i in range(price_s.col)], [inst], price_s.col) + volume_s = IndexData([1 for i in range(len(price_s.col))], price_s.col) else: raise NotImplementedError(f"This type of input is not supported") base_volume = volume_s.sum() base_price = (price_s * volume_s).sum() / base_volume - return base_price, base_volume def _agg_base_price( @@ -451,6 +450,7 @@ class Indicator: for oi, (dec, start, end) in zip(inner_order_indicators, decision_list): bp_s = oi.get_index_data("base_price").reindex(trade_dir.col) bv_s = oi.get_index_data("base_volume").reindex(trade_dir.col) + bp_new, bv_new = {}, {} for pr, v, (inst, direction) in zip(bp_s.data, bv_s.data, zip(trade_dir.col, trade_dir.data)): if np.isnan(pr): @@ -468,16 +468,16 @@ class Indicator: else: bp_new[inst], bv_new[inst] = pr, v - bp_new = IndexData(list(bp_new.values()), ["base_price"], list(bp_new.keys())) - bv_new = IndexData(list(bv_new.values()), ["base_volume"], list(bv_new.keys())) + bp_new = IndexData(list(bp_new.values()), list(bp_new.keys())) + bv_new = IndexData(list(bv_new.values()), list(bv_new.keys())) bp_all.append(bp_new) bv_all.append(bv_new) bp_all = IndexData.concat_by_col(bp_all) bv_all = IndexData.concat_by_col(bv_all) - base_volume = bv_all.sum(axis = 0) + base_volume = bv_all.sum(axis=0) self.order_indicator.assign("base_volume", base_volume.to_dict()) - self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis = 0) / base_volume).to_dict()) + self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=0) / base_volume).to_dict()) def _agg_order_price_advantage(self): def if_empty_func(trade_price):