From a0f32036a61416bd4db9e6aa248518c0c82fbccc Mon Sep 17 00:00:00 2001 From: zhupr Date: Tue, 15 Dec 2020 23:41:14 +0800 Subject: [PATCH] Fix the first trading day of the calendar extra in report_df --- qlib/contrib/backtest/backtest.py | 2 +- qlib/contrib/evaluate.py | 4 +- .../report/analysis_position/report.py | 4 +- qlib/utils/__init__.py | 54 ++++++++++++------- 4 files changed, 40 insertions(+), 24 deletions(-) diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index 7ee8dceb0..2e785357c 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -69,7 +69,7 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark) raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark") bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean() - trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift)) + trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift)) executor = SimulatorExecutor(trade_exchange, verbose=verbose) # trading apart diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index 4bb5e4372..a7b715321 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -168,7 +168,7 @@ def get_exchange( codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks dates = sorted(pred.index.get_level_values("datetime").unique()) - dates = np.append(dates, get_date_range(dates[-1], shift=shift)) + dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift)) exchange = Exchange( trade_dates=dates, @@ -340,7 +340,7 @@ def long_short_backtest( _pred_dates = pred.index.get_level_values(level="datetime") predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max()) - trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift)) + trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift)) long_returns = {} short_returns = {} diff --git a/qlib/contrib/report/analysis_position/report.py b/qlib/contrib/report/analysis_position/report.py index 438aab8b9..f82e654c4 100644 --- a/qlib/contrib/report/analysis_position/report.py +++ b/qlib/contrib/report/analysis_position/report.py @@ -38,7 +38,7 @@ def _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame: :param df: :return: """ - + index_names = df.index.names df.index = df.index.strftime("%Y-%m-%d") report_df = pd.DataFrame() @@ -58,6 +58,8 @@ def _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame: report_df["turnover"] = df["turnover"] report_df.sort_index(ascending=True, inplace=True) + + report_df.index.names = index_names return report_df diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index ddc17c478..a5a4b4a56 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -279,8 +279,10 @@ def compare_dict_value(src_data: dict, dst_data: dict): def create_save_path(save_path=None): """Create save path - :param save_path: - :return: + Parameters + ---------- + save_path: str + """ if save_path: if not os.path.exists(save_path): @@ -471,30 +473,28 @@ def is_tradable_date(cur_date): return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date()) -def get_date_range(trading_date, shift, future=False): +def get_date_range(trading_date, left_shift=0, right_shift=0, future=False): """get trading date range by shift - :param trading_date: - :param shift: int - :param future: bool - :return: + Parameters + ---------- + trading_date: pd.Timestamp + left_shift: int + right_shift: int + future: bool + """ + from ..data import D - calendar = D.calendar(future=future) - if pd.to_datetime(trading_date) not in list(calendar): - raise ValueError("{} is not trading day!".format(str(trading_date))) - day_index = bisect.bisect_left(calendar, trading_date) - if 0 <= (day_index + shift) < len(calendar): - if shift > 0: - return calendar[day_index + 1 : day_index + 1 + shift] - else: - return calendar[day_index + shift : day_index] - else: - return calendar + start = get_date_by_shift(trading_date, left_shift, future=future) + end = get_date_by_shift(trading_date, right_shift, future=future) + + calendar = D.calendar(start, end, future=future) + return calendar -def get_date_by_shift(trading_date, shift, future=False): +def get_date_by_shift(trading_date, shift, future=False, clip_shift=True): """get trading date with shift bias wil cur_date e.g. : shift == 1, return next trading date shift == -1, return previous trading date @@ -502,8 +502,22 @@ def get_date_by_shift(trading_date, shift, future=False): trading_date : pandas.Timestamp current date shift : int + clip_shift: bool + """ - return get_date_range(trading_date, shift, future)[0 if shift < 0 else -1] if shift != 0 else trading_date + from qlib.data import D + + cal = D.calendar(future=future) + if pd.to_datetime(trading_date) not in list(cal): + raise ValueError("{} is not trading day!".format(str(trading_date))) + _index = bisect.bisect_left(cal, trading_date) + shift_index = _index + shift + if shift_index < 0 or shift_index >= len(cal): + if clip_shift: + shift_index = np.clip(shift_index, 0, len(cal) - 1) + else: + raise IndexError(f"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range") + return cal[shift_index] def get_next_trading_date(trading_date, future=False):