From b6564cd7600ac185630eea533c95c5254da8cb06 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 24 Jun 2021 19:09:36 +0000 Subject: [PATCH] support trade decision update --- qlib/backtest/executor.py | 9 +- qlib/backtest/utils.py | 118 ++++++++++++++++- qlib/contrib/strategy/model_strategy.py | 6 +- qlib/contrib/strategy/order_generator.py | 6 +- qlib/contrib/strategy/rule_strategy.py | 155 ++++++++++++----------- qlib/strategy/base.py | 23 +++- 6 files changed, 228 insertions(+), 89 deletions(-) diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 226f112b7..5cc2c00c3 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -5,7 +5,7 @@ from typing import Union from .order import Order from .exchange import Exchange -from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure +from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, TradeDecison from ..utils import init_instance_by_config from ..utils.resam import parse_freq @@ -135,7 +135,7 @@ class BaseExecutor: Parameters ---------- - trade_decision : object + trade_decision : TradeDecison Returns ---------- @@ -149,7 +149,7 @@ class BaseExecutor: Parameters ---------- - trade_decision : object + trade_decision : TradeDecison Returns ---------- @@ -352,7 +352,8 @@ class SimulatorExecutor(BaseExecutor): trade_step = self.trade_calendar.get_trade_step() trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) execute_result = [] - for order in trade_decision: + order_generator = trade_decision.generator() + for order in order_generator: if self.trade_exchange.check_order(order) is True: # execute the order trade_val, trade_cost, trade_price = self.trade_exchange.deal_order( diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 25ddc45a4..120f80609 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from re import L import pandas as pd import warnings -from typing import Union +from typing import Union, List, Set from ..utils.resam import get_resam_calendar from ..data.data import Cal @@ -145,3 +146,118 @@ class CommonInfrastructure(BaseInfrastructure): class LevelInfrastructure(BaseInfrastructure): def get_support_infra(self): return ["trade_calendar"] + + +class TradeDecison: + """trade decison that made by strategy""" + + def __init__(self, order_list, ori_strategy, init_enable=False): + """ + Parameters + ---------- + order_list : list + the order list + ori_strategy : BaseStrategy + the original strategy that make the decison + init_enable : bool, optional + wether to enable order initially, default by False + """ + self.order_list = order_list + self.ori_strategy = ori_strategy + if init_enable: + self.enable_dict = {_order.stock_id: _order for _order in self.order_list} + self.disable_dict = dict() + else: + self.enable_dict = dict() + self.disable_dict = {_order.stock_id: _order for _order in self.order_list} + + def enable(self, enable_set: Union[List[str], Set[str]] = None, all_enable=False): + """enable order set + Parameters + ---------- + enable_set : Union[List[str], Set[str]], optional + the order set that will be enabled, by default None + - if all_enable is True, enable_set will be ignored + - else, enable the order whose stock_id in enable_set + all_enable : bool, optional + wether to enable all order, by default False + """ + if all_enable is True: + self.enable_dict.update(self.disable_dict) + self.disable_dict.clear() + if enable_set is not None: + warnings.warn(f"`enable_set` is ignored because `all_enable` is set True") + else: + enable_set = set(enable_set) + for _stock_id in enable_set: + enable_order = self.disable_dict.get(_stock_id) + if enable_order is None: + raise ValueError(f"_stock_id {_stock_id} is not found in disable set") + self.enable_order.update({_stock_id: enable_order}) + self.disable_dict.pop(_stock_id) + + def disable(self, disable_set: Union[List[str], Set[str]] = None, all_disable=False): + """disable order set + Parameters + ---------- + disable_set : Union[List[str], Set[str]], optional + the order set that will be disabled, by default None + - if all_disable is True, disable_set will be ignored + - else, disable the order whose stock_id in disable_set + all_disable : bool, optional + wether to disable all order, by default False + """ + if all_disable is True: + self.disable_dict.update(self.enable_dict) + self.enable_dict.clear() + if disable_set is not None: + warnings.warn(f"`disable_set` is ignored because `all_disable` is set True") + else: + disable_set = set(disable_set) + for _stock_id in disable_set: + disable_order = self.enable_dict.get(_stock_id) + if disable_order is None: + raise ValueError(f"_stock_id {_stock_id} is not found in enable set") + self.disable_dict.update({_stock_id: disable_order}) + self.enable_dict.pop(_stock_id) + + def generator(self, only_enable=False, only_disable=False): + """get order generator used for iteration + Parameters + ---------- + only_enable : bool, optional + wether to ignore disabled order, by default False + only_disable : bool, optional + wether to ignore enabled order, by default False + """ + if not only_disable and not only_enable: + yield from self.order_list + elif not only_disable: + yield from self.enable_dict.values() + elif not only_enable: + yield from self.disable_dict.values() + + def get_order_list(self, only_enable=False, only_disable=False): + """get the order list + + Parameters + ---------- + only_enable : bool, optional + wether to ignore disabled order, by default False + only_disable : bool, optional + wether to ignore enabled order, by default False + Returns + ------- + List[Order] + the order list + """ + if not only_disable and not only_enable: + return self.order_list + elif not only_disable: + return list(self.enable_dict.values()) + elif not only_enable: + return list(self.disable_dict.values()) + + def update(self, trade_step, trade_len): + """make the original strategy update the enabled status of orders.""" + self.ori_strategy.update_trade_decision(self, trade_step, trade_len) diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index d88dcd7d6..679385043 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -6,6 +6,8 @@ import pandas as pd from ...utils.resam import resam_ts_data from ...strategy.base import ModelStrategy from ...backtest.order import Order +from ...backtest.utils import TradeDecison + from .order_generator import OrderGenWInteract @@ -244,7 +246,7 @@ class TopkDropoutStrategy(ModelStrategy): factor=factor, ) buy_order_list.append(buy_order) - return sell_order_list + buy_order_list + return TradeDecison(order_list=sell_order_list + buy_order_list, ori_strategy=self) class WeightStrategyBase(ModelStrategy): @@ -339,4 +341,4 @@ class WeightStrategyBase(ModelStrategy): trade_start_time=trade_start_time, trade_end_time=trade_end_time, ) - return order_list + return TradeDecison(order_list=order_list, ori_strategy=self) diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index d3e94551a..7e4ee1a07 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -6,6 +6,8 @@ This order generator is for strategies based on WeightStrategyBase """ from ...backtest.position import Position from ...backtest.exchange import Exchange +from ...backtest.utils import TradeDecison + import pandas as pd import copy @@ -125,7 +127,7 @@ class OrderGenWInteract(OrderGenerator): trade_start_time=trade_start_time, trade_end_time=trade_end_time, ) - return order_list + return TradeDecison(order_list=order_list, ori_strategy=self) class OrderGenWOInteract(OrderGenerator): @@ -189,4 +191,4 @@ class OrderGenWOInteract(OrderGenerator): trade_start_time=trade_start_time, trade_end_time=trade_end_time, ) - return order_list + return TradeDecison(order_list=order_list, ori_strategy=self) diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 9f0cca8c8..01eb42803 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -9,7 +9,7 @@ from ...data.dataset.utils import convert_index_format from ...strategy.base import BaseStrategy from ...backtest.order import Order from ...backtest.exchange import Exchange -from ...backtest.utils import CommonInfrastructure, LevelInfrastructure +from ...backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeDecison class TWAPStrategy(BaseStrategy): @@ -17,7 +17,7 @@ class TWAPStrategy(BaseStrategy): def __init__( self, - outer_trade_decision: List[Order] = None, + outer_trade_decision: TradeDecison = None, trade_exchange: Exchange = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, @@ -25,8 +25,8 @@ class TWAPStrategy(BaseStrategy): """ Parameters ---------- - outer_trade_decision : List[Order] - the trade decison of outer strategy which this startegy relies, it should be List[Order] in TWAPStrategy + outer_trade_decision : TradeDecison + the trade decison of outer strategy which this startegy relies trade_exchange : Exchange exchange that provides market info, used to deal order and generate report - If `trade_exchange` is None, self.trade_exchange will be set with common_infra @@ -57,33 +57,37 @@ class TWAPStrategy(BaseStrategy): if common_infra.has("trade_exchange"): self.trade_exchange = common_infra.get("trade_exchange") - def reset(self, outer_trade_decision: List[Order] = None, **kwargs): + def reset(self, outer_trade_decision: TradeDecison = None, **kwargs): """ Parameters ---------- - outer_trade_decision : List[Order], optional + outer_trade_decision : TradeDecison, optional """ super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) if outer_trade_decision is not None: self.trade_amount = {} - for order in outer_trade_decision: - self.trade_amount[(order.stock_id, order.direction)] = order.amount + outer_order_generator = outer_trade_decision.generator() + for order in outer_order_generator: + self.trade_amount[order.stock_id] = order.amount def generate_trade_decision(self, execute_result=None): - - # update the order amount - if execute_result is not None: - for order, _, _, _ in execute_result: - self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount - # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] trade_step = self.trade_calendar.get_trade_step() # get the total count of trading step trade_len = self.trade_calendar.get_trade_len() + # update outer trade decision + self.outer_trade_decision.update(trade_step, trade_len) + + # update the order amount + if execute_result is not None: + for order, _, _, _ in execute_result: + self.trade_amount[order.stock_id] -= order.deal_amount + trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) order_list = [] - for order in self.outer_trade_decision: + outer_order_generator = self.outer_trade_decision.generator(only_enable=True) + for order in outer_order_generator: # if not tradable, continue if not self.trade_exchange.is_stock_tradable( stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time @@ -94,12 +98,12 @@ class TWAPStrategy(BaseStrategy): # considering trade unit if _amount_trade_unit is None: # divide the order into equal parts, and trade one part - _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step) + _order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step) # without considering trade unit else: # divide the order into equal parts, and trade one part # calculate the total count of trade units to trade - trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) + trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit) # calculate the amount of one part, ceil the amount # floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1)) _order_amount = ( @@ -108,12 +112,10 @@ class TWAPStrategy(BaseStrategy): if order.direction == order.SELL: # sell all amount at last - if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and ( - _order_amount < 1e-5 or trade_step == trade_len - 1 - ): - _order_amount = self.trade_amount[(order.stock_id, order.direction)] + if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or trade_step == trade_len - 1): + _order_amount = self.trade_amount[order.stock_id] - _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)]) + _order_amount = min(_order_amount, self.trade_amount[order.stock_id]) if _order_amount > 1e-5: @@ -126,7 +128,7 @@ class TWAPStrategy(BaseStrategy): factor=order.factor, ) order_list.append(_order) - return order_list + return TradeDecison(order_list=order_list, ori_strategy=self) class SBBStrategyBase(BaseStrategy): @@ -140,7 +142,7 @@ class SBBStrategyBase(BaseStrategy): def __init__( self, - outer_trade_decision: List[Order] = None, + outer_trade_decision: TradeDecison = None, trade_exchange: Exchange = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, @@ -148,8 +150,8 @@ class SBBStrategyBase(BaseStrategy): """ Parameters ---------- - outer_trade_decision : List[Order] - the trade decison of outer strategy which this startegy relies, it should be List[Order] in SBBStrategyBase + outer_trade_decision : TradeDecison + the trade decison of outer strategy which this startegy relies trade_exchange : Exchange exchange that provides market info, used to deal order and generate report - If `trade_exchange` is None, self.trade_exchange will be set with common_infra @@ -178,52 +180,57 @@ class SBBStrategyBase(BaseStrategy): if common_infra.has("trade_exchange"): self.trade_exchange = common_infra.get("trade_exchange") - def reset(self, outer_trade_decision: List[Order] = None, **kwargs): + def reset(self, outer_trade_decision: TradeDecison = None, **kwargs): """ Parameters ---------- - outer_trade_decision : List[Order], optional + outer_trade_decision : TradeDecison, optional """ super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) if outer_trade_decision is not None: self.trade_trend = {} self.trade_amount = {} # init the trade amount of order and predicted trade trend - for order in outer_trade_decision: - self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID - self.trade_amount[(order.stock_id, order.direction)] = order.amount + outer_order_generator = outer_trade_decision.generator() + for order in outer_order_generator: + self.trade_trend[order.stock_id] = self.TREND_MID + self.trade_amount[order.stock_id] = order.amount def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): raise NotImplementedError("pred_price_trend method is not implemented!") def generate_trade_decision(self, execute_result=None): - - # update the order amount - if execute_result is not None: - for order, _, _, _ in execute_result: - self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] trade_step = self.trade_calendar.get_trade_step() # get the total count of trading step trade_len = self.trade_calendar.get_trade_len() + # update outer trade decision + self.outer_trade_decision.update(trade_step, trade_len) + + # update the order amount + if execute_result is not None: + for order, _, _, _ in execute_result: + self.trade_amount[order.stock_id] -= order.deal_amount + trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) order_list = [] # for each order in in self.outer_trade_decision - for order in self.outer_trade_decision: + outer_order_generator = self.outer_trade_decision.generator(only_enable=True) + for order in outer_order_generator: # get the price trend if trade_step % 2 == 0: # in the first of two adjacent bars, predict the price trend _pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time) else: # in the second of two adjacent bars, use the trend predicted in the first one - _pred_trend = self.trade_trend[(order.stock_id, order.direction)] + _pred_trend = self.trade_trend[order.stock_id] # if not tradable, continue if not self.trade_exchange.is_stock_tradable( stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time ): if trade_step % 2 == 0: - self.trade_trend[(order.stock_id, order.direction)] = _pred_trend + self.trade_trend[order.stock_id] = _pred_trend continue # get amount of one trade unit _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) @@ -232,12 +239,12 @@ class SBBStrategyBase(BaseStrategy): # considering trade unit if _amount_trade_unit is None: # divide the order into equal parts, and trade one part - _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step) + _order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step) # without considering trade unit else: # divide the order into equal parts, and trade one part # calculate the total count of trade units to trade - trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) + trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit) # calculate the amount of one part, ceil the amount # floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step)) _order_amount = ( @@ -245,12 +252,12 @@ class SBBStrategyBase(BaseStrategy): ) if order.direction == order.SELL: # sell all amount at last - if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and ( + if self.trade_amount[order.stock_id] > 1e-5 and ( _order_amount < 1e-5 or trade_step == trade_len - 1 ): - _order_amount = self.trade_amount[(order.stock_id, order.direction)] + _order_amount = self.trade_amount[order.stock_id] - _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)]) + _order_amount = min(_order_amount, self.trade_amount[order.stock_id]) if _order_amount > 1e-5: _order = Order( @@ -268,13 +275,11 @@ class SBBStrategyBase(BaseStrategy): # considering trade unit if _amount_trade_unit is None: # N trade day left, divide the order into N + 1 parts, and trade 2 parts - _order_amount = ( - 2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1) - ) + _order_amount = 2 * self.trade_amount[order.stock_id] / (trade_len - trade_step + 1) # without considering trade unit else: # cal how many trade unit - trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) + trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit) # N trade day left, divide the order into N + 1 parts, and trade 2 parts _order_amount = ( (trade_unit_cnt + trade_len - trade_step) @@ -284,12 +289,12 @@ class SBBStrategyBase(BaseStrategy): ) if order.direction == order.SELL: # sell all amount at last - if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and ( + if self.trade_amount[order.stock_id] > 1e-5 and ( _order_amount < 1e-5 or trade_step == trade_len - 1 ): - _order_amount = self.trade_amount[(order.stock_id, order.direction)] + _order_amount = self.trade_amount[order.stock_id] - _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)]) + _order_amount = min(_order_amount, self.trade_amount[order.stock_id]) if _order_amount > 1e-5: if trade_step % 2 == 0: @@ -333,9 +338,9 @@ class SBBStrategyBase(BaseStrategy): if trade_step % 2 == 0: # in the first one of two adjacent bars, store the trend for the second one to use - self.trade_trend[(order.stock_id, order.direction)] = _pred_trend + self.trade_trend[order.stock_id] = _pred_trend - return order_list + return TradeDecison(order_list=order_list, ori_strategy=self) class SBBStrategyEMA(SBBStrategyBase): @@ -345,7 +350,7 @@ class SBBStrategyEMA(SBBStrategyBase): def __init__( self, - outer_trade_decision: List[Order] = None, + outer_trade_decision: TradeDecison = None, instruments: Union[List, str] = "csi300", freq: str = "day", trade_exchange: Exchange = None, @@ -425,7 +430,7 @@ class ACStrategy(BaseStrategy): lamb: float = 1e-6, eta: float = 2.5e-6, window_size: int = 20, - outer_trade_decision: List[Order] = None, + outer_trade_decision: TradeDecison = None, instruments: Union[List, str] = "csi300", freq: str = "day", trade_exchange: Exchange = None, @@ -502,34 +507,38 @@ class ACStrategy(BaseStrategy): self.trade_calendar = level_infra.get("trade_calendar") self._reset_signal() - def reset(self, outer_trade_decision: List[Order] = None, **kwargs): + def reset(self, outer_trade_decision: TradeDecison = None, **kwargs): """ Parameters ---------- - outer_trade_decision : List[Order], optional + outer_trade_decision : TradeDecison, optional """ super(ACStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) if outer_trade_decision is not None: self.trade_amount = {} # init the trade amount of order and predicted trade trend - for order in outer_trade_decision: - self.trade_amount[(order.stock_id, order.direction)] = order.amount + outer_order_generator = outer_trade_decision.generator() + for order in outer_order_generator: + self.trade_amount[order.stock_id] = order.amount def generate_trade_decision(self, execute_result=None): - - # update the order amount - if execute_result is not None: - for order, _, _, _ in execute_result: - self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount - # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] trade_step = self.trade_calendar.get_trade_step() # get the total count of trading step trade_len = self.trade_calendar.get_trade_len() + # update outer trade decision + self.outer_trade_decision.update(trade_step, trade_len) + + # update the order amount + if execute_result is not None: + for order, _, _, _ in execute_result: + self.trade_amount[order.stock_id] -= order.deal_amount + trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) order_list = [] - for order in self.outer_trade_decision: + outer_order_generator = self.outer_trade_decision.generator(only_enable=True) + for order in outer_order_generator: # if not tradable, continue if not self.trade_exchange.is_stock_tradable( stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time @@ -549,11 +558,11 @@ class ACStrategy(BaseStrategy): _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) if _amount_trade_unit is None: # divide the order into equal parts, and trade one part - _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step) + _order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step) else: # divide the order into equal parts, and trade one part # calculate the total count of trade units to trade - trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) + trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit) # calculate the amount of one part, ceil the amount # floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step)) _order_amount = ( @@ -571,12 +580,10 @@ class ACStrategy(BaseStrategy): if order.direction == order.SELL: # sell all amount at last - if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and ( - _order_amount < 1e-5 or trade_step == trade_len - 1 - ): - _order_amount = self.trade_amount[(order.stock_id, order.direction)] + if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or trade_step == trade_len - 1): + _order_amount = self.trade_amount[order.stock_id] - _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)]) + _order_amount = min(_order_amount, self.trade_amount[order.stock_id]) if _order_amount > 1e-5: @@ -589,4 +596,4 @@ class ACStrategy(BaseStrategy): factor=order.factor, ) order_list.append(_order) - return order_list + return TradeDecison(order_list=order_list, ori_strategy=self) diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 961fb5044..9f9feb3b1 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -7,7 +7,7 @@ from ..data.dataset import DatasetH from ..data.dataset.utils import convert_index_format from ..rl.interpreter import ActionInterpreter, StateInterpreter from ..utils import init_instance_by_config -from ..backtest.utils import CommonInfrastructure, LevelInfrastructure +from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeDecison class BaseStrategy: @@ -15,14 +15,14 @@ class BaseStrategy: def __init__( self, - outer_trade_decision: object = None, + outer_trade_decision: TradeDecison = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, ): """ Parameters ---------- - outer_trade_decision : object, optional + outer_trade_decision : TradeDecison, optional the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None - If the strategy is used to split trade decison, it will be used - If the strategy is used for portfolio management, it can be ignored @@ -84,6 +84,17 @@ class BaseStrategy: """ raise NotImplementedError("generate_trade_decision is not implemented!") + def update_trade_decision(self, trade_decison: TradeDecison, trade_step, trade_len): + """update trade decision in each step of inner execution, this method enable all order + + Parameters + ---------- + trade_decison : TradeDecison + the trade decison that will be updated + """ + if trade_step == 0: + trade_decison.enable(all_enable=True) + class ModelStrategy(BaseStrategy): """Model-based trading strategy, use model to make predictions for trading""" @@ -92,7 +103,7 @@ class ModelStrategy(BaseStrategy): self, model: BaseModel, dataset: DatasetH, - outer_trade_decision: object = None, + outer_trade_decision: TradeDecison = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, **kwargs, @@ -128,7 +139,7 @@ class RLStrategy(BaseStrategy): def __init__( self, policy, - outer_trade_decision: object = None, + outer_trade_decision: TradeDecison = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, **kwargs, @@ -151,7 +162,7 @@ class RLIntStrategy(RLStrategy): policy, state_interpreter: Union[dict, StateInterpreter], action_interpreter: Union[dict, ActionInterpreter], - outer_trade_decision: object = None, + outer_trade_decision: TradeDecison = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, **kwargs,