mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
support trade decision update
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Union
|
||||
|
||||
from .order import Order
|
||||
from .exchange import Exchange
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, TradeDecison
|
||||
|
||||
from ..utils import init_instance_by_config
|
||||
from ..utils.resam import parse_freq
|
||||
@@ -135,7 +135,7 @@ class BaseExecutor:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : object
|
||||
trade_decision : TradeDecison
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -149,7 +149,7 @@ class BaseExecutor:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : object
|
||||
trade_decision : TradeDecison
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -352,7 +352,8 @@ class SimulatorExecutor(BaseExecutor):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
execute_result = []
|
||||
for order in trade_decision:
|
||||
order_generator = trade_decision.generator()
|
||||
for order in order_generator:
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
# execute the order
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from re import L
|
||||
import pandas as pd
|
||||
import warnings
|
||||
from typing import Union
|
||||
from typing import Union, List, Set
|
||||
|
||||
from ..utils.resam import get_resam_calendar
|
||||
from ..data.data import Cal
|
||||
@@ -145,3 +146,118 @@ class CommonInfrastructure(BaseInfrastructure):
|
||||
class LevelInfrastructure(BaseInfrastructure):
|
||||
def get_support_infra(self):
|
||||
return ["trade_calendar"]
|
||||
|
||||
|
||||
class TradeDecison:
|
||||
"""trade decison that made by strategy"""
|
||||
|
||||
def __init__(self, order_list, ori_strategy, init_enable=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
order_list : list
|
||||
the order list
|
||||
ori_strategy : BaseStrategy
|
||||
the original strategy that make the decison
|
||||
init_enable : bool, optional
|
||||
wether to enable order initially, default by False
|
||||
"""
|
||||
self.order_list = order_list
|
||||
self.ori_strategy = ori_strategy
|
||||
if init_enable:
|
||||
self.enable_dict = {_order.stock_id: _order for _order in self.order_list}
|
||||
self.disable_dict = dict()
|
||||
else:
|
||||
self.enable_dict = dict()
|
||||
self.disable_dict = {_order.stock_id: _order for _order in self.order_list}
|
||||
|
||||
def enable(self, enable_set: Union[List[str], Set[str]] = None, all_enable=False):
|
||||
"""enable order set
|
||||
Parameters
|
||||
----------
|
||||
enable_set : Union[List[str], Set[str]], optional
|
||||
the order set that will be enabled, by default None
|
||||
- if all_enable is True, enable_set will be ignored
|
||||
- else, enable the order whose stock_id in enable_set
|
||||
all_enable : bool, optional
|
||||
wether to enable all order, by default False
|
||||
"""
|
||||
if all_enable is True:
|
||||
self.enable_dict.update(self.disable_dict)
|
||||
self.disable_dict.clear()
|
||||
if enable_set is not None:
|
||||
warnings.warn(f"`enable_set` is ignored because `all_enable` is set True")
|
||||
else:
|
||||
enable_set = set(enable_set)
|
||||
for _stock_id in enable_set:
|
||||
enable_order = self.disable_dict.get(_stock_id)
|
||||
if enable_order is None:
|
||||
raise ValueError(f"_stock_id {_stock_id} is not found in disable set")
|
||||
self.enable_order.update({_stock_id: enable_order})
|
||||
self.disable_dict.pop(_stock_id)
|
||||
|
||||
def disable(self, disable_set: Union[List[str], Set[str]] = None, all_disable=False):
|
||||
"""disable order set
|
||||
Parameters
|
||||
----------
|
||||
disable_set : Union[List[str], Set[str]], optional
|
||||
the order set that will be disabled, by default None
|
||||
- if all_disable is True, disable_set will be ignored
|
||||
- else, disable the order whose stock_id in disable_set
|
||||
all_disable : bool, optional
|
||||
wether to disable all order, by default False
|
||||
"""
|
||||
if all_disable is True:
|
||||
self.disable_dict.update(self.enable_dict)
|
||||
self.enable_dict.clear()
|
||||
if disable_set is not None:
|
||||
warnings.warn(f"`disable_set` is ignored because `all_disable` is set True")
|
||||
else:
|
||||
disable_set = set(disable_set)
|
||||
for _stock_id in disable_set:
|
||||
disable_order = self.enable_dict.get(_stock_id)
|
||||
if disable_order is None:
|
||||
raise ValueError(f"_stock_id {_stock_id} is not found in enable set")
|
||||
self.disable_dict.update({_stock_id: disable_order})
|
||||
self.enable_dict.pop(_stock_id)
|
||||
|
||||
def generator(self, only_enable=False, only_disable=False):
|
||||
"""get order generator used for iteration
|
||||
Parameters
|
||||
----------
|
||||
only_enable : bool, optional
|
||||
wether to ignore disabled order, by default False
|
||||
only_disable : bool, optional
|
||||
wether to ignore enabled order, by default False
|
||||
"""
|
||||
if not only_disable and not only_enable:
|
||||
yield from self.order_list
|
||||
elif not only_disable:
|
||||
yield from self.enable_dict.values()
|
||||
elif not only_enable:
|
||||
yield from self.disable_dict.values()
|
||||
|
||||
def get_order_list(self, only_enable=False, only_disable=False):
|
||||
"""get the order list
|
||||
|
||||
Parameters
|
||||
----------
|
||||
only_enable : bool, optional
|
||||
wether to ignore disabled order, by default False
|
||||
only_disable : bool, optional
|
||||
wether to ignore enabled order, by default False
|
||||
Returns
|
||||
-------
|
||||
List[Order]
|
||||
the order list
|
||||
"""
|
||||
if not only_disable and not only_enable:
|
||||
return self.order_list
|
||||
elif not only_disable:
|
||||
return list(self.enable_dict.values())
|
||||
elif not only_enable:
|
||||
return list(self.disable_dict.values())
|
||||
|
||||
def update(self, trade_step, trade_len):
|
||||
"""make the original strategy update the enabled status of orders."""
|
||||
self.ori_strategy.update_trade_decision(self, trade_step, trade_len)
|
||||
|
||||
@@ -6,6 +6,8 @@ import pandas as pd
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...strategy.base import ModelStrategy
|
||||
from ...backtest.order import Order
|
||||
from ...backtest.utils import TradeDecison
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
@@ -244,7 +246,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
factor=factor,
|
||||
)
|
||||
buy_order_list.append(buy_order)
|
||||
return sell_order_list + buy_order_list
|
||||
return TradeDecison(order_list=sell_order_list + buy_order_list, ori_strategy=self)
|
||||
|
||||
|
||||
class WeightStrategyBase(ModelStrategy):
|
||||
@@ -339,4 +341,4 @@ class WeightStrategyBase(ModelStrategy):
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
)
|
||||
return order_list
|
||||
return TradeDecison(order_list=order_list, ori_strategy=self)
|
||||
|
||||
@@ -6,6 +6,8 @@ This order generator is for strategies based on WeightStrategyBase
|
||||
"""
|
||||
from ...backtest.position import Position
|
||||
from ...backtest.exchange import Exchange
|
||||
from ...backtest.utils import TradeDecison
|
||||
|
||||
import pandas as pd
|
||||
import copy
|
||||
|
||||
@@ -125,7 +127,7 @@ class OrderGenWInteract(OrderGenerator):
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
)
|
||||
return order_list
|
||||
return TradeDecison(order_list=order_list, ori_strategy=self)
|
||||
|
||||
|
||||
class OrderGenWOInteract(OrderGenerator):
|
||||
@@ -189,4 +191,4 @@ class OrderGenWOInteract(OrderGenerator):
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
)
|
||||
return order_list
|
||||
return TradeDecison(order_list=order_list, ori_strategy=self)
|
||||
|
||||
@@ -9,7 +9,7 @@ from ...data.dataset.utils import convert_index_format
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...backtest.order import Order
|
||||
from ...backtest.exchange import Exchange
|
||||
from ...backtest.utils import CommonInfrastructure, LevelInfrastructure
|
||||
from ...backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeDecison
|
||||
|
||||
|
||||
class TWAPStrategy(BaseStrategy):
|
||||
@@ -17,7 +17,7 @@ class TWAPStrategy(BaseStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: List[Order] = None,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
@@ -25,8 +25,8 @@ class TWAPStrategy(BaseStrategy):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : List[Order]
|
||||
the trade decison of outer strategy which this startegy relies, it should be List[Order] in TWAPStrategy
|
||||
outer_trade_decision : TradeDecison
|
||||
the trade decison of outer strategy which this startegy relies
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
@@ -57,33 +57,37 @@ class TWAPStrategy(BaseStrategy):
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
|
||||
def reset(self, outer_trade_decision: TradeDecison = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : List[Order], optional
|
||||
outer_trade_decision : TradeDecison, optional
|
||||
"""
|
||||
|
||||
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
self.trade_amount = {}
|
||||
for order in outer_trade_decision:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount
|
||||
outer_order_generator = outer_trade_decision.generator()
|
||||
for order in outer_order_generator:
|
||||
self.trade_amount[order.stock_id] = order.amount
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
# get the total count of trading step
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
# update outer trade decision
|
||||
self.outer_trade_decision.update(trade_step, trade_len)
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[order.stock_id] -= order.deal_amount
|
||||
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
order_list = []
|
||||
for order in self.outer_trade_decision:
|
||||
outer_order_generator = self.outer_trade_decision.generator(only_enable=True)
|
||||
for order in outer_order_generator:
|
||||
# if not tradable, continue
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
@@ -94,12 +98,12 @@ class TWAPStrategy(BaseStrategy):
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# divide the order into equal parts, and trade one part
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
|
||||
_order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
|
||||
# without considering trade unit
|
||||
else:
|
||||
# divide the order into equal parts, and trade one part
|
||||
# calculate the total count of trade units to trade
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
|
||||
# calculate the amount of one part, ceil the amount
|
||||
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
|
||||
_order_amount = (
|
||||
@@ -108,12 +112,10 @@ class TWAPStrategy(BaseStrategy):
|
||||
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
|
||||
_order_amount < 1e-5 or trade_step == trade_len - 1
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or trade_step == trade_len - 1):
|
||||
_order_amount = self.trade_amount[order.stock_id]
|
||||
|
||||
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
|
||||
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
|
||||
|
||||
if _order_amount > 1e-5:
|
||||
|
||||
@@ -126,7 +128,7 @@ class TWAPStrategy(BaseStrategy):
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
return order_list
|
||||
return TradeDecison(order_list=order_list, ori_strategy=self)
|
||||
|
||||
|
||||
class SBBStrategyBase(BaseStrategy):
|
||||
@@ -140,7 +142,7 @@ class SBBStrategyBase(BaseStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: List[Order] = None,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
@@ -148,8 +150,8 @@ class SBBStrategyBase(BaseStrategy):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : List[Order]
|
||||
the trade decison of outer strategy which this startegy relies, it should be List[Order] in SBBStrategyBase
|
||||
outer_trade_decision : TradeDecison
|
||||
the trade decison of outer strategy which this startegy relies
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
@@ -178,52 +180,57 @@ class SBBStrategyBase(BaseStrategy):
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
|
||||
def reset(self, outer_trade_decision: TradeDecison = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : List[Order], optional
|
||||
outer_trade_decision : TradeDecison, optional
|
||||
"""
|
||||
super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
self.trade_trend = {}
|
||||
self.trade_amount = {}
|
||||
# init the trade amount of order and predicted trade trend
|
||||
for order in outer_trade_decision:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount
|
||||
outer_order_generator = outer_trade_decision.generator()
|
||||
for order in outer_order_generator:
|
||||
self.trade_trend[order.stock_id] = self.TREND_MID
|
||||
self.trade_amount[order.stock_id] = order.amount
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
raise NotImplementedError("pred_price_trend method is not implemented!")
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
# get the total count of trading step
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
# update outer trade decision
|
||||
self.outer_trade_decision.update(trade_step, trade_len)
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[order.stock_id] -= order.deal_amount
|
||||
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
order_list = []
|
||||
# for each order in in self.outer_trade_decision
|
||||
for order in self.outer_trade_decision:
|
||||
outer_order_generator = self.outer_trade_decision.generator(only_enable=True)
|
||||
for order in outer_order_generator:
|
||||
# get the price trend
|
||||
if trade_step % 2 == 0:
|
||||
# in the first of two adjacent bars, predict the price trend
|
||||
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
|
||||
else:
|
||||
# in the second of two adjacent bars, use the trend predicted in the first one
|
||||
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
|
||||
_pred_trend = self.trade_trend[order.stock_id]
|
||||
# if not tradable, continue
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
if trade_step % 2 == 0:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
|
||||
self.trade_trend[order.stock_id] = _pred_trend
|
||||
continue
|
||||
# get amount of one trade unit
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
|
||||
@@ -232,12 +239,12 @@ class SBBStrategyBase(BaseStrategy):
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# divide the order into equal parts, and trade one part
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
|
||||
_order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
|
||||
# without considering trade unit
|
||||
else:
|
||||
# divide the order into equal parts, and trade one part
|
||||
# calculate the total count of trade units to trade
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
|
||||
# calculate the amount of one part, ceil the amount
|
||||
# floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
|
||||
_order_amount = (
|
||||
@@ -245,12 +252,12 @@ class SBBStrategyBase(BaseStrategy):
|
||||
)
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
|
||||
if self.trade_amount[order.stock_id] > 1e-5 and (
|
||||
_order_amount < 1e-5 or trade_step == trade_len - 1
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
_order_amount = self.trade_amount[order.stock_id]
|
||||
|
||||
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
|
||||
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
|
||||
|
||||
if _order_amount > 1e-5:
|
||||
_order = Order(
|
||||
@@ -268,13 +275,11 @@ class SBBStrategyBase(BaseStrategy):
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# N trade day left, divide the order into N + 1 parts, and trade 2 parts
|
||||
_order_amount = (
|
||||
2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
|
||||
)
|
||||
_order_amount = 2 * self.trade_amount[order.stock_id] / (trade_len - trade_step + 1)
|
||||
# without considering trade unit
|
||||
else:
|
||||
# cal how many trade unit
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
|
||||
# N trade day left, divide the order into N + 1 parts, and trade 2 parts
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + trade_len - trade_step)
|
||||
@@ -284,12 +289,12 @@ class SBBStrategyBase(BaseStrategy):
|
||||
)
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
|
||||
if self.trade_amount[order.stock_id] > 1e-5 and (
|
||||
_order_amount < 1e-5 or trade_step == trade_len - 1
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
_order_amount = self.trade_amount[order.stock_id]
|
||||
|
||||
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
|
||||
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
|
||||
|
||||
if _order_amount > 1e-5:
|
||||
if trade_step % 2 == 0:
|
||||
@@ -333,9 +338,9 @@ class SBBStrategyBase(BaseStrategy):
|
||||
|
||||
if trade_step % 2 == 0:
|
||||
# in the first one of two adjacent bars, store the trend for the second one to use
|
||||
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
|
||||
self.trade_trend[order.stock_id] = _pred_trend
|
||||
|
||||
return order_list
|
||||
return TradeDecison(order_list=order_list, ori_strategy=self)
|
||||
|
||||
|
||||
class SBBStrategyEMA(SBBStrategyBase):
|
||||
@@ -345,7 +350,7 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: List[Order] = None,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
instruments: Union[List, str] = "csi300",
|
||||
freq: str = "day",
|
||||
trade_exchange: Exchange = None,
|
||||
@@ -425,7 +430,7 @@ class ACStrategy(BaseStrategy):
|
||||
lamb: float = 1e-6,
|
||||
eta: float = 2.5e-6,
|
||||
window_size: int = 20,
|
||||
outer_trade_decision: List[Order] = None,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
instruments: Union[List, str] = "csi300",
|
||||
freq: str = "day",
|
||||
trade_exchange: Exchange = None,
|
||||
@@ -502,34 +507,38 @@ class ACStrategy(BaseStrategy):
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
self._reset_signal()
|
||||
|
||||
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
|
||||
def reset(self, outer_trade_decision: TradeDecison = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : List[Order], optional
|
||||
outer_trade_decision : TradeDecison, optional
|
||||
"""
|
||||
super(ACStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
self.trade_amount = {}
|
||||
# init the trade amount of order and predicted trade trend
|
||||
for order in outer_trade_decision:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount
|
||||
outer_order_generator = outer_trade_decision.generator()
|
||||
for order in outer_order_generator:
|
||||
self.trade_amount[order.stock_id] = order.amount
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
# get the total count of trading step
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
# update outer trade decision
|
||||
self.outer_trade_decision.update(trade_step, trade_len)
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[order.stock_id] -= order.deal_amount
|
||||
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
order_list = []
|
||||
for order in self.outer_trade_decision:
|
||||
outer_order_generator = self.outer_trade_decision.generator(only_enable=True)
|
||||
for order in outer_order_generator:
|
||||
# if not tradable, continue
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
@@ -549,11 +558,11 @@ class ACStrategy(BaseStrategy):
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
|
||||
if _amount_trade_unit is None:
|
||||
# divide the order into equal parts, and trade one part
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
|
||||
_order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
|
||||
else:
|
||||
# divide the order into equal parts, and trade one part
|
||||
# calculate the total count of trade units to trade
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
|
||||
# calculate the amount of one part, ceil the amount
|
||||
# floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
|
||||
_order_amount = (
|
||||
@@ -571,12 +580,10 @@ class ACStrategy(BaseStrategy):
|
||||
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
|
||||
_order_amount < 1e-5 or trade_step == trade_len - 1
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or trade_step == trade_len - 1):
|
||||
_order_amount = self.trade_amount[order.stock_id]
|
||||
|
||||
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
|
||||
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
|
||||
|
||||
if _order_amount > 1e-5:
|
||||
|
||||
@@ -589,4 +596,4 @@ class ACStrategy(BaseStrategy):
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
return order_list
|
||||
return TradeDecison(order_list=order_list, ori_strategy=self)
|
||||
|
||||
@@ -7,7 +7,7 @@ from ..data.dataset import DatasetH
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from ..utils import init_instance_by_config
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeDecison
|
||||
|
||||
|
||||
class BaseStrategy:
|
||||
@@ -15,14 +15,14 @@ class BaseStrategy:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: object = None,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : object, optional
|
||||
outer_trade_decision : TradeDecison, optional
|
||||
the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None
|
||||
- If the strategy is used to split trade decison, it will be used
|
||||
- If the strategy is used for portfolio management, it can be ignored
|
||||
@@ -84,6 +84,17 @@ class BaseStrategy:
|
||||
"""
|
||||
raise NotImplementedError("generate_trade_decision is not implemented!")
|
||||
|
||||
def update_trade_decision(self, trade_decison: TradeDecison, trade_step, trade_len):
|
||||
"""update trade decision in each step of inner execution, this method enable all order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decison : TradeDecison
|
||||
the trade decison that will be updated
|
||||
"""
|
||||
if trade_step == 0:
|
||||
trade_decison.enable(all_enable=True)
|
||||
|
||||
|
||||
class ModelStrategy(BaseStrategy):
|
||||
"""Model-based trading strategy, use model to make predictions for trading"""
|
||||
@@ -92,7 +103,7 @@ class ModelStrategy(BaseStrategy):
|
||||
self,
|
||||
model: BaseModel,
|
||||
dataset: DatasetH,
|
||||
outer_trade_decision: object = None,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
@@ -128,7 +139,7 @@ class RLStrategy(BaseStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
policy,
|
||||
outer_trade_decision: object = None,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
@@ -151,7 +162,7 @@ class RLIntStrategy(RLStrategy):
|
||||
policy,
|
||||
state_interpreter: Union[dict, StateInterpreter],
|
||||
action_interpreter: Union[dict, ActionInterpreter],
|
||||
outer_trade_decision: object = None,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
|
||||
Reference in New Issue
Block a user