1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

align interface

This commit is contained in:
wangwenxi.handsome
2021-08-19 14:06:33 +00:00
committed by you-n-g
parent be0d9e6a22
commit f111e34bd2
4 changed files with 415 additions and 307 deletions

View File

@@ -21,7 +21,7 @@ from ..config import C, REG_CN
from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
from .order import Order, OrderDir, OrderHelper
from .high_performance_ds import PandasQuote, NumpyQuote
from .high_performance_ds import PandasQuote, CN1Min_NumpyQuote
class Exchange:
@@ -39,7 +39,7 @@ class Exchange:
close_cost=0.0025,
min_cost=5,
extra_quote=None,
quote_cls=NumpyQuote,
quote_cls=CN1Min_NumpyQuote,
**kwargs,
):
"""__init__

View File

@@ -3,9 +3,7 @@
import logging
from pandas._config.config import is_instance_factory
from qlib.data.base import Feature
from typing import List, Text, Tuple, Union, Callable, Iterable, Dict, ValuesView
from typing import List, Text, Union, Callable, Iterable, Dict
from collections import OrderedDict
import inspect
@@ -15,6 +13,7 @@ import numpy as np
from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
from ..utils.time import _if_single_data
class BaseQuote:
@@ -34,12 +33,12 @@ class BaseQuote:
def get_data(
self,
stock_id: Union[str, list],
stock_id: str,
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
fields: Union[str, list] = None,
fields: str = None,
method: Union[str, Callable] = None,
) -> Union[None, float, pd.Series, pd.DataFrame]:
) -> Union[None, float, pd.Series, pd.DataFrame, "IndexData"]:
"""get the specific fields of stock data during start time and end_time,
and apply method to the data.
@@ -59,21 +58,40 @@ class BaseQuote:
2010-01-12 2788.688232 164587.937500
2010-01-13 2790.604004 145460.453125
print(get_data(stock_id=["SH600000", "SH600655"], start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
this function is used for three case:
$close $volume
instrument
SH600000 87.433578 28117442.0
SH600655 2699.567383 158193.328125
1. Both fields and method are not None. It returns float.
print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", fields="$close", method="last"))
print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
85.713585
$close 87.433578
$volume 28117442.0
2. Both fields and method are None. It returns pd.Dataframe or np.ndarray.
print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", fields=None, method=None))
print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-05", fields="$close", method="last"))
1) pd.Dataframe
$close $volume
datetime
2010-01-04 86.778313 16162960.0
2010-01-05 87.433578 28117442.0
2010-01-06 85.713585 23632884.0
87.433578
2) np.ndarray
[
[86.778313, 16162960.0],
[87.433578, 28117442.0],
[85.713585, 23632884.0],
]
3. fields is not None, and method is None. It returns pd.Series or IndexData.
print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", fields="$close", method=None))
1) pd.Series
2010-01-04 86.778313
2010-01-05 87.433578
2010-01-06 85.713585
2) IndexData
IndexData([86.778313, 87.433578, 85.713585], [2010-01-04, 2010-01-05, 2010-01-06])
Parameters
----------
@@ -86,12 +104,12 @@ class BaseQuote:
the columns of data to fetch
method : Union[str, Callable]
the method apply to data.
e.g [None, "last", "all", "sum", "mean", "any", qlib/utils/resam.py/ts_data_last]
e.g [None, "last", "all", "sum", "mean", qlib/utils/resam.py/ts_data_last]
Return
----------
Union[None, float, pd.Series, pd.DataFrame]
The resampled DataFrame/Series/value, return None when the resampled data is empty.
Union[None, float, pd.Series, pd.DataFrame, IndexData]
please refer to Example as following.
"""
raise NotImplementedError(f"Please implement the `get_data` method")
@@ -104,7 +122,6 @@ class PandasQuote(BaseQuote):
for stock_id, stock_val in quote_df.groupby(level="instrument"):
quote_dict[stock_id] = stock_val.droplevel(level="instrument")
self.data = quote_dict
self.freq = np.timedelta64(1, "m")
def get_all_stock(self):
return self.data.keys()
@@ -117,19 +134,10 @@ class PandasQuote(BaseQuote):
else:
raise ValueError(f"fields must be None, str or list")
def _if_single_data(self, start_time, end_time):
if end_time - start_time < self.freq:
return True
if start_time.hour == 11 and start_time.minute == 29 and start_time.second == 0:
return True
if start_time.hour == 14 and start_time.minute == 59 and start_time.second == 0:
return True
return False
class NumpyQuote(BaseQuote):
class CN1Min_NumpyQuote(BaseQuote):
def __init__(self, quote_df: pd.DataFrame):
"""NumpyQuote
"""CN1Min_NumpyQuote
Parameters
----------
@@ -141,20 +149,20 @@ class NumpyQuote(BaseQuote):
each stock has one two-dimensional np.ndarray to represent data.
self.columns: Dict[str, int]
map column name to column id in self.data.
self.dates: Dict[stock_id, Dict[pd.Timestap, int]]
self.dt2idx: Dict[stock_id, Dict[pd.Timestap, int]]
map timestap to row id in self.data.
self.dates_list: Dict[stock_id, List[pd.Timestap]]
the dates of each stock for searching.
self.idx2dt: Dict[stock_id, List[pd.Timestap]]
the dt2idx of each stock for searching.
"""
super().__init__(quote_df=quote_df)
# init data
columns = quote_df.columns.values
self.columns = dict(zip(columns, range(len(columns))))
self.data, self.dates, self.dates_list = self._to_numpy(quote_df)
self.data, self.dt2idx, self.idx2dt = self._to_numpy(quote_df)
# lru
self.muti_lru = {}
self.multi_lru = {}
self.max_lru_len = 256
def _to_numpy(self, quote_df):
@@ -175,27 +183,32 @@ class NumpyQuote(BaseQuote):
return self.data.keys()
def get_data(self, stock_id, start_time, end_time, fields=None, method=None):
# check fields
if isinstance(fields, list) and len(fields) > 1:
raise ValueError(f"get_data in CN1Min_NumpyQuote only supports one field")
# check stock id
if stock_id not in self.get_all_stock():
return None
# get single data
if self._if_single_data(start_time, end_time):
if start_time not in self.dates[stock_id]:
# single data is only one piece of data, so it don't need to agg by method.
if _if_single_data(start_time, end_time, np.timedelta64(1, "m")):
if start_time not in self.dt2idx[stock_id]:
return None
if fields is None:
# it used for check if data is None
return self.data[stock_id][self.dates[stock_id][start_time]]
return self.data[stock_id][self.dt2idx[stock_id][start_time]]
else:
return self.data[stock_id][self.dates[stock_id][start_time]][self.columns[fields]]
return self.data[stock_id][self.dt2idx[stock_id][start_time]][self.columns[fields]]
# get muti row data
else:
# check lru
if (stock_id, start_time, end_time, fields, method) in self.muti_lru:
return self.muti_lru[(stock_id, start_time, end_time, fields, method)]
if (stock_id, start_time, end_time, fields, method) in self.multi_lru:
return self.multi_lru[(stock_id, start_time, end_time, fields, method)]
start_id = bisect.bisect_left(self.dates_list[stock_id], start_time)
end_id = bisect.bisect_right(self.dates_list[stock_id], end_time)
start_id = bisect.bisect_left(self.idx2dt[stock_id], start_time)
end_id = bisect.bisect_right(self.idx2dt[stock_id], end_time)
if start_id == end_id:
return None
# it used for check if data is None
@@ -203,59 +216,38 @@ class NumpyQuote(BaseQuote):
return self.data[stock_id][start_id:end_id]
elif method is None:
stock_data = self.data[stock_id][start_id:end_id, self.columns[fields]]
stock_dates = self.dates_list[stock_id][start_id:end_id].to_list()
return IndexData(stock_data, stock_dates)
stock_dt2idx = self.idx2dt[stock_id][start_id:end_id].to_list()
return IndexData(stock_data, stock_dt2idx)
else:
agg_stock_data = self._agg_data(self.data[stock_id][start_id:end_id, self.columns[fields]], method)
# result lru
if len(self.muti_lru) >= self.max_lru_len:
self.muti_lru.clear()
self.muti_lru[(stock_id, start_time, end_time, fields, method)] = agg_stock_data
if len(self.multi_lru) >= self.max_lru_len:
self.multi_lru.clear()
self.multi_lru[(stock_id, start_time, end_time, fields, method)] = agg_stock_data
return agg_stock_data
def _agg_data(self, data, method):
"""Agg data by specific method."""
valid_data = data[data != np.array(None)].copy()
if method == "sum":
return data.sum()
if method == "mean":
return data.mean()
if method == "last":
return data[-1]
if method == "all":
return data.all()
if method == "any":
return data.any()
if method == ts_data_last:
valid_data = data[data != np.NaN]
return np.nansum(valid_data)
elif method == "mean":
return np.nanmean(valid_data)
elif method == "last":
return valid_data[-1]
elif method == "all":
return valid_data.all()
elif method == "any":
return valid_data.any()
elif method == ts_data_last:
valid_data = valid_data[valid_data != np.NaN]
if len(valid_data) == 0:
return None
else:
return valid_data[0]
def _if_single_data(self, start_time, end_time):
"""Is there only one piece of data to obtaine.
Parameters
----------
start_time : Union[pd.Timestamp, str]
closed start time for data.
end_time : Union[pd.Timestamp, str]
closed end time for data.
Returns
-------
bool
True means one piece of data to obtaine.
"""
if end_time - start_time < np.timedelta64(1, "m"):
return True
if start_time.hour == 11 and start_time.minute == 29 and start_time.second == 0:
return True
if start_time.hour == 14 and start_time.minute == 59 and start_time.second == 0:
return True
return False
else:
raise ValueError(f"{method} is not supported")
class BaseSingleMetric:
@@ -346,10 +338,13 @@ class BaseOrderIndicator:
structure of PandasOrderIndicator is Dict[str, PandasSingleMetric]. It uses
PandasSingleMetric based on pd.Series to represent each metric.
2. The another way doesn't use BaseSingleMetric to represent each metric. The data
structure of PandasOrderIndicator is a whole matrix. It means you are not neccesary
structure of PandasOrderIndicator is a whole matrix. It means you are not necessary
to inherit the BaseSingleMetric.
"""
def __init__(self):
self.logger = get_module_logger("online operator")
def assign(self, col: str, metric: Union[dict, pd.Series]):
"""assign one metric.
@@ -358,10 +353,17 @@ class BaseOrderIndicator:
col : str
the metric name of one metric.
metric : Union[dict, pd.Series]
the metric data.
one metric with stock_id index, such as deal_amount, ffr, etc.
for example:
SH600068 NaN
SH600079 1.0
SH600266 NaN
...
SZ300692 NaN
SZ300719 NaN,
"""
pass
raise NotImplementedError(f"Please implement the 'assign' method")
def transfer(self, func: Callable, new_col: str = None) -> Union[None, BaseSingleMetric]:
"""compute new metric with existing metrics.
@@ -383,7 +385,7 @@ class BaseOrderIndicator:
new metric.
"""
pass
raise NotImplementedError(f"Please implement the 'transfer' method")
def get_metric_series(self, metric: str) -> pd.Series:
"""return the single metric with pd.Series format.
@@ -400,16 +402,32 @@ class BaseOrderIndicator:
If there is no metric name in the data, return pd.Series().
"""
pass
raise NotImplementedError(f"Please implement the 'get_metric_series' method")
@staticmethod
def sum_all_indicators(cls, indicators: list, metrics: Union[str, List[str]], fill_value: float = None):
"""sum indicators with the same metrics.
and assign to the cls(BaseOrderIndicator).
def get_index_data(self, metric):
"""get one metric with the format of IndexData
Parameters
----------
cls : BaseOrderIndicator
metric : str
the metric name.
Return
------
IndexData
one metric with the format of IndexData
"""
raise NotImplementedError(f"Please implement the 'get_index_data' method")
@staticmethod
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None):
"""sum indicators with the same metrics.
and assign to the order_indicator(BaseOrderIndicator).
Parameters
----------
order_indicator : BaseOrderIndicator
the order indicator to assign.
indicators : List[BaseOrderIndicator]
the list of all inner indicators.
@@ -419,7 +437,7 @@ class BaseOrderIndicator:
fill np.NaN with value. By default None.
"""
pass
raise NotImplementedError(f"Please implement the 'sum_all_indicators' method")
def to_series(self) -> Dict[Text, pd.Series]:
"""return the metrics as pandas series
@@ -437,7 +455,76 @@ class BaseOrderIndicator:
raise NotImplementedError(f"Please implement the `to_series` method")
class PandasSingleMetric:
class SingleMetric(BaseSingleMetric):
def __add__(self, other):
if isinstance(other, (int, float)):
return self.__class__(self.metric + other)
elif isinstance(other, self.__class__):
return self.__class__(self.metric + other.metric)
else:
return NotImplemented
def __sub__(self, other):
if isinstance(other, (int, float)):
return self.__class__(self.metric - other)
elif isinstance(other, self.__class__):
return self.__class__(self.metric - other.metric)
else:
return NotImplemented
def __rsub__(self, other):
if isinstance(other, (int, float)):
return self.__class__(other - self.metric)
elif isinstance(other, self.__class__):
return self.__class__(other.metric - self.metric)
else:
return NotImplemented
def __mul__(self, other):
if isinstance(other, (int, float)):
return self.__class__(self.metric * other)
elif isinstance(other, self.__class__):
return self.__class__(self.metric * other.metric)
else:
return NotImplemented
def __truediv__(self, other):
if isinstance(other, (int, float)):
return self.__class__(self.metric / other)
elif isinstance(other, self.__class__):
return self.__class__(self.metric / other.metric)
else:
return NotImplemented
def __eq__(self, other):
if isinstance(other, (int, float)):
return self.__class__(self.metric == other)
elif isinstance(other, self.__class__):
return self.__class__(self.metric == other.metric)
else:
return NotImplemented
def __gt__(self, other):
if isinstance(other, (int, float)):
return self.__class__(self.metric > other)
elif isinstance(other, self.__class__):
return self.__class__(self.metric > other.metric)
else:
return NotImplemented
def __lt__(self, other):
if isinstance(other, (int, float)):
return self.__class__(self.metric < other)
elif isinstance(other, self.__class__):
return self.__class__(self.metric < other.metric)
else:
return NotImplemented
def __len__(self):
return len(self.metric)
class PandasSingleMetric(SingleMetric):
"""Each SingleMetric is based on pd.Series."""
def __init__(self, metric: Union[dict, pd.Series]):
@@ -448,73 +535,6 @@ class PandasSingleMetric:
else:
raise ValueError(f"metric must be dict or pd.Series")
def __add__(self, other):
if isinstance(other, (int, float)):
return PandasSingleMetric(self.metric + other)
elif isinstance(other, PandasSingleMetric):
return PandasSingleMetric(self.metric + other.metric)
else:
return NotImplemented
def __sub__(self, other):
if isinstance(other, (int, float)):
return PandasSingleMetric(self.metric - other)
elif isinstance(other, PandasSingleMetric):
return PandasSingleMetric(self.metric - other.metric)
else:
return NotImplemented
def __rsub__(self, other):
if isinstance(other, (int, float)):
return PandasSingleMetric(other - self.metric)
elif isinstance(other, PandasSingleMetric):
return PandasSingleMetric(other.metric - self.metric)
else:
return NotImplemented
def __mul__(self, other):
if isinstance(other, (int, float)):
return PandasSingleMetric(self.metric * other)
elif isinstance(other, PandasSingleMetric):
return PandasSingleMetric(self.metric * other.metric)
else:
return NotImplemented
def __truediv__(self, other):
if isinstance(other, (int, float)):
return PandasSingleMetric(self.metric / other)
elif isinstance(other, PandasSingleMetric):
return PandasSingleMetric(self.metric / other.metric)
else:
return NotImplemented
def __eq__(self, other):
if isinstance(other, (int, float)):
return PandasSingleMetric(self.metric == other)
elif isinstance(other, PandasSingleMetric):
return PandasSingleMetric(self.metric == other.metric)
else:
return NotImplemented
def __gt__(self, other):
if isinstance(other, (int, float)):
return PandasSingleMetric(self.metric > other)
elif isinstance(other, PandasSingleMetric):
return PandasSingleMetric(self.metric > other.metric)
else:
return NotImplemented
def __lt__(self, other):
if isinstance(other, (int, float)):
return PandasSingleMetric(self.metric < other)
elif isinstance(other, PandasSingleMetric):
return PandasSingleMetric(self.metric < other.metric)
else:
return NotImplemented
def __len__(self):
return len(self.metric)
def sum(self):
return self.metric.sum()
@@ -525,23 +545,23 @@ class PandasSingleMetric:
return self.metric.count()
def abs(self):
return PandasSingleMetric(self.metric.abs())
return self.__class__(self.metric.abs())
def astype(self, type):
return PandasSingleMetric(self.metric.astype(type))
return self.__class__(self.metric.astype(type))
@property
def empty(self):
return self.metric.empty
def add(self, other, fill_value=None):
return PandasSingleMetric(self.metric.add(other.metric, fill_value=fill_value))
return self.__class__(self.metric.add(other.metric, fill_value=fill_value))
def replace(self, replace_dict: dict):
return PandasSingleMetric(self.metric.replace(replace_dict))
return self.__class__(self.metric.replace(replace_dict))
def apply(self, func: Callable):
return PandasSingleMetric(self.metric.apply(func))
return self.__class__(self.metric.apply(func))
class PandasOrderIndicator(BaseOrderIndicator):
@@ -573,87 +593,29 @@ class PandasOrderIndicator(BaseOrderIndicator):
return pd.Series()
@staticmethod
def sum_all_indicators(cls, indicators: list, metrics: Union[str, List[str]], fill_value=None):
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=None):
if isinstance(metrics, str):
metrics = [metrics]
for metric in metrics:
tmp_metric = PandasSingleMetric({})
for indicator in indicators:
tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
cls.assign(metric, tmp_metric.metric)
order_indicator.assign(metric, tmp_metric.metric)
def to_series(self):
return {k: v.metric for k, v in self.data.items()}
def get_index_data(self, metric):
if metric in self.data:
return IndexData(self.data[metric].values(), list(self.data[metric].index))
else:
return IndexData([], [])
class NumpySingleMetric(BaseSingleMetric):
class NumpySingleMetric(SingleMetric):
def __init__(self, metric: np.ndarray):
self.metric = metric
def __add__(self, other):
if isinstance(other, (int, float)):
return NumpySingleMetric(self.metric + other)
elif isinstance(other, NumpySingleMetric):
return NumpySingleMetric(self.metric + other.metric)
else:
return NotImplemented
def __sub__(self, other):
if isinstance(other, (int, float)):
return NumpySingleMetric(self.metric - other)
elif isinstance(other, NumpySingleMetric):
return NumpySingleMetric(self.metric - other.metric)
else:
return NotImplemented
def __rsub__(self, other):
if isinstance(other, (int, float)):
return NumpySingleMetric(other - self.metric)
elif isinstance(other, NumpySingleMetric):
return NumpySingleMetric(other.metric - self.metric)
else:
return NotImplemented
def __mul__(self, other):
if isinstance(other, (int, float)):
return NumpySingleMetric(self.metric * other)
elif isinstance(other, NumpySingleMetric):
return NumpySingleMetric(self.metric * other.metric)
else:
return NotImplemented
def __truediv__(self, other):
if isinstance(other, (int, float)):
return NumpySingleMetric(self.metric / other)
elif isinstance(other, NumpySingleMetric):
return NumpySingleMetric(self.metric / other.metric)
else:
return NotImplemented
def __eq__(self, other):
if isinstance(other, (int, float)):
return NumpySingleMetric(self.metric == other)
elif isinstance(other, NumpySingleMetric):
return NumpySingleMetric(self.metric == other.metric)
else:
return NotImplemented
def __gt__(self, other):
if isinstance(other, (int, float)):
return NumpySingleMetric(self.metric > other)
elif isinstance(other, NumpySingleMetric):
return NumpySingleMetric(self.metric > other.metric)
else:
return NotImplemented
def __lt__(self, other):
if isinstance(other, (int, float)):
return NumpySingleMetric(self.metric < other)
elif isinstance(other, NumpySingleMetric):
return NumpySingleMetric(self.metric < other.metric)
else:
return NotImplemented
def __len__(self):
return len(self.metric)
@@ -667,10 +629,10 @@ class NumpySingleMetric(BaseSingleMetric):
return len(self.metric[~np.isnan(self.metric)])
def abs(self):
return NumpySingleMetric(np.absolute(self.metric))
return self.__class__(np.absolute(self.metric))
def astype(self, type):
return NumpySingleMetric(self.metric.astype(type))
return self.__class__(self.metric.astype(type))
@property
def empty(self):
@@ -680,13 +642,13 @@ class NumpySingleMetric(BaseSingleMetric):
tmp_metric = self.metric.copy()
for num in replace_dict:
tmp_metric[tmp_metric == num] = replace_dict[num]
return NumpySingleMetric(tmp_metric)
return self.__class__(tmp_metric)
def apply(self, func: Callable):
tmp_metric = self.metric.copy()
for i in range(len(tmp_metric)):
tmp_metric[i] = func(tmp_metric[i])
return NumpySingleMetric(tmp_metric)
return self.__class__(tmp_metric)
class NumpyOrderIndicator(BaseOrderIndicator):
@@ -713,13 +675,13 @@ class NumpyOrderIndicator(BaseOrderIndicator):
def assign(self, col: str, metric: dict):
if col not in NumpyOrderIndicator.ROW:
raise ValueError(f"{col} metric is not supoorted")
raise ValueError(f"{col} metric is not supported")
if not isinstance(metric, dict):
raise ValueError(f"metric must be dict")
# if data is None, init numpy ndarray
if self.data is None:
self.data = np.zeros((len(NumpyOrderIndicator.ROW), len(metric)))
self.data = np.full((len(NumpyOrderIndicator.ROW), len(metric)), np.NaN)
self.column = list(metric.keys())
self.column_map = dict(zip(self.column, range(len(self.column))))
@@ -743,7 +705,7 @@ class NumpyOrderIndicator(BaseOrderIndicator):
if self._if_valid_metric(sig):
func_kwargs[sig] = NumpySingleMetric(self.data[NumpyOrderIndicator.ROW_MAP[sig]])
else:
print(f"{sig} is not assigned")
self.logger.warning(f"{sig} is not assigned")
func_kwargs[sig] = NumpySingleMetric(np.array([]))
tmp_metric = func(**func_kwargs)
if new_col is not None:
@@ -778,7 +740,7 @@ class NumpyOrderIndicator(BaseOrderIndicator):
@staticmethod
def sum_all_indicators(
cls, indicators: list, metrics: Union[str, List[str]], fill_value=None
order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=None
) -> Dict[str, NumpySingleMetric]:
# metrics is all metrics to add
# metrics_id means the index in the NumpyOrderIndicator.ROW for metrics.
@@ -800,27 +762,37 @@ class NumpyOrderIndicator(BaseOrderIndicator):
if fill_value is not None:
base_metrics = fill_value * np.ones((len(metrics), len(stocks)))
for i in range(len(indicators)):
tmp_netrics = base_metrics.copy()
tmp_metrics = base_metrics.copy()
stocks_index = [stocks_map[stock] for stock in indicators[i].column]
tmp_netrics[:, stocks_index] = indicator_metrics[i]
indicator_metrics[i] = tmp_netrics
tmp_metrics[:, stocks_index] = indicator_metrics[i]
indicator_metrics[i] = tmp_metrics
else:
raise ValueError(f"fill value can not be None in NumpyOrderIndicator")
# add metric and assign to cls
# add metric and assign to order_indicator
metric_sum = sum(indicator_metrics)
if cls.data is not None:
if order_indicator.data is not None:
raise ValueError(f"this function must assign to an empty order indicator")
cls.data = np.zeros((len(NumpyOrderIndicator.ROW), len(stocks)))
cls.column = stocks
cls.column_map = dict(zip(stocks, range(len(stocks))))
order_indicator.data = np.zeros((len(NumpyOrderIndicator.ROW), len(stocks)))
order_indicator.column = stocks
order_indicator.column_map = dict(zip(stocks, range(len(stocks))))
for i in range(len(metrics)):
cls.row_tag[NumpyOrderIndicator.ROW_MAP[metrics[i]]] = 1
cls.data[NumpyOrderIndicator.ROW_MAP[metrics[i]]] = metric_sum[i]
order_indicator.row_tag[NumpyOrderIndicator.ROW_MAP[metrics[i]]] = 1
order_indicator.data[NumpyOrderIndicator.ROW_MAP[metrics[i]]] = metric_sum[i]
class IndexData:
def __init__(self, data, column):
def __init__(self, data, index):
"""A data structure of index and numpy data.
Parameters
----------
data : np.ndarray
the dim of data must be 1 or 2.
different functions have dimensional limitations
index : list
the index of data.
"""
if isinstance(data, list):
self.data = np.array(data)
elif isinstance(data, np.ndarray):
@@ -829,78 +801,188 @@ class IndexData:
raise ValueError(f"data must be list or np.ndarray")
self.ndim = self.data.ndim
assert isinstance(column, list)
self.col = column
self.col_map = dict(zip(self.col, range(len(self.col))))
assert isinstance(index, list)
self.index = index
self.index_map = dict(zip(self.index, range(len(self.index))))
def reindex(self, new_column):
def reindex(self, new_index):
"""reindex data and fill the missing value with np.NaN.
just for 1-dim data.
Parameters
----------
new_index : list
new index
Returns
-------
IndexData
reindex data
"""
assert self.ndim == 1
tmp_data = np.full(len(new_column), np.NaN)
for col_id, col in enumerate(new_column):
if col in self.col:
tmp_data[col_id] = self.data[self.col_map[col]]
return IndexData(tmp_data, list(new_column))
tmp_data = np.full(len(new_index), np.NaN)
for index_id, index in enumerate(new_index):
if index in self.index:
tmp_data[index_id] = self.data[self.index_map[index]]
return IndexData(tmp_data, list(new_index))
def to_dict(self):
assert self.ndim == 1
return dict(zip(self.col, self.data.tolist()))
"""convert IndexData to dict.
just for 1-dim data.
def keep_positive(self, limit=1e-08):
Returns
-------
dict
data with the dict format.
"""
assert self.ndim == 1
new_col = []
new_data = []
for col_id, col in enumerate(self.col):
if self.data[col_id] < 1e-08:
continue
else:
new_col.append(col)
new_data.append(self.data[col_id])
return IndexData(new_data, new_col)
return dict(zip(self.index, self.data.tolist()))
def sum(self, axis=None):
"""get the sum of data.
Parameters
----------
axis : 0 or None, optional
which axis to sum, by default None
Returns
-------
Union[float, IndexData]
if axis is None, it sums all data, return float.
if axis == 1, it sums by row, return IndexData.
"""
if axis is None:
return np.nansum(self.data)
if axis == 0:
assert self.ndim == 2
tmp_data = np.nansum(self.data, axis=0)
return IndexData(tmp_data, self.col)
return IndexData(tmp_data, self.index)
else:
raise NotImplementedError(f"axis must be 0 or None")
def __mul__(self, other):
"""multiply with another IndexData.
Returns
-------
IndexData
"""
if isinstance(other, IndexData):
assert self.ndim == other.ndim
assert self.col == other.col
assert self.index == other.index
assert len(self.data) == len(other.data)
return IndexData(self.data * other.data, self.col)
return IndexData(self.data * other.data, self.index)
else:
return NotImplemented
def __truediv__(self, other):
"""divide with another IndexData.
Returns
-------
IndexData
"""
if isinstance(other, IndexData):
assert self.ndim == other.ndim
assert self.col == other.col
assert self.index == other.index
assert len(self.data) == len(other.data)
return IndexData(self.data / other.data, self.col)
return IndexData(self.data / other.data, self.index)
else:
return NotImplemented
def __len__(self):
return len(self.col)
"""the length of the data.
Returns
-------
int
the length of the data.
"""
return len(self.index)
def __getitem__(self, bool_list: "IndexData"):
"""get IndexData by a bool_list which has the same shape of self.data.
just for 1-dim data.
Parameters
----------
bool_list : Union[list, np.ndarray]
a bool_list which has the same shape of self.data. such as array([True, False, True]).
True means the data of the position is reserved. False is not.
Returns
-------
IndexData
new IndexData.
"""
assert self.ndim == 1
assert isinstance(bool_list, IndexData)
new_data = self.data[bool_list.data]
new_index = list(np.array(self.index)[bool_list.data])
return IndexData(new_data, new_index)
def __gt__(self, other):
if isinstance(other, (int, float)):
return IndexData(self.data > other, self.index)
elif isinstance(other, IndexData):
return IndexData(self.data > other.data, self.index)
else:
return NotImplemented
def __lt__(self, other):
if isinstance(other, (int, float)):
return IndexData(self.data < other, self.index)
elif isinstance(other, IndexData):
return IndexData(self.data < other.data, self.index)
else:
return NotImplemented
def __invert__(self):
return IndexData(~self.data, self.index)
@staticmethod
def concat_by_col(index_data_list):
# get all col and row
all_col = set()
def concat_by_index(index_data_list):
"""concat all IndexData by index.
just for 1-dim data.
Parameters
----------
index_data_list : List[IndexData]
the list of all IndexData to concat.
Returns
-------
IndexData
the IndexData with ndim == 2
"""
# get all index and row
all_index = set()
for index_data in index_data_list:
all_col = all_col | set(index_data.col)
all_col = list(all_col)
all_col.sort()
all_col_map = dict(zip(all_col, range(len(all_col))))
all_index = all_index | set(index_data.index)
all_index = list(all_index)
all_index.sort()
all_index_map = dict(zip(all_index, range(len(all_index))))
# concat all
tmp_data = np.full((len(index_data_list), len(all_col)), np.NaN)
tmp_data = np.full((len(index_data_list), len(all_index)), np.NaN)
for data_id, index_data in enumerate(index_data_list):
now_data_map = [all_col_map[col] for col in index_data.col]
assert index_data.ndim == 1
now_data_map = [all_index_map[index] for index in index_data.index]
tmp_data[data_id, now_data_map] = index_data.data
return IndexData(tmp_data, all_col)
return IndexData(tmp_data, all_index)
@staticmethod
def ones(index):
"""initial the IndexData with index, and fill data with 1.
Parameters
----------
index : list
the index of new data.
Returns
-------
IndexData
"""
return IndexData([1 for i in range(len(index))], list(index))

View File

@@ -390,22 +390,24 @@ class Indicator:
if price_s is None:
return None, None
if isinstance(price_s, pd.Series):
price_s = IndexData(price_s.values, list(price_s.index))
if isinstance(price_s, (int, float)):
price_s = IndexData([price_s], [trade_start_time])
# NOTE: there are some zeros in the trading price. These cases are known meaningless
# for aligning the previous logic, remove it.
# remove zero and negative values.
price_s = price_s.keep_positive(1e-08)
price_s = price_s[~(price_s < 1e-08)]
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
if agg == "vwap":
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
if isinstance(volume_s, (int, float)):
volume_s = IndexData([volume_s], [trade_start_time])
volume_s = volume_s.reindex(price_s.col)
volume_s = volume_s.reindex(price_s.index)
elif agg == "twap":
volume_s = IndexData([1 for i in range(len(price_s.col))], price_s.col)
volume_s = IndexData.ones(price_s.index)
else:
raise NotImplementedError(f"This type of input is not supported")
@@ -448,11 +450,11 @@ class Indicator:
bp_all, bv_all = [], []
# <step, inst, (base_volume | base_price)>
for oi, (dec, start, end) in zip(inner_order_indicators, decision_list):
bp_s = oi.get_index_data("base_price").reindex(trade_dir.col)
bv_s = oi.get_index_data("base_volume").reindex(trade_dir.col)
bp_s = oi.get_index_data("base_price").reindex(trade_dir.index)
bv_s = oi.get_index_data("base_volume").reindex(trade_dir.index)
bp_new, bv_new = {}, {}
for pr, v, (inst, direction) in zip(bp_s.data, bv_s.data, zip(trade_dir.col, trade_dir.data)):
for pr, v, (inst, direction) in zip(bp_s.data, bv_s.data, zip(trade_dir.index, trade_dir.data)):
if np.isnan(pr):
bp_tmp, bv_tmp = self._get_base_vol_pri(
inst,
@@ -472,8 +474,8 @@ class Indicator:
bv_new = IndexData(list(bv_new.values()), list(bv_new.keys()))
bp_all.append(bp_new)
bv_all.append(bv_new)
bp_all = IndexData.concat_by_col(bp_all)
bv_all = IndexData.concat_by_col(bv_all)
bp_all = IndexData.concat_by_index(bp_all)
bv_all = IndexData.concat_by_index(bv_all)
base_volume = bv_all.sum(axis=0)
self.order_indicator.assign("base_volume", base_volume.to_dict())

View File

@@ -5,13 +5,13 @@ Time related utils are compiled in this script
"""
import bisect
from datetime import datetime, time, date
from typing import List, Tuple
import re
from numpy import append
import pandas as pd
from qlib.config import C
from typing import List, Tuple, Union
import functools
from typing import Union
import re
import pandas as pd
from qlib.config import C
@functools.lru_cache(maxsize=240)
@@ -38,6 +38,30 @@ def get_min_cal(shift: int = 0) -> List[time]:
return cal
def _if_single_data(start_time, end_time, freq):
"""Is there only one piece of data to obtain.
Parameters
----------
start_time : Union[pd.Timestamp, str]
closed start time for data.
end_time : Union[pd.Timestamp, str]
closed end time for data.
Returns
-------
bool
True means one piece of data to obtaine.
"""
if end_time - start_time < freq:
return True
if start_time.hour == 11 and start_time.minute == 29 and start_time.second == 0:
return True
if start_time.hour == 14 and start_time.minute == 59 and start_time.second == 0:
return True
return False
class Freq:
NORM_FREQ_MONTH = "month"
NORM_FREQ_WEEK = "week"