mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 01:51:18 +08:00
new high freq struc
This commit is contained in:
committed by
you-n-g
parent
d9ad8ff791
commit
25f54ddaeb
@@ -19,6 +19,7 @@ from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManag
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
from ..config import C
|
||||
|
||||
# make import more user-friendly by enable `from qlib.backtest import STH`
|
||||
|
||||
|
||||
|
||||
@@ -9,19 +9,16 @@ if TYPE_CHECKING:
|
||||
|
||||
from qlib.backtest.position import BasePosition, Position
|
||||
import random
|
||||
import logging
|
||||
from typing import List, Tuple, Union, Callable, Iterable
|
||||
|
||||
from typing import List, Tuple, Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..data.data import D
|
||||
from ..data.dataset.utils import get_level_index
|
||||
from ..config import C, REG_CN
|
||||
from ..utils.resam import resam_ts_data, ts_data_last
|
||||
from ..log import get_module_logger
|
||||
from .order import Order, OrderDir, OrderHelper
|
||||
from .high_performance_ds import PandasQuote, CN1Min_NumpyQuote
|
||||
from .high_performance_ds import PandasQuote, CN1min_NumpyQuote
|
||||
|
||||
|
||||
class Exchange:
|
||||
@@ -39,7 +36,7 @@ class Exchange:
|
||||
close_cost=0.0025,
|
||||
min_cost=5,
|
||||
extra_quote=None,
|
||||
quote_cls=PandasQuote,
|
||||
quote_cls=CN1min_NumpyQuote,
|
||||
**kwargs,
|
||||
):
|
||||
"""__init__
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
|
||||
from builtins import ValueError, isinstance
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import List, Text, Union, Callable, Iterable, Dict
|
||||
from collections import OrderedDict
|
||||
@@ -15,7 +16,7 @@ import numpy as np
|
||||
from ..utils.index_data import IndexData
|
||||
from ..utils.resam import resam_ts_data, ts_data_last
|
||||
from ..log import get_module_logger
|
||||
from ..utils.time import _if_single_data
|
||||
from ..utils.time import if_single_data
|
||||
|
||||
|
||||
class BaseQuote:
|
||||
@@ -38,9 +39,9 @@ class BaseQuote:
|
||||
stock_id: str,
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
fields: str = None,
|
||||
method: Union[str, Callable] = None,
|
||||
) -> Union[None, float, "IndexData"]:
|
||||
fields: Union[str, None] = None,
|
||||
method: Union[str, Callable, None] = None,
|
||||
) -> Union[None, Union[int, float, bool], "IndexData"]:
|
||||
"""get the specific fields of stock data during start time and end_time,
|
||||
and apply method to the data.
|
||||
|
||||
@@ -62,7 +63,7 @@ class BaseQuote:
|
||||
|
||||
this function is used for three case:
|
||||
|
||||
1. Both fields and method are not None. It returns float.
|
||||
1. Both fields and method are not None. It returns int/float/bool.
|
||||
print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", fields="$close", method="last"))
|
||||
|
||||
85.713585
|
||||
@@ -88,15 +89,15 @@ class BaseQuote:
|
||||
closed start time for backtest
|
||||
end_time : Union[pd.Timestamp, str]
|
||||
closed end time for backtest
|
||||
fields : str
|
||||
fields : Union[str, None]
|
||||
the columns of data to fetch
|
||||
method : Union[str, Callable]
|
||||
method : Union[str, Callable, None]
|
||||
the method apply to data.
|
||||
e.g [None, "last", "all", "sum", "mean", qlib/utils/resam.py/ts_data_last]
|
||||
|
||||
Return
|
||||
----------
|
||||
Union[None, float, pd.Series, pd.DataFrame, IndexData]
|
||||
Union[None, Union[int, float, bool], IndexData]
|
||||
please refer to Example as following.
|
||||
"""
|
||||
|
||||
@@ -115,121 +116,105 @@ class PandasQuote(BaseQuote):
|
||||
return self.data.keys()
|
||||
|
||||
def get_data(self, stock_id, start_time, end_time, fields=None, method=None):
|
||||
if fields is None and method is not None:
|
||||
raise ValueError(f"method must be None when fields is None")
|
||||
|
||||
if fields is None:
|
||||
return resam_ts_data(self.data[stock_id], start_time, end_time, method=method)
|
||||
elif isinstance(fields, (str, list)):
|
||||
return resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method)
|
||||
stock_data = resam_ts_data(self.data[stock_id], start_time, end_time, method=method)
|
||||
elif isinstance(fields, str):
|
||||
stock_data = resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method)
|
||||
else:
|
||||
raise ValueError(f"fields must be None, str or list")
|
||||
raise ValueError(f"fields must be None, str")
|
||||
|
||||
if stock_data is None:
|
||||
return None
|
||||
elif isinstance(stock_data, (bool, np.bool_, int, float, np.signedinteger, np.floating)):
|
||||
return stock_data
|
||||
elif isinstance(stock_data, pd.Series):
|
||||
return IndexData.Series(stock_data)
|
||||
elif isinstance(stock_data, pd.DataFrame):
|
||||
return stock_data.values
|
||||
else:
|
||||
raise ValueError(f"stock data from resam_ts_data must be a number, pd.Series or pd.DataFrame")
|
||||
|
||||
|
||||
class CN1Min_NumpyQuote(BaseQuote):
|
||||
class CN1min_NumpyQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame):
|
||||
"""CN1Min_NumpyQuote
|
||||
"""CN1min_NumpyQuote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
quote_df : pd.DataFrame
|
||||
the init dataframe from qlib.
|
||||
|
||||
Variables
|
||||
self.data: Dict[stock_id, np.ndarray]
|
||||
each stock has one two-dimensional np.ndarray to represent data.
|
||||
self.columns: Dict[str, int]
|
||||
map column name to column id in self.data.
|
||||
self.dt2idx: Dict[stock_id, Dict[pd.Timestap, int]]
|
||||
map timestap to row id in self.data.
|
||||
self.idx2dt: Dict[stock_id, List[pd.Timestap]]
|
||||
the dt2idx of each stock for searching.
|
||||
self.data : Dict(stock_id, IndexData.DataFrame)
|
||||
"""
|
||||
|
||||
super().__init__(quote_df=quote_df)
|
||||
# init data
|
||||
columns = quote_df.columns.values
|
||||
self.columns = dict(zip(columns, range(len(columns))))
|
||||
self.data, self.dt2idx, self.idx2dt = self._to_numpy(quote_df)
|
||||
|
||||
# lru
|
||||
self.multi_lru = {}
|
||||
self.max_lru_len = 256
|
||||
|
||||
def _to_numpy(self, quote_df):
|
||||
"""convert dataframe to numpy."""
|
||||
|
||||
quote_dict = {}
|
||||
date_dict = {}
|
||||
date_list = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
quote_dict[stock_id] = stock_val.values
|
||||
date_dict[stock_id] = stock_val.index.get_level_values("datetime")
|
||||
date_list[stock_id] = list(date_dict[stock_id])
|
||||
for stock_id in date_dict:
|
||||
date_dict[stock_id] = dict(zip(date_dict[stock_id], range(len(date_dict[stock_id]))))
|
||||
return quote_dict, date_dict, date_list
|
||||
quote_dict[stock_id] = IndexData.DataFrame(stock_val.droplevel(level="instrument"))
|
||||
self.data = quote_dict
|
||||
self.freq = np.timedelta64(1, "m")
|
||||
|
||||
def get_all_stock(self):
|
||||
return self.data.keys()
|
||||
|
||||
def get_data(self, stock_id, start_time, end_time, fields=None, method=None):
|
||||
# check fields
|
||||
if isinstance(fields, list) and len(fields) > 1:
|
||||
raise ValueError(f"get_data in CN1Min_NumpyQuote only supports one field")
|
||||
if fields is None and method is not None:
|
||||
raise ValueError(f"method must be None when fields is None")
|
||||
|
||||
# check stock id
|
||||
if stock_id not in self.get_all_stock():
|
||||
return None
|
||||
|
||||
# get single data
|
||||
# single data is only one piece of data, so it don't need to agg by method.
|
||||
if _if_single_data(start_time, end_time, np.timedelta64(1, "m")):
|
||||
if start_time not in self.dt2idx[stock_id]:
|
||||
# single data
|
||||
# If it don't consider the classification of single data, it will consume a lot of time.
|
||||
if if_single_data(start_time, end_time, self.freq):
|
||||
now_index_map = self.data[stock_id].index_map
|
||||
now_columns_map = self.data[stock_id].columns_map
|
||||
if start_time not in now_index_map:
|
||||
return None
|
||||
if fields is None:
|
||||
# it used for check if data is None
|
||||
return self.data[stock_id][self.dt2idx[stock_id][start_time]]
|
||||
return self.data[stock_id].values[now_index_map[start_time]]
|
||||
else:
|
||||
return self.data[stock_id][self.dt2idx[stock_id][start_time]][self.columns[fields]]
|
||||
# get muti row data
|
||||
return self.data[stock_id].values[now_index_map[start_time], now_columns_map[fields]]
|
||||
|
||||
# multi data
|
||||
else:
|
||||
# check lru
|
||||
if (stock_id, start_time, end_time, fields, method) in self.multi_lru:
|
||||
return self.multi_lru[(stock_id, start_time, end_time, fields, method)]
|
||||
|
||||
start_id = bisect.bisect_left(self.idx2dt[stock_id], start_time)
|
||||
end_id = bisect.bisect_right(self.idx2dt[stock_id], end_time)
|
||||
if start_id == end_id:
|
||||
return None
|
||||
# it used for check if data is None
|
||||
if fields is None:
|
||||
return self.data[stock_id][start_id:end_id]
|
||||
elif method is None:
|
||||
stock_data = self.data[stock_id][start_id:end_id, self.columns[fields]]
|
||||
stock_dt2idx = self.idx2dt[stock_id][start_id:end_id].to_list()
|
||||
return IndexData(stock_data, stock_dt2idx)
|
||||
else:
|
||||
agg_stock_data = self._agg_data(self.data[stock_id][start_id:end_id, self.columns[fields]], method)
|
||||
|
||||
# result lru
|
||||
if len(self.multi_lru) >= self.max_lru_len:
|
||||
self.multi_lru.clear()
|
||||
self.multi_lru[(stock_id, start_time, end_time, fields, method)] = agg_stock_data
|
||||
return agg_stock_data
|
||||
if fields is None and method is None:
|
||||
stock_data = self.data[stock_id].loc(start_time, end_time)
|
||||
if stock_data.empty:
|
||||
return None
|
||||
else:
|
||||
return stock_data.values
|
||||
elif fields is not None and method is None:
|
||||
stock_data = self.data[stock_id].loc(start_time, end_time, fields)
|
||||
if stock_data.empty:
|
||||
return None
|
||||
else:
|
||||
return stock_data
|
||||
elif fields is not None and method is not None:
|
||||
stock_data = self.data[stock_id].loc(start_time, end_time, fields)
|
||||
if stock_data.empty:
|
||||
return None
|
||||
elif len(stock_data) == 1:
|
||||
return stock_data[0]
|
||||
else:
|
||||
return self._agg_data(stock_data.values, method)
|
||||
|
||||
def _agg_data(self, data, method):
|
||||
"""Agg data by specific method."""
|
||||
valid_data = data[data != np.array(None)].copy()
|
||||
if method == "sum":
|
||||
return np.nansum(valid_data)
|
||||
return np.nansum(data)
|
||||
elif method == "mean":
|
||||
return np.nanmean(valid_data)
|
||||
return np.nanmean(data)
|
||||
elif method == "last":
|
||||
return valid_data[-1]
|
||||
return data[-1]
|
||||
elif method == "all":
|
||||
return valid_data.all()
|
||||
return data.all()
|
||||
elif method == "any":
|
||||
return valid_data.any()
|
||||
return data.any()
|
||||
elif method == ts_data_last:
|
||||
valid_data = valid_data[valid_data != np.NaN]
|
||||
valid_data = data[data != np.NaN]
|
||||
if len(valid_data) == 0:
|
||||
return None
|
||||
else:
|
||||
@@ -412,6 +397,7 @@ class BaseOrderIndicator:
|
||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None):
|
||||
"""sum indicators with the same metrics.
|
||||
and assign to the order_indicator(BaseOrderIndicator).
|
||||
NOTE: indicators could be a empty list when orders in lower level all fail.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -601,6 +587,11 @@ class PandasOrderIndicator(BaseOrderIndicator):
|
||||
|
||||
|
||||
class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
"""
|
||||
The data structure is OrderedDict(str: IndexData.Series).
|
||||
Each IndexData.Series is one metric.
|
||||
Str is the name of metric.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.data: Dict[str, IndexData.Series] = OrderedDict()
|
||||
@@ -640,4 +631,4 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
tmp_metric = IndexData.Series()
|
||||
for indicator in indicators:
|
||||
tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
|
||||
order_indicator.data[metric] = tmp_metric
|
||||
order_indicator.data[metric] = tmp_metric
|
||||
|
||||
@@ -3,25 +3,19 @@
|
||||
|
||||
|
||||
from collections import OrderedDict
|
||||
from logging import warning
|
||||
import pathlib
|
||||
from typing import Dict, List, Tuple, Union, Callable
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas.core import groupby
|
||||
from pandas.core.frame import DataFrame
|
||||
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.order import BaseTradeDecision, Order, OrderDir
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
|
||||
from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator
|
||||
from ..utils.index_data import IndexData, SingleData
|
||||
from ..data import D
|
||||
from ..utils.index_data import IndexData, SingleData
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
from ..utils.time import Freq
|
||||
from .order import IdxTradeRange
|
||||
|
||||
|
||||
@@ -391,9 +385,7 @@ class Indicator:
|
||||
if price_s is None:
|
||||
return None, None
|
||||
|
||||
if isinstance(price_s, pd.Series):
|
||||
price_s = IndexData.Series(price_s)
|
||||
elif isinstance(price_s, (int, float, np.floating)):
|
||||
if isinstance(price_s, (int, float, np.signedinteger, np.floating)):
|
||||
price_s = IndexData.Series(price_s, [trade_start_time])
|
||||
elif isinstance(price_s, SingleData):
|
||||
pass
|
||||
@@ -479,10 +471,10 @@ class Indicator:
|
||||
bv_new = IndexData.Series(bv_new)
|
||||
bp_all.append(bp_new)
|
||||
bv_all.append(bv_new)
|
||||
bp_all = IndexData.concat(bp_all, axis = 1)
|
||||
bv_all = IndexData.concat(bv_all, axis = 1)
|
||||
bp_all = IndexData.concat(bp_all, axis=1)
|
||||
bv_all = IndexData.concat(bv_all, axis=1)
|
||||
|
||||
base_volume = bv_all.sum(axis = 1)
|
||||
base_volume = bv_all.sum(axis=1)
|
||||
self.order_indicator.assign("base_volume", base_volume.to_dict())
|
||||
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict())
|
||||
|
||||
|
||||
@@ -2,16 +2,20 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from typing import Union, Callable
|
||||
import bisect
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Union, Callable
|
||||
|
||||
|
||||
class IndexData:
|
||||
"""This is a simplified version of pandas which is faster based on numpy.
|
||||
"""
|
||||
"""This is a simplified version of pandas which is faster based on numpy."""
|
||||
|
||||
@staticmethod
|
||||
def Series(data: Union[dict, pd.Series, int, float, np.floating, list, np.ndarray] = [], index: Union[list, pd.Index] = []):
|
||||
def Series(
|
||||
data: Union[dict, pd.Series, int, float, np.floating, list, np.ndarray] = [], index: Union[list, pd.Index] = []
|
||||
):
|
||||
if isinstance(data, dict):
|
||||
return SingleData(list(data.values()), list(data.keys()))
|
||||
elif isinstance(data, pd.Series):
|
||||
@@ -20,16 +24,20 @@ class IndexData:
|
||||
return SingleData(data, index)
|
||||
|
||||
@staticmethod
|
||||
def DataFrame(data: Union[pd.DataFrame, list, np.ndarray] = [[]], index: Union[list, pd.Index] = [], columns: Union[list, pd.Index] = []):
|
||||
def DataFrame(
|
||||
data: Union[pd.DataFrame, list, np.ndarray] = [[]],
|
||||
index: Union[list, pd.Index] = [],
|
||||
columns: Union[list, pd.Index] = [],
|
||||
):
|
||||
if isinstance(data, pd.DataFrame):
|
||||
return MultiData(data.values, data.index, data.columns)
|
||||
else:
|
||||
else:
|
||||
return MultiData(data, index, columns)
|
||||
|
||||
@staticmethod
|
||||
def concat(data_list, axis = 0):
|
||||
def concat(data_list, axis=0):
|
||||
"""concat all SingleData by index.
|
||||
just for 1-dim data.
|
||||
TODO: now just for SingleData.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -57,15 +65,15 @@ class IndexData:
|
||||
for data_id, index_data in enumerate(data_list):
|
||||
assert isinstance(index_data, SingleData)
|
||||
now_data_map = [all_index_map[index] for index in index_data.index]
|
||||
tmp_data[now_data_map, data_id] = index_data.data
|
||||
tmp_data[now_data_map, data_id] = index_data.data
|
||||
return MultiData(tmp_data, all_index)
|
||||
else:
|
||||
raise ValueError(f"axis must be 0 or 1")
|
||||
|
||||
|
||||
class BaseData:
|
||||
"""Base data structure of SingleData and MultiData.
|
||||
"""
|
||||
"""Base data structure of SingleData and MultiData."""
|
||||
|
||||
def __init__(self):
|
||||
self.index_columns = self._get_index_columns()
|
||||
|
||||
@@ -78,8 +86,7 @@ class BaseData:
|
||||
return index_columns
|
||||
|
||||
def _align_index(self, other):
|
||||
"""Align index before performing the four arithmetic operations.
|
||||
"""
|
||||
"""Align index before performing the four arithmetic operations."""
|
||||
raise NotImplementedError(f"please implement _align_index func")
|
||||
|
||||
def __add__(self, other):
|
||||
@@ -158,14 +165,12 @@ class BaseData:
|
||||
return self.__class__(~self.data, *self.index_columns)
|
||||
|
||||
def abs(self):
|
||||
"""get the abs of data except np.NaN.
|
||||
"""
|
||||
"""get the abs of data except np.NaN."""
|
||||
tmp_data = np.absolute(self.data)
|
||||
return self.__class__(tmp_data, *self.index_columns)
|
||||
|
||||
def astype(self, type):
|
||||
"""change the type of data.
|
||||
"""
|
||||
"""change the type of data."""
|
||||
tmp_data = self.data.astype(type)
|
||||
return self.__class__(tmp_data, *self.index_columns)
|
||||
|
||||
@@ -178,8 +183,7 @@ class BaseData:
|
||||
return self.__class__(tmp_data, *self.index_columns)
|
||||
|
||||
def apply(self, func: Callable):
|
||||
"""apply a function to data.
|
||||
"""
|
||||
"""apply a function to data."""
|
||||
tmp_data = func(self.data)
|
||||
return self.__class__(tmp_data, *self.index_columns)
|
||||
|
||||
@@ -224,6 +228,10 @@ class BaseData:
|
||||
def empty(self):
|
||||
return len(self.data) == 0
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
return self.data
|
||||
|
||||
|
||||
class SingleData(BaseData):
|
||||
def __init__(self, data: Union[int, float, np.floating, list, np.ndarray] = [], index: Union[list, pd.Index] = []):
|
||||
@@ -239,7 +247,7 @@ class SingleData(BaseData):
|
||||
"""
|
||||
# data
|
||||
if isinstance(data, (int, float, np.floating)):
|
||||
self.data = np.full(len(index), fill_value=data)
|
||||
self.data = np.full(len(index), fill_value=data, dtype=np.float64)
|
||||
elif isinstance(data, list):
|
||||
self.data = np.array(data)
|
||||
elif isinstance(data, np.ndarray):
|
||||
@@ -249,12 +257,12 @@ class SingleData(BaseData):
|
||||
# data in SingleData must be one dim
|
||||
assert self.data.ndim == 1
|
||||
# replace int with float
|
||||
if self.data.dtype == np.int:
|
||||
if self.data.dtype == np.signedinteger:
|
||||
self.data = self.data.astype(np.float64)
|
||||
# replace None with np.NaN, because pd.Series does it.
|
||||
if None in self.data:
|
||||
self.data[self.data == None] = np.NaN
|
||||
|
||||
|
||||
# index
|
||||
if isinstance(index, list):
|
||||
if index == [] and len(self.data) > 0:
|
||||
@@ -265,18 +273,20 @@ class SingleData(BaseData):
|
||||
else:
|
||||
raise ValueError(f"index must be list or pd.Index")
|
||||
assert len(self.data) == len(self.index)
|
||||
# if data is not empty,
|
||||
# if data is not empty,
|
||||
self.index_map = dict(zip(self.index, range(len(self.index))))
|
||||
|
||||
super(SingleData, self).__init__()
|
||||
|
||||
def _align_index(self, other):
|
||||
if self.index == other.index:
|
||||
return self, other
|
||||
return self, other
|
||||
elif set(self.index) == set(other.index):
|
||||
return self, other.reindex(self.index)
|
||||
else:
|
||||
raise ValueError(f"The indexes of self and other do not meet the requirements of the four arithmetic operations")
|
||||
raise ValueError(
|
||||
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
|
||||
)
|
||||
|
||||
def reindex(self, index, fill_value=np.NaN):
|
||||
"""reindex data and fill the missing value with np.NaN.
|
||||
@@ -291,7 +301,7 @@ class SingleData(BaseData):
|
||||
SingleData
|
||||
reindex data
|
||||
"""
|
||||
tmp_data = np.full(len(index), fill_value, np.float64)
|
||||
tmp_data = np.full(len(index), fill_value, dtype=np.float64)
|
||||
for index_id, index_item in enumerate(index):
|
||||
if index_item in self.index:
|
||||
tmp_data[index_id] = self.data[self.index_map[index_item]]
|
||||
@@ -299,8 +309,8 @@ class SingleData(BaseData):
|
||||
|
||||
def add(self, other, fill_value=0):
|
||||
common_index = list(set(self.index) | set(other.index))
|
||||
tmp_data1 = self.reindex(common_index,fill_value)
|
||||
tmp_data2 = other.reindex(common_index,fill_value)
|
||||
tmp_data1 = self.reindex(common_index, fill_value)
|
||||
tmp_data2 = other.reindex(common_index, fill_value)
|
||||
return tmp_data1 + tmp_data2
|
||||
|
||||
def to_dict(self):
|
||||
@@ -324,7 +334,7 @@ class SingleData(BaseData):
|
||||
return MultiData(self.data[:, np.newaxis], self.index)
|
||||
|
||||
def to_pd_series(self):
|
||||
return pd.Series(self.data, index = self.index)
|
||||
return pd.Series(self.data, index=self.index)
|
||||
|
||||
def __getitem__(self, index: Union["SingleData", int, str]):
|
||||
if isinstance(index, int):
|
||||
@@ -340,7 +350,12 @@ class SingleData(BaseData):
|
||||
|
||||
|
||||
class MultiData(BaseData):
|
||||
def __init__(self, data: Union[list, np.ndarray] = [[]], index: Union[list, pd.Index] = [], columns: Union[list, pd.Index] = []):
|
||||
def __init__(
|
||||
self,
|
||||
data: Union[list, np.ndarray] = [[]],
|
||||
index: Union[list, pd.Index] = [],
|
||||
columns: Union[list, pd.Index] = [],
|
||||
):
|
||||
"""A data structure of index and numpy data.
|
||||
It's used to replace pd.DataFrame due to high-speed.
|
||||
|
||||
@@ -363,12 +378,12 @@ class MultiData(BaseData):
|
||||
# data in SingleData must be two dim
|
||||
assert self.data.ndim == 2
|
||||
# replace int with float
|
||||
if self.data.dtype == np.int:
|
||||
if self.data.dtype == np.signedinteger:
|
||||
self.data = self.data.astype(np.float64)
|
||||
# replace None with np.NaN, because pd.DataFrame does it.
|
||||
if None in self.data:
|
||||
self.data[self.data == None] = np.NaN
|
||||
|
||||
|
||||
# index
|
||||
if isinstance(index, list):
|
||||
if index == [] and self.data.shape[0] > 0:
|
||||
@@ -379,7 +394,7 @@ class MultiData(BaseData):
|
||||
else:
|
||||
raise ValueError(f"index must be list or pd.Index")
|
||||
assert self.data.shape[0] == len(self.index)
|
||||
# if data is not empty,
|
||||
# if data is not empty,
|
||||
self.index_map = dict(zip(self.index, range(len(self.index))))
|
||||
|
||||
# columns
|
||||
@@ -392,19 +407,29 @@ class MultiData(BaseData):
|
||||
else:
|
||||
raise ValueError(f"columns must be list or pd.Index")
|
||||
assert self.data.shape[1] == len(self.columns)
|
||||
# if data is not empty,
|
||||
self.columns_map = dict(zip(self.columns, range(len(self.columns))))
|
||||
# if data is not empty,
|
||||
self.columns_map = dict(zip(self.columns, range(len(self.columns))))
|
||||
|
||||
super(MultiData, self).__init__()
|
||||
|
||||
def _align_index(self, other):
|
||||
if self.index_columns == other.index_columns:
|
||||
return self, other
|
||||
return self, other
|
||||
else:
|
||||
raise ValueError(f"The indexes of self and other do not meet the requirements of the four arithmetic operations")
|
||||
raise ValueError(
|
||||
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
|
||||
)
|
||||
|
||||
def __getitem__(self, col) -> SingleData:
|
||||
if col not in self.columns:
|
||||
return SingleData()
|
||||
else:
|
||||
return SingleData(self.data[:, self.columns_map[col]], self.index)
|
||||
|
||||
def loc(self, start, end, col=None):
|
||||
start_id = bisect.bisect_left(self.index, start)
|
||||
end_id = bisect.bisect_right(self.index, end)
|
||||
if col is None:
|
||||
return MultiData(self.data[start_id:end_id], self.index[start_id:end_id], self.columns)
|
||||
else:
|
||||
return SingleData(self.data[start_id:end_id, self.columns_map[col]], self.index[start_id:end_id])
|
||||
|
||||
@@ -38,7 +38,7 @@ def get_min_cal(shift: int = 0) -> List[time]:
|
||||
return cal
|
||||
|
||||
|
||||
def _if_single_data(start_time, end_time, freq):
|
||||
def if_single_data(start_time, end_time, freq):
|
||||
"""Is there only one piece of data to obtain.
|
||||
|
||||
Parameters
|
||||
|
||||
Reference in New Issue
Block a user