diff --git a/qlib/contrib/backtest/executor.py b/qlib/contrib/backtest/executor.py index ef0f205ce..943b26f9c 100644 --- a/qlib/contrib/backtest/executor.py +++ b/qlib/contrib/backtest/executor.py @@ -74,11 +74,12 @@ class BaseTradeCalendar: return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1) def finished(self): - return self.trade_index >= self.trade_len - 1 + return self.trade_index >= self.trade_len def step(self): if self.finished(): raise RuntimeError(f"this env has completed its task, please reset it if you want to call it!") + # trade count += 1 self.trade_index = self.trade_index + 1 @@ -165,6 +166,7 @@ class SplitExecutor(BaseExecutor): trading strategy in each trading bar trade_exchange : Exchange exchange that provides market info + - If `trade_exchange` is None, self.trade_exchange will be set with common_faculty """ super(SplitExecutor, self).__init__( step_bar=step_bar, diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 3a37d71d3..0e0f2b907 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -2,7 +2,7 @@ import copy import warnings import numpy as np import pandas as pd - +from typing import Union from ...utils.sample import sample_feature from ...data.data import D @@ -13,6 +13,8 @@ from ..backtest.faculty import common_faculty class TWAPStrategy(RuleStrategy, OrderEnhancement): + """TWAP Strategy for trading""" + def __init__( self, step_bar, @@ -22,6 +24,15 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement): trade_order_list=[], **kwargs, ): + """ + Parameters + ---------- + trade_exchange : Exchange, optional + exchange that provides market info, by default None + - If `trade_exchange` is None, self.trade_exchange will be set with common_faculty + trade_order_list : list, optional + order list to trade, which the strategy will trade in [start_time , end_time] , by default [] + """ super(TWAPStrategy, self).__init__(step_bar, start_time, end_time, **kwargs) self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange self.trade_order_list = trade_order_list @@ -51,19 +62,19 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement): _order_amount = None if _amount_trade_unit is None: _order_amount = self.trade_amount[(order.stock_id, order.direction)] / ( - self.trade_len - self.trade_index + self.trade_len - self.trade_index + 1 ) if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) _order_amount = ( - (trade_unit_cnt + self.trade_len - self.trade_index - 1) - // (self.trade_len - self.trade_index) + (trade_unit_cnt + self.trade_len - self.trade_index) + // (self.trade_len - self.trade_index + 1) * _amount_trade_unit ) if order.direction == order.SELL: if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and ( - _order_amount is None or self.trade_index == self.trade_len - 1 + _order_amount is None or self.trade_index == self.trade_len ): _order_amount = self.trade_amount[(order.stock_id, order.direction)] @@ -99,6 +110,15 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement): trade_order_list=[], **kwargs, ): + """ + Parameters + ---------- + trade_exchange : Exchange, optional + exchange that provides market info, by default None + - If `trade_exchange` is None, self.trade_exchange will be set with common_faculty + trade_order_list : list, optional + order list to trade, which the strategy will trade in [start_time , end_time] , by default [] + """ super(SBBStrategyBase, self).__init__(step_bar, start_time, end_time, **kwargs) self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange self.trade_order_list = trade_order_list @@ -144,18 +164,18 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement): _order_amount = None if _amount_trade_unit is None: _order_amount = self.trade_amount[(order.stock_id, order.direction)] / ( - self.trade_len - self.trade_index + self.trade_len - self.trade_index + 1 ) elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) _order_amount = ( - (trade_unit_cnt + self.trade_len - self.trade_index - 1) - // (self.trade_len - self.trade_index) + (trade_unit_cnt + self.trade_len - self.trade_index) + // (self.trade_len - self.trade_index + 1) * _amount_trade_unit ) if order.direction == order.SELL: if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and ( - _order_amount is None or self.trade_index == self.trade_len - 1 + _order_amount is None or self.trade_index == self.trade_len ): _order_amount = self.trade_amount[(order.stock_id, order.direction)] @@ -176,19 +196,19 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement): _order_amount = ( 2 * self.trade_amount[(order.stock_id, order.direction)] - / (self.trade_len - self.trade_index + 1) + / (self.trade_len - self.trade_index + 2) ) elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) _order_amount = ( - (trade_unit_cnt + self.trade_len - self.trade_index) - // (self.trade_len - self.trade_index + 1) + (trade_unit_cnt + self.trade_len - self.trade_index + 1) + // (self.trade_len - self.trade_index + 2) * 2 * _amount_trade_unit ) if order.direction == order.SELL: if self.trade_amount[(order.stock_id, order.direction)] >= 1e-5 and ( - _order_amount is None or self.trade_index == self.trade_len - 1 + _order_amount is None or self.trade_index == self.trade_len ): _order_amount = self.trade_amount[(order.stock_id, order.direction)] @@ -235,7 +255,7 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement): class SBBStrategyEMA(SBBStrategyBase): """ - (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA). + (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA) signal. """ def __init__( @@ -249,6 +269,15 @@ class SBBStrategyEMA(SBBStrategyBase): freq="day", **kwargs, ): + """ + Parameters + ---------- + instruments : str, optional + instruments of EMA signal, by default "csi300" + freq : str, optional + freq of EMA signal, by default "day" + Note: `freq` may be different from `steb_bar` + """ super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange, trade_order_list, **kwargs) if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") @@ -257,13 +286,25 @@ class SBBStrategyEMA(SBBStrategyBase): self.instruments = D.instruments(instruments) self.freq = freq - def reset(self, start_time=None, end_time=None, **kwargs): + def reset(self, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, **kwargs): + """ + Reset EMA signal for trading + + Parameters + ---------- + start_time : Union[str, pd.Timestamp], optional + start time for trading, also used to calculate the start time of EMA signal, by default None + + end_time : Union[str, pd.Timestamp], optional + end time for trading, also used to calculate the end time of EMA signal, by default None + """ super(SBBStrategyEMA, self).reset(start_time=start_time, end_time=end_time, **kwargs) if self.start_time and self.end_time and (start_time or end_time): fields = ["EMA($close, 10)-EMA($close, 20)"] - signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1) + signal_start_time, _ = self._get_calendar_time(trade_index=1, shift=1) + _, signal_end_time = self._get_calendar_time(trade_index=self.trade_len, shift=1) signal_df = D.features( - self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq + self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq ) signal_df = convert_index_format(signal_df) signal_df.columns = ["signal"] @@ -272,6 +313,7 @@ class SBBStrategyEMA(SBBStrategyBase): self.signal[stock_id] = stock_val def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): + if stock_id not in self.signal: return self.TREND_MID else: