mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 19:41:00 +08:00
align range limit
This commit is contained in:
@@ -13,7 +13,7 @@ from .executor import BaseExecutor
|
||||
from .backtest import backtest_loop
|
||||
from .backtest import collect_data_loop
|
||||
from .order import Order
|
||||
from .utils import CommonInfrastructure, TradeCalendarManager
|
||||
from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
from ..config import C
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
|
||||
import copy
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Tuple
|
||||
from qlib.utils import init_instance_by_config
|
||||
import warnings
|
||||
import pandas as pd
|
||||
@@ -250,6 +250,7 @@ class Account:
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_info: list = None,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
indicator_config: dict = {},
|
||||
):
|
||||
"""update account at each trading bar step
|
||||
@@ -274,6 +275,9 @@ class Account:
|
||||
indicators of inner executor, by default None
|
||||
- necessary if atomic is False
|
||||
- used to aggregate outer indicators
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
The decision list of the inner level: List[Tuple[<decision>, <start_time>, <end_time>]]
|
||||
The inner level
|
||||
indicator_config : dict, optional
|
||||
config of calculating indicators, by default {}
|
||||
"""
|
||||
@@ -289,22 +293,27 @@ class Account:
|
||||
# report is portfolio related analysis
|
||||
self.update_report(trade_start_time, trade_end_time)
|
||||
|
||||
# indicator is trading (e.g. high-frequency order execution) related analysis
|
||||
self.indicator.clear()
|
||||
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`
|
||||
|
||||
# indicator is trading (e.g. high-frequency order execution) related analysis
|
||||
self.indicator.reset()
|
||||
|
||||
# aggregate the information for each order
|
||||
if atomic:
|
||||
self.indicator.update_order_indicators(trade_info)
|
||||
else:
|
||||
self.indicator.agg_order_indicators(
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
inner_order_indicators,
|
||||
decision_list=decision_list,
|
||||
outer_trade_decision=outer_trade_decision,
|
||||
trade_exchange=trade_exchange,
|
||||
indicator_config=indicator_config,
|
||||
)
|
||||
|
||||
# aggregate all the order metrics a single step
|
||||
self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config)
|
||||
|
||||
# record the metrics
|
||||
self.indicator.record(trade_start_time)
|
||||
|
||||
def get_report(self):
|
||||
|
||||
@@ -55,14 +55,13 @@ def collect_data_loop(
|
||||
trade decision
|
||||
"""
|
||||
trade_executor.reset(start_time=start_time, end_time=end_time)
|
||||
level_infra = trade_executor.get_level_infra()
|
||||
trade_strategy.reset(level_infra=level_infra)
|
||||
trade_strategy.reset(level_infra=trade_executor.get_level_infra())
|
||||
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
_execute_result = yield from trade_executor.collect_data(_trade_decision)
|
||||
_execute_result = yield from trade_executor.collect_data(_trade_decision, level=0)
|
||||
bar.update(1)
|
||||
|
||||
if return_value is not None:
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
from abc import abstractclassmethod, abstractmethod
|
||||
import copy
|
||||
from types import GeneratorType
|
||||
from qlib.backtest.account import Account
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from typing import List, Union
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from qlib.backtest.report import Indicator
|
||||
|
||||
from .order import Order, BaseTradeDecision
|
||||
from .order import EmptyTradeDecision, Order, BaseTradeDecision
|
||||
from .exchange import Exchange
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx
|
||||
|
||||
from ..utils import init_instance_by_config
|
||||
from ..utils.time import Freq
|
||||
@@ -26,6 +29,7 @@ class BaseExecutor:
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -62,8 +66,8 @@ class BaseExecutor:
|
||||
{
|
||||
'show_indicator': True,
|
||||
'pa_config': {
|
||||
'base_value': 'twap',
|
||||
'weight_method': 'value_weighted',
|
||||
"agg": "twap", # "vwap"
|
||||
"price": "$close", # default to use deal price of the exchange
|
||||
},
|
||||
'ffr_config':{
|
||||
'weight_method': 'value_weighted',
|
||||
@@ -77,6 +81,12 @@ class BaseExecutor:
|
||||
whether to generate trade_decision, will be used when training rl agent
|
||||
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
|
||||
- Else, `trade_decision` will not be generated
|
||||
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to generate report
|
||||
- If generate_report is None, trade_exchange will be ignored
|
||||
- Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
|
||||
common_infra : CommonInfrastructure, optional:
|
||||
common infrastructure for backtesting, may including:
|
||||
- trade_account : Account, optional
|
||||
@@ -90,7 +100,9 @@ class BaseExecutor:
|
||||
self.generate_report = generate_report
|
||||
self.verbose = verbose
|
||||
self.track_data = track_data
|
||||
self.reset(start_time=start_time, end_time=end_time, track_data=track_data, common_infra=common_infra)
|
||||
self._trade_exchange = trade_exchange
|
||||
self.level_infra = LevelInfrastructure()
|
||||
self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra)
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
@@ -105,60 +117,106 @@ class BaseExecutor:
|
||||
if common_infra.has("trade_account"):
|
||||
# NOTE: there is a trick in the code.
|
||||
# copy is used instead of deepcopy. So positions are shared
|
||||
self.trade_account = copy.copy(common_infra.get("trade_account"))
|
||||
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
|
||||
self.trade_account.reset(freq=self.time_per_step, init_report=True, port_metr_enabled=self.generate_report)
|
||||
|
||||
def reset(self, track_data: bool = None, common_infra: CommonInfrastructure = None, **kwargs):
|
||||
@property
|
||||
def trade_exchange(self) -> Exchange:
|
||||
"""get trade exchange in a prioritized order"""
|
||||
return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange")
|
||||
|
||||
@property
|
||||
def trade_calendar(self) -> TradeCalendarManager:
|
||||
"""
|
||||
Though trade calendar can be accessed from multiple sources, but managing in a centralized way will make the
|
||||
code easier
|
||||
"""
|
||||
return self.level_infra.get("trade_calendar")
|
||||
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs):
|
||||
"""
|
||||
- reset `start_time` and `end_time`, used in trade calendar
|
||||
- reset `track_data`, used when making data for multi-level training
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
"""
|
||||
|
||||
if track_data is not None:
|
||||
self.track_data = track_data
|
||||
|
||||
if "start_time" in kwargs or "end_time" in kwargs:
|
||||
start_time = kwargs.get("start_time")
|
||||
end_time = kwargs.get("end_time")
|
||||
self.trade_calendar = TradeCalendarManager(
|
||||
freq=self.time_per_step, start_time=start_time, end_time=end_time
|
||||
)
|
||||
|
||||
self.level_infra.reset_cal(freq=self.time_per_step, start_time=start_time, end_time=end_time)
|
||||
if common_infra is not None:
|
||||
self.reset_common_infra(common_infra)
|
||||
|
||||
def get_level_infra(self):
|
||||
return LevelInfrastructure(trade_calendar=self.trade_calendar)
|
||||
return self.level_infra
|
||||
|
||||
def finished(self):
|
||||
return self.trade_calendar.finished()
|
||||
|
||||
def execute(self, trade_decision):
|
||||
def execute(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
"""execute the trade decision and return the executed result
|
||||
|
||||
NOTE: this function is never used directly in the framework. Should we delete it?
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : BaseTradeDecision
|
||||
|
||||
level : int
|
||||
the level of current executor
|
||||
|
||||
Returns
|
||||
----------
|
||||
execute_result : List[object]
|
||||
the executed result for trade decision
|
||||
"""
|
||||
raise NotImplementedError("execute is not implemented!")
|
||||
return_value = {}
|
||||
for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):
|
||||
pass
|
||||
return return_value.get("execute_result")
|
||||
|
||||
def collect_data(self, trade_decision):
|
||||
@abstractclassmethod
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||
"""
|
||||
Please refer to the doc of collect_data
|
||||
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
|
||||
collect_data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
Please refer to the doc of collect_data
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[List[object], dict]:
|
||||
(<the executed result for trade decision>, <the extra kwargs for `self.trade_account.update_bar_end`>)
|
||||
"""
|
||||
|
||||
def collect_data(
|
||||
self, trade_decision: BaseTradeDecision, return_value: dict = None, level: int = 0
|
||||
) -> List[object]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
his function will make a step forward
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : BaseTradeDecision
|
||||
|
||||
level : int
|
||||
the level of current executor. 0 indicates the top level
|
||||
|
||||
return_value : dict
|
||||
the mem address to return the value
|
||||
e.g. {"return_value": <the executed result>}
|
||||
|
||||
Returns
|
||||
----------
|
||||
execute_result : List[object]
|
||||
the executed result for trade decision
|
||||
the executed result for trade decision.
|
||||
** NOTE!!!! **:
|
||||
1) This is necessary, The return value of geenrator will be used in NestedExecutor
|
||||
2) Please note the executed results are not merged.
|
||||
|
||||
Yields
|
||||
-------
|
||||
@@ -167,7 +225,36 @@ class BaseExecutor:
|
||||
"""
|
||||
if self.track_data:
|
||||
yield trade_decision
|
||||
return self.execute(trade_decision)
|
||||
|
||||
atomic = not issubclass(self.__class__, NestedExecutor) # issubclass(A, A) is True
|
||||
|
||||
if atomic and trade_decision.get_range_limit(default_value=None) is not None:
|
||||
raise ValueError("atomic executor doesn't support specify `range_limit`")
|
||||
|
||||
obj = self._collect_data(trade_decision=trade_decision, level=level)
|
||||
|
||||
if isinstance(obj, GeneratorType):
|
||||
res, kwargs = yield from obj
|
||||
else:
|
||||
# Some concrete executor don't have inner decisions
|
||||
res, kwargs = obj
|
||||
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_cur_step_time()
|
||||
# Account will not be changed in this function
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
self.trade_exchange,
|
||||
atomic=atomic,
|
||||
outer_trade_decision=trade_decision,
|
||||
indicator_config=self.indicator_config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.trade_calendar.step()
|
||||
if return_value is not None:
|
||||
return_value.update({"execute_result": res})
|
||||
return res
|
||||
|
||||
def get_all_executors(self):
|
||||
"""get all executors"""
|
||||
@@ -192,7 +279,7 @@ class NestedExecutor(BaseExecutor):
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
skip_empty_decision: bool = True,
|
||||
trade_exchange: Exchange = None,
|
||||
align_range_limit: bool = True,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -203,24 +290,24 @@ class NestedExecutor(BaseExecutor):
|
||||
trading env in each trading bar.
|
||||
inner_strategy : BaseStrategy
|
||||
trading strategy in each trading bar
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to generate report
|
||||
- If generate_report is None, trade_exchange will be ignored
|
||||
- Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
skip_empty_decision: bool
|
||||
Will the executor skip the inner loop when the decision is empty.
|
||||
Will the executor skip call inner loop when the decision is empty.
|
||||
It should be False in following cases
|
||||
- The decisions may be updated by steps
|
||||
- The inner executor may not follow the decisions from the outer strategy
|
||||
align_range_limit: bool
|
||||
force to align the index_range decision
|
||||
It is only for nested executor, because range_limit is given by outer strategy
|
||||
"""
|
||||
self.inner_executor = init_instance_by_config(
|
||||
self.inner_executor: BaseExecutor = init_instance_by_config(
|
||||
inner_executor, common_infra=common_infra, accept_types=BaseExecutor
|
||||
)
|
||||
self.inner_strategy = init_instance_by_config(
|
||||
self.inner_strategy: BaseStrategy = init_instance_by_config(
|
||||
inner_strategy, common_infra=common_infra, accept_types=BaseStrategy
|
||||
)
|
||||
|
||||
self._skip_empty_decision = skip_empty_decision
|
||||
self._align_range_limit = align_range_limit
|
||||
|
||||
super(NestedExecutor, self).__init__(
|
||||
time_per_step=time_per_step,
|
||||
@@ -234,82 +321,82 @@ class NestedExecutor(BaseExecutor):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset trade_exchange
|
||||
- reset inner_strategyand inner_executor common infra
|
||||
"""
|
||||
super(NestedExecutor, self).reset_common_infra(common_infra)
|
||||
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
self.inner_executor.reset_common_infra(common_infra)
|
||||
self.inner_strategy.reset_common_infra(common_infra)
|
||||
|
||||
def _init_sub_trading(self, trade_decision):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_cur_step_time()
|
||||
self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
sub_level_infra = self.inner_executor.get_level_infra()
|
||||
self.level_infra.set_sub_level_infra(sub_level_infra)
|
||||
self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision)
|
||||
|
||||
def execute(self, trade_decision):
|
||||
return_value = {}
|
||||
for _decision in self.collect_data(trade_decision, return_value):
|
||||
pass
|
||||
return return_value.get("execute_result")
|
||||
def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
# outter strategy have chance to update decision each iterator
|
||||
updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar)
|
||||
if updated_trade_decision is not None:
|
||||
trade_decision = updated_trade_decision
|
||||
# NEW UPDATE
|
||||
# create a hook for inner strategy to update outter decision
|
||||
self.inner_strategy.alter_outer_trade_decision(trade_decision)
|
||||
return trade_decision
|
||||
|
||||
def collect_data(self, trade_decision: BaseTradeDecision, return_value=None):
|
||||
if self.track_data:
|
||||
yield trade_decision
|
||||
# def _get_inner_trade_decision(self, outer_trade_decision: BaseTradeDecision, inner_execute_result):
|
||||
# # In some cases, the inner strategy can be skipped, but the inner executor should keep running
|
||||
# if outer_trade_decision.empty() and self._skip_empty_decision:
|
||||
# return EmptyTradeDecision(self.inner_strategy)
|
||||
# return self.inner_strategy.generate_trade_decision(inner_execute_result)
|
||||
# _inner_trade_decision = self._get_inner_trade_decision(trade_decision, _inner_execute_result)
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
execute_result = []
|
||||
inner_order_indicators = []
|
||||
decision_list = []
|
||||
# NOTE:
|
||||
# - this is necessary to calculating the steps in sub level
|
||||
# - more detailed information will be set into trade decision
|
||||
self._init_sub_trading(trade_decision)
|
||||
|
||||
if not (trade_decision.empty() and self._skip_empty_decision):
|
||||
_inner_execute_result = None
|
||||
self._init_sub_trading(trade_decision)
|
||||
while not self.inner_executor.finished():
|
||||
# outter strategy have chance to update decision each iterator
|
||||
updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar)
|
||||
if updated_trade_decision is not None:
|
||||
trade_decision = updated_trade_decision
|
||||
# NEW UPDATE
|
||||
# create a hook for inner strategy to update outter decision
|
||||
self.inner_strategy.alter_outer_trade_decision(trade_decision)
|
||||
_inner_execute_result = None
|
||||
while not self.inner_executor.finished():
|
||||
trade_decision = self._update_trade_decision(trade_decision)
|
||||
|
||||
if trade_decision.empty() and self._skip_empty_decision:
|
||||
# give one chance for outer stategy to update the strategy
|
||||
# - For updating some information in the sub executor(the strategy have no knowledge of the inner
|
||||
# executor when generating the decision)
|
||||
break
|
||||
|
||||
sub_cal: TradeCalendarManager = self.inner_executor.trade_calendar
|
||||
start_idx, end_idx = get_start_end_idx(sub_cal, trade_decision)
|
||||
if not self._align_range_limit or start_idx <= sub_cal.get_trade_step() <= end_idx:
|
||||
# if force align the range limit, skip the steps outside the decision range limit
|
||||
|
||||
_inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result)
|
||||
# NOTE sub_cal.get_cur_step_time() must be called before collect_data in case of step shifting
|
||||
decision_list.append((_inner_trade_decision, *sub_cal.get_cur_step_time()))
|
||||
|
||||
# NOTE: Trade Calendar will step forward in the follow line
|
||||
_inner_execute_result = yield from self.inner_executor.collect_data(
|
||||
trade_decision=_inner_trade_decision
|
||||
trade_decision=_inner_trade_decision, level=level + 1
|
||||
)
|
||||
|
||||
execute_result.extend(_inner_execute_result)
|
||||
|
||||
inner_order_indicators.append(
|
||||
self.inner_executor.trade_account.get_trade_indicator().get_order_indicator()
|
||||
)
|
||||
else:
|
||||
# do nothing and just step forward
|
||||
sub_cal.step()
|
||||
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
self.trade_exchange,
|
||||
atomic=False,
|
||||
outer_trade_decision=trade_decision,
|
||||
inner_order_indicators=inner_order_indicators,
|
||||
indicator_config=self.indicator_config,
|
||||
)
|
||||
|
||||
self.trade_calendar.step()
|
||||
if return_value is not None:
|
||||
return_value.update({"execute_result": execute_result})
|
||||
return execute_result
|
||||
return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}
|
||||
|
||||
def get_all_executors(self):
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
@@ -337,17 +424,13 @@ class SimulatorExecutor(BaseExecutor):
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
trade_type: str = TT_PARAL,
|
||||
trade_type: str = TT_SERIAL,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
trade_type: str
|
||||
please refer to the doc of `TT_SERIAL` & `TT_PARAL`
|
||||
"""
|
||||
@@ -362,20 +445,9 @@ class SimulatorExecutor(BaseExecutor):
|
||||
common_infra=common_infra,
|
||||
**kwargs,
|
||||
)
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
self.trade_type = trade_type
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset trade_exchange
|
||||
"""
|
||||
super(SimulatorExecutor, self).reset_common_infra(common_infra)
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def _get_order_iterator(self, trade_decision: BaseTradeDecision) -> List[Order]:
|
||||
"""
|
||||
|
||||
@@ -405,10 +477,9 @@ class SimulatorExecutor(BaseExecutor):
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return order_it
|
||||
|
||||
def execute(self, trade_decision: BaseTradeDecision):
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
trade_start_time, _ = self.trade_calendar.get_cur_step_time()
|
||||
execute_result = []
|
||||
|
||||
for order in self._get_order_iterator(trade_decision):
|
||||
@@ -450,16 +521,4 @@ class SimulatorExecutor(BaseExecutor):
|
||||
print("[W {:%Y-%m-%d %H:%M:%S}]: {} wrong.".format(trade_start_time, order.stock_id))
|
||||
# do nothing
|
||||
pass
|
||||
|
||||
# Account will not be changed in this function
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
self.trade_exchange,
|
||||
atomic=True,
|
||||
outer_trade_decision=trade_decision,
|
||||
trade_info=execute_result,
|
||||
indicator_config=self.indicator_config,
|
||||
)
|
||||
self.trade_calendar.step()
|
||||
return execute_result
|
||||
return execute_result, {"trade_info": execute_result}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# TODO: rename it with decision.py
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -179,7 +180,7 @@ class BaseTradeDecision:
|
||||
2. Same as `case 1.3`
|
||||
"""
|
||||
|
||||
def __init__(self, strategy: BaseStrategy):
|
||||
def __init__(self, strategy: BaseStrategy, idx_range: Tuple[int, int] = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -187,6 +188,8 @@ class BaseTradeDecision:
|
||||
The strategy who make the decision
|
||||
"""
|
||||
self.strategy = strategy
|
||||
self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading`
|
||||
self.idx_range = idx_range
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
"""
|
||||
@@ -207,7 +210,11 @@ class BaseTradeDecision:
|
||||
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecision", None]:
|
||||
"""
|
||||
Be called at the **start** of each step
|
||||
Be called at the **start** of each step.
|
||||
|
||||
This function is designn for following purpose
|
||||
1) Leave a hook for the strategy who make `self` decision to update the decision itself
|
||||
2) Update some information from the inner executor calendar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -221,13 +228,27 @@ class BaseTradeDecision:
|
||||
BaseTradeDecision:
|
||||
New update, use new decision
|
||||
"""
|
||||
# purpose 1)
|
||||
self.total_step = trade_calendar.get_trade_len()
|
||||
if self.idx_range is not None:
|
||||
logger = get_module_logger("decision")
|
||||
start_idx, end_idx = self.idx_range
|
||||
if start_idx < 0 or end_idx >= self.total_step:
|
||||
logger.warning(f"{self.idx_range} go beyound the total_step({self.total_step}), it will be clipped")
|
||||
self.idx_range = max(0, start_idx), min(self.total_step - 1, end_idx)
|
||||
|
||||
# purpose 2)
|
||||
return self.strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
def get_range_limit(self) -> Tuple[int, int]:
|
||||
def get_range_limit(self, **kwargs) -> Tuple[int, int]:
|
||||
"""
|
||||
return the expected step range for limiting the decision execution time
|
||||
Both left and right are **closed**
|
||||
|
||||
**kwargs:
|
||||
{"default_value": <default_value>}
|
||||
# using dict is for distinguish no value provided or None provided
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]:
|
||||
@@ -235,12 +256,32 @@ class BaseTradeDecision:
|
||||
Raises
|
||||
------
|
||||
NotImplementedError:
|
||||
If the decision can't provide a unified start and end
|
||||
If the following criteria meet
|
||||
1) the decision can't provide a unified start and end
|
||||
2) default_value is None
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `func` method")
|
||||
if self.idx_range is None:
|
||||
if "default_value" in kwargs:
|
||||
return kwargs["default_value"]
|
||||
else:
|
||||
# Default to get full index
|
||||
raise NotImplementedError(f"The decision didn't provide an index range")
|
||||
return self.idx_range
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.get_decision()) == 0
|
||||
for obj in self.get_decision():
|
||||
if isinstance(obj, Order):
|
||||
# Zero amount order will be treated as empty
|
||||
if not np.isclose(obj.amount, 0.0):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
class EmptyTradeDecision(BaseTradeDecision):
|
||||
def empty(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class TradeDecisionWO(BaseTradeDecision):
|
||||
@@ -249,16 +290,9 @@ class TradeDecisionWO(BaseTradeDecision):
|
||||
Besides, the time_range is also included.
|
||||
"""
|
||||
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, idx_range: Tuple = None):
|
||||
super().__init__(strategy)
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, idx_range: Tuple[int, int] = None):
|
||||
super().__init__(strategy, idx_range=idx_range)
|
||||
self.order_list = order_list
|
||||
self.idx_range = idx_range
|
||||
|
||||
def get_range_limit(self) -> Tuple[int, int]:
|
||||
if self.idx_range is None:
|
||||
# Default to get full index
|
||||
raise NotImplementedError(f"The decision didn't provide an index range")
|
||||
return self.idx_range
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
return self.order_list
|
||||
|
||||
@@ -4,21 +4,23 @@
|
||||
|
||||
from collections import OrderedDict
|
||||
from logging import warning
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from typing import Dict, List
|
||||
from qlib.backtest.order import BaseTradeDecision, Order, OrderDir
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pathlib
|
||||
from typing import Dict, List, Tuple
|
||||
import warnings
|
||||
from pandas.core import groupby
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas.core import groupby
|
||||
from pandas.core.frame import DataFrame
|
||||
|
||||
from ..utils.time import Freq
|
||||
from ..utils.resam import resam_ts_data, get_higher_eq_freq_feature
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.order import BaseTradeDecision, Order, OrderDir
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
|
||||
from ..data import D
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
from ..utils.time import Freq
|
||||
|
||||
|
||||
class Report:
|
||||
@@ -251,14 +253,21 @@ class Indicator:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# order indicator is metrics for a single order for a specific step
|
||||
self.order_indicator_his = OrderedDict()
|
||||
self.order_indicator = OrderedDict()
|
||||
self.trade_indicator_his = OrderedDict()
|
||||
self.trade_indicator = OrderedDict()
|
||||
self.order_indicator: Dict[str, pd.Series] = OrderedDict()
|
||||
|
||||
def clear(self):
|
||||
# trade indicator is metrics for all orders for a specific step
|
||||
self.trade_indicator_his = OrderedDict()
|
||||
self.trade_indicator: Dict[str, float] = OrderedDict()
|
||||
|
||||
self._trade_calendar = None
|
||||
|
||||
# def reset(self, trade_calendar: TradeCalendarManager):
|
||||
def reset(self):
|
||||
self.order_indicator = OrderedDict()
|
||||
self.trade_indicator = OrderedDict()
|
||||
# self._trade_calendar = trade_calendar
|
||||
|
||||
def record(self, trade_start_time):
|
||||
self.order_indicator_his[trade_start_time] = self.order_indicator
|
||||
@@ -294,9 +303,14 @@ class Indicator:
|
||||
def _update_order_price_advantage(self):
|
||||
# NOTE:
|
||||
# trade_price and baseline price will be same on the lowest-level
|
||||
# So Pa should be 0
|
||||
# So Pa should be 0 or do nothing
|
||||
self.order_indicator["pa"] = 0
|
||||
|
||||
def update_order_indicators(self, trade_info: list):
|
||||
self._update_order_trade_info(trade_info=trade_info)
|
||||
self._update_order_fulfill_rate()
|
||||
self._update_order_price_advantage()
|
||||
|
||||
def _agg_order_trade_info(self, inner_order_indicators: List[Dict[str, pd.Series]]):
|
||||
inner_amount = pd.Series()
|
||||
deal_amount = pd.Series()
|
||||
@@ -312,7 +326,7 @@ class Indicator:
|
||||
)
|
||||
trade_value = trade_value.add(_order_indicator["trade_value"], fill_value=0)
|
||||
trade_cost = trade_cost.add(_order_indicator["trade_cost"], fill_value=0)
|
||||
trade_dir = trade_dir.add(_order_indicator["trade_dir"])
|
||||
trade_dir = trade_dir.add(_order_indicator["trade_dir"], fill_value=0)
|
||||
|
||||
trade_dir = trade_dir.apply(Order.parse_dir)
|
||||
|
||||
@@ -335,24 +349,77 @@ class Indicator:
|
||||
def _agg_order_fulfill_rate(self):
|
||||
self.order_indicator["ffr"] = self.order_indicator["deal_amount"] / self.order_indicator["amount"]
|
||||
|
||||
def _agg_order_price_advantage(
|
||||
def _get_base_vol_pri(
|
||||
self,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]],
|
||||
inst: str,
|
||||
trade_start_time: pd.Timestamp,
|
||||
trade_end_time: pd.Timestamp,
|
||||
direction: OrderDir,
|
||||
decision: BaseTradeDecision,
|
||||
trade_exchange: Exchange,
|
||||
pa_config: dict = {},
|
||||
):
|
||||
"""Get the base volume and price information"""
|
||||
|
||||
agg = pa_config.get("agg", "twap").lower()
|
||||
price = pa_config.get("price", "deal_price").lower()
|
||||
|
||||
if price == "deal_price":
|
||||
price_s = trade_exchange.get_deal_price(
|
||||
inst, trade_start_time, trade_end_time, direction=direction, method=None
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
# NOTE: there are some zeros in the trading price. These cases are known meaningless
|
||||
# for aligning the previous logic, remove it.
|
||||
# price_s = price_s.mask(np.isclose(price_s, 0))
|
||||
|
||||
if agg == "vwap":
|
||||
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
|
||||
elif agg == "twap":
|
||||
volume_s = pd.Series(1, index=price_s.index)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
# no sub executor on the lowest level
|
||||
# So range_limit an total step will all be None
|
||||
total_step = decision.total_step
|
||||
if total_step is None:
|
||||
total_step = 1
|
||||
range_limit = decision.get_range_limit(default_value=(0, total_step - 1))
|
||||
|
||||
assert volume_s.shape[0] % total_step == 0, "The price series can't be divided by step length"
|
||||
factor = volume_s.shape[0] // total_step
|
||||
|
||||
slc = slice(range_limit[0] * factor, (range_limit[1] + 1) * factor)
|
||||
|
||||
volume_s = volume_s.iloc[slc]
|
||||
price_s = price_s.iloc[slc]
|
||||
|
||||
base_volume = volume_s.sum().item()
|
||||
base_price = ((price_s * volume_s).sum() / base_volume).item()
|
||||
|
||||
return base_price, base_volume
|
||||
|
||||
def _agg_base_price(
|
||||
self,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||
trade_exchange: Exchange,
|
||||
pa_config: dict = {},
|
||||
):
|
||||
"""
|
||||
# NOTE:!!!!
|
||||
# Strong assumption!!!!!!
|
||||
# the correctness of the base_price relies on that the **same** exchange is used
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inner_order_indicators : List[Dict[str, pd.Series]]
|
||||
the indicators of account of inner executor
|
||||
trade_start_time : pd.Timestamp
|
||||
the start_time of the trade period, for slicing
|
||||
trade_end_time : pd.Timestamp
|
||||
the end_time of the trade period, for slicing (so it may include more time at the end)
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||
a list of decisions according to inner_order_indicators
|
||||
trade_exchange : Exchange
|
||||
for retrieving trading price
|
||||
pa_config : dict
|
||||
@@ -362,32 +429,61 @@ class Indicator:
|
||||
"price": "$close", # TODO: this is not supported now!!!!!
|
||||
# default to use deal price of the exchange
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
agg = pa_config.get("agg", "twap").lower()
|
||||
price = pa_config.get("price", "deal_price").lower()
|
||||
# TODO: I think there are potentials to be optimized
|
||||
trade_dir = self.order_indicator["trade_dir"]
|
||||
if len(trade_dir) > 0:
|
||||
bp_all, bv_all = [], []
|
||||
# <step, inst, (base_volume | base_price)>
|
||||
for oi, (dec, start, end) in zip(inner_order_indicators, decision_list):
|
||||
bp_s = oi.get("base_price", pd.Series()).reindex(trade_dir.index)
|
||||
bv_s = oi.get("base_volume", pd.Series()).reindex(trade_dir.index)
|
||||
bp_new, bv_new = {}, {}
|
||||
for pr, v, (inst, direction) in zip(bp_s.values, bv_s.values, trade_dir.items()):
|
||||
if np.isnan(pr):
|
||||
bp_new[inst], bv_new[inst] = self._get_base_vol_pri(
|
||||
inst,
|
||||
start,
|
||||
end,
|
||||
decision=dec,
|
||||
direction=direction,
|
||||
trade_exchange=trade_exchange,
|
||||
pa_config=pa_config,
|
||||
)
|
||||
else:
|
||||
bp_new[inst], bv_new[inst] = pr, v
|
||||
|
||||
base_price = {}
|
||||
for inst, dir in self.order_indicator["trade_dir"].items():
|
||||
bp_new, bv_new = pd.Series(bp_new), pd.Series(bv_new)
|
||||
bp_all.append(bp_new)
|
||||
bv_all.append(bv_new)
|
||||
bp_all = pd.concat(bp_all, axis=1)
|
||||
bv_all = pd.concat(bv_all, axis=1)
|
||||
|
||||
if price == "deal_price":
|
||||
price_s = trade_exchange.get_deal_price(inst, trade_start_time, trade_end_time, dir, method=None)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
self.order_indicator["base_volume"] = bv_all.sum(axis=1)
|
||||
self.order_indicator["base_price"] = (bp_all * bv_all).sum(axis=1) / self.order_indicator["base_volume"]
|
||||
|
||||
# there are some zeros in the trading price. These cases are known meaningless
|
||||
price_s = price_s.mask(np.isclose(price_s, 0))
|
||||
def _agg_order_price_advantage(self):
|
||||
if not self.order_indicator["trade_price"].empty:
|
||||
self.order_indicator["pa"] = self.order_indicator["trade_price"] / self.order_indicator["base_price"] - 1
|
||||
else:
|
||||
self.order_indicator["pa"] = pd.Series()
|
||||
|
||||
if agg == "vwap":
|
||||
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
|
||||
base_price[inst] = ((price_s * volume_s).sum() / volume_s.sum()).item()
|
||||
elif agg == "twap":
|
||||
base_price[inst] = price_s.mean().item()
|
||||
|
||||
base_price = pd.Series(base_price)
|
||||
|
||||
# update PA
|
||||
self.order_indicator["pa"] = self.order_indicator["trade_price"] / base_price - 1
|
||||
def agg_order_indicators(
|
||||
self,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_exchange: Exchange,
|
||||
indicator_config={},
|
||||
):
|
||||
self._agg_order_trade_info(inner_order_indicators)
|
||||
self._update_trade_amount(outer_trade_decision)
|
||||
self._agg_order_fulfill_rate()
|
||||
pa_config = indicator_config.get("pa_config", {})
|
||||
self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config)
|
||||
self._agg_order_price_advantage()
|
||||
|
||||
def _cal_trade_fulfill_rate(self, method="mean"):
|
||||
if method == "mean":
|
||||
@@ -402,7 +498,7 @@ class Indicator:
|
||||
raise ValueError(f"method {method} is not supported!")
|
||||
|
||||
def _cal_trade_price_advantage(self, method="mean"):
|
||||
pa_order = self.order_indicator["pa"] * (2 * (self.order_indicator["amount"] < 0).astype(int) - 1)
|
||||
pa_order = self.order_indicator["pa"] * (1 - self.order_indicator["trade_dir"] * 2)
|
||||
if method == "mean":
|
||||
return pa_order.mean()
|
||||
elif method == "amount_weighted":
|
||||
@@ -427,28 +523,6 @@ class Indicator:
|
||||
def _cal_trade_order_count(self):
|
||||
return self.order_indicator["amount"].count()
|
||||
|
||||
def update_order_indicators(self, trade_info: list):
|
||||
self._update_order_trade_info(trade_info=trade_info)
|
||||
self._update_order_fulfill_rate()
|
||||
self._update_order_price_advantage()
|
||||
|
||||
def agg_order_indicators(
|
||||
self,
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]],
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_exchange: Exchange,
|
||||
indicator_config={},
|
||||
):
|
||||
self._agg_order_trade_info(inner_order_indicators)
|
||||
self._update_trade_amount(outer_trade_decision)
|
||||
self._agg_order_fulfill_rate()
|
||||
pa_config = indicator_config.get("pa_config", {})
|
||||
self._agg_order_price_advantage(
|
||||
inner_order_indicators, trade_start_time, trade_end_time, trade_exchange, pa_config=pa_config
|
||||
)
|
||||
|
||||
def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}):
|
||||
show_indicator = indicator_config.get("show_indicator", False)
|
||||
ffr_config = indicator_config.get("ffr_config", {})
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
from typing import Union, TYPE_CHECKING, Tuple, Union, List, Set
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.order import BaseTradeDecision
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
|
||||
import pandas as pd
|
||||
import warnings
|
||||
from typing import Tuple, Union, List, Set
|
||||
|
||||
from ..utils.resam import get_resam_calendar
|
||||
from ..data.data import Cal
|
||||
@@ -30,17 +35,20 @@ class TradeCalendarManager:
|
||||
closed end of the trade time range, by default None
|
||||
If `end_time` is None, it must be reset before trading.
|
||||
"""
|
||||
self.freq = freq
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(end_time) if end_time else None
|
||||
self._init_trade_calendar(freq=freq, start_time=start_time, end_time=end_time)
|
||||
self.reset(freq=freq, start_time=start_time, end_time=end_time)
|
||||
|
||||
def _init_trade_calendar(self, freq, start_time, end_time):
|
||||
def reset(self, freq, start_time, end_time):
|
||||
"""
|
||||
Please refer to the docs of `__init__`
|
||||
|
||||
Reset the trade calendar
|
||||
- self.trade_len : The total count for trading step
|
||||
- self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1]
|
||||
"""
|
||||
self.freq = freq
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(end_time) if end_time else None
|
||||
|
||||
_calendar, freq, freq_sam = get_resam_calendar(freq=freq)
|
||||
self._calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
|
||||
@@ -67,6 +75,7 @@ class TradeCalendarManager:
|
||||
return self.freq
|
||||
|
||||
def get_trade_len(self):
|
||||
"""get the total step length"""
|
||||
return self.trade_len
|
||||
|
||||
def get_trade_step(self):
|
||||
@@ -99,6 +108,12 @@ class TradeCalendarManager:
|
||||
calendar_index = self.start_index + trade_step
|
||||
return self._calendar[calendar_index], self._calendar[calendar_index + 1] - pd.Timedelta(seconds=1)
|
||||
|
||||
def get_cur_step_time(self):
|
||||
"""
|
||||
get current step time
|
||||
"""
|
||||
return self.get_step_time(self.get_trade_step())
|
||||
|
||||
def get_all_time(self):
|
||||
"""Get the start_time and end_time for trading"""
|
||||
return self.start_time, self.end_time
|
||||
@@ -146,5 +161,40 @@ class CommonInfrastructure(BaseInfrastructure):
|
||||
|
||||
|
||||
class LevelInfrastructure(BaseInfrastructure):
|
||||
"""level instrastructure is created by executor, and then shared to strategies on the same level"""
|
||||
|
||||
def get_support_infra(self):
|
||||
return ["trade_calendar"]
|
||||
return ["trade_calendar", "sub_level_infra"]
|
||||
|
||||
def reset_cal(self, freq, start_time, end_time):
|
||||
"""reset trade calendar manager"""
|
||||
if self.has("trade_calendar"):
|
||||
self.get("trade_calendar").reset(freq, start_time=start_time, end_time=end_time)
|
||||
else:
|
||||
self.reset_infra(trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time))
|
||||
|
||||
def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure):
|
||||
"""this will make the calendar access easier when acrossing multi-levels"""
|
||||
self.reset_infra(sub_level_infra=sub_level_infra)
|
||||
|
||||
|
||||
def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Union[int, int]:
|
||||
"""
|
||||
A helper function for getting the decision-level index range limitation for inner strategy
|
||||
- NOTE: this function is not applicable to order-level
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_calendar : TradeCalendarManager
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the trade decision made by outer strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[int, int]:
|
||||
start index and end index
|
||||
"""
|
||||
try:
|
||||
return outer_trade_decision.get_range_limit()
|
||||
except NotImplementedError:
|
||||
return 0, trade_calendar.get_trade_len() - 1
|
||||
|
||||
@@ -14,29 +14,7 @@ from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO
|
||||
from ...backtest.exchange import Exchange, OrderHelper
|
||||
from ...backtest.utils import CommonInfrastructure, LevelInfrastructure
|
||||
from qlib.utils.file import get_io_object
|
||||
|
||||
|
||||
def get_start_end_idx(strategy: BaseStrategy, outer_trade_decision: BaseTradeDecision) -> Union[int, int]:
|
||||
"""
|
||||
A helper function for getting the decision-level index range limitation for inner strategy
|
||||
- NOTE: this function is not applicable to order-level
|
||||
|
||||
Parameters
|
||||
----------
|
||||
strategy : BaseStrategy
|
||||
the inner strawtegy
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the trade decision made by outer strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[int, int]:
|
||||
start index and end index
|
||||
"""
|
||||
try:
|
||||
return outer_trade_decision.get_range_limit()
|
||||
except NotImplementedError:
|
||||
return 0, strategy.trade_calendar.get_trade_len() - 1
|
||||
from qlib.backtest.utils import get_start_end_idx
|
||||
|
||||
|
||||
class TWAPStrategy(BaseStrategy):
|
||||
@@ -105,7 +83,7 @@ class TWAPStrategy(BaseStrategy):
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
# get the total count of trading step
|
||||
start_idx, end_idx = get_start_end_idx(self, self.outer_trade_decision)
|
||||
start_idx, end_idx = get_start_end_idx(self.trade_calendar, self.outer_trade_decision)
|
||||
trade_len = end_idx - start_idx + 1
|
||||
|
||||
if trade_step < start_idx or trade_step > end_idx:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.backtest.position import BasePosition
|
||||
from typing import List, Union
|
||||
|
||||
from ..model.base import BaseModel
|
||||
@@ -37,24 +38,26 @@ class BaseStrategy:
|
||||
|
||||
self.reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)
|
||||
|
||||
@property
|
||||
def trade_calendar(self) -> TradeCalendarManager:
|
||||
return self.level_infra.get("trade_calendar")
|
||||
|
||||
@property
|
||||
def trade_position(self) -> BasePosition:
|
||||
return self.common_infra.get("trade_account").current
|
||||
|
||||
def reset_level_infra(self, level_infra: LevelInfrastructure):
|
||||
if not hasattr(self, "level_infra"):
|
||||
self.level_infra = level_infra
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if level_infra.has("trade_calendar"):
|
||||
self.trade_calendar: TradeCalendarManager = level_infra.get("trade_calendar")
|
||||
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure):
|
||||
if not hasattr(self, "common_infra"):
|
||||
self.common_infra: CommonInfrastructure = common_infra
|
||||
else:
|
||||
self.common_infra.update(common_infra)
|
||||
|
||||
if common_infra.has("trade_account"):
|
||||
self.trade_position = common_infra.get("trade_account").current
|
||||
|
||||
def reset(
|
||||
self,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
|
||||
Reference in New Issue
Block a user