1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

support trade decision update

This commit is contained in:
bxdd
2021-06-24 19:09:36 +00:00
parent 1517a9eb91
commit b6564cd760
6 changed files with 228 additions and 89 deletions

View File

@@ -5,7 +5,7 @@ from typing import Union
from .order import Order
from .exchange import Exchange
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, TradeDecison
from ..utils import init_instance_by_config
from ..utils.resam import parse_freq
@@ -135,7 +135,7 @@ class BaseExecutor:
Parameters
----------
trade_decision : object
trade_decision : TradeDecison
Returns
----------
@@ -149,7 +149,7 @@ class BaseExecutor:
Parameters
----------
trade_decision : object
trade_decision : TradeDecison
Returns
----------
@@ -352,7 +352,8 @@ class SimulatorExecutor(BaseExecutor):
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
execute_result = []
for order in trade_decision:
order_generator = trade_decision.generator()
for order in order_generator:
if self.trade_exchange.check_order(order) is True:
# execute the order
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(

View File

@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from re import L
import pandas as pd
import warnings
from typing import Union
from typing import Union, List, Set
from ..utils.resam import get_resam_calendar
from ..data.data import Cal
@@ -145,3 +146,118 @@ class CommonInfrastructure(BaseInfrastructure):
class LevelInfrastructure(BaseInfrastructure):
def get_support_infra(self):
return ["trade_calendar"]
class TradeDecison:
"""trade decison that made by strategy"""
def __init__(self, order_list, ori_strategy, init_enable=False):
"""
Parameters
----------
order_list : list
the order list
ori_strategy : BaseStrategy
the original strategy that make the decison
init_enable : bool, optional
wether to enable order initially, default by False
"""
self.order_list = order_list
self.ori_strategy = ori_strategy
if init_enable:
self.enable_dict = {_order.stock_id: _order for _order in self.order_list}
self.disable_dict = dict()
else:
self.enable_dict = dict()
self.disable_dict = {_order.stock_id: _order for _order in self.order_list}
def enable(self, enable_set: Union[List[str], Set[str]] = None, all_enable=False):
"""enable order set
Parameters
----------
enable_set : Union[List[str], Set[str]], optional
the order set that will be enabled, by default None
- if all_enable is True, enable_set will be ignored
- else, enable the order whose stock_id in enable_set
all_enable : bool, optional
wether to enable all order, by default False
"""
if all_enable is True:
self.enable_dict.update(self.disable_dict)
self.disable_dict.clear()
if enable_set is not None:
warnings.warn(f"`enable_set` is ignored because `all_enable` is set True")
else:
enable_set = set(enable_set)
for _stock_id in enable_set:
enable_order = self.disable_dict.get(_stock_id)
if enable_order is None:
raise ValueError(f"_stock_id {_stock_id} is not found in disable set")
self.enable_order.update({_stock_id: enable_order})
self.disable_dict.pop(_stock_id)
def disable(self, disable_set: Union[List[str], Set[str]] = None, all_disable=False):
"""disable order set
Parameters
----------
disable_set : Union[List[str], Set[str]], optional
the order set that will be disabled, by default None
- if all_disable is True, disable_set will be ignored
- else, disable the order whose stock_id in disable_set
all_disable : bool, optional
wether to disable all order, by default False
"""
if all_disable is True:
self.disable_dict.update(self.enable_dict)
self.enable_dict.clear()
if disable_set is not None:
warnings.warn(f"`disable_set` is ignored because `all_disable` is set True")
else:
disable_set = set(disable_set)
for _stock_id in disable_set:
disable_order = self.enable_dict.get(_stock_id)
if disable_order is None:
raise ValueError(f"_stock_id {_stock_id} is not found in enable set")
self.disable_dict.update({_stock_id: disable_order})
self.enable_dict.pop(_stock_id)
def generator(self, only_enable=False, only_disable=False):
"""get order generator used for iteration
Parameters
----------
only_enable : bool, optional
wether to ignore disabled order, by default False
only_disable : bool, optional
wether to ignore enabled order, by default False
"""
if not only_disable and not only_enable:
yield from self.order_list
elif not only_disable:
yield from self.enable_dict.values()
elif not only_enable:
yield from self.disable_dict.values()
def get_order_list(self, only_enable=False, only_disable=False):
"""get the order list
Parameters
----------
only_enable : bool, optional
wether to ignore disabled order, by default False
only_disable : bool, optional
wether to ignore enabled order, by default False
Returns
-------
List[Order]
the order list
"""
if not only_disable and not only_enable:
return self.order_list
elif not only_disable:
return list(self.enable_dict.values())
elif not only_enable:
return list(self.disable_dict.values())
def update(self, trade_step, trade_len):
"""make the original strategy update the enabled status of orders."""
self.ori_strategy.update_trade_decision(self, trade_step, trade_len)

View File

@@ -6,6 +6,8 @@ import pandas as pd
from ...utils.resam import resam_ts_data
from ...strategy.base import ModelStrategy
from ...backtest.order import Order
from ...backtest.utils import TradeDecison
from .order_generator import OrderGenWInteract
@@ -244,7 +246,7 @@ class TopkDropoutStrategy(ModelStrategy):
factor=factor,
)
buy_order_list.append(buy_order)
return sell_order_list + buy_order_list
return TradeDecison(order_list=sell_order_list + buy_order_list, ori_strategy=self)
class WeightStrategyBase(ModelStrategy):
@@ -339,4 +341,4 @@ class WeightStrategyBase(ModelStrategy):
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
)
return order_list
return TradeDecison(order_list=order_list, ori_strategy=self)

View File

@@ -6,6 +6,8 @@ This order generator is for strategies based on WeightStrategyBase
"""
from ...backtest.position import Position
from ...backtest.exchange import Exchange
from ...backtest.utils import TradeDecison
import pandas as pd
import copy
@@ -125,7 +127,7 @@ class OrderGenWInteract(OrderGenerator):
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
)
return order_list
return TradeDecison(order_list=order_list, ori_strategy=self)
class OrderGenWOInteract(OrderGenerator):
@@ -189,4 +191,4 @@ class OrderGenWOInteract(OrderGenerator):
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
)
return order_list
return TradeDecison(order_list=order_list, ori_strategy=self)

View File

@@ -9,7 +9,7 @@ from ...data.dataset.utils import convert_index_format
from ...strategy.base import BaseStrategy
from ...backtest.order import Order
from ...backtest.exchange import Exchange
from ...backtest.utils import CommonInfrastructure, LevelInfrastructure
from ...backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeDecison
class TWAPStrategy(BaseStrategy):
@@ -17,7 +17,7 @@ class TWAPStrategy(BaseStrategy):
def __init__(
self,
outer_trade_decision: List[Order] = None,
outer_trade_decision: TradeDecison = None,
trade_exchange: Exchange = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
@@ -25,8 +25,8 @@ class TWAPStrategy(BaseStrategy):
"""
Parameters
----------
outer_trade_decision : List[Order]
the trade decison of outer strategy which this startegy relies, it should be List[Order] in TWAPStrategy
outer_trade_decision : TradeDecison
the trade decison of outer strategy which this startegy relies
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
@@ -57,33 +57,37 @@ class TWAPStrategy(BaseStrategy):
if common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
def reset(self, outer_trade_decision: TradeDecison = None, **kwargs):
"""
Parameters
----------
outer_trade_decision : List[Order], optional
outer_trade_decision : TradeDecison, optional
"""
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
if outer_trade_decision is not None:
self.trade_amount = {}
for order in outer_trade_decision:
self.trade_amount[(order.stock_id, order.direction)] = order.amount
outer_order_generator = outer_trade_decision.generator()
for order in outer_order_generator:
self.trade_amount[order.stock_id] = order.amount
def generate_trade_decision(self, execute_result=None):
# update the order amount
if execute_result is not None:
for order, _, _, _ in execute_result:
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
trade_step = self.trade_calendar.get_trade_step()
# get the total count of trading step
trade_len = self.trade_calendar.get_trade_len()
# update outer trade decision
self.outer_trade_decision.update(trade_step, trade_len)
# update the order amount
if execute_result is not None:
for order, _, _, _ in execute_result:
self.trade_amount[order.stock_id] -= order.deal_amount
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
order_list = []
for order in self.outer_trade_decision:
outer_order_generator = self.outer_trade_decision.generator(only_enable=True)
for order in outer_order_generator:
# if not tradable, continue
if not self.trade_exchange.is_stock_tradable(
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
@@ -94,12 +98,12 @@ class TWAPStrategy(BaseStrategy):
# considering trade unit
if _amount_trade_unit is None:
# divide the order into equal parts, and trade one part
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
_order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
# without considering trade unit
else:
# divide the order into equal parts, and trade one part
# calculate the total count of trade units to trade
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
# calculate the amount of one part, ceil the amount
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
_order_amount = (
@@ -108,12 +112,10 @@ class TWAPStrategy(BaseStrategy):
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
_order_amount < 1e-5 or trade_step == trade_len - 1
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or trade_step == trade_len - 1):
_order_amount = self.trade_amount[order.stock_id]
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
if _order_amount > 1e-5:
@@ -126,7 +128,7 @@ class TWAPStrategy(BaseStrategy):
factor=order.factor,
)
order_list.append(_order)
return order_list
return TradeDecison(order_list=order_list, ori_strategy=self)
class SBBStrategyBase(BaseStrategy):
@@ -140,7 +142,7 @@ class SBBStrategyBase(BaseStrategy):
def __init__(
self,
outer_trade_decision: List[Order] = None,
outer_trade_decision: TradeDecison = None,
trade_exchange: Exchange = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
@@ -148,8 +150,8 @@ class SBBStrategyBase(BaseStrategy):
"""
Parameters
----------
outer_trade_decision : List[Order]
the trade decison of outer strategy which this startegy relies, it should be List[Order] in SBBStrategyBase
outer_trade_decision : TradeDecison
the trade decison of outer strategy which this startegy relies
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
@@ -178,52 +180,57 @@ class SBBStrategyBase(BaseStrategy):
if common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
def reset(self, outer_trade_decision: TradeDecison = None, **kwargs):
"""
Parameters
----------
outer_trade_decision : List[Order], optional
outer_trade_decision : TradeDecison, optional
"""
super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
if outer_trade_decision is not None:
self.trade_trend = {}
self.trade_amount = {}
# init the trade amount of order and predicted trade trend
for order in outer_trade_decision:
self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID
self.trade_amount[(order.stock_id, order.direction)] = order.amount
outer_order_generator = outer_trade_decision.generator()
for order in outer_order_generator:
self.trade_trend[order.stock_id] = self.TREND_MID
self.trade_amount[order.stock_id] = order.amount
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
raise NotImplementedError("pred_price_trend method is not implemented!")
def generate_trade_decision(self, execute_result=None):
# update the order amount
if execute_result is not None:
for order, _, _, _ in execute_result:
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
trade_step = self.trade_calendar.get_trade_step()
# get the total count of trading step
trade_len = self.trade_calendar.get_trade_len()
# update outer trade decision
self.outer_trade_decision.update(trade_step, trade_len)
# update the order amount
if execute_result is not None:
for order, _, _, _ in execute_result:
self.trade_amount[order.stock_id] -= order.deal_amount
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
order_list = []
# for each order in in self.outer_trade_decision
for order in self.outer_trade_decision:
outer_order_generator = self.outer_trade_decision.generator(only_enable=True)
for order in outer_order_generator:
# get the price trend
if trade_step % 2 == 0:
# in the first of two adjacent bars, predict the price trend
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
else:
# in the second of two adjacent bars, use the trend predicted in the first one
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
_pred_trend = self.trade_trend[order.stock_id]
# if not tradable, continue
if not self.trade_exchange.is_stock_tradable(
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
):
if trade_step % 2 == 0:
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
self.trade_trend[order.stock_id] = _pred_trend
continue
# get amount of one trade unit
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
@@ -232,12 +239,12 @@ class SBBStrategyBase(BaseStrategy):
# considering trade unit
if _amount_trade_unit is None:
# divide the order into equal parts, and trade one part
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
_order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
# without considering trade unit
else:
# divide the order into equal parts, and trade one part
# calculate the total count of trade units to trade
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
# calculate the amount of one part, ceil the amount
# floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
_order_amount = (
@@ -245,12 +252,12 @@ class SBBStrategyBase(BaseStrategy):
)
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
if self.trade_amount[order.stock_id] > 1e-5 and (
_order_amount < 1e-5 or trade_step == trade_len - 1
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
_order_amount = self.trade_amount[order.stock_id]
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
if _order_amount > 1e-5:
_order = Order(
@@ -268,13 +275,11 @@ class SBBStrategyBase(BaseStrategy):
# considering trade unit
if _amount_trade_unit is None:
# N trade day left, divide the order into N + 1 parts, and trade 2 parts
_order_amount = (
2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
)
_order_amount = 2 * self.trade_amount[order.stock_id] / (trade_len - trade_step + 1)
# without considering trade unit
else:
# cal how many trade unit
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
# N trade day left, divide the order into N + 1 parts, and trade 2 parts
_order_amount = (
(trade_unit_cnt + trade_len - trade_step)
@@ -284,12 +289,12 @@ class SBBStrategyBase(BaseStrategy):
)
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
if self.trade_amount[order.stock_id] > 1e-5 and (
_order_amount < 1e-5 or trade_step == trade_len - 1
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
_order_amount = self.trade_amount[order.stock_id]
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
if _order_amount > 1e-5:
if trade_step % 2 == 0:
@@ -333,9 +338,9 @@ class SBBStrategyBase(BaseStrategy):
if trade_step % 2 == 0:
# in the first one of two adjacent bars, store the trend for the second one to use
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
self.trade_trend[order.stock_id] = _pred_trend
return order_list
return TradeDecison(order_list=order_list, ori_strategy=self)
class SBBStrategyEMA(SBBStrategyBase):
@@ -345,7 +350,7 @@ class SBBStrategyEMA(SBBStrategyBase):
def __init__(
self,
outer_trade_decision: List[Order] = None,
outer_trade_decision: TradeDecison = None,
instruments: Union[List, str] = "csi300",
freq: str = "day",
trade_exchange: Exchange = None,
@@ -425,7 +430,7 @@ class ACStrategy(BaseStrategy):
lamb: float = 1e-6,
eta: float = 2.5e-6,
window_size: int = 20,
outer_trade_decision: List[Order] = None,
outer_trade_decision: TradeDecison = None,
instruments: Union[List, str] = "csi300",
freq: str = "day",
trade_exchange: Exchange = None,
@@ -502,34 +507,38 @@ class ACStrategy(BaseStrategy):
self.trade_calendar = level_infra.get("trade_calendar")
self._reset_signal()
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
def reset(self, outer_trade_decision: TradeDecison = None, **kwargs):
"""
Parameters
----------
outer_trade_decision : List[Order], optional
outer_trade_decision : TradeDecison, optional
"""
super(ACStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
if outer_trade_decision is not None:
self.trade_amount = {}
# init the trade amount of order and predicted trade trend
for order in outer_trade_decision:
self.trade_amount[(order.stock_id, order.direction)] = order.amount
outer_order_generator = outer_trade_decision.generator()
for order in outer_order_generator:
self.trade_amount[order.stock_id] = order.amount
def generate_trade_decision(self, execute_result=None):
# update the order amount
if execute_result is not None:
for order, _, _, _ in execute_result:
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
trade_step = self.trade_calendar.get_trade_step()
# get the total count of trading step
trade_len = self.trade_calendar.get_trade_len()
# update outer trade decision
self.outer_trade_decision.update(trade_step, trade_len)
# update the order amount
if execute_result is not None:
for order, _, _, _ in execute_result:
self.trade_amount[order.stock_id] -= order.deal_amount
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
order_list = []
for order in self.outer_trade_decision:
outer_order_generator = self.outer_trade_decision.generator(only_enable=True)
for order in outer_order_generator:
# if not tradable, continue
if not self.trade_exchange.is_stock_tradable(
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
@@ -549,11 +558,11 @@ class ACStrategy(BaseStrategy):
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
if _amount_trade_unit is None:
# divide the order into equal parts, and trade one part
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
_order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
else:
# divide the order into equal parts, and trade one part
# calculate the total count of trade units to trade
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
# calculate the amount of one part, ceil the amount
# floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
_order_amount = (
@@ -571,12 +580,10 @@ class ACStrategy(BaseStrategy):
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
_order_amount < 1e-5 or trade_step == trade_len - 1
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or trade_step == trade_len - 1):
_order_amount = self.trade_amount[order.stock_id]
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
if _order_amount > 1e-5:
@@ -589,4 +596,4 @@ class ACStrategy(BaseStrategy):
factor=order.factor,
)
order_list.append(_order)
return order_list
return TradeDecison(order_list=order_list, ori_strategy=self)

View File

@@ -7,7 +7,7 @@ from ..data.dataset import DatasetH
from ..data.dataset.utils import convert_index_format
from ..rl.interpreter import ActionInterpreter, StateInterpreter
from ..utils import init_instance_by_config
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeDecison
class BaseStrategy:
@@ -15,14 +15,14 @@ class BaseStrategy:
def __init__(
self,
outer_trade_decision: object = None,
outer_trade_decision: TradeDecison = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
):
"""
Parameters
----------
outer_trade_decision : object, optional
outer_trade_decision : TradeDecison, optional
the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None
- If the strategy is used to split trade decison, it will be used
- If the strategy is used for portfolio management, it can be ignored
@@ -84,6 +84,17 @@ class BaseStrategy:
"""
raise NotImplementedError("generate_trade_decision is not implemented!")
def update_trade_decision(self, trade_decison: TradeDecison, trade_step, trade_len):
"""update trade decision in each step of inner execution, this method enable all order
Parameters
----------
trade_decison : TradeDecison
the trade decison that will be updated
"""
if trade_step == 0:
trade_decison.enable(all_enable=True)
class ModelStrategy(BaseStrategy):
"""Model-based trading strategy, use model to make predictions for trading"""
@@ -92,7 +103,7 @@ class ModelStrategy(BaseStrategy):
self,
model: BaseModel,
dataset: DatasetH,
outer_trade_decision: object = None,
outer_trade_decision: TradeDecison = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
@@ -128,7 +139,7 @@ class RLStrategy(BaseStrategy):
def __init__(
self,
policy,
outer_trade_decision: object = None,
outer_trade_decision: TradeDecison = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
@@ -151,7 +162,7 @@ class RLIntStrategy(RLStrategy):
policy,
state_interpreter: Union[dict, StateInterpreter],
action_interpreter: Union[dict, ActionInterpreter],
outer_trade_decision: object = None,
outer_trade_decision: TradeDecison = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,