diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 34c0ef744..ccd5f4b45 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -12,7 +12,7 @@ import pandas as pd from ..data.data import D from ..data.dataset.utils import get_level_index from ..config import C, REG_CN -from ..utils.resam import resam_ts_data +from ..utils.resam import resam_ts_data, ts_data_last from ..log import get_module_logger from .order import Order, OrderDir, OrderHelper @@ -166,7 +166,7 @@ class Exchange: quote_dict = {} for stock_id, stock_val in quote_df.groupby(level="instrument"): - quote_dict[stock_id] = stock_val + quote_dict[stock_id] = stock_val.droplevel(level="instrument") self.quote = quote_dict @@ -186,13 +186,13 @@ class Exchange: """ if direction is None: - buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0] - sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0] + buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all") + sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all") return buy_limit or sell_limit elif direction == Order.BUY: - return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0] + return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all") elif direction == Order.SELL: - return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0] + return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all") else: raise ValueError(f"direction {direction} is not supported!") @@ -258,16 +258,16 @@ class Exchange: return trade_val, trade_cost, trade_price def get_quote_info(self, stock_id, start_time, end_time): - return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0] + return 
resam_ts_data(self.quote[stock_id], start_time, end_time, method=ts_data_last) def get_close(self, stock_id, start_time, end_time): - return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method="last").iloc[0] + return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method=ts_data_last) def get_volume(self, stock_id, start_time, end_time): - return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum").iloc[0] + return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum") def get_deal_price(self, stock_id, start_time, end_time): - deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method="last").iloc[0] + deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method=ts_data_last) if np.isclose(deal_price, 0.0) or np.isnan(deal_price): self.logger.warning( f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!" 
@@ -286,10 +286,7 @@ class Exchange: """ if stock_id not in self.quote: return None - res = resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last") - if res is not None: - res = res.iloc[0] - return res + return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method=ts_data_last) def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time): """ diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index f217ea169..7623af551 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -91,7 +91,7 @@ class Report: if freq is None: raise ValueError("benchmark freq can't be None!") - _codes = benchmark if isinstance(benchmark, list) else [benchmark] + _codes = benchmark if isinstance(benchmark, (list, dict)) else [benchmark] fields = ["$close/Ref($close,1)-1"] _temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq) if len(_temp_result) == 0: diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 2bc01045d..d18eb2a27 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -7,7 +7,7 @@ from qlib.data.dataset.utils import convert_index_format from qlib.utils import lazy_sort_index -from ...utils.resam import resam_ts_data +from ...utils.resam import resam_ts_data, ts_data_last from ...data.data import D from ...strategy.base import BaseStrategy from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO @@ -432,7 +432,7 @@ class SBBStrategyEMA(SBBStrategyBase): if not signal_df.empty: for stock_id, stock_val in signal_df.groupby(level="instrument"): - self.signal[stock_id] = stock_val + self.signal[stock_id] = stock_val["signal"].droplevel(level="instrument") def reset_level_infra(self, level_infra): """ @@ -454,13 +454,16 @@ class SBBStrategyEMA(SBBStrategyBase): return self.TREND_MID else: _sample_signal = resam_ts_data( - 
self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last" + self.signal[stock_id], + pred_start_time, + pred_end_time, + method=ts_data_last, ) # if EMA signal == 0 or None, return mid trend - if _sample_signal is None or _sample_signal.iloc[0] == 0: + if _sample_signal is None or np.isnan(_sample_signal) or _sample_signal == 0: return self.TREND_MID # if EMA signal > 0, return long trend - elif _sample_signal.iloc[0] > 0: + elif _sample_signal > 0: return self.TREND_LONG # if EMA signal < 0, return short trend else: @@ -523,7 +526,7 @@ class ACStrategy(BaseStrategy): if not signal_df.empty: for stock_id, stock_val in signal_df.groupby(level="instrument"): - self.signal[stock_id] = stock_val + self.signal[stock_id] = stock_val["volatility"].droplevel(level="instrument") def reset_common_infra(self, common_infra): """ @@ -590,12 +593,12 @@ class ACStrategy(BaseStrategy): # considering trade unit sig_sam = ( - resam_ts_data(self.signal[order.stock_id]["volatility"], pred_start_time, pred_end_time, method="last") + resam_ts_data(self.signal[order.stock_id], pred_start_time, pred_end_time, method=ts_data_last) if order.stock_id in self.signal else None ) - if sig_sam is None or sig_sam.iloc[0] is None: + if sig_sam is None or np.isnan(sig_sam): # no signal, TWAP _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) if _amount_trade_unit is None: @@ -612,7 +615,7 @@ class ACStrategy(BaseStrategy): ) else: # VA strategy - kappa_tild = self.lamb / self.eta * sig_sam.iloc[0] * sig_sam.iloc[0] + kappa_tild = self.lamb / self.eta * sig_sam * sig_sam kappa = np.arccosh(kappa_tild / 2 + 1) amount_ratio = ( np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1)) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index edcc1ede2..2d5159292 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -197,7 +197,7 @@ class DataHandler(Serializable): ------- 
pd.DataFrame. """ - from .storage import HasingStockStorage + from .storage import BaseHandlerStorage data_storage = self._data if isinstance(data_storage, pd.DataFrame): @@ -211,10 +211,17 @@ class DataHandler(Serializable): # Fetch column first will be more friendly to SepDataFrame data_df = fetch_df_by_col(data_df, col_set) data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) - elif isinstance(data_storage, HasingStockStorage): - if proc_func is not None: - raise ValueError("proc_func is not supported by the HasingStockStorage") - data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + elif isinstance(data_storage, BaseHandlerStorage): + if not data_storage.is_proc_func_supported(): + if proc_func is not None: + raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}") + data_df = data_storage.fetch( + selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig + ) + else: + data_df = data_storage.fetch( + selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func + ) else: raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") @@ -522,7 +529,7 @@ class DataHandlerLP(DataHandler): ------- pd.DataFrame: """ - from .storage import HasingStockStorage + from .storage import BaseHandlerStorage data_storage = self._get_df_by_key(data_key) if isinstance(data_storage, pd.DataFrame): @@ -537,10 +544,17 @@ class DataHandlerLP(DataHandler): data_df = fetch_df_by_col(data_df, col_set) data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) - elif isinstance(data_storage, HasingStockStorage): - if proc_func is not None: - raise ValueError("proc_func is not supported by the HasingStockStorage") - data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + elif isinstance(data_storage, 
BaseHandlerStorage): + if not data_storage.is_proc_func_supported(): + if proc_func is not None: + raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}") + data_df = data_storage.fetch( + selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig + ) + else: + data_df = data_storage.fetch( + selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func + ) else: raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py index 247970481..9325807f9 100644 --- a/qlib/data/dataset/storage.py +++ b/qlib/data/dataset/storage.py @@ -14,6 +14,7 @@ class BaseHandlerStorage: level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = DataHandler.CS_ALL, fetch_orig: bool = True, + proc_func: Callable = None, **kwargs, ) -> pd.DataFrame: """fetch data from the data storage @@ -24,6 +25,7 @@ class BaseHandlerStorage: describe how to select data by index level : Union[str, int] which index level to select the data + - if level is None, apply selector to df directly col_set : Union[str, List[str]] - if isinstance(col_set, str): select a set of meaningful columns.(e.g. features, columns) @@ -33,15 +35,24 @@ class BaseHandlerStorage: select several sets of meaningful columns, the returned data has multiple level fetch_orig : bool Return the original data instead of copy if possible. 
+ proc_func: Callable + please refer to the doc of DataHandler.fetch + Returns + ------- + pd.DataFrame + the dataframe fetched """ - raise NotImplementedError("fetch is method not implemented!") @staticmethod def from_df(df: pd.DataFrame): raise NotImplementedError("from_df method is not implemented!") + def is_proc_func_supported(self): + """whether the arg `proc_func` in `fetch` method is supported.""" + raise NotImplementedError("is_proc_func_supported method is not implemented!") + class HasingStockStorage(BaseHandlerStorage): def __init__(self, df): @@ -105,3 +116,7 @@ class HasingStockStorage(BaseHandlerStorage): return fetch_stock_df_list[0] else: return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig) + + def is_proc_func_supported(self): + """the arg `proc_func` in `fetch` method is not supported in HasingStockStorage""" + return False diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index 4df155946..9e9590e30 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -3,6 +3,8 @@ import datetime import numpy as np import pandas as pd + +from functools import partial from typing import Tuple, List, Union, Optional, Callable from . 
# Reducers meant to be passed as `resam_ts_data(..., method=<callable>)`:
# they collapse a time-indexed slice to its first/last non-NaN observation.


def get_valid_value(series, last=True):
    """Get the first/last non-NaN value of a pd.Series with a single-level index.

    Parameters
    ----------
    series : pd.Series
        series should not be empty
    last : bool, optional
        whether to get the last valid value, by default True
        - if last is True, get the last valid value
        - else, get the first valid value

    Returns
    -------
    NaN | float
        the first/last valid value; NaN when every entry is NaN

    Raises
    ------
    IndexError
        if `series` is empty
    """
    # Forward-filling carries the last valid observation to the final row;
    # back-filling carries the first valid observation to the first row.
    # `ffill()` / `bfill()` replace the deprecated `fillna(method=...)` form.
    return series.ffill().iloc[-1] if last else series.bfill().iloc[0]


def _ts_data_valid(ts_feature, last=False):
    """Get the first/last non-NaN value of a pd.Series|DataFrame with a single-level index.

    Parameters
    ----------
    ts_feature : pd.DataFrame | pd.Series
        the time-series data to reduce
    last : bool, optional
        if True pick the last valid value, otherwise the first, by default False

    Returns
    -------
    pd.Series | scalar
        for a DataFrame, one valid value per column (indexed by column name);
        for a Series, a single valid value

    Raises
    ------
    TypeError
        if `ts_feature` is neither a pd.DataFrame nor a pd.Series
    """
    if isinstance(ts_feature, pd.DataFrame):
        # Reduce column by column so each column keeps its own valid value.
        return ts_feature.apply(lambda column: get_valid_value(column, last=last))
    elif isinstance(ts_feature, pd.Series):
        return get_valid_value(ts_feature, last=last)
    else:
        raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}")


# NOTE: the flags must match the names. `ts_data_last` stands in for the former
# `method="last"` reduction, so it has to select the *last* valid value.
ts_data_last = partial(_ts_data_valid, last=True)
ts_data_first = partial(_ts_data_valid, last=False)