From f7d30960c13bbb1ba4a92d6e69eeeee493b54af8 Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 7 May 2021 00:10:44 +0800 Subject: [PATCH] update the internal bar strategy --- examples/highfreq/backtest/workflow.py | 26 +++- qlib/contrib/backtest/exchange.py | 6 + qlib/contrib/strategy/rule_strategy.py | 182 +++++++++++++++++-------- qlib/strategy/base.py | 2 +- 4 files changed, 150 insertions(+), 66 deletions(-) diff --git a/examples/highfreq/backtest/workflow.py b/examples/highfreq/backtest/workflow.py index f229425c2..786469d8b 100644 --- a/examples/highfreq/backtest/workflow.py +++ b/examples/highfreq/backtest/workflow.py @@ -83,7 +83,7 @@ if __name__ == "__main__": "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.model_strategy", "kwargs": { - "step_bar": "day", + "step_bar": "week", "model": model, "dataset": dataset, "topk": 50, @@ -91,12 +91,28 @@ if __name__ == "__main__": }, }, "env": { - "class": "SimulatorEnv", + "class": "SplitEnv", "module_path": "qlib.contrib.backtest.env", "kwargs": { - "step_bar": "day", - "verbose": True, - "generate_report": True, + "step_bar": "week", + "sub_env": { + "class": "SimulatorEnv", + "module_path": "qlib.contrib.backtest.env", + "kwargs": { + "step_bar": "day", + "verbose": True, + "generate_report": True, + }, + }, + "sub_strategy": { + "class": "SBBStrategyEMA", + "module_path": "qlib.contrib.strategy.rule_strategy", + "kwargs": { + "step_bar": "day", + "freq": "day", + "instruments": market, + }, + }, }, }, "backtest": { diff --git a/qlib/contrib/backtest/exchange.py b/qlib/contrib/backtest/exchange.py index a25b9b4a0..51f0dd68d 100644 --- a/qlib/contrib/backtest/exchange.py +++ b/qlib/contrib/backtest/exchange.py @@ -390,6 +390,12 @@ class Exchange: ) return value + def get_amount_of_trade_unit(self, factor): + if not self.trade_w_adj_price: + return self.trade_unit / factor + else: + return None + def round_amount_by_trade_unit(self, deal_amount, factor): """Parameter deal_amount : float, adjusted amount diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 5f5329257..45df94830 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -24,27 +24,45 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement): def reset(self, trade_order_list=None, trade_exchange=None, **kwargs): super(TWAPStrategy, self).reset(**kwargs) TradingEnhancement.reset(self, trade_order_list=trade_order_list) + if trade_exchange: + self.trade_exchange = trade_exchange if trade_order_list: self.trade_amount = {} for order in self.trade_order_list: - self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len - if trade_exchange: - self.trade_exchange = trade_exchange + self.trade_amount[(order.stock_id, order.direction)] = order.amount def generate_order_list(self, **kwargs): super(TWAPStrategy, self).step() trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) order_list = [] for order in self.trade_order_list: - _order = Order( - stock_id=order.stock_id, - amount=self.trade_amount[(order.stock_id, order.direction)], - start_time=trade_start_time, - end_time=trade_end_time, - direction=order.direction, # 1 for buy - factor=order.factor, - ) - order_list.append(_order) + if not self.trade_exchange.is_stock_tradable( + stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time + ): + continue + _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) + _order_amount = None + if _amount_trade_unit is None: + _order_amount = self.trade_amount[(order.stock_id, order.direction)] / ( + self.trade_len - self.trade_index + ) + if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: + trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) + _order_amount = ( + (trade_unit_cnt + self.trade_len - self.trade_index - 1) + // (self.trade_len - self.trade_index) + * _amount_trade_unit + ) + if _order_amount: + _order = Order( + stock_id=order.stock_id, + amount=_order_amount, + start_time=trade_start_time, + end_time=trade_end_time, + direction=order.direction, # 1 for buy + factor=order.factor, + ) + order_list.append(_order) return order_list @@ -70,20 +88,22 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): def reset(self, trade_order_list=None, trade_exchange=None, **kwargs): super(SBBStrategyBase, self).reset(**kwargs) TradingEnhancement.reset(self, trade_order_list=trade_order_list) - if trade_order_list: - self.trade_amount = {} - self.trade_trend = {} - for order in self.trade_order_list: - self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len - self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID if trade_exchange: self.trade_exchange = trade_exchange + if trade_order_list is not None: + self.trade_trend = {} + self.trade_amount = {} + for order in self.trade_order_list: + self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID + self.trade_amount[(order.stock_id, order.direction)] = order.amount def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): raise NotImplementedError("pred_price_trend method is not implemented!") def generate_order_list(self, **kwargs): super(SBBStrategyBase, self).step() + if not self.trade_order_list: + return [] trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1) order_list = [] @@ -92,49 +112,91 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): _pred_trend = self._pred_price_trend(order.stock_id) else: _pred_trend = self.trade_trend[(order.stock_id, order.direction)] - if _pred_trend == self.TREND_MID: - _order = Order( - stock_id=order.stock_id, - amount=self.trade_amount[(order.stock_id, order.direction)], - start_time=trade_start_time, - end_time=trade_end_time, - direction=order.direction, # 1 for buy - factor=order.factor, - ) - order_list.append(_order) - else: + + if not self.trade_exchange.is_stock_tradable( + stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time + ): if self.trade_index % 2 == 1: - if ( - _pred_trend == self.TREND_SHORT - and order.direction == order.SELL - or _pred_trend == self.TREND_LONG - and order.direction == order.BUY - ): - _order = Order( - stock_id=order.stock_id, - amount=2 * self.trade_amount[(order.stock_id, order.direction)], - start_time=trade_start_time, - end_time=trade_end_time, - direction=order.direction, # 1 for buy - factor=order.factor, - ) - order_list.append(_order) - else: - if ( - _pred_trend == self.TREND_SHORT - and order.direction == order.BUY - or _pred_trend == self.TREND_LONG - and order.direction == order.SELL - ): - _order = Order( - stock_id=order.stock_id, - amount=2 * self.trade_amount[(order.stock_id, order.direction)], - start_time=trade_start_time, - end_time=trade_end_time, - direction=order.direction, # 1 for buy - factor=order.factor, - ) - order_list.append(_order) + self.trade_trend[(order.stock_id, order.direction)] = _pred_trend + continue + + _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) + if _pred_trend == self.TREND_MID: + _order_amount = None + if _amount_trade_unit is None: + _order_amount = self.trade_amount[(order.stock_id, order.direction)] / ( + self.trade_len - self.trade_index + ) + if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: + trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) + _order_amount = ( + (trade_unit_cnt + self.trade_len - self.trade_index - 1) + // (self.trade_len - self.trade_index) + * _amount_trade_unit + ) + + if _order_amount: + self.trade_amount[(order.stock_id, order.direction)] -= _order_amount + _order = Order( + stock_id=order.stock_id, + amount=_order_amount, + start_time=trade_start_time, + end_time=trade_end_time, + direction=order.direction, # 1 for buy + factor=order.factor, + ) + order_list.append(_order) + else: + _order_amount = None + if _amount_trade_unit is None: + _order_amount = ( + 2 + * self.trade_amount[(order.stock_id, order.direction)] + / (self.trade_len - self.trade_index + 1) + ) + if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: + trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) + _order_amount = ( + 2 + * (trade_unit_cnt + self.trade_len - self.trade_index) + // (self.trade_len - self.trade_index + 1) + * _amount_trade_unit + ) + if _order_amount: + _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)]) + self.trade_amount[(order.stock_id, order.direction)] -= _order_amount + if self.trade_index % 2 == 1: + if ( + _pred_trend == self.TREND_SHORT + and order.direction == order.SELL + or _pred_trend == self.TREND_LONG + and order.direction == order.BUY + ): + _order = Order( + stock_id=order.stock_id, + amount=_order_amount, + start_time=trade_start_time, + end_time=trade_end_time, + direction=order.direction, # 1 for buy + factor=order.factor, + ) + order_list.append(_order) + else: + if ( + _pred_trend == self.TREND_SHORT + and order.direction == order.BUY + or _pred_trend == self.TREND_LONG + and order.direction == order.SELL + ): + _order = Order( + stock_id=order.stock_id, + amount=_order_amount, + start_time=trade_start_time, + end_time=trade_end_time, + direction=order.direction, # 1 for buy + factor=order.factor, + ) + order_list.append(_order) if self.trade_index % 2 == 1: self.trade_trend[(order.stock_id, order.direction)] = _pred_trend diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index e5840d66a..8a857eb00 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -51,5 +51,5 @@ class ModelStrategy(BaseStrategy): class TradingEnhancement: def reset(self, trade_order_list=None): - if trade_order_list: + if trade_order_list is not None: self.trade_order_list = trade_order_list