1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 17:41:18 +08:00

Merge branch 'nested_decision_exe' into nested_decision_exe

This commit is contained in:
you-n-g
2021-07-23 12:15:45 +08:00
committed by GitHub
12 changed files with 343 additions and 234 deletions

View File

@@ -160,7 +160,7 @@ class Account:
self.accum_info.add_return_value(profit) # note here do not consider cost
def update_order(self, order, trade_val, cost, trade_price):
if not self.is_port_metr_enabled():
if self.current.skip_update():
# TODO: supporting polymorphism for account
# updating order for infinite position is meaningless
return

View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
from qlib.backtest.position import Position
import random
import logging
from typing import List, Tuple, Union, Callable, Iterable
@@ -281,6 +282,8 @@ class Exchange:
"""
Deal order when the actual transaction
the results section in `Order` will be changed.
:param order: Deal the order.
:param trade_account: Trade account to be updated after dealing the order.
:param position: position to be updated after dealing the order.
@@ -343,6 +346,7 @@ class Exchange:
`None`: if the stock is suspended `None` may be returned
`float`: return factor if the factor exists
"""
assert (start_time is not None and end_time is not None, "the time range must be given")
if stock_id not in self.quote.get_all_stock():
return None
return self.quote.get_data(stock_id, start_time, end_time, fields="$factor", method=ts_data_last)
@@ -505,20 +509,56 @@ class Exchange:
)
return value
def get_amount_of_trade_unit(self, factor):
def _get_factor_or_raise_erorr(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
"""Please refer to the docs of get_amount_of_trade_unit"""
if factor is None:
if stock_id is not None and start_time is not None and end_time is not None:
factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
else:
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
return factor
def get_amount_of_trade_unit(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
"""
get the trade unit of amount based on **factor**
the factor can be given directly or calculated in given time range and stock id.
`factor` has higher priority than `stock_id`, `start_time` and `end_time`
Parameters
----------
factor : float
the adjusted factor
stock_id : str
the id of the stock
start_time :
the start time of trading range
end_time :
the end time of trading range
"""
if not self.trade_w_adj_price and self.trade_unit is not None:
factor = self._get_factor_or_raise_erorr(
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
)
return self.trade_unit / factor
else:
return None
def round_amount_by_trade_unit(self, deal_amount, factor):
def round_amount_by_trade_unit(
self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None
):
"""Parameter
Please refer to the docs of get_amount_of_trade_unit
deal_amount : float, adjusted amount
factor : float, adjusted factor
return : float, real amount
"""
if not self.trade_w_adj_price and self.trade_unit is not None:
# the minimal amount is 1. Add 0.1 for solving precision problem.
factor = self._get_factor_or_raise_erorr(
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
)
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
return deal_amount
@@ -529,7 +569,7 @@ class Exchange:
else:
return deal_amount
def _calc_trade_info_by_order(self, order, position):
def _calc_trade_info_by_order(self, order, position: Position):
"""
Calculation of trade info
@@ -541,6 +581,7 @@ class Exchange:
"""
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
if order.direction == Order.SELL:
# sell
if position is not None:

View File

@@ -1,5 +1,6 @@
from abc import abstractclassmethod, abstractmethod
import copy
from qlib.log import get_module_logger
from types import GeneratorType
from qlib.backtest.account import Account
import warnings
@@ -102,7 +103,10 @@ class BaseExecutor:
self.track_data = track_data
self._trade_exchange = trade_exchange
self.level_infra = LevelInfrastructure()
self.level_infra.reset_infra(common_infra=common_infra)
self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra)
if common_infra is None:
get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")
def reset_common_infra(self, common_infra):
"""
@@ -239,7 +243,7 @@ class BaseExecutor:
# Some concrete executor don't have inner decisions
res, kwargs = obj
trade_start_time, trade_end_time = self.trade_calendar.get_cur_step_time()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
# Account will not be changed in this function
self.trade_account.update_bar_end(
trade_start_time,
@@ -332,7 +336,7 @@ class NestedExecutor(BaseExecutor):
self.inner_strategy.reset_common_infra(common_infra)
def _init_sub_trading(self, trade_decision):
trade_start_time, trade_end_time = self.trade_calendar.get_cur_step_time()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
sub_level_infra = self.inner_executor.get_level_infra()
self.level_infra.set_sub_level_infra(sub_level_infra)
@@ -379,8 +383,8 @@ class NestedExecutor(BaseExecutor):
)
trade_decision.mod_inner_decision(_inner_trade_decision) # propagate part of decision information
# NOTE sub_cal.get_cur_step_time() must be called before collect_data in case of step shifting
decision_list.append((_inner_trade_decision, *sub_cal.get_cur_step_time()))
# NOTE sub_cal.get_step_time() must be called before collect_data in case of step shifting
decision_list.append((_inner_trade_decision, *sub_cal.get_step_time()))
# NOTE: Trade Calendar will step forward in the follow line
_inner_execute_result = yield from self.inner_executor.collect_data(
@@ -478,7 +482,7 @@ class SimulatorExecutor(BaseExecutor):
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
trade_start_time, _ = self.trade_calendar.get_cur_step_time()
trade_start_time, _ = self.trade_calendar.get_step_time()
execute_result = []
for order in self._get_order_iterator(trade_decision):
@@ -491,30 +495,22 @@ class SimulatorExecutor(BaseExecutor):
execute_result.append((order, trade_val, trade_cost, trade_price))
if self.verbose:
if order.direction == Order.SELL: # sell
print(
"[I {:%Y-%m-%d %H:%M:%S}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
trade_start_time,
order.stock_id,
trade_price,
order.amount,
order.deal_amount,
order.factor,
trade_val,
)
)
action = "sell"
else:
print(
"[I {:%Y-%m-%d %H:%M:%S}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
trade_start_time,
order.stock_id,
trade_price,
order.amount,
order.deal_amount,
order.factor,
trade_val,
)
action = "buy"
print(
"[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}, cach {:.2f}.".format(
trade_start_time,
action,
order.stock_id,
trade_price,
order.amount,
order.deal_amount,
order.factor,
trade_val,
self.trade_account.get_cash(),
)
)
else:
if self.verbose:
print("[W {:%Y-%m-%d %H:%M:%S}]: {} wrong.".format(trade_start_time, order.stock_id))

View File

@@ -3,7 +3,8 @@
# TODO: rename it with decision.py
from __future__ import annotations
from enum import IntEnum
from qlib.utils.time import concat_date_time
from qlib.data.data import Cal
from qlib.utils.time import concat_date_time, epsilon_change
from qlib.log import get_module_logger
# try to fix circular imports when enabling type hints
@@ -41,16 +42,24 @@ class Order:
presents the weight factor assigned in Exchange()
"""
# 1) time invariant values
# - they are set by users and is time-invariant.
stock_id: str
amount: float # `amount` is a non-negative value
amount: float # `amount` is a non-negative and adjusted value
direction: int
# 2) time variant values:
# - Users may want to set these values when using lower level APIs
# - If users don't, TradeDecisionWO will help users to set them
# The interval of the order which belongs to (NOTE: this is not the expected order dealing range time)
start_time: pd.Timestamp
end_time: pd.Timestamp
direction: int
factor: float
# 3) results
# - users should not care about these values
# - they are set by the backtest system after finishing the results.
deal_amount: float = field(init=False) # `deal_amount` is a non-negative value
factor: float = field(init=False)
# FIXME:
# for compatible now.
@@ -127,8 +136,8 @@ class OrderHelper:
code: str,
amount: float,
direction: OrderDir,
start_time: Union[str, pd.Timestamp],
end_time: Union[str, pd.Timestamp],
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
) -> Order:
"""
help to create a order
@@ -143,9 +152,9 @@ class OrderHelper:
**adjusted trading amount**
direction : OrderDir
trading direction
start_time : Union[str, pd.Timestamp]
start_time : Union[str, pd.Timestamp] (optional)
The interval of the order which belongs to
end_time : Union[str, pd.Timestamp]
end_time : Union[str, pd.Timestamp] (optional)
The interval of the order which belongs to
Returns
@@ -153,15 +162,17 @@ class OrderHelper:
Order:
The created order
"""
start_time = pd.Timestamp(start_time)
end_time = pd.Timestamp(end_time)
if start_time is not None:
start_time = pd.Timestamp(start_time)
if end_time is not None:
end_time = pd.Timestamp(end_time)
# NOTE: factor is a value belongs to the results section. User don't have to care about it when creating orders
return Order(
stock_id=code,
amount=amount,
start_time=start_time,
end_time=end_time,
direction=direction,
factor=self.exchange.get_factor(code, start_time, end_time),
)
@@ -291,6 +302,7 @@ class BaseTradeDecision:
"""
self.strategy = strategy
self.start_time, self.end_time = strategy.trade_calendar.get_step_time()
self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading`
if isinstance(trade_range, Tuple):
# for Tuple[int, int]
@@ -406,6 +418,62 @@ class BaseTradeDecision:
_start_idx, _end_idx = max(0, _start_idx), min(self.total_step - 1, _end_idx)
return _start_idx, _end_idx
def get_data_cal_range_limit(self, rtype: str = "full", raise_error: bool = False) -> Tuple[int, int]:
"""
get the range limit based on data calendar
NOTE: it is **total** range limit instead of a single step
The following assumptions are made
1) The frequency of the exchange in common_infra is the same as the data calendar
2) Users want the index mod by **day** (i.e. 240 min)
Parameters
----------
rtype: str
- "full": return the full limitation of the deicsion in the day
- "step": return the limitation of current step
raise_error: bool
True: raise error if no trade_range is set
False: return full trade calendar.
It is useful in following cases
- users want to follow the order specific trading time range when decision level trade range is not
available. Raising NotImplementedError to indicates that range limit is not available
Returns
-------
Tuple[int, int]:
the range limit in data calendar
Raises
------
NotImplementedError:
If the following criteria meet
1) the decision can't provide a unified start and end
2) raise_error is True
"""
# potential performance issue
day_start = pd.Timestamp(self.start_time.date())
day_end = epsilon_change(day_start + pd.Timedelta(days=1))
freq = self.strategy.trade_exchange.freq
_, _, day_start_idx, day_end_idx = Cal.locate_index(day_start, day_end, freq=freq)
if self.trade_range is None:
if raise_error:
raise NotImplementedError(f"There is no trade_range in this case")
else:
return 0, day_end_idx - day_start_idx
else:
if rtype == "full":
val_start, val_end = self.trade_range.clip_time_range(day_start, day_end)
elif rtype == "step":
val_start, val_end = self.trade_range.clip_time_range(self.start_time, self.end_time)
else:
raise ValueError(f"This type of input {rtype} is not supported")
_, _, start_idx, end_index = Cal.locate_index(val_start, val_end, freq=freq)
return start_idx - day_start_idx, end_index - day_start_idx
def empty(self) -> bool:
for obj in self.get_decision():
if isinstance(obj, Order):
@@ -452,9 +520,15 @@ class TradeDecisionWO(BaseTradeDecision):
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None):
super().__init__(strategy, trade_range=trade_range)
self.order_list = order_list
start, end = strategy.trade_calendar.get_step_time()
for o in order_list:
if o.start_time is None:
o.start_time = start
if o.end_time is None:
o.end_time = end
def get_decision(self) -> List[object]:
return self.order_list
def __repr__(self) -> str:
return f"strategy: {self.strategy}; trade_range: {self.trade_range}; order_list[{len(self.order_list)}]"
return f"class: {self.__class__.__name__}; strategy: {self.strategy}; trade_range: {self.trade_range}; order_list[{len(self.order_list)}]"

View File

@@ -351,7 +351,10 @@ class Indicator:
trade_exchange: Exchange,
pa_config: dict = {},
):
"""Get the base volume and price information"""
"""
Get the base volume and price information
All the base price values are rooted from this function
"""
agg = pa_config.get("agg", "twap").lower()
price = pa_config.get("price", "deal_price").lower()
@@ -374,10 +377,12 @@ class Indicator:
# NOTE: there are some zeros in the trading price. These cases are known meaningless
# for aligning the previous logic, remove it.
# price_s = price_s.mask(np.isclose(price_s, 0))
price_s = price_s[~(price_s < 1e-08)] # remove zero and negative values.
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
if agg == "vwap":
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
volume_s = volume_s.reindex(price_s.index)
elif agg == "twap":
volume_s = pd.Series(1, index=price_s.index)
else:

View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
from __future__ import annotations
import bisect
from qlib.utils.time import epsilon_change
from typing import Union, TYPE_CHECKING, Tuple, Union, List, Set
if TYPE_CHECKING:
@@ -22,7 +23,11 @@ class TradeCalendarManager:
"""
def __init__(
self, freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
self,
freq: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
level_infra: "LevelInfrastructure" = None,
):
"""
Parameters
@@ -36,6 +41,7 @@ class TradeCalendarManager:
closed end of the trade time range, by default None
If `end_time` is None, it must be reset before trading.
"""
self.level_infra = level_infra
self.reset(freq=freq, start_time=start_time, end_time=end_time)
def reset(self, freq, start_time, end_time):
@@ -82,19 +88,19 @@ class TradeCalendarManager:
def get_trade_step(self):
return self.trade_step
def get_step_time(self, trade_step=0, shift=0):
def get_step_time(self, trade_step=None, shift=0):
"""
Get the left and right endpoints of the trade_step'th trading interval
About the endpoints:
- Qlib uses the closed interval in time-series data selection, which has the same performance as pandas.Series.loc
- The returned right endpoints should minus 1 seconds becasue of the closed interval representation in Qlib.
Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time interval.
# - The returned right endpoints should minus 1 seconds becasue of the closed interval representation in Qlib.
# Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time interval.
Parameters
----------
trade_step : int, optional
the number of trading step finished, by default 0
the number of trading step finished, by default None to indicate current step
shift : int, optional
shift bars , by default 0
@@ -105,15 +111,43 @@ class TradeCalendarManager:
- If shift > 0, return the trading time range of the earlier shift bars
- If shift < 0, return the trading time range of the later shift bar
"""
if trade_step is None:
trade_step = self.get_trade_step()
trade_step = trade_step - shift
calendar_index = self.start_index + trade_step
return self._calendar[calendar_index], self._calendar[calendar_index + 1] - pd.Timedelta(seconds=1)
return self._calendar[calendar_index], epsilon_change(self._calendar[calendar_index + 1])
def get_cur_step_time(self):
def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]:
"""
get current step time
get the calendar range
The following assumptions are made
1) The frequency of the exchange in common_infra is the same as the data calendar
2) Users want the **data index** mod by **day** (i.e. 240 min)
Parameters
----------
rtype: str
- "full": return the full limitation of the deicsion in the day
- "step": return the limitation of current step
Returns
-------
Tuple[int, int]:
"""
return self.get_step_time(self.get_trade_step())
# potential performance issue
day_start = pd.Timestamp(self.start_time.date())
day_end = epsilon_change(day_start + pd.Timedelta(days=1))
freq = self.level_infra.get("common_infra").get("trade_exchange").freq
_, _, day_start_idx, _ = Cal.locate_index(day_start, day_end, freq=freq)
if rtype == "full":
_, _, start_idx, end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq)
elif rtype == "step":
_, _, start_idx, end_index = Cal.locate_index(*self.get_step_time(), freq=freq)
else:
raise ValueError(f"This type of input {rtype} is not supported")
return start_idx - day_start_idx, end_index - day_start_idx
def get_all_time(self):
"""Get the start_time and end_time for trading"""
@@ -147,7 +181,7 @@ class TradeCalendarManager:
return clip(left), clip(right)
def __repr__(self) -> str:
return f"{self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]"
return f"class: {self.__class__.__name__}; {self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]"
class BaseInfrastructure:
@@ -198,14 +232,16 @@ class LevelInfrastructure(BaseInfrastructure):
sub_level_infra:
- **NOTE**: this will only work after _init_sub_trading !!!
"""
return ["trade_calendar", "sub_level_infra"]
return ["trade_calendar", "sub_level_infra", "common_infra"]
def reset_cal(self, freq, start_time, end_time):
"""reset trade calendar manager"""
if self.has("trade_calendar"):
self.get("trade_calendar").reset(freq, start_time=start_time, end_time=end_time)
else:
self.reset_infra(trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time))
self.reset_infra(
trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self)
)
def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure):
"""this will make the calendar access easier when acrossing multi-levels"""

View File

@@ -73,20 +73,20 @@ def indicator_analysis(df, method="mean"):
Parameters
----------
df : pandas.DataFrame
columns: like ['pa', 'pos', 'ffr', 'amount', 'value'].
columns: like ['pa', 'pos', 'ffr', 'deal_amount', 'value'].
Necessary fields:
- 'pa' is the price advantage in trade indicators
- 'pos' is the positive rate in trade indicators
- 'ffr' is the fulfill rate in trade indicators
Optional fields:
- 'amount' is the total deal amount, only necessary when method is 'amount_weighted'
- 'deal_amount' is the total deal deal_amount, only necessary when method is 'amount_weighted'
- 'value' is the total trade value, only necessary when method is 'value_weighted'
index: Index(datetime)
method : str, optional
statistics method of pa/ffr, by default "mean"
- if method is 'mean', count the mean statistical value of each trade indicator
- if method is 'amount_weighted', count the amount weighted mean statistical value of each trade indicator
- if method is 'amount_weighted', count the deal_amount weighted mean statistical value of each trade indicator
- if method is 'value_weighted', count the value weighted mean statistical value of each trade indicator
Note: statistics method of pos is always "mean"
@@ -97,7 +97,7 @@ def indicator_analysis(df, method="mean"):
"""
weights_dict = {
"mean": df["count"],
"amount_weighted": df["amount"].abs(),
"amount_weighted": df["deal_amount"].abs(),
"value_weighted": df["value"].abs(),
}
if method not in weights_dict:

View File

@@ -64,7 +64,7 @@ class TopkDropoutStrategy(ModelStrategy):
"""
super(TopkDropoutStrategy, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
)
self.topk = topk
self.n_drop = n_drop
@@ -73,22 +73,6 @@ class TopkDropoutStrategy(ModelStrategy):
self.risk_degree = risk_degree
self.hold_thresh = hold_thresh
self.only_tradable = only_tradable
if trade_exchange is not None:
self.trade_exchange = trade_exchange
def reset_common_infra(self, common_infra):
"""
Parameters
----------
common_infra : dict, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(TopkDropoutStrategy, self).reset_common_infra(common_infra)
if common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
def get_risk_degree(self, trade_step=None):
"""get_risk_degree
@@ -210,7 +194,6 @@ class TopkDropoutStrategy(ModelStrategy):
start_time=trade_start_time,
end_time=trade_end_time,
direction=Order.SELL, # 0 for sell, 1 for buy
factor=factor,
)
# is order executable
if self.trade_exchange.check_order(sell_order):
@@ -247,7 +230,6 @@ class TopkDropoutStrategy(ModelStrategy):
start_time=trade_start_time,
end_time=trade_end_time,
direction=Order.BUY, # 1 for buy
factor=factor,
)
buy_order_list.append(buy_order)
return TradeDecisionWO(sell_order_list + buy_order_list, self)
@@ -278,28 +260,12 @@ class WeightStrategyBase(ModelStrategy):
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
"""
super(WeightStrategyBase, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
)
if isinstance(order_generator_cls_or_obj, type):
self.order_generator = order_generator_cls_or_obj()
else:
self.order_generator = order_generator_cls_or_obj
if trade_exchange is not None:
self.trade_exchange = trade_exchange
def reset_common_infra(self, common_infra):
"""
Parameters
----------
common_infra : dict, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(WeightStrategyBase, self).reset_common_infra(common_infra)
if common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
def get_risk_degree(self, trade_step=None):
"""get_risk_degree

View File

@@ -20,48 +20,6 @@ from qlib.backtest.utils import get_start_end_idx
class TWAPStrategy(BaseStrategy):
"""TWAP Strategy for trading"""
def __init__(
self,
outer_trade_decision: BaseTradeDecision = None,
trade_exchange: Exchange = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
):
"""
Parameters
----------
outer_trade_decision : BaseTradeDecision
the trade decision of outer strategy which this startegy relies
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
- It allowes different trade_exchanges is used in different executions.
- For example:
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
"""
super(TWAPStrategy, self).__init__(
outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra
)
if trade_exchange is not None:
self.trade_exchange = trade_exchange
def reset_common_infra(self, common_infra):
"""
Parameters
----------
common_infra : CommonInfrastructure, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(TWAPStrategy, self).reset_common_infra(common_infra)
if common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
"""
Parameters
@@ -105,7 +63,9 @@ class TWAPStrategy(BaseStrategy):
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
):
continue
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
)
_order_amount = None
# considering trade unit
if _amount_trade_unit is None:
@@ -141,7 +101,6 @@ class TWAPStrategy(BaseStrategy):
start_time=trade_start_time,
end_time=trade_end_time,
direction=order.direction, # 1 for buy
factor=order.factor,
)
order_list.append(_order)
return TradeDecisionWO(order_list=order_list, strategy=self)
@@ -161,46 +120,6 @@ class SBBStrategyBase(BaseStrategy):
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
outer_trade_decision: BaseTradeDecision = None,
trade_exchange: Exchange = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
):
"""
Parameters
----------
outer_trade_decision : BaseTradeDecision
the trade decision of outer strategy which this startegy relies
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
- It allowes different trade_exchanges is used in different executions.
- For example:
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
"""
super(SBBStrategyBase, self).__init__(
outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra
)
if trade_exchange is not None:
self.trade_exchange = trade_exchange
def reset_common_infra(self, common_infra):
"""
Parameters
----------
common_infra : dict, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(SBBStrategyBase, self).reset_common_infra(common_infra)
if common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
"""
Parameters
@@ -250,7 +169,9 @@ class SBBStrategyBase(BaseStrategy):
self.trade_trend[order.stock_id] = _pred_trend
continue
# get amount of one trade unit
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
)
if _pred_trend == self.TREND_MID:
_order_amount = None
# considering trade unit
@@ -283,7 +204,6 @@ class SBBStrategyBase(BaseStrategy):
start_time=trade_start_time,
end_time=trade_end_time,
direction=order.direction,
factor=order.factor,
)
order_list.append(_order)
@@ -330,7 +250,6 @@ class SBBStrategyBase(BaseStrategy):
start_time=trade_start_time,
end_time=trade_end_time,
direction=order.direction, # 1 for buy
factor=order.factor,
)
order_list.append(_order)
else:
@@ -349,7 +268,6 @@ class SBBStrategyBase(BaseStrategy):
start_time=trade_start_time,
end_time=trade_end_time,
direction=order.direction, # 1 for buy
factor=order.factor,
)
order_list.append(_order)
@@ -395,7 +313,9 @@ class SBBStrategyEMA(SBBStrategyBase):
if isinstance(instruments, str):
self.instruments = D.instruments(instruments)
self.freq = freq
super(SBBStrategyEMA, self).__init__(outer_trade_decision, trade_exchange, level_infra, common_infra, **kwargs)
super(SBBStrategyEMA, self).__init__(
outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs
)
def _reset_signal(self):
trade_len = self.trade_calendar.get_trade_len()
@@ -417,14 +337,8 @@ class SBBStrategyEMA(SBBStrategyBase):
reset level-shared infra
- After reset the trade calendar, the signal will be changed
"""
if not hasattr(self, "level_infra"):
self.level_infra = level_infra
else:
self.level_infra.update(level_infra)
if level_infra.has("trade_calendar"):
self.trade_calendar = level_infra.get("trade_calendar")
self._reset_signal()
super().reset_level_infra(level_infra)
self._reset_signal()
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
# if no signal, return mid trend
@@ -484,10 +398,9 @@ class ACStrategy(BaseStrategy):
if isinstance(instruments, str):
self.instruments = D.instruments(instruments)
self.freq = freq
super(ACStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
if trade_exchange is not None:
self.trade_exchange = trade_exchange
super(ACStrategy, self).__init__(
outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs
)
def _reset_signal(self):
trade_len = self.trade_calendar.get_trade_len()
@@ -506,33 +419,13 @@ class ACStrategy(BaseStrategy):
for stock_id, stock_val in signal_df.groupby(level="instrument"):
self.signal[stock_id] = stock_val["volatility"].droplevel(level="instrument")
def reset_common_infra(self, common_infra):
"""
Parameters
----------
common_infra : CommonInfrastructure, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(ACStrategy, self).reset_common_infra(common_infra)
if common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
def reset_level_infra(self, level_infra):
"""
reset level-shared infra
- After reset the trade calendar, the signal will be changed
"""
if not hasattr(self, "level_infra"):
self.level_infra = level_infra
else:
self.level_infra.update(level_infra)
if level_infra.has("trade_calendar"):
self.trade_calendar = level_infra.get("trade_calendar")
self._reset_signal()
super().reset_level_infra(level_infra)
self._reset_signal()
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
"""
@@ -578,7 +471,9 @@ class ACStrategy(BaseStrategy):
if sig_sam is None or np.isnan(sig_sam):
# no signal, TWAP
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
)
if _amount_trade_unit is None:
# divide the order into equal parts, and trade one part
_order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
@@ -599,7 +494,9 @@ class ACStrategy(BaseStrategy):
np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1))
) / np.sinh(kappa * trade_len)
_order_amount = order.amount * amount_ratio
_order_amount = self.trade_exchange.round_amount_by_trade_unit(_order_amount, order.factor)
_order_amount = self.trade_exchange.round_amount_by_trade_unit(
_order_amount, stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
)
if order.direction == order.SELL:
# sell all amount at last
@@ -673,8 +570,6 @@ class RandomOrderStrategy(BaseStrategy):
.create(
code=stock_id,
amount=volume * self.volume_ratio,
start_time=step_time_start,
end_time=step_time_end,
direction=self.direction,
)
)
@@ -734,9 +629,7 @@ class FileOrderStrategy(BaseStrategy):
execute_result will be ignored in FileOrderStrategy
"""
oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
tc = self.trade_calendar
step = tc.get_trade_step()
start, end = tc.get_step_time(step)
start, _ = self.trade_calendar.get_step_time()
# CONVERSION: the bar is indexed by the time
try:
df = self.order_df.loc(axis=0)[start]
@@ -750,8 +643,6 @@ class FileOrderStrategy(BaseStrategy):
code=idx,
amount=row["amount"],
direction=Order.parse_dir(row["direction"]),
start_time=start,
end_time=end,
)
)
return TradeDecisionWO(order_list, self, self.trade_range)

View File

@@ -1,7 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.backtest.exchange import Exchange
from qlib.backtest.position import BasePosition
from typing import List, Union
from typing import List, Tuple, Union
from ..model.base import BaseModel
from ..data.dataset import DatasetH
@@ -22,6 +23,7 @@ class BaseStrategy:
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
trade_exchange: Exchange = None,
):
"""
Parameters
@@ -34,9 +36,18 @@ class BaseStrategy:
level shared infrastructure for backtesting, including trade calendar
common_infra : CommonInfrastructure, optional
common infrastructure for backtesting, including trade_account, trade_exchange, .etc
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
- It allowes different trade_exchanges is used in different executions.
- For example:
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
"""
self.reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)
self._reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)
self._trade_exchange = trade_exchange
@property
def trade_calendar(self) -> TradeCalendarManager:
@@ -46,6 +57,11 @@ class BaseStrategy:
def trade_position(self) -> BasePosition:
return self.common_infra.get("trade_account").current
@property
def trade_exchange(self) -> Exchange:
"""get trade exchange in a prioritized order"""
return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange")
def reset_level_infra(self, level_infra: LevelInfrastructure):
if not hasattr(self, "level_infra"):
self.level_infra = level_infra
@@ -69,6 +85,24 @@ class BaseStrategy:
- reset `level_infra`, used to reset trade calendar, .etc
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
- reset `outer_trade_decision`, used to make split decision
**NOTE**:
split this function into `reset` and `_reset` will make following cases more convenient
1. Users want to initialize his strategy by overriding `reset`, but they don't want to affect the `_reset` called
when initialization
"""
self._reset(
level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision, **kwargs
)
def _reset(
self,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
outer_trade_decision=None,
):
"""
Please refer to the docs of `reset`
"""
if level_infra is not None:
self.reset_level_infra(level_infra)
@@ -124,6 +158,36 @@ class BaseStrategy:
# NOTE: normally, user should do something to the strategy due to the change of outer decision
raise NotImplementedError(f"Please implement the `alter_outer_trade_decision` method")
# helper methods: not necessary but for convenience
def get_data_cal_avail_range(self, rtype: str = "full") -> Tuple[int, int]:
"""
return data calendar's available decision range for `self` strategy
the range consider following factors
- data calendar in the charge of `self` strategy
- trading range limitation from the decision of outer strategy
related methods
- TradeCalendarManager.get_data_cal_range
- BaseTradeDecision.get_data_cal_range_limit
Parameters
----------
rtype: str
- "full": return the available data index range of the strategy from `start_time` to `end_time`
- "step": return the available data index range of the strategy of current step
Returns
-------
Tuple[int, int]:
the available range both sides are closed
"""
cal_range = self.trade_calendar.get_data_cal_range(rtype=rtype)
if self.outer_trade_decision is None:
raise ValueError(f"There is not limitation for strategy {self}")
range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype)
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
class ModelStrategy(BaseStrategy):
"""Model-based trading strategy, use model to make predictions for trading"""

View File

@@ -210,10 +210,13 @@ def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleTy
the class object and it's arguments.
"""
if isinstance(config, dict):
module = get_module_by_module_path(config.get("module_path", default_module))
if isinstance(config["class"], str):
module = get_module_by_module_path(config.get("module_path", default_module))
# raise AttributeError
klass = getattr(module, config["class"])
# raise AttributeError
klass = getattr(module, config["class"])
else:
klass = config["class"] # the class type itself is passed in
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
module = get_module_by_module_path(default_module)
@@ -235,11 +238,17 @@ def init_instance_by_config(
----------
config : Union[str, dict, object]
dict example.
case 1)
{
'class': 'ClassName',
'kwargs': dict, # It is optional. {} will be used if not given
'model_path': path, # It is optional if module is given
}
case 2)
{
'class': <The class it self>,
'kwargs': dict, # It is optional. {} will be used if not given
}
str example.
1) specify a pickle object
- path like 'file:///<path to pickle file>/obj.pkl'

View File

@@ -160,5 +160,32 @@ def cal_sam_minute(x: pd.Timestamp, sam_minutes: int) -> pd.Timestamp:
return concat_date_time(date, new_time)
def epsilon_change(datetime: pd.Timestamp, direction: str = "backward") -> pd.Timestamp:
"""
change the time by infinitely small quantity.
Parameters
----------
datetime : pd.Timestamp
the original time
direction : str
the direction the time are going to
- "backward" for going to history
- "forward" for going to the future
Returns
-------
pd.Timestamp:
the shifted time
"""
if direction == "backward":
return datetime - pd.Timedelta(seconds=1)
elif direction == "forward":
return datetime + pd.Timedelta(seconds=1)
else:
raise ValueError("Wrong input")
if __name__ == "__main__":
print(get_day_min_idx_range("8:30", "14:59", "10min"))