mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
Fix the first trading day of the calendar extra in report_df
This commit is contained in:
@@ -69,7 +69,7 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
|
||||
|
||||
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
|
||||
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
|
||||
executor = SimulatorExecutor(trade_exchange, verbose=verbose)
|
||||
|
||||
# trading apart
|
||||
|
||||
@@ -168,7 +168,7 @@ def get_exchange(
|
||||
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
|
||||
|
||||
dates = sorted(pred.index.get_level_values("datetime").unique())
|
||||
dates = np.append(dates, get_date_range(dates[-1], shift=shift))
|
||||
dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift))
|
||||
|
||||
exchange = Exchange(
|
||||
trade_dates=dates,
|
||||
@@ -340,7 +340,7 @@ def long_short_backtest(
|
||||
|
||||
_pred_dates = pred.index.get_level_values(level="datetime")
|
||||
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
|
||||
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
|
||||
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
|
||||
|
||||
long_returns = {}
|
||||
short_returns = {}
|
||||
|
||||
@@ -38,7 +38,7 @@ def _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame:
|
||||
:param df:
|
||||
:return:
|
||||
"""
|
||||
|
||||
index_names = df.index.names
|
||||
df.index = df.index.strftime("%Y-%m-%d")
|
||||
|
||||
report_df = pd.DataFrame()
|
||||
@@ -58,6 +58,8 @@ def _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame:
|
||||
|
||||
report_df["turnover"] = df["turnover"]
|
||||
report_df.sort_index(ascending=True, inplace=True)
|
||||
|
||||
report_df.index.names = index_names
|
||||
return report_df
|
||||
|
||||
|
||||
|
||||
@@ -279,8 +279,10 @@ def compare_dict_value(src_data: dict, dst_data: dict):
|
||||
def create_save_path(save_path=None):
|
||||
"""Create save path
|
||||
|
||||
:param save_path:
|
||||
:return:
|
||||
Parameters
|
||||
----------
|
||||
save_path: str
|
||||
|
||||
"""
|
||||
if save_path:
|
||||
if not os.path.exists(save_path):
|
||||
@@ -471,30 +473,28 @@ def is_tradable_date(cur_date):
|
||||
return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date())
|
||||
|
||||
|
||||
def get_date_range(trading_date, shift, future=False):
|
||||
def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
|
||||
"""get trading date range by shift
|
||||
|
||||
:param trading_date:
|
||||
:param shift: int
|
||||
:param future: bool
|
||||
:return:
|
||||
Parameters
|
||||
----------
|
||||
trading_date: pd.Timestamp
|
||||
left_shift: int
|
||||
right_shift: int
|
||||
future: bool
|
||||
|
||||
"""
|
||||
|
||||
from ..data import D
|
||||
|
||||
calendar = D.calendar(future=future)
|
||||
if pd.to_datetime(trading_date) not in list(calendar):
|
||||
raise ValueError("{} is not trading day!".format(str(trading_date)))
|
||||
day_index = bisect.bisect_left(calendar, trading_date)
|
||||
if 0 <= (day_index + shift) < len(calendar):
|
||||
if shift > 0:
|
||||
return calendar[day_index + 1 : day_index + 1 + shift]
|
||||
else:
|
||||
return calendar[day_index + shift : day_index]
|
||||
else:
|
||||
return calendar
|
||||
start = get_date_by_shift(trading_date, left_shift, future=future)
|
||||
end = get_date_by_shift(trading_date, right_shift, future=future)
|
||||
|
||||
calendar = D.calendar(start, end, future=future)
|
||||
return calendar
|
||||
|
||||
|
||||
def get_date_by_shift(trading_date, shift, future=False):
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
|
||||
"""get trading date with shift bias wil cur_date
|
||||
e.g. : shift == 1, return next trading date
|
||||
shift == -1, return previous trading date
|
||||
@@ -502,8 +502,22 @@ def get_date_by_shift(trading_date, shift, future=False):
|
||||
trading_date : pandas.Timestamp
|
||||
current date
|
||||
shift : int
|
||||
clip_shift: bool
|
||||
|
||||
"""
|
||||
return get_date_range(trading_date, shift, future)[0 if shift < 0 else -1] if shift != 0 else trading_date
|
||||
from qlib.data import D
|
||||
|
||||
cal = D.calendar(future=future)
|
||||
if pd.to_datetime(trading_date) not in list(cal):
|
||||
raise ValueError("{} is not trading day!".format(str(trading_date)))
|
||||
_index = bisect.bisect_left(cal, trading_date)
|
||||
shift_index = _index + shift
|
||||
if shift_index < 0 or shift_index >= len(cal):
|
||||
if clip_shift:
|
||||
shift_index = np.clip(shift_index, 0, len(cal) - 1)
|
||||
else:
|
||||
raise IndexError(f"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range")
|
||||
return cal[shift_index]
|
||||
|
||||
|
||||
def get_next_trading_date(trading_date, future=False):
|
||||
|
||||
Reference in New Issue
Block a user