1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00

index_data

This commit is contained in:
wangwenxi.handsome
2021-08-26 12:41:12 +00:00
committed by you-n-g
parent 13a9b7cea0
commit d9ad8ff791
5 changed files with 468 additions and 395 deletions

View File

@@ -39,7 +39,7 @@ class Exchange:
close_cost=0.0025,
min_cost=5,
extra_quote=None,
quote_cls=CN1Min_NumpyQuote,
quote_cls=PandasQuote,
**kwargs,
):
"""__init__

View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
from builtins import ValueError, isinstance
import logging
from typing import List, Text, Union, Callable, Iterable, Dict
from collections import OrderedDict
@@ -11,6 +12,7 @@ import bisect
import pandas as pd
import numpy as np
from ..utils.index_data import IndexData
from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
from ..utils.time import _if_single_data
@@ -38,7 +40,7 @@ class BaseQuote:
end_time: Union[pd.Timestamp, str],
fields: str = None,
method: Union[str, Callable] = None,
) -> Union[None, float, pd.Series, pd.DataFrame, "IndexData"]:
) -> Union[None, float, "IndexData"]:
"""get the specific fields of stock data during start time and end_time,
and apply method to the data.
@@ -65,42 +67,28 @@ class BaseQuote:
85.713585
2. Both fields and method are None. It returns pd.Dataframe or np.ndarray.
2. Both fields and method are None. It returns np.ndarray.
print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", fields=None, method=None))
1) pd.Dataframe
$close $volume
datetime
2010-01-04 86.778313 16162960.0
2010-01-05 87.433578 28117442.0
2010-01-06 85.713585 23632884.0
2) np.ndarray
[
[86.778313, 16162960.0],
[87.433578, 28117442.0],
[85.713585, 23632884.0],
]
3. fields is not None, and method is None. It returns pd.Series or IndexData.
3. fields is not None, and method is None. It returns IndexData.
print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", fields="$close", method=None))
1) pd.Series
2010-01-04 86.778313
2010-01-05 87.433578
2010-01-06 85.713585
2) IndexData
IndexData([86.778313, 87.433578, 85.713585], [2010-01-04, 2010-01-05, 2010-01-06])
Parameters
----------
stock_id: Union[str, list]
stock_id: str
start_time : Union[pd.Timestamp, str]
closed start time for backtest
end_time : Union[pd.Timestamp, str]
closed end time for backtest
fields : Union[str, List]
fields : str
the columns of data to fetch
method : Union[str, Callable]
the method apply to data.
@@ -404,8 +392,8 @@ class BaseOrderIndicator:
raise NotImplementedError(f"Please implement the 'get_metric_series' method")
def get_index_data(self, metric):
"""get one metric with the format of IndexData
def get_index_data(self, metric) -> IndexData.Series:
"""get one metric with the format of IndexData.Series
Parameters
----------
@@ -414,8 +402,8 @@ class BaseOrderIndicator:
Return
------
IndexData
one metric with the format of IndexData
IndexData.Series
one metric with the format of IndexData.Series
"""
raise NotImplementedError(f"Please implement the 'get_index_data' method")
@@ -586,12 +574,21 @@ class PandasOrderIndicator(BaseOrderIndicator):
else:
return tmp_metric
def get_index_data(self, metric):
if metric in self.data:
return IndexData.Series(self.data[metric].metric)
else:
return IndexData.Series()
def get_metric_series(self, metric: str) -> Union[pd.Series]:
if metric in self.data:
return self.data[metric].metric
else:
return pd.Series()
def to_series(self):
return {k: v.metric for k, v in self.data.items()}
@staticmethod
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=None):
if isinstance(metrics, str):
@@ -602,387 +599,45 @@ class PandasOrderIndicator(BaseOrderIndicator):
tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
order_indicator.assign(metric, tmp_metric.metric)
def to_series(self):
return {k: v.metric for k, v in self.data.items()}
def get_index_data(self, metric):
if metric in self.data:
return IndexData(self.data[metric].values(), list(self.data[metric].index))
else:
return IndexData([], [])
class NumpySingleMetric(SingleMetric):
def __init__(self, metric: np.ndarray):
self.metric = metric
def __len__(self):
return len(self.metric)
def sum(self):
return np.nansum(self.metric)
def mean(self):
return np.nanmean(self.metric)
def count(self):
return len(self.metric[~np.isnan(self.metric)])
def abs(self):
return self.__class__(np.absolute(self.metric))
def astype(self, type):
return self.__class__(self.metric.astype(type))
@property
def empty(self):
return len(self.metric) == 0
def replace(self, replace_dict: dict):
tmp_metric = self.metric.copy()
for num in replace_dict:
tmp_metric[tmp_metric == num] = replace_dict[num]
return self.__class__(tmp_metric)
def apply(self, func: Callable):
tmp_metric = self.metric.copy()
for i in range(len(tmp_metric)):
tmp_metric[i] = func(tmp_metric[i])
return self.__class__(tmp_metric)
class NumpyOrderIndicator(BaseOrderIndicator):
# all metrics
ROW = [
"amount",
"deal_amount",
"inner_amount",
"trade_price",
"trade_value",
"trade_cost",
"trade_dir",
"ffr",
"pa",
"pos",
"base_price",
"base_volume",
]
ROW_MAP = dict(zip(ROW, range(len(ROW))))
def __init__(self):
self.row_tag = [0 for tag in range(len(NumpyOrderIndicator.ROW))]
self.data = None
self.data: Dict[str, IndexData.Series] = OrderedDict()
def assign(self, col: str, metric: dict):
if col not in NumpyOrderIndicator.ROW:
raise ValueError(f"{col} metric is not supported")
if not isinstance(metric, dict):
raise ValueError(f"metric must be dict")
self.data[col] = IndexData.Series(metric)
# if data is None, init numpy ndarray
if self.data is None:
self.data = np.full((len(NumpyOrderIndicator.ROW), len(metric)), np.NaN)
self.column = list(metric.keys())
self.column_map = dict(zip(self.column, range(len(self.column))))
metric_column = list(metric.keys())
if self.column != metric_column:
assert len(set(self.column) - set(metric_column)) == 0
# modify the order
tmp_metric = {}
for column in self.column:
tmp_metric[column] = metric[column]
metric = tmp_metric
# assign data
self.row_tag[NumpyOrderIndicator.ROW_MAP[col]] = 1
self.data[NumpyOrderIndicator.ROW_MAP[col]] = list(metric.values())
def transfer(self, func: Callable, new_col: str = None) -> Union[None, NumpySingleMetric]:
def transfer(self, func: Callable, new_col: str = None) -> Union[None, IndexData.Series]:
func_sig = inspect.signature(func).parameters.keys()
func_kwargs = {}
for sig in func_sig:
if self._if_valid_metric(sig):
func_kwargs[sig] = NumpySingleMetric(self.data[NumpyOrderIndicator.ROW_MAP[sig]])
else:
self.logger.warning(f"{sig} is not assigned")
func_kwargs[sig] = NumpySingleMetric(np.array([]))
func_kwargs = {sig: self.data[sig] for sig in func_sig}
tmp_metric = func(**func_kwargs)
if new_col is not None:
self.row_tag[NumpyOrderIndicator.ROW_MAP[new_col]] = 1
self.data[NumpyOrderIndicator.ROW_MAP[new_col]] = tmp_metric.metric
self.data[new_col] = tmp_metric
else:
return tmp_metric
def get_index_data(self, metric):
if self._if_valid_metric(metric):
return IndexData(self.data[NumpyOrderIndicator.ROW_MAP[metric]], self.column)
if metric in self.data:
return self.data[metric]
else:
return IndexData([], [])
return IndexData.Series()
def get_metric_series(self, metric: str) -> Union[pd.Series]:
if self._if_valid_metric(metric):
return pd.Series(self.data[NumpyOrderIndicator.ROW_MAP[metric]], index=self.column)
else:
return pd.Series()
return self.data[metric].to_pd_series()
def to_series(self) -> Dict[str, pd.Series]:
tmp_metric_dict = {}
for metric in NumpyOrderIndicator.ROW:
for metric in self.data:
tmp_metric_dict[metric] = self.get_metric_series(metric)
return tmp_metric_dict
def _if_valid_metric(self, metric):
if metric in NumpyOrderIndicator.ROW and self.row_tag[NumpyOrderIndicator.ROW_MAP[metric]] == 1:
return True
else:
return False
@staticmethod
def sum_all_indicators(
order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=None
) -> Dict[str, NumpySingleMetric]:
# metrics is all metrics to add
# metrics_id means the index in the NumpyOrderIndicator.ROW for metrics.
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
if isinstance(metrics, str):
metrics = [metrics]
metrics_id = [NumpyOrderIndicator.ROW_MAP[metric] for metric in metrics]
# get all stock_id and all metric data
stocks = set()
indicator_metrics = []
for indicator in indicators:
stocks = stocks | set(indicator.column)
indicator_metrics.append(indicator.data[metrics_id, :].copy())
stocks = list(stocks)
stocks.sort()
stocks_map = dict(zip(stocks, range(len(stocks))))
# fill value
if fill_value is not None:
base_metrics = fill_value * np.ones((len(metrics), len(stocks)))
for i in range(len(indicators)):
tmp_metrics = base_metrics.copy()
stocks_index = [stocks_map[stock] for stock in indicators[i].column]
tmp_metrics[:, stocks_index] = indicator_metrics[i]
indicator_metrics[i] = tmp_metrics
else:
raise ValueError(f"fill value can not be None in NumpyOrderIndicator")
# add metric and assign to order_indicator
metric_sum = sum(indicator_metrics)
if order_indicator.data is not None:
raise ValueError(f"this function must assign to an empty order indicator")
order_indicator.data = np.zeros((len(NumpyOrderIndicator.ROW), len(stocks)))
order_indicator.column = stocks
order_indicator.column_map = dict(zip(stocks, range(len(stocks))))
for i in range(len(metrics)):
order_indicator.row_tag[NumpyOrderIndicator.ROW_MAP[metrics[i]]] = 1
order_indicator.data[NumpyOrderIndicator.ROW_MAP[metrics[i]]] = metric_sum[i]
class IndexData:
def __init__(self, data, index):
"""A data structure of index and numpy data.
Parameters
----------
data : np.ndarray
the dim of data must be 1 or 2.
different functions have dimensional limitations
index : list
the index of data.
"""
if isinstance(data, list):
self.data = np.array(data)
elif isinstance(data, np.ndarray):
self.data = data
else:
raise ValueError(f"data must be list or np.ndarray")
self.ndim = self.data.ndim
assert isinstance(index, list)
self.index = index
self.index_map = dict(zip(self.index, range(len(self.index))))
def reindex(self, new_index):
"""reindex data and fill the missing value with np.NaN.
just for 1-dim data.
Parameters
----------
new_index : list
new index
Returns
-------
IndexData
reindex data
"""
assert self.ndim == 1
tmp_data = np.full(len(new_index), np.NaN)
for index_id, index in enumerate(new_index):
if index in self.index:
tmp_data[index_id] = self.data[self.index_map[index]]
return IndexData(tmp_data, list(new_index))
def to_dict(self):
"""convert IndexData to dict.
just for 1-dim data.
Returns
-------
dict
data with the dict format.
"""
assert self.ndim == 1
return dict(zip(self.index, self.data.tolist()))
def sum(self, axis=None):
"""get the sum of data.
Parameters
----------
axis : 0 or None, optional
which axis to sum, by default None
Returns
-------
Union[float, IndexData]
if axis is None, it sums all data, return float.
if axis == 1, it sums by row, return IndexData.
"""
if axis is None:
return np.nansum(self.data)
if axis == 0:
assert self.ndim == 2
tmp_data = np.nansum(self.data, axis=0)
return IndexData(tmp_data, self.index)
else:
raise NotImplementedError(f"axis must be 0 or None")
def __mul__(self, other):
"""multiply with another IndexData.
Returns
-------
IndexData
"""
if isinstance(other, IndexData):
assert self.ndim == other.ndim
assert self.index == other.index
assert len(self.data) == len(other.data)
return IndexData(self.data * other.data, self.index)
else:
return NotImplemented
def __truediv__(self, other):
"""divide with another IndexData.
Returns
-------
IndexData
"""
if isinstance(other, IndexData):
assert self.ndim == other.ndim
assert self.index == other.index
assert len(self.data) == len(other.data)
return IndexData(self.data / other.data, self.index)
else:
return NotImplemented
def __len__(self):
"""the length of the data.
Returns
-------
int
the length of the data.
"""
return len(self.index)
def __getitem__(self, bool_list: "IndexData"):
"""get IndexData by a bool_list which has the same shape of self.data.
just for 1-dim data.
Parameters
----------
bool_list : Union[list, np.ndarray]
a bool_list which has the same shape of self.data. such as array([True, False, True]).
True means the data of the position is reserved. False is not.
Returns
-------
IndexData
new IndexData.
"""
assert self.ndim == 1
assert isinstance(bool_list, IndexData)
new_data = self.data[bool_list.data]
new_index = list(np.array(self.index)[bool_list.data])
return IndexData(new_data, new_index)
def __gt__(self, other):
if isinstance(other, (int, float)):
return IndexData(self.data > other, self.index)
elif isinstance(other, IndexData):
return IndexData(self.data > other.data, self.index)
else:
return NotImplemented
def __lt__(self, other):
if isinstance(other, (int, float)):
return IndexData(self.data < other, self.index)
elif isinstance(other, IndexData):
return IndexData(self.data < other.data, self.index)
else:
return NotImplemented
def __invert__(self):
return IndexData(~self.data, self.index)
@staticmethod
def concat_by_index(index_data_list):
"""concat all IndexData by index.
just for 1-dim data.
Parameters
----------
index_data_list : List[IndexData]
the list of all IndexData to concat.
Returns
-------
IndexData
the IndexData with ndim == 2
"""
# get all index and row
all_index = set()
for index_data in index_data_list:
all_index = all_index | set(index_data.index)
all_index = list(all_index)
all_index.sort()
all_index_map = dict(zip(all_index, range(len(all_index))))
# concat all
tmp_data = np.full((len(index_data_list), len(all_index)), np.NaN)
for data_id, index_data in enumerate(index_data_list):
assert index_data.ndim == 1
now_data_map = [all_index_map[index] for index in index_data.index]
tmp_data[data_id, now_data_map] = index_data.data
return IndexData(tmp_data, all_index)
@staticmethod
def ones(index):
"""initial the IndexData with index, and fill data with 1.
Parameters
----------
index : list
the index of new data.
Returns
-------
IndexData
"""
return IndexData([1 for i in range(len(index))], list(index))
for metric in metrics:
tmp_metric = IndexData.Series()
for indicator in indicators:
tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
order_indicator.data[metric] = tmp_metric

View File

@@ -109,7 +109,7 @@ class Order:
return self.direction * 2 - 1
@staticmethod
def parse_dir(direction: Union[str, int, np.integer, OrderDir]) -> OrderDir:
def parse_dir(direction: Union[str, int, np.integer, OrderDir, np.ndarray]) -> OrderDir:
if isinstance(direction, OrderDir):
return direction
elif isinstance(direction, (int, float, np.integer, np.floating)):
@@ -125,6 +125,11 @@ class Order:
return OrderDir.BUY
else:
raise NotImplementedError(f"This type of input is not supported")
elif isinstance(direction, np.ndarray):
direction_array = direction.copy()
direction_array[direction_array > 0] = Order.BUY
direction_array[direction_array <= 0] = Order.SELL
return direction_array
else:
raise NotImplementedError(f"This type of input is not supported")

View File

@@ -16,7 +16,8 @@ from qlib.backtest.exchange import Exchange
from qlib.backtest.order import BaseTradeDecision, Order, OrderDir
from qlib.backtest.utils import TradeCalendarManager
from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator, IndexData
from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator
from ..utils.index_data import IndexData, SingleData
from ..data import D
from ..tests.config import CSI300_BENCH
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
@@ -391,9 +392,11 @@ class Indicator:
return None, None
if isinstance(price_s, pd.Series):
price_s = IndexData(price_s.values, list(price_s.index))
price_s = IndexData.Series(price_s)
elif isinstance(price_s, (int, float, np.floating)):
price_s = IndexData([price_s], [trade_start_time])
price_s = IndexData.Series(price_s, [trade_start_time])
elif isinstance(price_s, SingleData):
pass
else:
raise NotImplementedError(f"This type of input is not supported")
@@ -405,11 +408,11 @@ class Indicator:
if agg == "vwap":
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
if isinstance(volume_s, (int, float)):
volume_s = IndexData([volume_s], [trade_start_time])
if isinstance(volume_s, (int, float, np.floating)):
volume_s = IndexData.Series(volume_s, [trade_start_time])
volume_s = volume_s.reindex(price_s.index)
elif agg == "twap":
volume_s = IndexData.ones(price_s.index)
volume_s = IndexData.Series(1, price_s.index)
else:
raise NotImplementedError(f"This type of input is not supported")
@@ -472,16 +475,16 @@ class Indicator:
else:
bp_new[inst], bv_new[inst] = pr, v
bp_new = IndexData(list(bp_new.values()), list(bp_new.keys()))
bv_new = IndexData(list(bv_new.values()), list(bv_new.keys()))
bp_new = IndexData.Series(bp_new)
bv_new = IndexData.Series(bv_new)
bp_all.append(bp_new)
bv_all.append(bv_new)
bp_all = IndexData.concat_by_index(bp_all)
bv_all = IndexData.concat_by_index(bv_all)
bp_all = IndexData.concat(bp_all, axis = 1)
bv_all = IndexData.concat(bv_all, axis = 1)
base_volume = bv_all.sum(axis=0)
base_volume = bv_all.sum(axis = 1)
self.order_indicator.assign("base_volume", base_volume.to_dict())
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=0) / base_volume).to_dict())
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict())
def _agg_order_price_advantage(self):
def if_empty_func(trade_price):

410
qlib/utils/index_data.py Normal file
View File

@@ -0,0 +1,410 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
from typing import Union, Callable
class IndexData:
    """Factory namespace for the numpy-backed stand-ins for pandas objects.

    The static constructors mirror the pandas names: ``Series`` builds a
    ``SingleData`` (1-dim), ``DataFrame`` builds a ``MultiData`` (2-dim) and
    ``concat`` joins several ``SingleData`` into one ``MultiData``.
    """

    @staticmethod
    def Series(
        data: Union[dict, pd.Series, int, float, np.floating, list, np.ndarray, None] = None,
        index: Union[list, pd.Index, None] = None,
    ):
        """Build a SingleData.

        Parameters
        ----------
        data : Union[dict, pd.Series, int, float, np.floating, list, np.ndarray]
            dict: values become the data and keys become the index.
            pd.Series: its values and index are used.
            anything else is forwarded to SingleData together with `index`.
            None (the default) means empty data.
        index : Union[list, pd.Index]
            the index of data; ignored when `data` is a dict or pd.Series.
            None (the default) means an empty index.

        Returns
        -------
        SingleData
        """
        # Create fresh containers on every call: the list passed down is kept
        # by SingleData as its index, so a shared default would be aliased by
        # every empty result.
        if data is None:
            data = []
        if index is None:
            index = []

        if isinstance(data, dict):
            return SingleData(list(data.values()), list(data.keys()))
        if isinstance(data, pd.Series):
            return SingleData(data.values, data.index)
        return SingleData(data, index)

    @staticmethod
    def DataFrame(
        data: Union[pd.DataFrame, list, np.ndarray, None] = None,
        index: Union[list, pd.Index, None] = None,
        columns: Union[list, pd.Index, None] = None,
    ):
        """Build a MultiData.

        Parameters
        ----------
        data : Union[pd.DataFrame, list, np.ndarray]
            pd.DataFrame: its values, index and columns are used.
            anything else is forwarded to MultiData with `index` and `columns`.
            None (the default) means ``[[]]``.
        index : Union[list, pd.Index]
            the index of data; ignored when `data` is a pd.DataFrame.
        columns : Union[list, pd.Index]
            the columns of data; ignored when `data` is a pd.DataFrame.

        Returns
        -------
        MultiData
        """
        # Same reasoning as in Series: never hand a shared default list down.
        if data is None:
            data = [[]]
        if index is None:
            index = []
        if columns is None:
            columns = []

        if isinstance(data, pd.DataFrame):
            return MultiData(data.values, data.index, data.columns)
        return MultiData(data, index, columns)

    @staticmethod
    def concat(data_list, axis=0):
        """Concat several SingleData into one MultiData.

        Only `axis == 1` is implemented: every SingleData becomes one column,
        rows are the sorted union of all indexes, and positions a SingleData
        has no value for are filled with np.nan.

        Parameters
        ----------
        data_list : List[SingleData]
            the list of all SingleData to concat.
        axis : int
            0 is not implemented yet, 1 places the inputs side by side.

        Returns
        -------
        MultiData
            the MultiData with ndim == 2

        Raises
        ------
        NotImplementedError
            if axis == 0.
        ValueError
            if axis is neither 0 nor 1.
        """
        if axis == 0:
            raise NotImplementedError("please implement this fuc when axis == 0")
        if axis != 1:
            raise ValueError("axis must be 0 or 1")

        # Check the element type before touching any attribute of the inputs.
        for single in data_list:
            assert isinstance(single, SingleData)

        # Row labels: sorted union of every input index.
        all_index = sorted(set().union(*(single.index for single in data_list)))
        row_of = {label: row for row, label in enumerate(all_index)}

        # np.nan (lower case) is the spelling that exists on every NumPy
        # version; the np.NaN alias was removed in NumPy 2.0.
        tmp_data = np.full((len(all_index), len(data_list)), np.nan)
        for col, single in enumerate(data_list):
            rows = [row_of[label] for label in single.index]
            tmp_data[rows, col] = single.data
        return MultiData(tmp_data, all_index)
class BaseData:
    """Base data structure of SingleData and MultiData.

    A subclass is expected to set ``self.data`` (a numpy array) and its label
    attributes (``index`` and, for 2-dim data, ``columns``) before calling
    ``BaseData.__init__``, and to accept ``(data, *index_columns)`` in its own
    constructor, because every operation here rebuilds its result through
    ``self.__class__(new_data, *self.index_columns)``.

    Note: ``__eq__`` is element-wise and returns a new object, so instances
    are not hashable.
    """

    # Scalar operand types accepted by the arithmetic / comparison operators.
    # NumPy integer scalars are accepted alongside NumPy floating scalars.
    _SCALAR_TYPES = (int, float, np.integer, np.floating)

    def __init__(self):
        self.index_columns = self._get_index_columns()

    def _get_index_columns(self):
        """Collect the label attributes (index, then columns) that exist."""
        return [getattr(self, name) for name in ("index", "columns") if hasattr(self, name)]

    def _align_index(self, other):
        """Align index before performing the four arithmetic operations."""
        raise NotImplementedError("please implement _align_index func")

    def _binary_op(self, other, op: Callable, reflected: bool = False):
        """Apply the binary function `op` between self and `other`.

        Parameters
        ----------
        other : Union[int, float, np.integer, np.floating, BaseData]
            a scalar, or an instance of the same class as self.
        op : Callable
            a function of two operands, e.g. ``lambda a, b: a + b``.
        reflected : bool
            if True, `other` is the left operand (used by ``__rsub__``).

        Returns
        -------
        the same class as self, or NotImplemented for unsupported operands.
        """
        if isinstance(other, self._SCALAR_TYPES):
            result = op(other, self.data) if reflected else op(self.data, other)
            return self.__class__(result, *self.index_columns)
        if isinstance(other, self.__class__):
            lhs, rhs = self._align_index(other)
            result = op(rhs.data, lhs.data) if reflected else op(lhs.data, rhs.data)
            return self.__class__(result, *lhs.index_columns)
        return NotImplemented

    def __add__(self, other):
        return self._binary_op(other, lambda a, b: a + b)

    def __sub__(self, other):
        return self._binary_op(other, lambda a, b: a - b)

    def __rsub__(self, other):
        return self._binary_op(other, lambda a, b: a - b, reflected=True)

    def __mul__(self, other):
        return self._binary_op(other, lambda a, b: a * b)

    def __truediv__(self, other):
        return self._binary_op(other, lambda a, b: a / b)

    def __eq__(self, other):
        return self._binary_op(other, lambda a, b: a == b)

    def __gt__(self, other):
        return self._binary_op(other, lambda a, b: a > b)

    def __lt__(self, other):
        return self._binary_op(other, lambda a, b: a < b)

    def __invert__(self):
        return self.__class__(~self.data, *self.index_columns)

    def abs(self):
        """get the absolute value of every element (np.nan stays np.nan)."""
        tmp_data = np.absolute(self.data)
        return self.__class__(tmp_data, *self.index_columns)

    def astype(self, type):
        """change the type of data."""
        tmp_data = self.data.astype(type)
        return self.__class__(tmp_data, *self.index_columns)

    def replace(self, to_replace: dict):
        """replace values of data according to a mapping.

        Parameters
        ----------
        to_replace : dict
            {old_value: new_value}. Matching uses ``==``, so np.nan can not be
            used as an old value.

        Returns
        -------
        the same class as self, with a copy of data where values are replaced.
        """
        assert isinstance(to_replace, dict)
        tmp_data = self.data.copy()
        for old_value, new_value in to_replace.items():
            # Match against the untouched original so that substitutions do
            # not chain: with {1: 2, 2: 3} an original 1 becomes 2, not 3.
            tmp_data[self.data == old_value] = new_value
        return self.__class__(tmp_data, *self.index_columns)

    def apply(self, func: Callable):
        """apply a function to data.

        `func` receives the whole numpy array in one call (not one element at
        a time) and its return value becomes the new data.
        """
        tmp_data = func(self.data)
        return self.__class__(tmp_data, *self.index_columns)

    def __len__(self):
        """the length of the data.

        Returns
        -------
        int
            the length of the data.
        """
        return len(self.data)

    def sum(self, axis=None):
        """get the sum of data, np.nan is treated as zero.

        Parameters
        ----------
        axis : None, 0 or 1
            None sums everything; 0 gives one value per column; 1 gives one
            value per index entry.

        Returns
        -------
        Union[float, SingleData]
            a scalar if axis is None, otherwise SingleData.
        """
        if axis is None:
            return np.nansum(self.data)
        elif axis == 0:
            tmp_data = np.nansum(self.data, axis=0)
            return SingleData(tmp_data, self.columns)
        elif axis == 1:
            tmp_data = np.nansum(self.data, axis=1)
            return SingleData(tmp_data, self.index)
        else:
            raise ValueError("axis must be None, 0 or 1")

    def mean(self, axis=None):
        """get the mean of data, np.nan is ignored.

        Parameters
        ----------
        axis : None, 0 or 1
            None averages everything; 0 gives one value per column; 1 gives
            one value per index entry.

        Returns
        -------
        Union[float, SingleData]
            a scalar if axis is None, otherwise SingleData.
        """
        if axis is None:
            return np.nanmean(self.data)
        elif axis == 0:
            tmp_data = np.nanmean(self.data, axis=0)
            return SingleData(tmp_data, self.columns)
        elif axis == 1:
            tmp_data = np.nanmean(self.data, axis=1)
            return SingleData(tmp_data, self.index)
        else:
            raise ValueError("axis must be None, 0 or 1")

    def count(self):
        """get the number of elements that are not np.nan."""
        return int(np.count_nonzero(~np.isnan(self.data)))

    @property
    def empty(self):
        """True if data has length zero."""
        return len(self.data) == 0
class SingleData(BaseData):
    """A 1-dim numpy array with an index, used in place of pd.Series."""

    def __init__(
        self,
        data: Union[int, float, np.floating, list, np.ndarray, None] = None,
        index: Union[list, pd.Index, None] = None,
    ):
        """A data structure of index and numpy data.
        It's used to replace pd.Series due to high-speed.

        Parameters
        ----------
        data : Union[int, float, np.floating, list, np.ndarray]
            the dim of data must be 1.
            a scalar is broadcast to the length of `index`.
            None (the default) means empty data.
        index : Union[list, pd.Index]
            the index of data. A list is stored by reference (not copied);
            a pd.Index is converted to a list. If it is empty while data is
            not, it defaults to 0..len(data)-1.
            None (the default) means an empty index.

        Raises
        ------
        ValueError
            if data or index has an unsupported type.
        """
        # Fresh containers per instance; a shared mutable default would be
        # stored below as ``self.index`` and aliased by every empty instance.
        if data is None:
            data = []
        if index is None:
            index = []

        # data
        if isinstance(data, (int, float, np.integer, np.floating)):
            self.data = np.full(len(index), fill_value=data)
        elif isinstance(data, list):
            self.data = np.array(data)
        elif isinstance(data, np.ndarray):
            self.data = data
        else:
            raise ValueError("data must be list or np.ndarray")
        # data in SingleData must be one dim
        assert self.data.ndim == 1
        # replace int with float. ``np.int`` was removed in NumPy 1.24, so the
        # dtype family is tested instead; bool arrays (masks) are not affected.
        if np.issubdtype(self.data.dtype, np.integer):
            self.data = self.data.astype(np.float64)
        # replace None with np.nan, because pd.Series does it.
        # Only object arrays can hold None. The comparison is element-wise on
        # purpose, hence ``== None`` rather than ``is None``.
        if self.data.dtype == object:
            self.data[self.data == None] = np.nan  # noqa: E711

        # index
        if isinstance(index, list):
            if not index and len(self.data) > 0:
                index = list(range(len(self.data)))
            self.index = index
        elif isinstance(index, pd.Index):
            self.index = list(index)
        else:
            raise ValueError("index must be list or pd.Index")
        assert len(self.data) == len(self.index)
        # label -> position lookup
        self.index_map = dict(zip(self.index, range(len(self.index))))

        super().__init__()

    def _align_index(self, other):
        """Return (self, other) with `other` reordered to the index of self.

        Raises
        ------
        ValueError
            if the two indexes do not contain the same set of labels.
        """
        if self.index == other.index:
            return self, other
        elif set(self.index) == set(other.index):
            return self, other.reindex(self.index)
        else:
            raise ValueError(
                "The indexes of self and other do not meet the requirements of the four arithmetic operations"
            )

    def reindex(self, index, fill_value=np.nan):
        """reindex data and fill the missing value with `fill_value`.

        Parameters
        ----------
        index : list
            new index
        fill_value :
            value for labels of `index` that self does not contain,
            np.nan by default.

        Returns
        -------
        SingleData
            reindex data (always float64)
        """
        tmp_data = np.full(len(index), fill_value, np.float64)
        lookup = self.index_map
        for position, label in enumerate(index):
            # dict lookup instead of a linear scan of self.index per label
            source = lookup.get(label)
            if source is not None:
                tmp_data[position] = self.data[source]
        return SingleData(tmp_data, index)

    def add(self, other, fill_value=0):
        """add another SingleData over the union of both indexes.

        Parameters
        ----------
        other : SingleData
            the data to add.
        fill_value :
            value used on either side for labels it does not contain.

        Returns
        -------
        SingleData
            indexed by the labels of self followed by the labels that only
            `other` has.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # result index is deterministic (a set union is not).
        common_index = list(dict.fromkeys([*self.index, *other.index]))
        tmp_data1 = self.reindex(common_index, fill_value)
        tmp_data2 = other.reindex(common_index, fill_value)
        return tmp_data1 + tmp_data2

    def to_dict(self):
        """convert SingleData to dict.

        Returns
        -------
        dict
            data with the dict format.
        """
        return dict(zip(self.index, self.data.tolist()))

    def to_frame(self):
        """convert SingleData to MultiData.

        Returns
        -------
        MultiData
            data with the MultiData format (a single column).
        """
        return MultiData(self.data[:, np.newaxis], self.index)

    def to_pd_series(self):
        """convert SingleData to pd.Series with the same index."""
        return pd.Series(self.data, index=self.index)

    def __getitem__(self, index: Union["SingleData", int, str]):
        """get one element or a filtered SingleData.

        Parameters
        ----------
        index : Union[SingleData, int, str]
            int (or numpy integer): position in data.
            str: label looked up in the index.
            SingleData: its data is used as a mask over self.

        Raises
        ------
        ValueError
            if index has any other type.
        """
        if isinstance(index, (int, np.integer)):
            return self.data[index]
        elif isinstance(index, str):
            return self.data[self.index_map[index]]
        elif isinstance(index, SingleData):
            new_data = self.data[index.data]
            new_index = list(np.array(self.index)[index.data])
            return SingleData(new_data, new_index)
        else:
            raise ValueError("index must be SingleData, int, str")
class MultiData(BaseData):
    """A 2-dim numpy array with index and columns, used in place of pd.DataFrame."""

    def __init__(
        self,
        data: Union[list, np.ndarray, None] = None,
        index: Union[list, pd.Index, None] = None,
        columns: Union[list, pd.Index, None] = None,
    ):
        """A data structure of index and numpy data.
        It's used to replace pd.DataFrame due to high-speed.

        Parameters
        ----------
        data : Union[list, np.ndarray]
            the dim of data must be 2. None (the default) means ``[[]]``.
        index : Union[list, pd.Index]
            the index of data. If it is empty while data has rows, it
            defaults to 0..n_rows-1.
        columns: Union[list, pd.Index]
            the columns of data. If it is empty while data has columns, it
            defaults to 0..n_cols-1.

        Raises
        ------
        ValueError
            if data, index or columns has an unsupported type.
        """
        # Fresh container per instance instead of a shared mutable default.
        if data is None:
            data = [[]]

        # data
        if isinstance(data, list):
            self.data = np.array(data)
        elif isinstance(data, np.ndarray):
            self.data = data
        else:
            raise ValueError("data must be list or np.ndarray")
        # data in MultiData must be two dim
        assert self.data.ndim == 2
        # replace int with float. ``np.int`` was removed in NumPy 1.24, so the
        # dtype family is tested instead; bool arrays are not affected.
        if np.issubdtype(self.data.dtype, np.integer):
            self.data = self.data.astype(np.float64)
        # replace None with np.nan, because pd.DataFrame does it.
        # Only object arrays can hold None. The comparison is element-wise on
        # purpose, hence ``== None`` rather than ``is None``.
        if self.data.dtype == object:
            self.data[self.data == None] = np.nan  # noqa: E711

        # index
        self.index = self._normalize_labels(index, self.data.shape[0], "index")
        assert self.data.shape[0] == len(self.index)
        self.index_map = dict(zip(self.index, range(len(self.index))))

        # columns
        self.columns = self._normalize_labels(columns, self.data.shape[1], "columns")
        assert self.data.shape[1] == len(self.columns)
        self.columns_map = dict(zip(self.columns, range(len(self.columns))))

        super().__init__()

    @staticmethod
    def _normalize_labels(labels, size, name):
        """Turn the labels of one axis into a list.

        Parameters
        ----------
        labels : Union[list, pd.Index, None]
            a list is returned as is (not copied) unless it is empty;
            a pd.Index is converted to a list; None means empty.
        size : int
            length of the axis; an empty `labels` becomes 0..size-1.
        name : str
            axis name used in the error message.

        Raises
        ------
        ValueError
            if labels is neither list, pd.Index nor None.
        """
        if labels is None:
            labels = []
        if isinstance(labels, list):
            if not labels and size > 0:
                return list(range(size))
            return labels
        if isinstance(labels, pd.Index):
            return list(labels)
        raise ValueError(f"{name} must be list or pd.Index")

    def _align_index(self, other):
        """Return (self, other) if index and columns are both identical.

        Raises
        ------
        ValueError
            if index or columns differ (no reordering is attempted).
        """
        if self.index_columns == other.index_columns:
            return self, other
        else:
            raise ValueError(
                "The indexes of self and other do not meet the requirements of the four arithmetic operations"
            )

    def __getitem__(self, col) -> SingleData:
        """get one column as SingleData.

        Returns an empty SingleData if `col` is not one of the columns.
        """
        try:
            position = self.columns_map[col]
        except (KeyError, TypeError):
            # KeyError: unknown column. TypeError: unhashable key, which can
            # not be a column either. Both mean "no such column".
            return SingleData()
        return SingleData(self.data[:, position], self.index)