mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
get_base_info
This commit is contained in:
committed by
you-n-g
parent
f7d7f1a223
commit
16b954866f
@@ -200,7 +200,12 @@ class NumpyQuote(BaseQuote):
|
||||
# it used for check if data is None
|
||||
if fields is None:
|
||||
return self.data[stock_id][start_id:end_id]
|
||||
agg_stock_data = self._agg_data(self.data[stock_id][start_id:end_id, self.columns[fields]], method)
|
||||
elif method is None:
|
||||
stock_data = self.data[stock_id][start_id:end_id, self.columns[fields]]
|
||||
stock_dates = self.dates_list[stock_id][start_id:end_id].to_list()
|
||||
return IndexData(stock_data, [stock_id], stock_dates)
|
||||
else:
|
||||
agg_stock_data = self._agg_data(self.data[stock_id][start_id:end_id, self.columns[fields]], method)
|
||||
|
||||
# result lru
|
||||
if len(self.muti_lru) >= self.max_lru_len:
|
||||
@@ -705,20 +710,18 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
self.row_tag = [0 for tag in range(len(NumpyOrderIndicator.ROW))]
|
||||
self.data = None
|
||||
|
||||
def assign(self, col: str, metric: Union[dict, np.ndarray, pd.Series]):
|
||||
def assign(self, col: str, metric: dict):
|
||||
if col not in NumpyOrderIndicator.ROW:
|
||||
raise ValueError(f"{col} metric is not supoorted")
|
||||
if not isinstance(metric, (dict, np.ndarray, pd.Series)):
|
||||
raise ValueError(f"metric must be dict, pd.Series or np.ndarray")
|
||||
if isinstance(metric, (pd.Series, np.ndarray)) and self.data is None:
|
||||
raise ValueError(f"data can not be None when metric is np.ndarray or pd.Series")
|
||||
if not isinstance(metric, dict):
|
||||
raise ValueError(f"metric must be dict")
|
||||
|
||||
# if data is None, init numpy ndarray
|
||||
if self.data is None:
|
||||
self.data = np.zeros((len(NumpyOrderIndicator.ROW), len(metric)))
|
||||
self.column = list(metric.keys())
|
||||
self.column_map = dict(zip(self.column, range(len(self.column))))
|
||||
|
||||
|
||||
metric_column = list(metric.keys())
|
||||
if self.column != metric_column:
|
||||
assert len(set(self.column) - set(metric_column)) == 0
|
||||
@@ -730,12 +733,7 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
|
||||
# assign data
|
||||
self.row_tag[NumpyOrderIndicator.ROW_MAP[col]] = 1
|
||||
if isinstance(metric, dict):
|
||||
self.data[NumpyOrderIndicator.ROW_MAP[col]] = list(metric.values())
|
||||
elif isinstance(metric, np.ndarray):
|
||||
self.data[NumpyOrderIndicator.ROW_MAP[col]] = metric
|
||||
elif isinstance(metric, pd.Series):
|
||||
self.data[NumpyOrderIndicator.ROW_MAP[col]] = metric.values
|
||||
self.data[NumpyOrderIndicator.ROW_MAP[col]] = list(metric.values())
|
||||
|
||||
def transfer(self, func: Callable, new_col: str = None) -> Union[None, NumpySingleMetric]:
|
||||
func_sig = inspect.signature(func).parameters.keys()
|
||||
@@ -753,6 +751,12 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
else:
|
||||
return tmp_metric
|
||||
|
||||
def get_index_data(self, metric):
|
||||
if self._if_valid_metric(metric):
|
||||
return IndexData(self.data[NumpyOrderIndicator.ROW_MAP[metric]], [metric], self.column)
|
||||
else:
|
||||
return IndexData([], [], [])
|
||||
|
||||
def get_metric_series(self, metric: str) -> Union[pd.Series]:
|
||||
if self._if_valid_metric(metric):
|
||||
return pd.Series(self.data[NumpyOrderIndicator.ROW_MAP[metric]], index=self.column)
|
||||
@@ -788,6 +792,7 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
stocks = stocks | set(indicator.column)
|
||||
indicator_metrics.append(indicator.data[metrics_id, :].copy())
|
||||
stocks = list(stocks)
|
||||
stocks.sort()
|
||||
stocks_map = dict(zip(stocks, range(len(stocks))))
|
||||
|
||||
# fill value
|
||||
@@ -811,3 +816,110 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
for i in range(len(metrics)):
|
||||
cls.row_tag[NumpyOrderIndicator.ROW_MAP[metrics[i]]] = 1
|
||||
cls.data[NumpyOrderIndicator.ROW_MAP[metrics[i]]] = metric_sum[i]
|
||||
|
||||
|
||||
class IndexData:
|
||||
def __init__(self, data, row, column):
|
||||
if isinstance(data, list):
|
||||
self.data = np.array([data])
|
||||
elif isinstance(data, np.ndarray):
|
||||
if data.ndim == 1:
|
||||
self.data = data[np.newaxis, :]
|
||||
elif data.ndim == 2:
|
||||
self.data = data
|
||||
else:
|
||||
raise ValueError(f"the dimension of data must <= 2")
|
||||
else:
|
||||
raise ValueError(f"data must be list or np.ndarray")
|
||||
self.data = data
|
||||
|
||||
assert isinstance(row, list)
|
||||
self.row = row
|
||||
self.row_map = dict(zip(self.row, range(len(self.row))))
|
||||
assert isinstance(column, list)
|
||||
self.col = column
|
||||
self.col_map = dict(zip(self.col, range(len(self.col))))
|
||||
|
||||
def reindex(self, new_column):
|
||||
tmp_data = self.data.copy()
|
||||
for row_id, row in enumerate(self.row):
|
||||
for col_id, col in new_column:
|
||||
if col in self.col:
|
||||
tmp_data[row_id, col_id] = self.data[row_id, self.row_map[col]]
|
||||
else:
|
||||
tmp_data[row_id, col_id] = np.NaN
|
||||
return IndexData(tmp_data, self.row, list(new_column))
|
||||
|
||||
def to_dict(self):
|
||||
assert len(self.row) == 1
|
||||
if self.data.size == 0:
|
||||
return {col: np.NaN for col in self.col}
|
||||
else:
|
||||
return dict(zip(self.col, self.data[0, :].tolist()))
|
||||
|
||||
@staticmethod
|
||||
def concat_by_col(index_data_list):
|
||||
# get all col and row
|
||||
all_col = set()
|
||||
all_row = []
|
||||
for index_data in index_data_list:
|
||||
all_col = all_col | set(index_data.col)
|
||||
all_row.append(index_data.row[0])
|
||||
all_col = list(all_col)
|
||||
all_col.sort()
|
||||
all_col_map = dict(zip(all_col, range(len(all_col))))
|
||||
|
||||
# concat all
|
||||
tmp_data = np.full((len(index_data_list), len(all_col)), np.NaN)
|
||||
for data_id, index_data in enumerate(index_data_list):
|
||||
now_data_map = [all_col_map[col] for col in index_data.col]
|
||||
tmp_data[data_id, now_data_map] = index_data.data
|
||||
return IndexData(tmp_data, all_row, all_col)
|
||||
|
||||
def sum(self, axis = None):
|
||||
if axis is None:
|
||||
return np.nansum(self.data)
|
||||
if axis == 0:
|
||||
tmp_data = np.nansum(self.data, axis=0)
|
||||
return IndexData(tmp_data, [self.row[0]], self.col)
|
||||
else:
|
||||
raise NotImplementedError(f"axis must be 0 or None")
|
||||
|
||||
def keep_positive(self, limit = 1e-08):
|
||||
assert len(self.row) == 1
|
||||
new_col = []
|
||||
new_data = []
|
||||
for col_id, col in enumerate(self.col):
|
||||
if self.data[0: col_id] < 1e-08:
|
||||
continue
|
||||
else:
|
||||
new_col.append(col)
|
||||
new_data.append(self.data[0: col_id])
|
||||
return IndexData(new_data, self.row, new_col)
|
||||
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, IndexData):
|
||||
assert len(self.row) == len(other.row)
|
||||
assert self.col == other.col
|
||||
return IndexData(self.data * other.data, ["mul"], self.col)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __truediv__(self, other):
|
||||
if isinstance(other, IndexData):
|
||||
assert len(self.row) == len(other.row)
|
||||
assert self.col == other.col
|
||||
return IndexData(self.data / other.data, ["div"], self.col)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __len__(self):
|
||||
return len(self.col)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.order import BaseTradeDecision, Order, OrderDir
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
|
||||
from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator
|
||||
from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator, IndexData
|
||||
from ..data import D
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
@@ -391,23 +391,26 @@ class Indicator:
|
||||
return None, None
|
||||
|
||||
if isinstance(price_s, (int, float)):
|
||||
price_s = pd.Series(price_s, index=[trade_start_time])
|
||||
price_s = IndexData([price_s], [inst], [trade_start_time])
|
||||
|
||||
# NOTE: there are some zeros in the trading price. These cases are known meaningless
|
||||
# for aligning the previous logic, remove it.
|
||||
price_s = price_s[~(price_s < 1e-08)] # remove zero and negative values.
|
||||
# remove zero and negative values.
|
||||
price_s = price_s.keep_positive(1e-08)
|
||||
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
||||
|
||||
if agg == "vwap":
|
||||
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
|
||||
volume_s = volume_s.reindex(price_s.index)
|
||||
if isinstance(volume_s, (int, float)):
|
||||
volume_s = IndexData([volume_s], [inst], [trade_start_time])
|
||||
volume_s = volume_s.reindex(price_s.col)
|
||||
elif agg == "twap":
|
||||
volume_s = pd.Series(1, index=price_s.index)
|
||||
volume_s = IndexData([1 for i in range(price_s.col)], [inst], price_s.col)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
base_volume = volume_s.sum().item()
|
||||
base_price = ((price_s * volume_s).sum() / base_volume).item()
|
||||
base_volume = volume_s.sum()
|
||||
base_price = (price_s * volume_s).sum() / base_volume
|
||||
|
||||
return base_price, base_volume
|
||||
|
||||
@@ -441,15 +444,15 @@ class Indicator:
|
||||
"""
|
||||
|
||||
# TODO: I think there are potentials to be optimized
|
||||
trade_dir = self.order_indicator.get_metric_series("trade_dir")
|
||||
trade_dir = self.order_indicator.get_index_data("trade_dir")
|
||||
if len(trade_dir) > 0:
|
||||
bp_all, bv_all = [], []
|
||||
# <step, inst, (base_volume | base_price)>
|
||||
for oi, (dec, start, end) in zip(inner_order_indicators, decision_list):
|
||||
bp_s = oi.get_metric_series("base_price").reindex(trade_dir.index)
|
||||
bv_s = oi.get_metric_series("base_volume").reindex(trade_dir.index)
|
||||
bp_s = oi.get_index_data("base_price").reindex(trade_dir.col)
|
||||
bv_s = oi.get_index_data("base_volume").reindex(trade_dir.col)
|
||||
bp_new, bv_new = {}, {}
|
||||
for pr, v, (inst, direction) in zip(bp_s.values, bv_s.values, trade_dir.items()):
|
||||
for pr, v, (inst, direction) in zip(bp_s.data, bv_s.data, zip(trade_dir.col, trade_dir.data)):
|
||||
if np.isnan(pr):
|
||||
bp_tmp, bv_tmp = self._get_base_vol_pri(
|
||||
inst,
|
||||
@@ -465,15 +468,16 @@ class Indicator:
|
||||
else:
|
||||
bp_new[inst], bv_new[inst] = pr, v
|
||||
|
||||
bp_new, bv_new = pd.Series(bp_new), pd.Series(bv_new)
|
||||
bp_new = IndexData(list(bp_new.values()), ["base_price"], list(bp_new.keys()))
|
||||
bv_new = IndexData(list(bv_new.values()), ["base_volume"], list(bv_new.keys()))
|
||||
bp_all.append(bp_new)
|
||||
bv_all.append(bv_new)
|
||||
bp_all = pd.concat(bp_all, axis=1)
|
||||
bv_all = pd.concat(bv_all, axis=1)
|
||||
bp_all = IndexData.concat_by_col(bp_all)
|
||||
bv_all = IndexData.concat_by_col(bv_all)
|
||||
|
||||
base_volume = bv_all.sum(axis=1)
|
||||
self.order_indicator.assign("base_volume", base_volume)
|
||||
self.order_indicator.assign("base_price", (bp_all * bv_all).sum(axis=1) / base_volume)
|
||||
base_volume = bv_all.sum(axis = 0)
|
||||
self.order_indicator.assign("base_volume", base_volume.to_dict())
|
||||
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis = 0) / base_volume).to_dict())
|
||||
|
||||
def _agg_order_price_advantage(self):
|
||||
def if_empty_func(trade_price):
|
||||
@@ -592,7 +596,7 @@ class Indicator:
|
||||
)
|
||||
)
|
||||
|
||||
def get_order_indicator(self, raw: bool = False):
|
||||
def get_order_indicator(self, raw: bool = True):
|
||||
if raw:
|
||||
return self.order_indicator
|
||||
return self.order_indicator.to_series()
|
||||
|
||||
Reference in New Issue
Block a user