diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index 53d148280..b4a46614e 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -19,6 +19,7 @@ from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManag from ..utils import init_instance_by_config from ..log import get_module_logger from ..config import C + # make import more user-friendly by enable `from qlib.backtest import STH` diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 21a1d2547..4c726720c 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -9,19 +9,16 @@ if TYPE_CHECKING: from qlib.backtest.position import BasePosition, Position import random -import logging -from typing import List, Tuple, Union, Callable, Iterable - +from typing import List, Tuple, Union import numpy as np import pandas as pd from ..data.data import D -from ..data.dataset.utils import get_level_index from ..config import C, REG_CN from ..utils.resam import resam_ts_data, ts_data_last from ..log import get_module_logger from .order import Order, OrderDir, OrderHelper -from .high_performance_ds import PandasQuote, CN1Min_NumpyQuote +from .high_performance_ds import PandasQuote, CN1min_NumpyQuote class Exchange: @@ -39,7 +36,7 @@ class Exchange: close_cost=0.0025, min_cost=5, extra_quote=None, - quote_cls=PandasQuote, + quote_cls=CN1min_NumpyQuote, **kwargs, ): """__init__ diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index 61bf636ae..6f38b390a 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -3,6 +3,7 @@ from builtins import ValueError, isinstance +from functools import lru_cache import logging from typing import List, Text, Union, Callable, Iterable, Dict from collections import OrderedDict @@ -15,7 +16,7 @@ import numpy as np from ..utils.index_data import IndexData from ..utils.resam import resam_ts_data, ts_data_last from ..log import get_module_logger -from ..utils.time import _if_single_data +from ..utils.time import if_single_data class BaseQuote: @@ -38,9 +39,9 @@ class BaseQuote: stock_id: str, start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], - fields: str = None, - method: Union[str, Callable] = None, - ) -> Union[None, float, "IndexData"]: + fields: Union[str, None] = None, + method: Union[str, Callable, None] = None, + ) -> Union[None, Union[int, float, bool], "IndexData"]: """get the specific fields of stock data during start time and end_time, and apply method to the data. @@ -62,7 +63,7 @@ class BaseQuote: this function is used for three case: - 1. Both fields and method are not None. It returns float. + 1. Both fields and method are not None. It returns int/float/bool. print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", fields="$close", method="last")) 85.713585 @@ -88,15 +89,15 @@ class BaseQuote: closed start time for backtest end_time : Union[pd.Timestamp, str] closed end time for backtest - fields : str + fields : Union[str, None] the columns of data to fetch - method : Union[str, Callable] + method : Union[str, Callable, None] the method apply to data. e.g [None, "last", "all", "sum", "mean", qlib/utils/resam.py/ts_data_last] Return ---------- - Union[None, float, pd.Series, pd.DataFrame, IndexData] + Union[None, Union[int, float, bool], IndexData] please refer to Example as following. """ @@ -115,121 +116,105 @@ class PandasQuote(BaseQuote): return self.data.keys() def get_data(self, stock_id, start_time, end_time, fields=None, method=None): + if fields is None and method is not None: + raise ValueError(f"method must be None when fields is None") + if fields is None: - return resam_ts_data(self.data[stock_id], start_time, end_time, method=method) - elif isinstance(fields, (str, list)): - return resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method) + stock_data = resam_ts_data(self.data[stock_id], start_time, end_time, method=method) + elif isinstance(fields, str): + stock_data = resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method) else: - raise ValueError(f"fields must be None, str or list") + raise ValueError(f"fields must be None, str") + + if stock_data is None: + return None + elif isinstance(stock_data, (bool, np.bool_, int, float, np.signedinteger, np.floating)): + return stock_data + elif isinstance(stock_data, pd.Series): + return IndexData.Series(stock_data) + elif isinstance(stock_data, pd.DataFrame): + return stock_data.values + else: + raise ValueError(f"stock data from resam_ts_data must be a number, pd.Series or pd.DataFrame") -class CN1Min_NumpyQuote(BaseQuote): +class CN1min_NumpyQuote(BaseQuote): def __init__(self, quote_df: pd.DataFrame): - """CN1Min_NumpyQuote + """CN1min_NumpyQuote Parameters ---------- quote_df : pd.DataFrame the init dataframe from qlib. - - Variables - self.data: Dict[stock_id, np.ndarray] - each stock has one two-dimensional np.ndarray to represent data. - self.columns: Dict[str, int] - map column name to column id in self.data. - self.dt2idx: Dict[stock_id, Dict[pd.Timestap, int]] - map timestap to row id in self.data. - self.idx2dt: Dict[stock_id, List[pd.Timestap]] - the dt2idx of each stock for searching. + self.data : Dict(stock_id, IndexData.DataFrame) """ - super().__init__(quote_df=quote_df) - # init data - columns = quote_df.columns.values - self.columns = dict(zip(columns, range(len(columns)))) - self.data, self.dt2idx, self.idx2dt = self._to_numpy(quote_df) - - # lru - self.multi_lru = {} - self.max_lru_len = 256 - - def _to_numpy(self, quote_df): - """convert dataframe to numpy.""" - quote_dict = {} - date_dict = {} - date_list = {} for stock_id, stock_val in quote_df.groupby(level="instrument"): - quote_dict[stock_id] = stock_val.values - date_dict[stock_id] = stock_val.index.get_level_values("datetime") - date_list[stock_id] = list(date_dict[stock_id]) - for stock_id in date_dict: - date_dict[stock_id] = dict(zip(date_dict[stock_id], range(len(date_dict[stock_id])))) - return quote_dict, date_dict, date_list + quote_dict[stock_id] = IndexData.DataFrame(stock_val.droplevel(level="instrument")) + self.data = quote_dict + self.freq = np.timedelta64(1, "m") def get_all_stock(self): return self.data.keys() def get_data(self, stock_id, start_time, end_time, fields=None, method=None): - # check fields - if isinstance(fields, list) and len(fields) > 1: - raise ValueError(f"get_data in CN1Min_NumpyQuote only supports one field") + if fields is None and method is not None: + raise ValueError(f"method must be None when fields is None") # check stock id if stock_id not in self.get_all_stock(): return None - # get single data - # single data is only one piece of data, so it don't need to agg by method. - if _if_single_data(start_time, end_time, np.timedelta64(1, "m")): - if start_time not in self.dt2idx[stock_id]: + # single data + # If it don't consider the classification of single data, it will consume a lot of time. + if if_single_data(start_time, end_time, self.freq): + now_index_map = self.data[stock_id].index_map + now_columns_map = self.data[stock_id].columns_map + if start_time not in now_index_map: return None if fields is None: - # it used for check if data is None - return self.data[stock_id][self.dt2idx[stock_id][start_time]] + return self.data[stock_id].values[now_index_map[start_time]] else: - return self.data[stock_id][self.dt2idx[stock_id][start_time]][self.columns[fields]] - # get muti row data + return self.data[stock_id].values[now_index_map[start_time], now_columns_map[fields]] + + # multi data else: - # check lru - if (stock_id, start_time, end_time, fields, method) in self.multi_lru: - return self.multi_lru[(stock_id, start_time, end_time, fields, method)] - - start_id = bisect.bisect_left(self.idx2dt[stock_id], start_time) - end_id = bisect.bisect_right(self.idx2dt[stock_id], end_time) - if start_id == end_id: - return None - # it used for check if data is None - if fields is None: - return self.data[stock_id][start_id:end_id] - elif method is None: - stock_data = self.data[stock_id][start_id:end_id, self.columns[fields]] - stock_dt2idx = self.idx2dt[stock_id][start_id:end_id].to_list() - return IndexData(stock_data, stock_dt2idx) - else: - agg_stock_data = self._agg_data(self.data[stock_id][start_id:end_id, self.columns[fields]], method) - - # result lru - if len(self.multi_lru) >= self.max_lru_len: - self.multi_lru.clear() - self.multi_lru[(stock_id, start_time, end_time, fields, method)] = agg_stock_data - return agg_stock_data + if fields is None and method is None: + stock_data = self.data[stock_id].loc(start_time, end_time) + if stock_data.empty: + return None + else: + return stock_data.values + elif fields is not None and method is None: + stock_data = self.data[stock_id].loc(start_time, end_time, fields) + if stock_data.empty: + return None + else: + return stock_data + elif fields is not None and method is not None: + stock_data = self.data[stock_id].loc(start_time, end_time, fields) + if stock_data.empty: + return None + elif len(stock_data) == 1: + return stock_data[0] + else: + return self._agg_data(stock_data.values, method) def _agg_data(self, data, method): """Agg data by specific method.""" - valid_data = data[data != np.array(None)].copy() if method == "sum": - return np.nansum(valid_data) + return np.nansum(data) elif method == "mean": - return np.nanmean(valid_data) + return np.nanmean(data) elif method == "last": - return valid_data[-1] + return data[-1] elif method == "all": - return valid_data.all() + return data.all() elif method == "any": - return valid_data.any() + return data.any() elif method == ts_data_last: - valid_data = valid_data[valid_data != np.NaN] + valid_data = data[data != np.NaN] if len(valid_data) == 0: return None else: @@ -412,6 +397,7 @@ class BaseOrderIndicator: def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None): """sum indicators with the same metrics. and assign to the order_indicator(BaseOrderIndicator). + NOTE: indicators could be a empty list when orders in lower level all fail. Parameters ---------- @@ -601,6 +587,11 @@ class PandasOrderIndicator(BaseOrderIndicator): class NumpyOrderIndicator(BaseOrderIndicator): + """ + The data structure is OrderedDict(str: IndexData.Series). + Each IndexData.Series is one metric. + Str is the name of metric. + """ def __init__(self): self.data: Dict[str, IndexData.Series] = OrderedDict() @@ -640,4 +631,4 @@ class NumpyOrderIndicator(BaseOrderIndicator): tmp_metric = IndexData.Series() for indicator in indicators: tmp_metric = tmp_metric.add(indicator.data[metric], fill_value) - order_indicator.data[metric] = tmp_metric \ No newline at end of file + order_indicator.data[metric] = tmp_metric diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index dbda82dd6..31c3e7b0a 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -3,25 +3,19 @@ from collections import OrderedDict -from logging import warning import pathlib -from typing import Dict, List, Tuple, Union, Callable +from typing import Dict, List, Tuple import numpy as np import pandas as pd -from pandas.core import groupby -from pandas.core.frame import DataFrame from qlib.backtest.exchange import Exchange from qlib.backtest.order import BaseTradeDecision, Order, OrderDir -from qlib.backtest.utils import TradeCalendarManager from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator -from ..utils.index_data import IndexData, SingleData -from ..data import D +from ..utils.index_data import IndexData, SingleData from ..tests.config import CSI300_BENCH from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data -from ..utils.time import Freq from .order import IdxTradeRange @@ -391,9 +385,7 @@ class Indicator: if price_s is None: return None, None - if isinstance(price_s, pd.Series): - price_s = IndexData.Series(price_s) - elif isinstance(price_s, (int, float, np.floating)): + if isinstance(price_s, (int, float, np.signedinteger, np.floating)): price_s = IndexData.Series(price_s, [trade_start_time]) elif isinstance(price_s, SingleData): pass @@ -479,10 +471,10 @@ class Indicator: bv_new = IndexData.Series(bv_new) bp_all.append(bp_new) bv_all.append(bv_new) - bp_all = IndexData.concat(bp_all, axis = 1) - bv_all = IndexData.concat(bv_all, axis = 1) + bp_all = IndexData.concat(bp_all, axis=1) + bv_all = IndexData.concat(bv_all, axis=1) - base_volume = bv_all.sum(axis = 1) + base_volume = bv_all.sum(axis=1) self.order_indicator.assign("base_volume", base_volume.to_dict()) self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict()) diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index 47e657c59..8c0d1874e 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -2,16 +2,20 @@ # Licensed under the MIT License. +from typing import Union, Callable +import bisect + import numpy as np import pandas as pd -from typing import Union, Callable class IndexData: - """This is a simplified version of pandas which is faster based on numpy. - """ + """This is a simplified version of pandas which is faster based on numpy.""" + @staticmethod - def Series(data: Union[dict, pd.Series, int, float, np.floating, list, np.ndarray] = [], index: Union[list, pd.Index] = []): + def Series( + data: Union[dict, pd.Series, int, float, np.floating, list, np.ndarray] = [], index: Union[list, pd.Index] = [] + ): if isinstance(data, dict): return SingleData(list(data.values()), list(data.keys())) elif isinstance(data, pd.Series): @@ -20,16 +24,20 @@ class IndexData: return SingleData(data, index) @staticmethod - def DataFrame(data: Union[pd.DataFrame, list, np.ndarray] = [[]], index: Union[list, pd.Index] = [], columns: Union[list, pd.Index] = []): + def DataFrame( + data: Union[pd.DataFrame, list, np.ndarray] = [[]], + index: Union[list, pd.Index] = [], + columns: Union[list, pd.Index] = [], + ): if isinstance(data, pd.DataFrame): return MultiData(data.values, data.index, data.columns) - else: + else: return MultiData(data, index, columns) @staticmethod - def concat(data_list, axis = 0): + def concat(data_list, axis=0): """concat all SingleData by index. - just for 1-dim data. + TODO: now just for SingleData. Parameters ---------- @@ -57,15 +65,15 @@ class IndexData: for data_id, index_data in enumerate(data_list): assert isinstance(index_data, SingleData) now_data_map = [all_index_map[index] for index in index_data.index] - tmp_data[now_data_map, data_id] = index_data.data + tmp_data[now_data_map, data_id] = index_data.data return MultiData(tmp_data, all_index) else: raise ValueError(f"axis must be 0 or 1") class BaseData: - """Base data structure of SingleData and MultiData. - """ + """Base data structure of SingleData and MultiData.""" + def __init__(self): self.index_columns = self._get_index_columns() @@ -78,8 +86,7 @@ class BaseData: return index_columns def _align_index(self, other): - """Align index before performing the four arithmetic operations. - """ + """Align index before performing the four arithmetic operations.""" raise NotImplementedError(f"please implement _align_index func") def __add__(self, other): @@ -158,14 +165,12 @@ class BaseData: return self.__class__(~self.data, *self.index_columns) def abs(self): - """get the abs of data except np.NaN. - """ + """get the abs of data except np.NaN.""" tmp_data = np.absolute(self.data) return self.__class__(tmp_data, *self.index_columns) def astype(self, type): - """change the type of data. - """ + """change the type of data.""" tmp_data = self.data.astype(type) return self.__class__(tmp_data, *self.index_columns) @@ -178,8 +183,7 @@ class BaseData: return self.__class__(tmp_data, *self.index_columns) def apply(self, func: Callable): - """apply a function to data. - """ + """apply a function to data.""" tmp_data = func(self.data) return self.__class__(tmp_data, *self.index_columns) @@ -224,6 +228,10 @@ class BaseData: def empty(self): return len(self.data) == 0 + @property + def values(self): + return self.data + class SingleData(BaseData): def __init__(self, data: Union[int, float, np.floating, list, np.ndarray] = [], index: Union[list, pd.Index] = []): @@ -239,7 +247,7 @@ class SingleData(BaseData): """ # data if isinstance(data, (int, float, np.floating)): - self.data = np.full(len(index), fill_value=data) + self.data = np.full(len(index), fill_value=data, dtype=np.float64) elif isinstance(data, list): self.data = np.array(data) elif isinstance(data, np.ndarray): @@ -249,12 +257,12 @@ class SingleData(BaseData): # data in SingleData must be one dim assert self.data.ndim == 1 # replace int with float - if self.data.dtype == np.int: + if self.data.dtype == np.signedinteger: self.data = self.data.astype(np.float64) # replace None with np.NaN, because pd.Series does it. if None in self.data: self.data[self.data == None] = np.NaN - + # index if isinstance(index, list): if index == [] and len(self.data) > 0: @@ -265,18 +273,20 @@ class SingleData(BaseData): else: raise ValueError(f"index must be list or pd.Index") assert len(self.data) == len(self.index) - # if data is not empty, + # if data is not empty, self.index_map = dict(zip(self.index, range(len(self.index)))) super(SingleData, self).__init__() def _align_index(self, other): if self.index == other.index: - return self, other + return self, other elif set(self.index) == set(other.index): return self, other.reindex(self.index) else: - raise ValueError(f"The indexes of self and other do not meet the requirements of the four arithmetic operations") + raise ValueError( + f"The indexes of self and other do not meet the requirements of the four arithmetic operations" + ) def reindex(self, index, fill_value=np.NaN): """reindex data and fill the missing value with np.NaN. @@ -291,7 +301,7 @@ class SingleData(BaseData): SingleData reindex data """ - tmp_data = np.full(len(index), fill_value, np.float64) + tmp_data = np.full(len(index), fill_value, dtype=np.float64) for index_id, index_item in enumerate(index): if index_item in self.index: tmp_data[index_id] = self.data[self.index_map[index_item]] @@ -299,8 +309,8 @@ class SingleData(BaseData): def add(self, other, fill_value=0): common_index = list(set(self.index) | set(other.index)) - tmp_data1 = self.reindex(common_index,fill_value) - tmp_data2 = other.reindex(common_index,fill_value) + tmp_data1 = self.reindex(common_index, fill_value) + tmp_data2 = other.reindex(common_index, fill_value) return tmp_data1 + tmp_data2 def to_dict(self): @@ -324,7 +334,7 @@ class SingleData(BaseData): return MultiData(self.data[:, np.newaxis], self.index) def to_pd_series(self): - return pd.Series(self.data, index = self.index) + return pd.Series(self.data, index=self.index) def __getitem__(self, index: Union["SingleData", int, str]): if isinstance(index, int): @@ -340,7 +350,12 @@ class SingleData(BaseData): class MultiData(BaseData): - def __init__(self, data: Union[list, np.ndarray] = [[]], index: Union[list, pd.Index] = [], columns: Union[list, pd.Index] = []): + def __init__( + self, + data: Union[list, np.ndarray] = [[]], + index: Union[list, pd.Index] = [], + columns: Union[list, pd.Index] = [], + ): """A data structure of index and numpy data. It's used to replace pd.DataFrame due to high-speed. @@ -363,12 +378,12 @@ class MultiData(BaseData): # data in SingleData must be two dim assert self.data.ndim == 2 # replace int with float - if self.data.dtype == np.int: + if self.data.dtype == np.signedinteger: self.data = self.data.astype(np.float64) # replace None with np.NaN, because pd.DataFrame does it. if None in self.data: self.data[self.data == None] = np.NaN - + # index if isinstance(index, list): if index == [] and self.data.shape[0] > 0: @@ -379,7 +394,7 @@ class MultiData(BaseData): else: raise ValueError(f"index must be list or pd.Index") assert self.data.shape[0] == len(self.index) - # if data is not empty, + # if data is not empty, self.index_map = dict(zip(self.index, range(len(self.index)))) # columns @@ -392,19 +407,29 @@ class MultiData(BaseData): else: raise ValueError(f"columns must be list or pd.Index") assert self.data.shape[1] == len(self.columns) - # if data is not empty, - self.columns_map = dict(zip(self.columns, range(len(self.columns)))) + # if data is not empty, + self.columns_map = dict(zip(self.columns, range(len(self.columns)))) super(MultiData, self).__init__() def _align_index(self, other): if self.index_columns == other.index_columns: - return self, other + return self, other else: - raise ValueError(f"The indexes of self and other do not meet the requirements of the four arithmetic operations") + raise ValueError( + f"The indexes of self and other do not meet the requirements of the four arithmetic operations" + ) def __getitem__(self, col) -> SingleData: if col not in self.columns: return SingleData() else: return SingleData(self.data[:, self.columns_map[col]], self.index) + + def loc(self, start, end, col=None): + start_id = bisect.bisect_left(self.index, start) + end_id = bisect.bisect_right(self.index, end) + if col is None: + return MultiData(self.data[start_id:end_id], self.index[start_id:end_id], self.columns) + else: + return SingleData(self.data[start_id:end_id, self.columns_map[col]], self.index[start_id:end_id]) diff --git a/qlib/utils/time.py b/qlib/utils/time.py index efee8f5eb..e9ae82c5f 100644 --- a/qlib/utils/time.py +++ b/qlib/utils/time.py @@ -38,7 +38,7 @@ def get_min_cal(shift: int = 0) -> List[time]: return cal -def _if_single_data(start_time, end_time, freq): +def if_single_data(start_time, end_time, freq): """Is there only one piece of data to obtain. Parameters