From d9ad8ff791d8b3d889939229db83bf79fab95123 Mon Sep 17 00:00:00 2001 From: "wangwenxi.handsome" Date: Thu, 26 Aug 2021 12:41:12 +0000 Subject: [PATCH] index_data --- qlib/backtest/exchange.py | 2 +- qlib/backtest/high_performance_ds.py | 417 +++------------------------ qlib/backtest/order.py | 7 +- qlib/backtest/report.py | 27 +- qlib/utils/index_data.py | 410 ++++++++++++++++++++++++++ 5 files changed, 468 insertions(+), 395 deletions(-) create mode 100644 qlib/utils/index_data.py diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index b9b8d087b..21a1d2547 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -39,7 +39,7 @@ class Exchange: close_cost=0.0025, min_cost=5, extra_quote=None, - quote_cls=CN1Min_NumpyQuote, + quote_cls=PandasQuote, **kwargs, ): """__init__ diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index 979ca7609..61bf636ae 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. +from builtins import ValueError, isinstance import logging from typing import List, Text, Union, Callable, Iterable, Dict from collections import OrderedDict @@ -11,6 +12,7 @@ import bisect import pandas as pd import numpy as np +from ..utils.index_data import IndexData from ..utils.resam import resam_ts_data, ts_data_last from ..log import get_module_logger from ..utils.time import _if_single_data @@ -38,7 +40,7 @@ class BaseQuote: end_time: Union[pd.Timestamp, str], fields: str = None, method: Union[str, Callable] = None, - ) -> Union[None, float, pd.Series, pd.DataFrame, "IndexData"]: + ) -> Union[None, float, "IndexData"]: """get the specific fields of stock data during start time and end_time, and apply method to the data. @@ -65,42 +67,28 @@ class BaseQuote: 85.713585 - 2. Both fields and method are None. It returns pd.Dataframe or np.ndarray. + 2. Both fields and method are None. It returns np.ndarray. print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", fields=None, method=None)) - 1) pd.Dataframe - $close $volume - datetime - 2010-01-04 86.778313 16162960.0 - 2010-01-05 87.433578 28117442.0 - 2010-01-06 85.713585 23632884.0 - - 2) np.ndarray [ [86.778313, 16162960.0], [87.433578, 28117442.0], [85.713585, 23632884.0], ] - 3. fields is not None, and method is None. It returns pd.Series or IndexData. + 3. fields is not None, and method is None. It returns IndexData. print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", fields="$close", method=None)) - 1) pd.Series - 2010-01-04 86.778313 - 2010-01-05 87.433578 - 2010-01-06 85.713585 - - 2) IndexData IndexData([86.778313, 87.433578, 85.713585], [2010-01-04, 2010-01-05, 2010-01-06]) Parameters ---------- - stock_id: Union[str, list] + stock_id: str start_time : Union[pd.Timestamp, str] closed start time for backtest end_time : Union[pd.Timestamp, str] closed end time for backtest - fields : Union[str, List] + fields : str the columns of data to fetch method : Union[str, Callable] the method apply to data. @@ -404,8 +392,8 @@ class BaseOrderIndicator: raise NotImplementedError(f"Please implement the 'get_metric_series' method") - def get_index_data(self, metric): - """get one metric with the format of IndexData + def get_index_data(self, metric) -> IndexData.Series: + """get one metric with the format of IndexData.Series Parameters ---------- @@ -414,8 +402,8 @@ class BaseOrderIndicator: Return ------ - IndexData - one metric with the format of IndexData + IndexData.Series + one metric with the format of IndexData.Series """ raise NotImplementedError(f"Please implement the 'get_index_data' method") @@ -586,12 +574,21 @@ class PandasOrderIndicator(BaseOrderIndicator): else: return tmp_metric + def get_index_data(self, metric): + if metric in self.data: + return IndexData.Series(self.data[metric].metric) + else: + return IndexData.Series() + def get_metric_series(self, metric: str) -> Union[pd.Series]: if metric in self.data: return self.data[metric].metric else: return pd.Series() + def to_series(self): + return {k: v.metric for k, v in self.data.items()} + @staticmethod def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=None): if isinstance(metrics, str): @@ -602,387 +599,45 @@ class PandasOrderIndicator(BaseOrderIndicator): tmp_metric = tmp_metric.add(indicator.data[metric], fill_value) order_indicator.assign(metric, tmp_metric.metric) - def to_series(self): - return {k: v.metric for k, v in self.data.items()} - - def get_index_data(self, metric): - if metric in self.data: - return IndexData(self.data[metric].values(), list(self.data[metric].index)) - else: - return IndexData([], []) - - -class NumpySingleMetric(SingleMetric): - def __init__(self, metric: np.ndarray): - self.metric = metric - - def __len__(self): - return len(self.metric) - - def sum(self): - return np.nansum(self.metric) - - def mean(self): - return np.nanmean(self.metric) - - def count(self): - return len(self.metric[~np.isnan(self.metric)]) - - def abs(self): - return self.__class__(np.absolute(self.metric)) - - def astype(self, type): - return self.__class__(self.metric.astype(type)) - - @property - def empty(self): - return len(self.metric) == 0 - - def replace(self, replace_dict: dict): - tmp_metric = self.metric.copy() - for num in replace_dict: - tmp_metric[tmp_metric == num] = replace_dict[num] - return self.__class__(tmp_metric) - - def apply(self, func: Callable): - tmp_metric = self.metric.copy() - for i in range(len(tmp_metric)): - tmp_metric[i] = func(tmp_metric[i]) - return self.__class__(tmp_metric) - class NumpyOrderIndicator(BaseOrderIndicator): - # all metrics - ROW = [ - "amount", - "deal_amount", - "inner_amount", - "trade_price", - "trade_value", - "trade_cost", - "trade_dir", - "ffr", - "pa", - "pos", - "base_price", - "base_volume", - ] - ROW_MAP = dict(zip(ROW, range(len(ROW)))) def __init__(self): - self.row_tag = [0 for tag in range(len(NumpyOrderIndicator.ROW))] - self.data = None + self.data: Dict[str, IndexData.Series] = OrderedDict() def assign(self, col: str, metric: dict): - if col not in NumpyOrderIndicator.ROW: - raise ValueError(f"{col} metric is not supported") - if not isinstance(metric, dict): - raise ValueError(f"metric must be dict") + self.data[col] = IndexData.Series(metric) - # if data is None, init numpy ndarray - if self.data is None: - self.data = np.full((len(NumpyOrderIndicator.ROW), len(metric)), np.NaN) - self.column = list(metric.keys()) - self.column_map = dict(zip(self.column, range(len(self.column)))) - - metric_column = list(metric.keys()) - if self.column != metric_column: - assert len(set(self.column) - set(metric_column)) == 0 - # modify the order - tmp_metric = {} - for column in self.column: - tmp_metric[column] = metric[column] - metric = tmp_metric - - # assign data - self.row_tag[NumpyOrderIndicator.ROW_MAP[col]] = 1 - self.data[NumpyOrderIndicator.ROW_MAP[col]] = list(metric.values()) - - def transfer(self, func: Callable, new_col: str = None) -> Union[None, NumpySingleMetric]: + def transfer(self, func: Callable, new_col: str = None) -> Union[None, IndexData.Series]: func_sig = inspect.signature(func).parameters.keys() - func_kwargs = {} - for sig in func_sig: - if self._if_valid_metric(sig): - func_kwargs[sig] = NumpySingleMetric(self.data[NumpyOrderIndicator.ROW_MAP[sig]]) - else: - self.logger.warning(f"{sig} is not assigned") - func_kwargs[sig] = NumpySingleMetric(np.array([])) + func_kwargs = {sig: self.data[sig] for sig in func_sig} tmp_metric = func(**func_kwargs) if new_col is not None: - self.row_tag[NumpyOrderIndicator.ROW_MAP[new_col]] = 1 - self.data[NumpyOrderIndicator.ROW_MAP[new_col]] = tmp_metric.metric + self.data[new_col] = tmp_metric else: return tmp_metric def get_index_data(self, metric): - if self._if_valid_metric(metric): - return IndexData(self.data[NumpyOrderIndicator.ROW_MAP[metric]], self.column) + if metric in self.data: + return self.data[metric] else: - return IndexData([], []) + return IndexData.Series() def get_metric_series(self, metric: str) -> Union[pd.Series]: - if self._if_valid_metric(metric): - return pd.Series(self.data[NumpyOrderIndicator.ROW_MAP[metric]], index=self.column) - else: - return pd.Series() + return self.data[metric].to_pd_series() def to_series(self) -> Dict[str, pd.Series]: tmp_metric_dict = {} - for metric in NumpyOrderIndicator.ROW: + for metric in self.data: tmp_metric_dict[metric] = self.get_metric_series(metric) return tmp_metric_dict - def _if_valid_metric(self, metric): - if metric in NumpyOrderIndicator.ROW and self.row_tag[NumpyOrderIndicator.ROW_MAP[metric]] == 1: - return True - else: - return False - @staticmethod - def sum_all_indicators( - order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=None - ) -> Dict[str, NumpySingleMetric]: - # metrics is all metrics to add - # metrics_id means the index in the NumpyOrderIndicator.ROW for metrics. + def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0): if isinstance(metrics, str): metrics = [metrics] - metrics_id = [NumpyOrderIndicator.ROW_MAP[metric] for metric in metrics] - - # get all stock_id and all metric data - stocks = set() - indicator_metrics = [] - for indicator in indicators: - stocks = stocks | set(indicator.column) - indicator_metrics.append(indicator.data[metrics_id, :].copy()) - stocks = list(stocks) - stocks.sort() - stocks_map = dict(zip(stocks, range(len(stocks)))) - - # fill value - if fill_value is not None: - base_metrics = fill_value * np.ones((len(metrics), len(stocks))) - for i in range(len(indicators)): - tmp_metrics = base_metrics.copy() - stocks_index = [stocks_map[stock] for stock in indicators[i].column] - tmp_metrics[:, stocks_index] = indicator_metrics[i] - indicator_metrics[i] = tmp_metrics - else: - raise ValueError(f"fill value can not be None in NumpyOrderIndicator") - - # add metric and assign to order_indicator - metric_sum = sum(indicator_metrics) - if order_indicator.data is not None: - raise ValueError(f"this function must assign to an empty order indicator") - order_indicator.data = np.zeros((len(NumpyOrderIndicator.ROW), len(stocks))) - order_indicator.column = stocks - order_indicator.column_map = dict(zip(stocks, range(len(stocks)))) - for i in range(len(metrics)): - order_indicator.row_tag[NumpyOrderIndicator.ROW_MAP[metrics[i]]] = 1 - order_indicator.data[NumpyOrderIndicator.ROW_MAP[metrics[i]]] = metric_sum[i] - - -class IndexData: - def __init__(self, data, index): - """A data structure of index and numpy data. - - Parameters - ---------- - data : np.ndarray - the dim of data must be 1 or 2. - different functions have dimensional limitations - index : list - the index of data. - """ - if isinstance(data, list): - self.data = np.array(data) - elif isinstance(data, np.ndarray): - self.data = data - else: - raise ValueError(f"data must be list or np.ndarray") - self.ndim = self.data.ndim - - assert isinstance(index, list) - self.index = index - self.index_map = dict(zip(self.index, range(len(self.index)))) - - def reindex(self, new_index): - """reindex data and fill the missing value with np.NaN. - just for 1-dim data. - - Parameters - ---------- - new_index : list - new index - - Returns - ------- - IndexData - reindex data - """ - assert self.ndim == 1 - tmp_data = np.full(len(new_index), np.NaN) - for index_id, index in enumerate(new_index): - if index in self.index: - tmp_data[index_id] = self.data[self.index_map[index]] - return IndexData(tmp_data, list(new_index)) - - def to_dict(self): - """convert IndexData to dict. - just for 1-dim data. - - Returns - ------- - dict - data with the dict format. - """ - assert self.ndim == 1 - return dict(zip(self.index, self.data.tolist())) - - def sum(self, axis=None): - """get the sum of data. - - Parameters - ---------- - axis : 0 or None, optional - which axis to sum, by default None - - Returns - ------- - Union[float, IndexData] - if axis is None, it sums all data, return float. - if axis == 1, it sums by row, return IndexData. - """ - if axis is None: - return np.nansum(self.data) - if axis == 0: - assert self.ndim == 2 - tmp_data = np.nansum(self.data, axis=0) - return IndexData(tmp_data, self.index) - else: - raise NotImplementedError(f"axis must be 0 or None") - - def __mul__(self, other): - """multiply with another IndexData. - - Returns - ------- - IndexData - """ - if isinstance(other, IndexData): - assert self.ndim == other.ndim - assert self.index == other.index - assert len(self.data) == len(other.data) - return IndexData(self.data * other.data, self.index) - else: - return NotImplemented - - def __truediv__(self, other): - """divide with another IndexData. - - Returns - ------- - IndexData - """ - if isinstance(other, IndexData): - assert self.ndim == other.ndim - assert self.index == other.index - assert len(self.data) == len(other.data) - return IndexData(self.data / other.data, self.index) - else: - return NotImplemented - - def __len__(self): - """the length of the data. - - Returns - ------- - int - the length of the data. - """ - return len(self.index) - - def __getitem__(self, bool_list: "IndexData"): - """get IndexData by a bool_list which has the same shape of self.data. - just for 1-dim data. - - Parameters - ---------- - bool_list : Union[list, np.ndarray] - a bool_list which has the same shape of self.data. such as array([True, False, True]). - True means the data of the position is reserved. False is not. - - Returns - ------- - IndexData - new IndexData. - """ - assert self.ndim == 1 - assert isinstance(bool_list, IndexData) - new_data = self.data[bool_list.data] - new_index = list(np.array(self.index)[bool_list.data]) - return IndexData(new_data, new_index) - - def __gt__(self, other): - if isinstance(other, (int, float)): - return IndexData(self.data > other, self.index) - elif isinstance(other, IndexData): - return IndexData(self.data > other.data, self.index) - else: - return NotImplemented - - def __lt__(self, other): - if isinstance(other, (int, float)): - return IndexData(self.data < other, self.index) - elif isinstance(other, IndexData): - return IndexData(self.data < other.data, self.index) - else: - return NotImplemented - - def __invert__(self): - return IndexData(~self.data, self.index) - - @staticmethod - def concat_by_index(index_data_list): - """concat all IndexData by index. - just for 1-dim data. - - Parameters - ---------- - index_data_list : List[IndexData] - the list of all IndexData to concat. - - Returns - ------- - IndexData - the IndexData with ndim == 2 - """ - # get all index and row - all_index = set() - for index_data in index_data_list: - all_index = all_index | set(index_data.index) - all_index = list(all_index) - all_index.sort() - all_index_map = dict(zip(all_index, range(len(all_index)))) - - # concat all - tmp_data = np.full((len(index_data_list), len(all_index)), np.NaN) - for data_id, index_data in enumerate(index_data_list): - assert index_data.ndim == 1 - now_data_map = [all_index_map[index] for index in index_data.index] - tmp_data[data_id, now_data_map] = index_data.data - return IndexData(tmp_data, all_index) - - @staticmethod - def ones(index): - """initial the IndexData with index, and fill data with 1. - - Parameters - ---------- - index : list - the index of new data. - - Returns - ------- - IndexData - """ - return IndexData([1 for i in range(len(index))], list(index)) + for metric in metrics: + tmp_metric = IndexData.Series() + for indicator in indicators: + tmp_metric = tmp_metric.add(indicator.data[metric], fill_value) + order_indicator.data[metric] = tmp_metric \ No newline at end of file diff --git a/qlib/backtest/order.py b/qlib/backtest/order.py index abd02554a..42af5f24e 100644 --- a/qlib/backtest/order.py +++ b/qlib/backtest/order.py @@ -109,7 +109,7 @@ class Order: return self.direction * 2 - 1 @staticmethod - def parse_dir(direction: Union[str, int, np.integer, OrderDir]) -> OrderDir: + def parse_dir(direction: Union[str, int, np.integer, OrderDir, np.ndarray]) -> OrderDir: if isinstance(direction, OrderDir): return direction elif isinstance(direction, (int, float, np.integer, np.floating)): @@ -125,6 +125,11 @@ class Order: return OrderDir.BUY else: raise NotImplementedError(f"This type of input is not supported") + elif isinstance(direction, np.ndarray): + direction_array = direction.copy() + direction_array[direction_array > 0] = Order.BUY + direction_array[direction_array <= 0] = Order.SELL + return direction_array else: raise NotImplementedError(f"This type of input is not supported") diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 6272258b7..dbda82dd6 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -16,7 +16,8 @@ from qlib.backtest.exchange import Exchange from qlib.backtest.order import BaseTradeDecision, Order, OrderDir from qlib.backtest.utils import TradeCalendarManager -from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator, IndexData +from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator +from ..utils.index_data import IndexData, SingleData from ..data import D from ..tests.config import CSI300_BENCH from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data @@ -391,9 +392,11 @@ class Indicator: return None, None if isinstance(price_s, pd.Series): - price_s = IndexData(price_s.values, list(price_s.index)) + price_s = IndexData.Series(price_s) elif isinstance(price_s, (int, float, np.floating)): - price_s = IndexData([price_s], [trade_start_time]) + price_s = IndexData.Series(price_s, [trade_start_time]) + elif isinstance(price_s, SingleData): + pass else: raise NotImplementedError(f"This type of input is not supported") @@ -405,11 +408,11 @@ class Indicator: if agg == "vwap": volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None) - if isinstance(volume_s, (int, float)): - volume_s = IndexData([volume_s], [trade_start_time]) + if isinstance(volume_s, (int, float, np.floating)): + volume_s = IndexData.Series(volume_s, [trade_start_time]) volume_s = volume_s.reindex(price_s.index) elif agg == "twap": - volume_s = IndexData.ones(price_s.index) + volume_s = IndexData.Series(1, price_s.index) else: raise NotImplementedError(f"This type of input is not supported") @@ -472,16 +475,16 @@ class Indicator: else: bp_new[inst], bv_new[inst] = pr, v - bp_new = IndexData(list(bp_new.values()), list(bp_new.keys())) - bv_new = IndexData(list(bv_new.values()), list(bv_new.keys())) + bp_new = IndexData.Series(bp_new) + bv_new = IndexData.Series(bv_new) bp_all.append(bp_new) bv_all.append(bv_new) - bp_all = IndexData.concat_by_index(bp_all) - bv_all = IndexData.concat_by_index(bv_all) + bp_all = IndexData.concat(bp_all, axis = 1) + bv_all = IndexData.concat(bv_all, axis = 1) - base_volume = bv_all.sum(axis=0) + base_volume = bv_all.sum(axis = 1) self.order_indicator.assign("base_volume", base_volume.to_dict()) - self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=0) / base_volume).to_dict()) + self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict()) def _agg_order_price_advantage(self): def if_empty_func(trade_price): diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py new file mode 100644 index 000000000..47e657c59 --- /dev/null +++ b/qlib/utils/index_data.py @@ -0,0 +1,410 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +import numpy as np +import pandas as pd +from typing import Union, Callable + + +class IndexData: + """This is a simplified version of pandas which is faster based on numpy. + """ + @staticmethod + def Series(data: Union[dict, pd.Series, int, float, np.floating, list, np.ndarray] = [], index: Union[list, pd.Index] = []): + if isinstance(data, dict): + return SingleData(list(data.values()), list(data.keys())) + elif isinstance(data, pd.Series): + return SingleData(data.values, data.index) + else: + return SingleData(data, index) + + @staticmethod + def DataFrame(data: Union[pd.DataFrame, list, np.ndarray] = [[]], index: Union[list, pd.Index] = [], columns: Union[list, pd.Index] = []): + if isinstance(data, pd.DataFrame): + return MultiData(data.values, data.index, data.columns) + else: + return MultiData(data, index, columns) + + @staticmethod + def concat(data_list, axis = 0): + """concat all SingleData by index. + just for 1-dim data. + + Parameters + ---------- + index_data_list : List[SingleData] + the list of all SingleData to concat. + + Returns + ------- + MultiData + the MultiData with ndim == 2 + """ + if axis == 0: + raise NotImplementedError(f"please implement this fuc when axis == 0") + elif axis == 1: + # get all index and row + all_index = set() + for index_data in data_list: + all_index = all_index | set(index_data.index) + all_index = list(all_index) + all_index.sort() + all_index_map = dict(zip(all_index, range(len(all_index)))) + + # concat all + tmp_data = np.full((len(all_index), len(data_list)), np.NaN) + for data_id, index_data in enumerate(data_list): + assert isinstance(index_data, SingleData) + now_data_map = [all_index_map[index] for index in index_data.index] + tmp_data[now_data_map, data_id] = index_data.data + return MultiData(tmp_data, all_index) + else: + raise ValueError(f"axis must be 0 or 1") + + +class BaseData: + """Base data structure of SingleData and MultiData. + """ + def __init__(self): + self.index_columns = self._get_index_columns() + + def _get_index_columns(self): + index_columns = [] + if hasattr(self, "index"): + index_columns.append(self.index) + if hasattr(self, "columns"): + index_columns.append(self.columns) + return index_columns + + def _align_index(self, other): + """Align index before performing the four arithmetic operations. + """ + raise NotImplementedError(f"please implement _align_index func") + + def __add__(self, other): + if isinstance(other, (int, float, np.floating)): + return self.__class__(self.data + other, *self.index_columns) + elif isinstance(other, self.__class__): + tmp_data1, tmp_data2 = self._align_index(other) + return self.__class__(tmp_data1.data + tmp_data2.data, *tmp_data1.index_columns) + else: + return NotImplemented + + def __sub__(self, other): + if isinstance(other, (int, float, np.floating)): + return self.__class__(self.data - other, *self.index_columns) + elif isinstance(other, self.__class__): + tmp_data1, tmp_data2 = self._align_index(other) + return self.__class__(tmp_data1.data - tmp_data2.data, *tmp_data1.index_columns) + else: + return NotImplemented + + def __rsub__(self, other): + if isinstance(other, (int, float, np.floating)): + return self.__class__(other - self.data, *self.index_columns) + elif isinstance(other, self.__class__): + tmp_data1, tmp_data2 = self._align_index(other) + return self.__class__(tmp_data2.data - tmp_data1.data, *tmp_data1.index_columns) + else: + return NotImplemented + + def __mul__(self, other): + if isinstance(other, (int, float, np.floating)): + return self.__class__(self.data * other, *self.index_columns) + elif isinstance(other, self.__class__): + tmp_data1, tmp_data2 = self._align_index(other) + return self.__class__(tmp_data1.data * tmp_data2.data, *tmp_data1.index_columns) + else: + return NotImplemented + + def __truediv__(self, other): + if isinstance(other, (int, float, np.floating)): + return self.__class__(self.data / other, *self.index_columns) + elif isinstance(other, self.__class__): + tmp_data1, tmp_data2 = self._align_index(other) + return self.__class__(tmp_data1.data / tmp_data2.data, *tmp_data1.index_columns) + else: + return NotImplemented + + def __eq__(self, other): + if isinstance(other, (int, float, np.floating)): + return self.__class__(self.data == other, *self.index_columns) + elif isinstance(other, self.__class__): + tmp_data1, tmp_data2 = self._align_index(other) + return self.__class__(tmp_data1.data == tmp_data2.data, *tmp_data1.index_columns) + else: + return NotImplemented + + def __gt__(self, other): + if isinstance(other, (int, float, np.floating)): + return self.__class__(self.data > other, *self.index_columns) + elif isinstance(other, self.__class__): + tmp_data1, tmp_data2 = self._align_index(other) + return self.__class__(tmp_data1.data > tmp_data2.data, *tmp_data1.index_columns) + else: + return NotImplemented + + def __lt__(self, other): + if isinstance(other, (int, float, np.floating)): + return self.__class__(self.data < other, *self.index_columns) + elif isinstance(other, self.__class__): + tmp_data1, tmp_data2 = self._align_index(other) + return self.__class__(tmp_data1.data < tmp_data2.data, *tmp_data1.index_columns) + else: + return NotImplemented + + def __invert__(self): + return self.__class__(~self.data, *self.index_columns) + + def abs(self): + """get the abs of data except np.NaN. + """ + tmp_data = np.absolute(self.data) + return self.__class__(tmp_data, *self.index_columns) + + def astype(self, type): + """change the type of data. + """ + tmp_data = self.data.astype(type) + return self.__class__(tmp_data, *self.index_columns) + + def replace(self, to_replace: dict): + assert isinstance(to_replace, dict) + tmp_data = self.data.copy() + for num in to_replace: + if num in tmp_data: + tmp_data[tmp_data == num] = to_replace[num] + return self.__class__(tmp_data, *self.index_columns) + + def apply(self, func: Callable): + """apply a function to data. + """ + tmp_data = func(self.data) + return self.__class__(tmp_data, *self.index_columns) + + def __len__(self): + """the length of the data. + + Returns + ------- + int + the length of the data. + """ + return len(self.data) + + def sum(self, axis=None): + if axis is None: + return np.nansum(self.data) + elif axis == 0: + tmp_data = np.nansum(self.data, axis=0) + return SingleData(tmp_data, self.columns) + elif axis == 1: + tmp_data = np.nansum(self.data, axis=1) + return SingleData(tmp_data, self.index) + else: + raise ValueError(f"axis must be None, 0 or 1") + + def mean(self, axis=None): + if axis is None: + return np.nanmean(self.data) + elif axis == 0: + tmp_data = np.nanmean(self.data, axis=0) + return SingleData(tmp_data, self.columns) + elif axis == 1: + tmp_data = np.nanmean(self.data, axis=1) + return SingleData(tmp_data, self.index) + else: + raise ValueError(f"axis must be None, 0 or 1") + + def count(self): + return len(self.data[~np.isnan(self.data)]) + + @property + def empty(self): + return len(self.data) == 0 + + +class SingleData(BaseData): + def __init__(self, data: Union[int, float, np.floating, list, np.ndarray] = [], index: Union[list, pd.Index] = []): + """A data structure of index and numpy data. + It's used to replace pd.Series due to high-speed. + + Parameters + ---------- + data : Union[int, float, np.floating, list, np.ndarray] + the dim of data must be 1. + index : Union[list, pd.Index] + the index of data. + """ + # data + if isinstance(data, (int, float, np.floating)): + self.data = np.full(len(index), fill_value=data) + elif isinstance(data, list): + self.data = np.array(data) + elif isinstance(data, np.ndarray): + self.data = data + else: + raise ValueError(f"data must be list or np.ndarray") + # data in SingleData must be one dim + assert self.data.ndim == 1 + # replace int with float + if self.data.dtype == np.int: + self.data = self.data.astype(np.float64) + # replace None with np.NaN, because pd.Series does it. + if None in self.data: + self.data[self.data == None] = np.NaN + + # index + if isinstance(index, list): + if index == [] and len(self.data) > 0: + index = list(range(len(self.data))) + self.index = index + elif isinstance(index, pd.Index): + self.index = list(index) + else: + raise ValueError(f"index must be list or pd.Index") + assert len(self.data) == len(self.index) + # if data is not empty, + self.index_map = dict(zip(self.index, range(len(self.index)))) + + super(SingleData, self).__init__() + + def _align_index(self, other): + if self.index == other.index: + return self, other + elif set(self.index) == set(other.index): + return self, other.reindex(self.index) + else: + raise ValueError(f"The indexes of self and other do not meet the requirements of the four arithmetic operations") + + def reindex(self, index, fill_value=np.NaN): + """reindex data and fill the missing value with np.NaN. + + Parameters + ---------- + new_index : list + new index + + Returns + ------- + SingleData + reindex data + """ + tmp_data = np.full(len(index), fill_value, np.float64) + for index_id, index_item in enumerate(index): + if index_item in self.index: + tmp_data[index_id] = self.data[self.index_map[index_item]] + return SingleData(tmp_data, index) + + def add(self, other, fill_value=0): + common_index = list(set(self.index) | set(other.index)) + tmp_data1 = self.reindex(common_index,fill_value) + tmp_data2 = other.reindex(common_index,fill_value) + return tmp_data1 + tmp_data2 + + def to_dict(self): + """convert SingleData to dict. + + Returns + ------- + dict + data with the dict format. + """ + return dict(zip(self.index, self.data.tolist())) + + def to_frame(self): + """convert SingleData to MultiData. + + Returns + ------- + MultiData + data with the MultiData format. + """ + return MultiData(self.data[:, np.newaxis], self.index) + + def to_pd_series(self): + return pd.Series(self.data, index = self.index) + + def __getitem__(self, index: Union["SingleData", int, str]): + if isinstance(index, int): + return self.data[index] + elif isinstance(index, str): + return self.data[self.index_map[index]] + elif isinstance(index, SingleData): + new_data = self.data[index.data] + new_index = list(np.array(self.index)[index.data]) + return SingleData(new_data, new_index) + else: + raise ValueError(f"index must be SingleData, int, str") + + +class MultiData(BaseData): + def __init__(self, data: Union[list, np.ndarray] = [[]], index: Union[list, pd.Index] = [], columns: Union[list, pd.Index] = []): + """A data structure of index and numpy data. + It's used to replace pd.DataFrame due to high-speed. + + Parameters + ---------- + data : Union[list, np.ndarray] + the dim of data must be 2. + index : Union[list, pd.Index] + the index of data. + columns: Union[list, pd.Index] + the columns of data. + """ + # data + if isinstance(data, list): + self.data = np.array(data) + elif isinstance(data, np.ndarray): + self.data = data + else: + raise ValueError(f"data must be list or np.ndarray") + # data in SingleData must be two dim + assert self.data.ndim == 2 + # replace int with float + if self.data.dtype == np.int: + self.data = self.data.astype(np.float64) + # replace None with np.NaN, because pd.DataFrame does it. + if None in self.data: + self.data[self.data == None] = np.NaN + + # index + if isinstance(index, list): + if index == [] and self.data.shape[0] > 0: + index = list(range(self.data.shape[0])) + self.index = index + elif isinstance(index, pd.Index): + self.index = list(index) + else: + raise ValueError(f"index must be list or pd.Index") + assert self.data.shape[0] == len(self.index) + # if data is not empty, + self.index_map = dict(zip(self.index, range(len(self.index)))) + + # columns + if isinstance(columns, list): + if columns == [] and self.data.shape[1] > 0: + columns = list(range(self.data.shape[1])) + self.columns = columns + elif isinstance(columns, pd.Index): + self.columns = list(columns) + else: + raise ValueError(f"columns must be list or pd.Index") + assert self.data.shape[1] == len(self.columns) + # if data is not empty, + self.columns_map = dict(zip(self.columns, range(len(self.columns)))) + + super(MultiData, self).__init__() + + def _align_index(self, other): + if self.index_columns == other.index_columns: + return self, other + else: + raise ValueError(f"The indexes of self and other do not meet the requirements of the four arithmetic operations") + + def __getitem__(self, col) -> SingleData: + if col not in self.columns: + return SingleData() + else: + return SingleData(self.data[:, self.columns_map[col]], self.index)