1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

Fix the first trading day of the calendar extra in report_df

This commit is contained in:
zhupr
2020-12-15 23:41:14 +08:00
committed by you-n-g
parent 660edeb94f
commit a0f32036a6
4 changed files with 40 additions and 24 deletions

View File

@@ -69,7 +69,7 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
executor = SimulatorExecutor(trade_exchange, verbose=verbose)
# trading apart

View File

@@ -168,7 +168,7 @@ def get_exchange(
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
dates = sorted(pred.index.get_level_values("datetime").unique())
dates = np.append(dates, get_date_range(dates[-1], shift=shift))
dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift))
exchange = Exchange(
trade_dates=dates,
@@ -340,7 +340,7 @@ def long_short_backtest(
_pred_dates = pred.index.get_level_values(level="datetime")
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
long_returns = {}
short_returns = {}

View File

@@ -38,7 +38,7 @@ def _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame:
:param df:
:return:
"""
index_names = df.index.names
df.index = df.index.strftime("%Y-%m-%d")
report_df = pd.DataFrame()
@@ -58,6 +58,8 @@ def _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame:
report_df["turnover"] = df["turnover"]
report_df.sort_index(ascending=True, inplace=True)
report_df.index.names = index_names
return report_df

View File

@@ -279,8 +279,10 @@ def compare_dict_value(src_data: dict, dst_data: dict):
def create_save_path(save_path=None):
"""Create save path
:param save_path:
:return:
Parameters
----------
save_path: str
"""
if save_path:
if not os.path.exists(save_path):
@@ -471,30 +473,28 @@ def is_tradable_date(cur_date):
return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date())
def get_date_range(trading_date, shift, future=False):
def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
"""get trading date range by shift
:param trading_date:
:param shift: int
:param future: bool
:return:
Parameters
----------
trading_date: pd.Timestamp
left_shift: int
right_shift: int
future: bool
"""
from ..data import D
calendar = D.calendar(future=future)
if pd.to_datetime(trading_date) not in list(calendar):
raise ValueError("{} is not trading day!".format(str(trading_date)))
day_index = bisect.bisect_left(calendar, trading_date)
if 0 <= (day_index + shift) < len(calendar):
if shift > 0:
return calendar[day_index + 1 : day_index + 1 + shift]
else:
return calendar[day_index + shift : day_index]
else:
return calendar
start = get_date_by_shift(trading_date, left_shift, future=future)
end = get_date_by_shift(trading_date, right_shift, future=future)
calendar = D.calendar(start, end, future=future)
return calendar
def get_date_by_shift(trading_date, shift, future=False):
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
"""get trading date with shift bias wil cur_date
e.g. : shift == 1, return next trading date
shift == -1, return previous trading date
@@ -502,8 +502,22 @@ def get_date_by_shift(trading_date, shift, future=False):
trading_date : pandas.Timestamp
current date
shift : int
clip_shift: bool
"""
return get_date_range(trading_date, shift, future)[0 if shift < 0 else -1] if shift != 0 else trading_date
from qlib.data import D
cal = D.calendar(future=future)
if pd.to_datetime(trading_date) not in list(cal):
raise ValueError("{} is not trading day!".format(str(trading_date)))
_index = bisect.bisect_left(cal, trading_date)
shift_index = _index + shift
if shift_index < 0 or shift_index >= len(cal):
if clip_shift:
shift_index = np.clip(shift_index, 0, len(cal) - 1)
else:
raise IndexError(f"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range")
return cal[shift_index]
def get_next_trading_date(trading_date, future=False):