mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 17:41:18 +08:00
Merge branch 'nested_decision_exe' into nested_decision_exe
This commit is contained in:
@@ -160,7 +160,7 @@ class Account:
|
||||
self.accum_info.add_return_value(profit) # note here do not consider cost
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
if not self.is_port_metr_enabled():
|
||||
if self.current.skip_update():
|
||||
# TODO: supporting polymorphism for account
|
||||
# updating order for infinite position is meaningless
|
||||
return
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from qlib.backtest.position import Position
|
||||
import random
|
||||
import logging
|
||||
from typing import List, Tuple, Union, Callable, Iterable
|
||||
@@ -281,6 +282,8 @@ class Exchange:
|
||||
"""
|
||||
Deal order when the actual transaction
|
||||
|
||||
the results section in `Order` will be changed.
|
||||
|
||||
:param order: Deal the order.
|
||||
:param trade_account: Trade account to be updated after dealing the order.
|
||||
:param position: position to be updated after dealing the order.
|
||||
@@ -343,6 +346,7 @@ class Exchange:
|
||||
`None`: if the stock is suspended `None` may be returned
|
||||
`float`: return factor if the factor exists
|
||||
"""
|
||||
assert (start_time is not None and end_time is not None, "the time range must be given")
|
||||
if stock_id not in self.quote.get_all_stock():
|
||||
return None
|
||||
return self.quote.get_data(stock_id, start_time, end_time, fields="$factor", method=ts_data_last)
|
||||
@@ -505,20 +509,56 @@ class Exchange:
|
||||
)
|
||||
return value
|
||||
|
||||
def get_amount_of_trade_unit(self, factor):
|
||||
def _get_factor_or_raise_erorr(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
|
||||
"""Please refer to the docs of get_amount_of_trade_unit"""
|
||||
if factor is None:
|
||||
if stock_id is not None and start_time is not None and end_time is not None:
|
||||
factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
else:
|
||||
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
|
||||
return factor
|
||||
|
||||
def get_amount_of_trade_unit(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
|
||||
"""
|
||||
get the trade unit of amount based on **factor**
|
||||
|
||||
the factor can be given directly or calculated in given time range and stock id.
|
||||
`factor` has higher priority than `stock_id`, `start_time` and `end_time`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
factor : float
|
||||
the adjusted factor
|
||||
stock_id : str
|
||||
the id of the stock
|
||||
start_time :
|
||||
the start time of trading range
|
||||
end_time :
|
||||
the end time of trading range
|
||||
"""
|
||||
if not self.trade_w_adj_price and self.trade_unit is not None:
|
||||
factor = self._get_factor_or_raise_erorr(
|
||||
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
)
|
||||
return self.trade_unit / factor
|
||||
else:
|
||||
return None
|
||||
|
||||
def round_amount_by_trade_unit(self, deal_amount, factor):
|
||||
def round_amount_by_trade_unit(
|
||||
self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None
|
||||
):
|
||||
"""Parameter
|
||||
Please refer to the docs of get_amount_of_trade_unit
|
||||
|
||||
deal_amount : float, adjusted amount
|
||||
factor : float, adjusted factor
|
||||
return : float, real amount
|
||||
"""
|
||||
if not self.trade_w_adj_price and self.trade_unit is not None:
|
||||
# the minimal amount is 1. Add 0.1 for solving precision problem.
|
||||
factor = self._get_factor_or_raise_erorr(
|
||||
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
)
|
||||
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
|
||||
return deal_amount
|
||||
|
||||
@@ -529,7 +569,7 @@ class Exchange:
|
||||
else:
|
||||
return deal_amount
|
||||
|
||||
def _calc_trade_info_by_order(self, order, position):
|
||||
def _calc_trade_info_by_order(self, order, position: Position):
|
||||
"""
|
||||
Calculation of trade info
|
||||
|
||||
@@ -541,6 +581,7 @@ class Exchange:
|
||||
"""
|
||||
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
|
||||
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
|
||||
if order.direction == Order.SELL:
|
||||
# sell
|
||||
if position is not None:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import abstractclassmethod, abstractmethod
|
||||
import copy
|
||||
from qlib.log import get_module_logger
|
||||
from types import GeneratorType
|
||||
from qlib.backtest.account import Account
|
||||
import warnings
|
||||
@@ -102,7 +103,10 @@ class BaseExecutor:
|
||||
self.track_data = track_data
|
||||
self._trade_exchange = trade_exchange
|
||||
self.level_infra = LevelInfrastructure()
|
||||
self.level_infra.reset_infra(common_infra=common_infra)
|
||||
self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra)
|
||||
if common_infra is None:
|
||||
get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
@@ -239,7 +243,7 @@ class BaseExecutor:
|
||||
# Some concrete executor don't have inner decisions
|
||||
res, kwargs = obj
|
||||
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_cur_step_time()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
|
||||
# Account will not be changed in this function
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time,
|
||||
@@ -332,7 +336,7 @@ class NestedExecutor(BaseExecutor):
|
||||
self.inner_strategy.reset_common_infra(common_infra)
|
||||
|
||||
def _init_sub_trading(self, trade_decision):
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_cur_step_time()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
|
||||
self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
sub_level_infra = self.inner_executor.get_level_infra()
|
||||
self.level_infra.set_sub_level_infra(sub_level_infra)
|
||||
@@ -379,8 +383,8 @@ class NestedExecutor(BaseExecutor):
|
||||
)
|
||||
trade_decision.mod_inner_decision(_inner_trade_decision) # propagate part of decision information
|
||||
|
||||
# NOTE sub_cal.get_cur_step_time() must be called before collect_data in case of step shifting
|
||||
decision_list.append((_inner_trade_decision, *sub_cal.get_cur_step_time()))
|
||||
# NOTE sub_cal.get_step_time() must be called before collect_data in case of step shifting
|
||||
decision_list.append((_inner_trade_decision, *sub_cal.get_step_time()))
|
||||
|
||||
# NOTE: Trade Calendar will step forward in the follow line
|
||||
_inner_execute_result = yield from self.inner_executor.collect_data(
|
||||
@@ -478,7 +482,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
|
||||
trade_start_time, _ = self.trade_calendar.get_cur_step_time()
|
||||
trade_start_time, _ = self.trade_calendar.get_step_time()
|
||||
execute_result = []
|
||||
|
||||
for order in self._get_order_iterator(trade_decision):
|
||||
@@ -491,30 +495,22 @@ class SimulatorExecutor(BaseExecutor):
|
||||
execute_result.append((order, trade_val, trade_cost, trade_price))
|
||||
if self.verbose:
|
||||
if order.direction == Order.SELL: # sell
|
||||
print(
|
||||
"[I {:%Y-%m-%d %H:%M:%S}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
|
||||
trade_start_time,
|
||||
order.stock_id,
|
||||
trade_price,
|
||||
order.amount,
|
||||
order.deal_amount,
|
||||
order.factor,
|
||||
trade_val,
|
||||
)
|
||||
)
|
||||
action = "sell"
|
||||
else:
|
||||
print(
|
||||
"[I {:%Y-%m-%d %H:%M:%S}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
|
||||
trade_start_time,
|
||||
order.stock_id,
|
||||
trade_price,
|
||||
order.amount,
|
||||
order.deal_amount,
|
||||
order.factor,
|
||||
trade_val,
|
||||
)
|
||||
action = "buy"
|
||||
print(
|
||||
"[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}, cach {:.2f}.".format(
|
||||
trade_start_time,
|
||||
action,
|
||||
order.stock_id,
|
||||
trade_price,
|
||||
order.amount,
|
||||
order.deal_amount,
|
||||
order.factor,
|
||||
trade_val,
|
||||
self.trade_account.get_cash(),
|
||||
)
|
||||
|
||||
)
|
||||
else:
|
||||
if self.verbose:
|
||||
print("[W {:%Y-%m-%d %H:%M:%S}]: {} wrong.".format(trade_start_time, order.stock_id))
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
# TODO: rename it with decision.py
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from qlib.utils.time import concat_date_time
|
||||
from qlib.data.data import Cal
|
||||
from qlib.utils.time import concat_date_time, epsilon_change
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
@@ -41,16 +42,24 @@ class Order:
|
||||
presents the weight factor assigned in Exchange()
|
||||
"""
|
||||
|
||||
# 1) time invariant values
|
||||
# - they are set by users and is time-invariant.
|
||||
stock_id: str
|
||||
amount: float # `amount` is a non-negative value
|
||||
amount: float # `amount` is a non-negative and adjusted value
|
||||
direction: int
|
||||
|
||||
# 2) time variant values:
|
||||
# - Users may want to set these values when using lower level APIs
|
||||
# - If users don't, TradeDecisionWO will help users to set them
|
||||
# The interval of the order which belongs to (NOTE: this is not the expected order dealing range time)
|
||||
start_time: pd.Timestamp
|
||||
end_time: pd.Timestamp
|
||||
|
||||
direction: int
|
||||
factor: float
|
||||
# 3) results
|
||||
# - users should not care about these values
|
||||
# - they are set by the backtest system after finishing the results.
|
||||
deal_amount: float = field(init=False) # `deal_amount` is a non-negative value
|
||||
factor: float = field(init=False)
|
||||
|
||||
# FIXME:
|
||||
# for compatible now.
|
||||
@@ -127,8 +136,8 @@ class OrderHelper:
|
||||
code: str,
|
||||
amount: float,
|
||||
direction: OrderDir,
|
||||
start_time: Union[str, pd.Timestamp],
|
||||
end_time: Union[str, pd.Timestamp],
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
) -> Order:
|
||||
"""
|
||||
help to create a order
|
||||
@@ -143,9 +152,9 @@ class OrderHelper:
|
||||
**adjusted trading amount**
|
||||
direction : OrderDir
|
||||
trading direction
|
||||
start_time : Union[str, pd.Timestamp]
|
||||
start_time : Union[str, pd.Timestamp] (optional)
|
||||
The interval of the order which belongs to
|
||||
end_time : Union[str, pd.Timestamp]
|
||||
end_time : Union[str, pd.Timestamp] (optional)
|
||||
The interval of the order which belongs to
|
||||
|
||||
Returns
|
||||
@@ -153,15 +162,17 @@ class OrderHelper:
|
||||
Order:
|
||||
The created order
|
||||
"""
|
||||
start_time = pd.Timestamp(start_time)
|
||||
end_time = pd.Timestamp(end_time)
|
||||
if start_time is not None:
|
||||
start_time = pd.Timestamp(start_time)
|
||||
if end_time is not None:
|
||||
end_time = pd.Timestamp(end_time)
|
||||
# NOTE: factor is a value belongs to the results section. User don't have to care about it when creating orders
|
||||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
direction=direction,
|
||||
factor=self.exchange.get_factor(code, start_time, end_time),
|
||||
)
|
||||
|
||||
|
||||
@@ -291,6 +302,7 @@ class BaseTradeDecision:
|
||||
|
||||
"""
|
||||
self.strategy = strategy
|
||||
self.start_time, self.end_time = strategy.trade_calendar.get_step_time()
|
||||
self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading`
|
||||
if isinstance(trade_range, Tuple):
|
||||
# for Tuple[int, int]
|
||||
@@ -406,6 +418,62 @@ class BaseTradeDecision:
|
||||
_start_idx, _end_idx = max(0, _start_idx), min(self.total_step - 1, _end_idx)
|
||||
return _start_idx, _end_idx
|
||||
|
||||
def get_data_cal_range_limit(self, rtype: str = "full", raise_error: bool = False) -> Tuple[int, int]:
|
||||
"""
|
||||
get the range limit based on data calendar
|
||||
|
||||
NOTE: it is **total** range limit instead of a single step
|
||||
|
||||
The following assumptions are made
|
||||
1) The frequency of the exchange in common_infra is the same as the data calendar
|
||||
2) Users want the index mod by **day** (i.e. 240 min)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rtype: str
|
||||
- "full": return the full limitation of the deicsion in the day
|
||||
- "step": return the limitation of current step
|
||||
|
||||
raise_error: bool
|
||||
True: raise error if no trade_range is set
|
||||
False: return full trade calendar.
|
||||
|
||||
It is useful in following cases
|
||||
- users want to follow the order specific trading time range when decision level trade range is not
|
||||
available. Raising NotImplementedError to indicates that range limit is not available
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]:
|
||||
the range limit in data calendar
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError:
|
||||
If the following criteria meet
|
||||
1) the decision can't provide a unified start and end
|
||||
2) raise_error is True
|
||||
"""
|
||||
# potential performance issue
|
||||
day_start = pd.Timestamp(self.start_time.date())
|
||||
day_end = epsilon_change(day_start + pd.Timedelta(days=1))
|
||||
freq = self.strategy.trade_exchange.freq
|
||||
_, _, day_start_idx, day_end_idx = Cal.locate_index(day_start, day_end, freq=freq)
|
||||
if self.trade_range is None:
|
||||
if raise_error:
|
||||
raise NotImplementedError(f"There is no trade_range in this case")
|
||||
else:
|
||||
return 0, day_end_idx - day_start_idx
|
||||
else:
|
||||
if rtype == "full":
|
||||
val_start, val_end = self.trade_range.clip_time_range(day_start, day_end)
|
||||
elif rtype == "step":
|
||||
val_start, val_end = self.trade_range.clip_time_range(self.start_time, self.end_time)
|
||||
else:
|
||||
raise ValueError(f"This type of input {rtype} is not supported")
|
||||
_, _, start_idx, end_index = Cal.locate_index(val_start, val_end, freq=freq)
|
||||
return start_idx - day_start_idx, end_index - day_start_idx
|
||||
|
||||
def empty(self) -> bool:
|
||||
for obj in self.get_decision():
|
||||
if isinstance(obj, Order):
|
||||
@@ -452,9 +520,15 @@ class TradeDecisionWO(BaseTradeDecision):
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None):
|
||||
super().__init__(strategy, trade_range=trade_range)
|
||||
self.order_list = order_list
|
||||
start, end = strategy.trade_calendar.get_step_time()
|
||||
for o in order_list:
|
||||
if o.start_time is None:
|
||||
o.start_time = start
|
||||
if o.end_time is None:
|
||||
o.end_time = end
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
return self.order_list
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"strategy: {self.strategy}; trade_range: {self.trade_range}; order_list[{len(self.order_list)}]"
|
||||
return f"class: {self.__class__.__name__}; strategy: {self.strategy}; trade_range: {self.trade_range}; order_list[{len(self.order_list)}]"
|
||||
|
||||
@@ -351,7 +351,10 @@ class Indicator:
|
||||
trade_exchange: Exchange,
|
||||
pa_config: dict = {},
|
||||
):
|
||||
"""Get the base volume and price information"""
|
||||
"""
|
||||
Get the base volume and price information
|
||||
All the base price values are rooted from this function
|
||||
"""
|
||||
|
||||
agg = pa_config.get("agg", "twap").lower()
|
||||
price = pa_config.get("price", "deal_price").lower()
|
||||
@@ -374,10 +377,12 @@ class Indicator:
|
||||
|
||||
# NOTE: there are some zeros in the trading price. These cases are known meaningless
|
||||
# for aligning the previous logic, remove it.
|
||||
# price_s = price_s.mask(np.isclose(price_s, 0))
|
||||
price_s = price_s[~(price_s < 1e-08)] # remove zero and negative values.
|
||||
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
||||
|
||||
if agg == "vwap":
|
||||
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
|
||||
volume_s = volume_s.reindex(price_s.index)
|
||||
elif agg == "twap":
|
||||
volume_s = pd.Series(1, index=price_s.index)
|
||||
else:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
import bisect
|
||||
from qlib.utils.time import epsilon_change
|
||||
from typing import Union, TYPE_CHECKING, Tuple, Union, List, Set
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -22,7 +23,11 @@ class TradeCalendarManager:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
|
||||
self,
|
||||
freq: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
level_infra: "LevelInfrastructure" = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
@@ -36,6 +41,7 @@ class TradeCalendarManager:
|
||||
closed end of the trade time range, by default None
|
||||
If `end_time` is None, it must be reset before trading.
|
||||
"""
|
||||
self.level_infra = level_infra
|
||||
self.reset(freq=freq, start_time=start_time, end_time=end_time)
|
||||
|
||||
def reset(self, freq, start_time, end_time):
|
||||
@@ -82,19 +88,19 @@ class TradeCalendarManager:
|
||||
def get_trade_step(self):
|
||||
return self.trade_step
|
||||
|
||||
def get_step_time(self, trade_step=0, shift=0):
|
||||
def get_step_time(self, trade_step=None, shift=0):
|
||||
"""
|
||||
Get the left and right endpoints of the trade_step'th trading interval
|
||||
|
||||
About the endpoints:
|
||||
- Qlib uses the closed interval in time-series data selection, which has the same performance as pandas.Series.loc
|
||||
- The returned right endpoints should minus 1 seconds becasue of the closed interval representation in Qlib.
|
||||
Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time interval.
|
||||
# - The returned right endpoints should minus 1 seconds becasue of the closed interval representation in Qlib.
|
||||
# Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time interval.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_step : int, optional
|
||||
the number of trading step finished, by default 0
|
||||
the number of trading step finished, by default None to indicate current step
|
||||
shift : int, optional
|
||||
shift bars , by default 0
|
||||
|
||||
@@ -105,15 +111,43 @@ class TradeCalendarManager:
|
||||
- If shift > 0, return the trading time range of the earlier shift bars
|
||||
- If shift < 0, return the trading time range of the later shift bar
|
||||
"""
|
||||
if trade_step is None:
|
||||
trade_step = self.get_trade_step()
|
||||
trade_step = trade_step - shift
|
||||
calendar_index = self.start_index + trade_step
|
||||
return self._calendar[calendar_index], self._calendar[calendar_index + 1] - pd.Timedelta(seconds=1)
|
||||
return self._calendar[calendar_index], epsilon_change(self._calendar[calendar_index + 1])
|
||||
|
||||
def get_cur_step_time(self):
|
||||
def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]:
|
||||
"""
|
||||
get current step time
|
||||
get the calendar range
|
||||
The following assumptions are made
|
||||
1) The frequency of the exchange in common_infra is the same as the data calendar
|
||||
2) Users want the **data index** mod by **day** (i.e. 240 min)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rtype: str
|
||||
- "full": return the full limitation of the deicsion in the day
|
||||
- "step": return the limitation of current step
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]:
|
||||
"""
|
||||
return self.get_step_time(self.get_trade_step())
|
||||
# potential performance issue
|
||||
day_start = pd.Timestamp(self.start_time.date())
|
||||
day_end = epsilon_change(day_start + pd.Timedelta(days=1))
|
||||
freq = self.level_infra.get("common_infra").get("trade_exchange").freq
|
||||
_, _, day_start_idx, _ = Cal.locate_index(day_start, day_end, freq=freq)
|
||||
|
||||
if rtype == "full":
|
||||
_, _, start_idx, end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq)
|
||||
elif rtype == "step":
|
||||
_, _, start_idx, end_index = Cal.locate_index(*self.get_step_time(), freq=freq)
|
||||
else:
|
||||
raise ValueError(f"This type of input {rtype} is not supported")
|
||||
|
||||
return start_idx - day_start_idx, end_index - day_start_idx
|
||||
|
||||
def get_all_time(self):
|
||||
"""Get the start_time and end_time for trading"""
|
||||
@@ -147,7 +181,7 @@ class TradeCalendarManager:
|
||||
return clip(left), clip(right)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]"
|
||||
return f"class: {self.__class__.__name__}; {self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]"
|
||||
|
||||
|
||||
class BaseInfrastructure:
|
||||
@@ -198,14 +232,16 @@ class LevelInfrastructure(BaseInfrastructure):
|
||||
sub_level_infra:
|
||||
- **NOTE**: this will only work after _init_sub_trading !!!
|
||||
"""
|
||||
return ["trade_calendar", "sub_level_infra"]
|
||||
return ["trade_calendar", "sub_level_infra", "common_infra"]
|
||||
|
||||
def reset_cal(self, freq, start_time, end_time):
|
||||
"""reset trade calendar manager"""
|
||||
if self.has("trade_calendar"):
|
||||
self.get("trade_calendar").reset(freq, start_time=start_time, end_time=end_time)
|
||||
else:
|
||||
self.reset_infra(trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time))
|
||||
self.reset_infra(
|
||||
trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self)
|
||||
)
|
||||
|
||||
def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure):
|
||||
"""this will make the calendar access easier when acrossing multi-levels"""
|
||||
|
||||
@@ -73,20 +73,20 @@ def indicator_analysis(df, method="mean"):
|
||||
Parameters
|
||||
----------
|
||||
df : pandas.DataFrame
|
||||
columns: like ['pa', 'pos', 'ffr', 'amount', 'value'].
|
||||
columns: like ['pa', 'pos', 'ffr', 'deal_amount', 'value'].
|
||||
Necessary fields:
|
||||
- 'pa' is the price advantage in trade indicators
|
||||
- 'pos' is the positive rate in trade indicators
|
||||
- 'ffr' is the fulfill rate in trade indicators
|
||||
Optional fields:
|
||||
- 'amount' is the total deal amount, only necessary when method is 'amount_weighted'
|
||||
- 'deal_amount' is the total deal deal_amount, only necessary when method is 'amount_weighted'
|
||||
- 'value' is the total trade value, only necessary when method is 'value_weighted'
|
||||
|
||||
index: Index(datetime)
|
||||
method : str, optional
|
||||
statistics method of pa/ffr, by default "mean"
|
||||
- if method is 'mean', count the mean statistical value of each trade indicator
|
||||
- if method is 'amount_weighted', count the amount weighted mean statistical value of each trade indicator
|
||||
- if method is 'amount_weighted', count the deal_amount weighted mean statistical value of each trade indicator
|
||||
- if method is 'value_weighted', count the value weighted mean statistical value of each trade indicator
|
||||
Note: statistics method of pos is always "mean"
|
||||
|
||||
@@ -97,7 +97,7 @@ def indicator_analysis(df, method="mean"):
|
||||
"""
|
||||
weights_dict = {
|
||||
"mean": df["count"],
|
||||
"amount_weighted": df["amount"].abs(),
|
||||
"amount_weighted": df["deal_amount"].abs(),
|
||||
"value_weighted": df["value"].abs(),
|
||||
}
|
||||
if method not in weights_dict:
|
||||
|
||||
@@ -64,7 +64,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
self.topk = topk
|
||||
self.n_drop = n_drop
|
||||
@@ -73,22 +73,6 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
self.risk_degree = risk_degree
|
||||
self.hold_thresh = hold_thresh
|
||||
self.only_tradable = only_tradable
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
common_infra : dict, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).reset_common_infra(common_infra)
|
||||
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
@@ -210,7 +194,6 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=Order.SELL, # 0 for sell, 1 for buy
|
||||
factor=factor,
|
||||
)
|
||||
# is order executable
|
||||
if self.trade_exchange.check_order(sell_order):
|
||||
@@ -247,7 +230,6 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=Order.BUY, # 1 for buy
|
||||
factor=factor,
|
||||
)
|
||||
buy_order_list.append(buy_order)
|
||||
return TradeDecisionWO(sell_order_list + buy_order_list, self)
|
||||
@@ -278,28 +260,12 @@ class WeightStrategyBase(ModelStrategy):
|
||||
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
|
||||
"""
|
||||
super(WeightStrategyBase, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
if isinstance(order_generator_cls_or_obj, type):
|
||||
self.order_generator = order_generator_cls_or_obj()
|
||||
else:
|
||||
self.order_generator = order_generator_cls_or_obj
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
common_infra : dict, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(WeightStrategyBase, self).reset_common_infra(common_infra)
|
||||
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
|
||||
@@ -20,48 +20,6 @@ from qlib.backtest.utils import get_start_end_idx
|
||||
class TWAPStrategy(BaseStrategy):
|
||||
"""TWAP Strategy for trading"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the trade decision of outer strategy which this startegy relies
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
- It allowes different trade_exchanges is used in different executions.
|
||||
- For example:
|
||||
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
|
||||
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
|
||||
|
||||
"""
|
||||
super(TWAPStrategy, self).__init__(
|
||||
outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra
|
||||
)
|
||||
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
common_infra : CommonInfrastructure, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(TWAPStrategy, self).reset_common_infra(common_infra)
|
||||
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
@@ -105,7 +63,9 @@ class TWAPStrategy(BaseStrategy):
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
continue
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
|
||||
stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
|
||||
)
|
||||
_order_amount = None
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
@@ -141,7 +101,6 @@ class TWAPStrategy(BaseStrategy):
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=order.direction, # 1 for buy
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
return TradeDecisionWO(order_list=order_list, strategy=self)
|
||||
@@ -161,46 +120,6 @@ class SBBStrategyBase(BaseStrategy):
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the trade decision of outer strategy which this startegy relies
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
- It allowes different trade_exchanges is used in different executions.
|
||||
- For example:
|
||||
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
|
||||
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
|
||||
"""
|
||||
super(SBBStrategyBase, self).__init__(
|
||||
outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra
|
||||
)
|
||||
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
common_infra : dict, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(SBBStrategyBase, self).reset_common_infra(common_infra)
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
@@ -250,7 +169,9 @@ class SBBStrategyBase(BaseStrategy):
|
||||
self.trade_trend[order.stock_id] = _pred_trend
|
||||
continue
|
||||
# get amount of one trade unit
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
|
||||
stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
|
||||
)
|
||||
if _pred_trend == self.TREND_MID:
|
||||
_order_amount = None
|
||||
# considering trade unit
|
||||
@@ -283,7 +204,6 @@ class SBBStrategyBase(BaseStrategy):
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=order.direction,
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
|
||||
@@ -330,7 +250,6 @@ class SBBStrategyBase(BaseStrategy):
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=order.direction, # 1 for buy
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
else:
|
||||
@@ -349,7 +268,6 @@ class SBBStrategyBase(BaseStrategy):
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=order.direction, # 1 for buy
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
|
||||
@@ -395,7 +313,9 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
if isinstance(instruments, str):
|
||||
self.instruments = D.instruments(instruments)
|
||||
self.freq = freq
|
||||
super(SBBStrategyEMA, self).__init__(outer_trade_decision, trade_exchange, level_infra, common_infra, **kwargs)
|
||||
super(SBBStrategyEMA, self).__init__(
|
||||
outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
|
||||
def _reset_signal(self):
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
@@ -417,14 +337,8 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
reset level-shared infra
|
||||
- After reset the trade calendar, the signal will be changed
|
||||
"""
|
||||
if not hasattr(self, "level_infra"):
|
||||
self.level_infra = level_infra
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if level_infra.has("trade_calendar"):
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
self._reset_signal()
|
||||
super().reset_level_infra(level_infra)
|
||||
self._reset_signal()
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
# if no signal, return mid trend
|
||||
@@ -484,10 +398,9 @@ class ACStrategy(BaseStrategy):
|
||||
if isinstance(instruments, str):
|
||||
self.instruments = D.instruments(instruments)
|
||||
self.freq = freq
|
||||
super(ACStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
super(ACStrategy, self).__init__(
|
||||
outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
|
||||
def _reset_signal(self):
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
@@ -506,33 +419,13 @@ class ACStrategy(BaseStrategy):
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
self.signal[stock_id] = stock_val["volatility"].droplevel(level="instrument")
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
common_infra : CommonInfrastructure, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(ACStrategy, self).reset_common_infra(common_infra)
|
||||
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset_level_infra(self, level_infra):
|
||||
"""
|
||||
reset level-shared infra
|
||||
- After reset the trade calendar, the signal will be changed
|
||||
"""
|
||||
if not hasattr(self, "level_infra"):
|
||||
self.level_infra = level_infra
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if level_infra.has("trade_calendar"):
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
self._reset_signal()
|
||||
super().reset_level_infra(level_infra)
|
||||
self._reset_signal()
|
||||
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
|
||||
"""
|
||||
@@ -578,7 +471,9 @@ class ACStrategy(BaseStrategy):
|
||||
|
||||
if sig_sam is None or np.isnan(sig_sam):
|
||||
# no signal, TWAP
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
|
||||
stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
|
||||
)
|
||||
if _amount_trade_unit is None:
|
||||
# divide the order into equal parts, and trade one part
|
||||
_order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
|
||||
@@ -599,7 +494,9 @@ class ACStrategy(BaseStrategy):
|
||||
np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1))
|
||||
) / np.sinh(kappa * trade_len)
|
||||
_order_amount = order.amount * amount_ratio
|
||||
_order_amount = self.trade_exchange.round_amount_by_trade_unit(_order_amount, order.factor)
|
||||
_order_amount = self.trade_exchange.round_amount_by_trade_unit(
|
||||
_order_amount, stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
|
||||
)
|
||||
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
@@ -673,8 +570,6 @@ class RandomOrderStrategy(BaseStrategy):
|
||||
.create(
|
||||
code=stock_id,
|
||||
amount=volume * self.volume_ratio,
|
||||
start_time=step_time_start,
|
||||
end_time=step_time_end,
|
||||
direction=self.direction,
|
||||
)
|
||||
)
|
||||
@@ -734,9 +629,7 @@ class FileOrderStrategy(BaseStrategy):
|
||||
execute_result will be ignored in FileOrderStrategy
|
||||
"""
|
||||
oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
|
||||
tc = self.trade_calendar
|
||||
step = tc.get_trade_step()
|
||||
start, end = tc.get_step_time(step)
|
||||
start, _ = self.trade_calendar.get_step_time()
|
||||
# CONVERSION: the bar is indexed by the time
|
||||
try:
|
||||
df = self.order_df.loc(axis=0)[start]
|
||||
@@ -750,8 +643,6 @@ class FileOrderStrategy(BaseStrategy):
|
||||
code=idx,
|
||||
amount=row["amount"],
|
||||
direction=Order.parse_dir(row["direction"]),
|
||||
start_time=start,
|
||||
end_time=end,
|
||||
)
|
||||
)
|
||||
return TradeDecisionWO(order_list, self, self.trade_range)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.position import BasePosition
|
||||
from typing import List, Union
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from ..model.base import BaseModel
|
||||
from ..data.dataset import DatasetH
|
||||
@@ -22,6 +23,7 @@ class BaseStrategy:
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
trade_exchange: Exchange = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
@@ -34,9 +36,18 @@ class BaseStrategy:
|
||||
level shared infrastructure for backtesting, including trade calendar
|
||||
common_infra : CommonInfrastructure, optional
|
||||
common infrastructure for backtesting, including trade_account, trade_exchange, .etc
|
||||
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
- It allowes different trade_exchanges is used in different executions.
|
||||
- For example:
|
||||
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
|
||||
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
|
||||
"""
|
||||
|
||||
self.reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)
|
||||
self._reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)
|
||||
self._trade_exchange = trade_exchange
|
||||
|
||||
@property
|
||||
def trade_calendar(self) -> TradeCalendarManager:
|
||||
@@ -46,6 +57,11 @@ class BaseStrategy:
|
||||
def trade_position(self) -> BasePosition:
|
||||
return self.common_infra.get("trade_account").current
|
||||
|
||||
@property
|
||||
def trade_exchange(self) -> Exchange:
|
||||
"""get trade exchange in a prioritized order"""
|
||||
return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange")
|
||||
|
||||
def reset_level_infra(self, level_infra: LevelInfrastructure):
|
||||
if not hasattr(self, "level_infra"):
|
||||
self.level_infra = level_infra
|
||||
@@ -69,6 +85,24 @@ class BaseStrategy:
|
||||
- reset `level_infra`, used to reset trade calendar, .etc
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
- reset `outer_trade_decision`, used to make split decision
|
||||
|
||||
**NOTE**:
|
||||
split this function into `reset` and `_reset` will make following cases more convenient
|
||||
1. Users want to initialize his strategy by overriding `reset`, but they don't want to affect the `_reset` called
|
||||
when initialization
|
||||
"""
|
||||
self._reset(
|
||||
level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision, **kwargs
|
||||
)
|
||||
|
||||
def _reset(
|
||||
self,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
outer_trade_decision=None,
|
||||
):
|
||||
"""
|
||||
Please refer to the docs of `reset`
|
||||
"""
|
||||
if level_infra is not None:
|
||||
self.reset_level_infra(level_infra)
|
||||
@@ -124,6 +158,36 @@ class BaseStrategy:
|
||||
# NOTE: normally, user should do something to the strategy due to the change of outer decision
|
||||
raise NotImplementedError(f"Please implement the `alter_outer_trade_decision` method")
|
||||
|
||||
# helper methods: not necessary but for convenience
|
||||
def get_data_cal_avail_range(self, rtype: str = "full") -> Tuple[int, int]:
|
||||
"""
|
||||
return data calendar's available decision range for `self` strategy
|
||||
the range consider following factors
|
||||
- data calendar in the charge of `self` strategy
|
||||
- trading range limitation from the decision of outer strategy
|
||||
|
||||
|
||||
related methods
|
||||
- TradeCalendarManager.get_data_cal_range
|
||||
- BaseTradeDecision.get_data_cal_range_limit
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rtype: str
|
||||
- "full": return the available data index range of the strategy from `start_time` to `end_time`
|
||||
- "step": return the available data index range of the strategy of current step
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]:
|
||||
the available range both sides are closed
|
||||
"""
|
||||
cal_range = self.trade_calendar.get_data_cal_range(rtype=rtype)
|
||||
if self.outer_trade_decision is None:
|
||||
raise ValueError(f"There is not limitation for strategy {self}")
|
||||
range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype)
|
||||
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
|
||||
|
||||
|
||||
class ModelStrategy(BaseStrategy):
|
||||
"""Model-based trading strategy, use model to make predictions for trading"""
|
||||
|
||||
@@ -210,10 +210,13 @@ def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleTy
|
||||
the class object and it's arguments.
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
module = get_module_by_module_path(config.get("module_path", default_module))
|
||||
if isinstance(config["class"], str):
|
||||
module = get_module_by_module_path(config.get("module_path", default_module))
|
||||
|
||||
# raise AttributeError
|
||||
klass = getattr(module, config["class"])
|
||||
# raise AttributeError
|
||||
klass = getattr(module, config["class"])
|
||||
else:
|
||||
klass = config["class"] # the class type itself is passed in
|
||||
kwargs = config.get("kwargs", {})
|
||||
elif isinstance(config, str):
|
||||
module = get_module_by_module_path(default_module)
|
||||
@@ -235,11 +238,17 @@ def init_instance_by_config(
|
||||
----------
|
||||
config : Union[str, dict, object]
|
||||
dict example.
|
||||
case 1)
|
||||
{
|
||||
'class': 'ClassName',
|
||||
'kwargs': dict, # It is optional. {} will be used if not given
|
||||
'model_path': path, # It is optional if module is given
|
||||
}
|
||||
case 2)
|
||||
{
|
||||
'class': <The class it self>,
|
||||
'kwargs': dict, # It is optional. {} will be used if not given
|
||||
}
|
||||
str example.
|
||||
1) specify a pickle object
|
||||
- path like 'file:///<path to pickle file>/obj.pkl'
|
||||
|
||||
@@ -160,5 +160,32 @@ def cal_sam_minute(x: pd.Timestamp, sam_minutes: int) -> pd.Timestamp:
|
||||
return concat_date_time(date, new_time)
|
||||
|
||||
|
||||
def epsilon_change(datetime: pd.Timestamp, direction: str = "backward") -> pd.Timestamp:
|
||||
"""
|
||||
change the time by infinitely small quantity.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
datetime : pd.Timestamp
|
||||
the original time
|
||||
direction : str
|
||||
the direction the time are going to
|
||||
- "backward" for going to history
|
||||
- "forward" for going to the future
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.Timestamp:
|
||||
the shifted time
|
||||
"""
|
||||
if direction == "backward":
|
||||
return datetime - pd.Timedelta(seconds=1)
|
||||
elif direction == "forward":
|
||||
return datetime + pd.Timedelta(seconds=1)
|
||||
else:
|
||||
raise ValueError("Wrong input")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(get_day_min_idx_range("8:30", "14:59", "10min"))
|
||||
|
||||
Reference in New Issue
Block a user