diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index fa57e354b..ab3d29408 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -13,7 +13,7 @@ from .executor import BaseExecutor from .backtest import backtest_loop from .backtest import collect_data_loop from .order import Order -from .utils import CommonInfrastructure, TradeCalendarManager +from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager from ..utils import init_instance_by_config from ..log import get_module_logger from ..config import C diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 67f7b056a..3ef1cdd03 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -3,7 +3,7 @@ import copy -from typing import Dict, List +from typing import Dict, List, Tuple from qlib.utils import init_instance_by_config import warnings import pandas as pd @@ -250,6 +250,7 @@ class Account: outer_trade_decision: BaseTradeDecision, trade_info: list = None, inner_order_indicators: List[Dict[str, pd.Series]] = None, + decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None, indicator_config: dict = {}, ): """update account at each trading bar step @@ -274,6 +275,9 @@ class Account: indicators of inner executor, by default None - necessary if atomic is False - used to aggregate outer indicators + decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None, + The decision list of the inner level: List[Tuple[, , ]] + The inner level indicator_config : dict, optional config of calculating indicators, by default {} """ @@ -289,22 +293,27 @@ class Account: # report is portfolio related analysis self.update_report(trade_start_time, trade_end_time) - # indicator is trading (e.g. high-frequency order execution) related analysis - self.indicator.clear() + # TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():` + # indicator is trading (e.g. 
high-frequency order execution) related analysis + self.indicator.reset() + + # aggregate the information for each order if atomic: self.indicator.update_order_indicators(trade_info) else: self.indicator.agg_order_indicators( - trade_start_time, - trade_end_time, inner_order_indicators, + decision_list=decision_list, outer_trade_decision=outer_trade_decision, trade_exchange=trade_exchange, indicator_config=indicator_config, ) + # aggregate all the order metrics a single step self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config) + + # record the metrics self.indicator.record(trade_start_time) def get_report(self): diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index 573c874b0..89b8c7830 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -55,14 +55,13 @@ def collect_data_loop( trade decision """ trade_executor.reset(start_time=start_time, end_time=end_time) - level_infra = trade_executor.get_level_infra() - trade_strategy.reset(level_infra=level_infra) + trade_strategy.reset(level_infra=trade_executor.get_level_infra()) with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar: _execute_result = None while not trade_executor.finished(): _trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result) - _execute_result = yield from trade_executor.collect_data(_trade_decision) + _execute_result = yield from trade_executor.collect_data(_trade_decision, level=0) bar.update(1) if return_value is not None: diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index c4807ebde..b99380c54 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -1,13 +1,16 @@ +from abc import abstractclassmethod, abstractmethod import copy +from types import GeneratorType +from qlib.backtest.account import Account import warnings import pandas as pd -from typing import List, Union +from typing import List, Tuple, Union from 
qlib.backtest.report import Indicator -from .order import Order, BaseTradeDecision +from .order import EmptyTradeDecision, Order, BaseTradeDecision from .exchange import Exchange -from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure +from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx from ..utils import init_instance_by_config from ..utils.time import Freq @@ -26,6 +29,7 @@ class BaseExecutor: generate_report: bool = False, verbose: bool = False, track_data: bool = False, + trade_exchange: Exchange = None, common_infra: CommonInfrastructure = None, **kwargs, ): @@ -62,8 +66,8 @@ class BaseExecutor: { 'show_indicator': True, 'pa_config': { - 'base_value': 'twap', - 'weight_method': 'value_weighted', + "agg": "twap", # "vwap" + "price": "$close", # default to use deal price of the exchange }, 'ffr_config':{ 'weight_method': 'value_weighted', @@ -77,6 +81,12 @@ class BaseExecutor: whether to generate trade_decision, will be used when training rl agent - If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data` - Else, `trade_decision` will not be generated + + trade_exchange : Exchange + exchange that provides market info, used to generate report + - If generate_report is None, trade_exchange will be ignored + - Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra + common_infra : CommonInfrastructure, optional: common infrastructure for backtesting, may including: - trade_account : Account, optional @@ -90,7 +100,9 @@ class BaseExecutor: self.generate_report = generate_report self.verbose = verbose self.track_data = track_data - self.reset(start_time=start_time, end_time=end_time, track_data=track_data, common_infra=common_infra) + self._trade_exchange = trade_exchange + self.level_infra = LevelInfrastructure() + self.reset(start_time=start_time, end_time=end_time, 
common_infra=common_infra) def reset_common_infra(self, common_infra): """ @@ -105,60 +117,106 @@ class BaseExecutor: if common_infra.has("trade_account"): # NOTE: there is a trick in the code. # copy is used instead of deepcopy. So positions are shared - self.trade_account = copy.copy(common_infra.get("trade_account")) + self.trade_account: Account = copy.copy(common_infra.get("trade_account")) self.trade_account.reset(freq=self.time_per_step, init_report=True, port_metr_enabled=self.generate_report) - def reset(self, track_data: bool = None, common_infra: CommonInfrastructure = None, **kwargs): + @property + def trade_exchange(self) -> Exchange: + """get trade exchange in a prioritized order""" + return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange") + + @property + def trade_calendar(self) -> TradeCalendarManager: + """ + Though trade calendar can be accessed from multiple sources, but managing in a centralized way will make the + code easier + """ + return self.level_infra.get("trade_calendar") + + def reset(self, common_infra: CommonInfrastructure = None, **kwargs): """ - reset `start_time` and `end_time`, used in trade calendar - - reset `track_data`, used when making data for multi-level training - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc """ - if track_data is not None: - self.track_data = track_data - if "start_time" in kwargs or "end_time" in kwargs: start_time = kwargs.get("start_time") end_time = kwargs.get("end_time") - self.trade_calendar = TradeCalendarManager( - freq=self.time_per_step, start_time=start_time, end_time=end_time - ) - + self.level_infra.reset_cal(freq=self.time_per_step, start_time=start_time, end_time=end_time) if common_infra is not None: self.reset_common_infra(common_infra) def get_level_infra(self): - return LevelInfrastructure(trade_calendar=self.trade_calendar) + return self.level_infra def finished(self): return self.trade_calendar.finished() - def 
execute(self, trade_decision): + def execute(self, trade_decision: BaseTradeDecision, level: int = 0): """execute the trade decision and return the executed result + NOTE: this function is never used directly in the framework. Should we delete it? + Parameters ---------- trade_decision : BaseTradeDecision + level : int + the level of current executor + Returns ---------- execute_result : List[object] the executed result for trade decision """ - raise NotImplementedError("execute is not implemented!") + return_value = {} + for _decision in self.collect_data(trade_decision, return_value=return_value, level=level): + pass + return return_value.get("execute_result") - def collect_data(self, trade_decision): + @abstractclassmethod + def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]: + """ + Please refer to the doc of collect_data + The only difference between `_collect_data` and `collect_data` is that some common steps are moved into + collect_data + + Parameters + ---------- + Please refer to the doc of collect_data + + + Returns + ------- + Tuple[List[object], dict]: + (<the executed result for trade decision>, <the extra kwargs for `self.trade_account.update_bar_end`>) + """ + + def collect_data( + self, trade_decision: BaseTradeDecision, return_value: dict = None, level: int = 0 + ) -> List[object]: """Generator for collecting the trade decision data for rl training + This function will make a step forward + Parameters ---------- trade_decision : BaseTradeDecision + level : int + the level of current executor. 0 indicates the top level + + return_value : dict + the mem address to return the value + e.g. {"execute_result": <the executed result>} + Returns ---------- execute_result : List[object] - the executed result for trade decision + the executed result for trade decision. + ** NOTE!!!! **: + 1) This is necessary. The return value of the generator will be used in NestedExecutor + 2) Please note the executed results are not merged.
Yields ------- @@ -167,7 +225,36 @@ class BaseExecutor: """ if self.track_data: yield trade_decision - return self.execute(trade_decision) + + atomic = not issubclass(self.__class__, NestedExecutor) # issubclass(A, A) is True + + if atomic and trade_decision.get_range_limit(default_value=None) is not None: + raise ValueError("atomic executor doesn't support specify `range_limit`") + + obj = self._collect_data(trade_decision=trade_decision, level=level) + + if isinstance(obj, GeneratorType): + res, kwargs = yield from obj + else: + # Some concrete executor don't have inner decisions + res, kwargs = obj + + trade_start_time, trade_end_time = self.trade_calendar.get_cur_step_time() + # Account will not be changed in this function + self.trade_account.update_bar_end( + trade_start_time, + trade_end_time, + self.trade_exchange, + atomic=atomic, + outer_trade_decision=trade_decision, + indicator_config=self.indicator_config, + **kwargs, + ) + + self.trade_calendar.step() + if return_value is not None: + return_value.update({"execute_result": res}) + return res def get_all_executors(self): """get all executors""" @@ -192,7 +279,7 @@ class NestedExecutor(BaseExecutor): verbose: bool = False, track_data: bool = False, skip_empty_decision: bool = True, - trade_exchange: Exchange = None, + align_range_limit: bool = True, common_infra: CommonInfrastructure = None, **kwargs, ): @@ -203,24 +290,24 @@ class NestedExecutor(BaseExecutor): trading env in each trading bar. inner_strategy : BaseStrategy trading strategy in each trading bar - trade_exchange : Exchange - exchange that provides market info, used to generate report - - If generate_report is None, trade_exchange will be ignored - - Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra skip_empty_decision: bool - Will the executor skip the inner loop when the decision is empty. + Will the executor skip call inner loop when the decision is empty. 
It should be False in following cases - The decisions may be updated by steps - The inner executor may not follow the decisions from the outer strategy + align_range_limit: bool + force to align the index_range decision + It is only for nested executor, because range_limit is given by outer strategy """ - self.inner_executor = init_instance_by_config( + self.inner_executor: BaseExecutor = init_instance_by_config( inner_executor, common_infra=common_infra, accept_types=BaseExecutor ) - self.inner_strategy = init_instance_by_config( + self.inner_strategy: BaseStrategy = init_instance_by_config( inner_strategy, common_infra=common_infra, accept_types=BaseStrategy ) self._skip_empty_decision = skip_empty_decision + self._align_range_limit = align_range_limit super(NestedExecutor, self).__init__( time_per_step=time_per_step, @@ -234,82 +321,82 @@ class NestedExecutor(BaseExecutor): **kwargs, ) - if trade_exchange is not None: - self.trade_exchange = trade_exchange - def reset_common_infra(self, common_infra): """ reset infrastructure for trading - - reset trade_exchange - reset inner_strategyand inner_executor common infra """ super(NestedExecutor, self).reset_common_infra(common_infra) - if common_infra.has("trade_exchange"): - self.trade_exchange = common_infra.get("trade_exchange") - self.inner_executor.reset_common_infra(common_infra) self.inner_strategy.reset_common_infra(common_infra) def _init_sub_trading(self, trade_decision): - trade_step = self.trade_calendar.get_trade_step() - trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) + trade_start_time, trade_end_time = self.trade_calendar.get_cur_step_time() self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time) sub_level_infra = self.inner_executor.get_level_infra() + self.level_infra.set_sub_level_infra(sub_level_infra) self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision) - def execute(self, trade_decision): - 
return_value = {} - for _decision in self.collect_data(trade_decision, return_value): - pass - return return_value.get("execute_result") + def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTradeDecision: + # outter strategy have chance to update decision each iterator + updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar) + if updated_trade_decision is not None: + trade_decision = updated_trade_decision + # NEW UPDATE + # create a hook for inner strategy to update outter decision + self.inner_strategy.alter_outer_trade_decision(trade_decision) + return trade_decision - def collect_data(self, trade_decision: BaseTradeDecision, return_value=None): - if self.track_data: - yield trade_decision + # def _get_inner_trade_decision(self, outer_trade_decision: BaseTradeDecision, inner_execute_result): + # # In some cases, the inner strategy can be skipped, but the inner executor should keep running + # if outer_trade_decision.empty() and self._skip_empty_decision: + # return EmptyTradeDecision(self.inner_strategy) + # return self.inner_strategy.generate_trade_decision(inner_execute_result) + # _inner_trade_decision = self._get_inner_trade_decision(trade_decision, _inner_execute_result) + + def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0): execute_result = [] inner_order_indicators = [] + decision_list = [] + # NOTE: + # - this is necessary to calculating the steps in sub level + # - more detailed information will be set into trade decision + self._init_sub_trading(trade_decision) - if not (trade_decision.empty() and self._skip_empty_decision): - _inner_execute_result = None - self._init_sub_trading(trade_decision) - while not self.inner_executor.finished(): - # outter strategy have chance to update decision each iterator - updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar) - if updated_trade_decision is not None: - trade_decision = updated_trade_decision - # NEW 
UPDATE - # create a hook for inner strategy to update outter decision - self.inner_strategy.alter_outer_trade_decision(trade_decision) + _inner_execute_result = None + while not self.inner_executor.finished(): + trade_decision = self._update_trade_decision(trade_decision) + + if trade_decision.empty() and self._skip_empty_decision: + # give one chance for outer stategy to update the strategy + # - For updating some information in the sub executor(the strategy have no knowledge of the inner + # executor when generating the decision) + break + + sub_cal: TradeCalendarManager = self.inner_executor.trade_calendar + start_idx, end_idx = get_start_end_idx(sub_cal, trade_decision) + if not self._align_range_limit or start_idx <= sub_cal.get_trade_step() <= end_idx: + # if force align the range limit, skip the steps outside the decision range limit _inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result) + # NOTE sub_cal.get_cur_step_time() must be called before collect_data in case of step shifting + decision_list.append((_inner_trade_decision, *sub_cal.get_cur_step_time())) # NOTE: Trade Calendar will step forward in the follow line _inner_execute_result = yield from self.inner_executor.collect_data( - trade_decision=_inner_trade_decision + trade_decision=_inner_trade_decision, level=level + 1 ) - execute_result.extend(_inner_execute_result) + inner_order_indicators.append( self.inner_executor.trade_account.get_trade_indicator().get_order_indicator() ) + else: + # do nothing and just step forward + sub_cal.step() - trade_step = self.trade_calendar.get_trade_step() - trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) - self.trade_account.update_bar_end( - trade_start_time, - trade_end_time, - self.trade_exchange, - atomic=False, - outer_trade_decision=trade_decision, - inner_order_indicators=inner_order_indicators, - indicator_config=self.indicator_config, - ) - - self.trade_calendar.step() - if return_value 
is not None: - return_value.update({"execute_result": execute_result}) - return execute_result + return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list} def get_all_executors(self): """get all executors, including self and inner_executor.get_all_executors()""" @@ -337,17 +424,13 @@ class SimulatorExecutor(BaseExecutor): generate_report: bool = False, verbose: bool = False, track_data: bool = False, - trade_exchange: Exchange = None, common_infra: CommonInfrastructure = None, - trade_type: str = TT_PARAL, + trade_type: str = TT_SERIAL, **kwargs, ): """ Parameters ---------- - trade_exchange : Exchange - exchange that provides market info, used to deal order and generate report - - If `trade_exchange` is None, self.trade_exchange will be set with common_infra trade_type: str please refer to the doc of `TT_SERIAL` & `TT_PARAL` """ @@ -362,20 +445,9 @@ class SimulatorExecutor(BaseExecutor): common_infra=common_infra, **kwargs, ) - if trade_exchange is not None: - self.trade_exchange = trade_exchange self.trade_type = trade_type - def reset_common_infra(self, common_infra): - """ - reset infrastructure for trading - - reset trade_exchange - """ - super(SimulatorExecutor, self).reset_common_infra(common_infra) - if common_infra.has("trade_exchange"): - self.trade_exchange = common_infra.get("trade_exchange") - def _get_order_iterator(self, trade_decision: BaseTradeDecision) -> List[Order]: """ @@ -405,10 +477,9 @@ class SimulatorExecutor(BaseExecutor): raise NotImplementedError(f"This type of input is not supported") return order_it - def execute(self, trade_decision: BaseTradeDecision): + def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0): - trade_step = self.trade_calendar.get_trade_step() - trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) + trade_start_time, _ = self.trade_calendar.get_cur_step_time() execute_result = [] for order in 
self._get_order_iterator(trade_decision): @@ -450,16 +521,4 @@ class SimulatorExecutor(BaseExecutor): print("[W {:%Y-%m-%d %H:%M:%S}]: {} wrong.".format(trade_start_time, order.stock_id)) # do nothing pass - - # Account will not be changed in this function - self.trade_account.update_bar_end( - trade_start_time, - trade_end_time, - self.trade_exchange, - atomic=True, - outer_trade_decision=trade_decision, - trade_info=execute_result, - indicator_config=self.indicator_config, - ) - self.trade_calendar.step() - return execute_result + return execute_result, {"trade_info": execute_result} diff --git a/qlib/backtest/order.py b/qlib/backtest/order.py index 20c97aa90..1a88ded93 100644 --- a/qlib/backtest/order.py +++ b/qlib/backtest/order.py @@ -3,6 +3,7 @@ # TODO: rename it with decision.py from __future__ import annotations from enum import IntEnum +from qlib.log import get_module_logger # try to fix circular imports when enabling type hints from typing import TYPE_CHECKING @@ -179,7 +180,7 @@ class BaseTradeDecision: 2. Same as `case 1.3` """ - def __init__(self, strategy: BaseStrategy): + def __init__(self, strategy: BaseStrategy, idx_range: Tuple[int, int] = None): """ Parameters ---------- @@ -187,6 +188,8 @@ class BaseTradeDecision: The strategy who make the decision """ self.strategy = strategy + self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading` + self.idx_range = idx_range def get_decision(self) -> List[object]: """ @@ -207,7 +210,11 @@ class BaseTradeDecision: def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecision", None]: """ - Be called at the **start** of each step + Be called at the **start** of each step. 
+ + This function is designed for the following purposes + 1) Leave a hook for the strategy who make `self` decision to update the decision itself + 2) Update some information from the inner executor calendar Parameters ---------- @@ -221,13 +228,27 @@ class BaseTradeDecision: BaseTradeDecision: New update, use new decision """ + # purpose 2) + self.total_step = trade_calendar.get_trade_len() + if self.idx_range is not None: + logger = get_module_logger("decision") + start_idx, end_idx = self.idx_range + if start_idx < 0 or end_idx >= self.total_step: + logger.warning(f"{self.idx_range} go beyound the total_step({self.total_step}), it will be clipped") + self.idx_range = max(0, start_idx), min(self.total_step - 1, end_idx) + + # purpose 1) return self.strategy.update_trade_decision(self, trade_calendar) - def get_range_limit(self) -> Tuple[int, int]: + def get_range_limit(self, **kwargs) -> Tuple[int, int]: """ return the expected step range for limiting the decision execution time Both left and right are **closed** + **kwargs: + {"default_value": <value to return when the decision has no index range>} + # using a dict makes it possible to distinguish between no value provided and None provided + Returns ------- Tuple[int, int]: @@ -235,12 +256,32 @@ class BaseTradeDecision: Raises ------ NotImplementedError: - If the decision can't provide a unified start and end + If both of the following criteria are met + 1) the decision can't provide a unified start and end + 2) `default_value` is not provided in kwargs """ - raise NotImplementedError(f"Please implement the `func` method") + if self.idx_range is None: + if "default_value" in kwargs: + return kwargs["default_value"] + else: + # Default to get full index + raise NotImplementedError(f"The decision didn't provide an index range") + return self.idx_range def empty(self) -> bool: - return len(self.get_decision()) == 0 + for obj in self.get_decision(): + if isinstance(obj, Order): + # Zero amount order will be treated as empty + if not np.isclose(obj.amount, 0.0): + return False + else: + return True + return True + + +class
EmptyTradeDecision(BaseTradeDecision): + def empty(self) -> bool: + return True class TradeDecisionWO(BaseTradeDecision): @@ -249,16 +290,9 @@ class TradeDecisionWO(BaseTradeDecision): Besides, the time_range is also included. """ - def __init__(self, order_list: List[Order], strategy: BaseStrategy, idx_range: Tuple = None): - super().__init__(strategy) + def __init__(self, order_list: List[Order], strategy: BaseStrategy, idx_range: Tuple[int, int] = None): + super().__init__(strategy, idx_range=idx_range) self.order_list = order_list - self.idx_range = idx_range - - def get_range_limit(self) -> Tuple[int, int]: - if self.idx_range is None: - # Default to get full index - raise NotImplementedError(f"The decision didn't provide an index range") - return self.idx_range def get_decision(self) -> List[object]: return self.order_list diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 43a6a455b..138a44faa 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -4,21 +4,23 @@ from collections import OrderedDict from logging import warning -from qlib.backtest.exchange import Exchange -from typing import Dict, List -from qlib.backtest.order import BaseTradeDecision, Order, OrderDir -import pandas as pd -import numpy as np import pathlib +from typing import Dict, List, Tuple import warnings -from pandas.core import groupby +import numpy as np +import pandas as pd +from pandas.core import groupby from pandas.core.frame import DataFrame -from ..utils.time import Freq -from ..utils.resam import resam_ts_data, get_higher_eq_freq_feature +from qlib.backtest.exchange import Exchange +from qlib.backtest.order import BaseTradeDecision, Order, OrderDir +from qlib.backtest.utils import TradeCalendarManager + from ..data import D from ..tests.config import CSI300_BENCH +from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data +from ..utils.time import Freq class Report: @@ -251,14 +253,21 @@ class Indicator: """ def __init__(self): + # order 
indicator is metrics for a single order for a specific step self.order_indicator_his = OrderedDict() - self.order_indicator = OrderedDict() - self.trade_indicator_his = OrderedDict() - self.trade_indicator = OrderedDict() + self.order_indicator: Dict[str, pd.Series] = OrderedDict() - def clear(self): + # trade indicator is metrics for all orders for a specific step + self.trade_indicator_his = OrderedDict() + self.trade_indicator: Dict[str, float] = OrderedDict() + + self._trade_calendar = None + + # def reset(self, trade_calendar: TradeCalendarManager): + def reset(self): self.order_indicator = OrderedDict() self.trade_indicator = OrderedDict() + # self._trade_calendar = trade_calendar def record(self, trade_start_time): self.order_indicator_his[trade_start_time] = self.order_indicator @@ -294,9 +303,14 @@ class Indicator: def _update_order_price_advantage(self): # NOTE: # trade_price and baseline price will be same on the lowest-level - # So Pa should be 0 + # So Pa should be 0 or do nothing self.order_indicator["pa"] = 0 + def update_order_indicators(self, trade_info: list): + self._update_order_trade_info(trade_info=trade_info) + self._update_order_fulfill_rate() + self._update_order_price_advantage() + def _agg_order_trade_info(self, inner_order_indicators: List[Dict[str, pd.Series]]): inner_amount = pd.Series() deal_amount = pd.Series() @@ -312,7 +326,7 @@ class Indicator: ) trade_value = trade_value.add(_order_indicator["trade_value"], fill_value=0) trade_cost = trade_cost.add(_order_indicator["trade_cost"], fill_value=0) - trade_dir = trade_dir.add(_order_indicator["trade_dir"]) + trade_dir = trade_dir.add(_order_indicator["trade_dir"], fill_value=0) trade_dir = trade_dir.apply(Order.parse_dir) @@ -335,24 +349,77 @@ class Indicator: def _agg_order_fulfill_rate(self): self.order_indicator["ffr"] = self.order_indicator["deal_amount"] / self.order_indicator["amount"] - def _agg_order_price_advantage( + def _get_base_vol_pri( self, - inner_order_indicators: 
List[Dict[str, pd.Series]], + inst: str, trade_start_time: pd.Timestamp, trade_end_time: pd.Timestamp, + direction: OrderDir, + decision: BaseTradeDecision, + trade_exchange: Exchange, + pa_config: dict = {}, + ): + """Get the base volume and price information""" + + agg = pa_config.get("agg", "twap").lower() + price = pa_config.get("price", "deal_price").lower() + + if price == "deal_price": + price_s = trade_exchange.get_deal_price( + inst, trade_start_time, trade_end_time, direction=direction, method=None + ) + else: + raise NotImplementedError(f"This type of input is not supported") + + # NOTE: there are some zeros in the trading price. These cases are known meaningless + # for aligning the previous logic, remove it. + # price_s = price_s.mask(np.isclose(price_s, 0)) + + if agg == "vwap": + volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None) + elif agg == "twap": + volume_s = pd.Series(1, index=price_s.index) + else: + raise NotImplementedError(f"This type of input is not supported") + + # no sub executor on the lowest level + # So range_limit an total step will all be None + total_step = decision.total_step + if total_step is None: + total_step = 1 + range_limit = decision.get_range_limit(default_value=(0, total_step - 1)) + + assert volume_s.shape[0] % total_step == 0, "The price series can't be divided by step length" + factor = volume_s.shape[0] // total_step + + slc = slice(range_limit[0] * factor, (range_limit[1] + 1) * factor) + + volume_s = volume_s.iloc[slc] + price_s = price_s.iloc[slc] + + base_volume = volume_s.sum().item() + base_price = ((price_s * volume_s).sum() / base_volume).item() + + return base_price, base_volume + + def _agg_base_price( + self, + inner_order_indicators: List[Dict[str, pd.Series]], + decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]], trade_exchange: Exchange, pa_config: dict = {}, ): """ + # NOTE:!!!! + # Strong assumption!!!!!! 
+ # the correctness of the base_price relies on that the **same** exchange is used Parameters ---------- inner_order_indicators : List[Dict[str, pd.Series]] the indicators of account of inner executor - trade_start_time : pd.Timestamp - the start_time of the trade period, for slicing - trade_end_time : pd.Timestamp - the end_time of the trade period, for slicing (so it may include more time at the end) + decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]], + a list of decisions according to inner_order_indicators trade_exchange : Exchange for retrieving trading price pa_config : dict @@ -362,32 +429,61 @@ class Indicator: "price": "$close", # TODO: this is not supported now!!!!! # default to use deal price of the exchange } + """ - agg = pa_config.get("agg", "twap").lower() - price = pa_config.get("price", "deal_price").lower() + # TODO: I think there are potentials to be optimized + trade_dir = self.order_indicator["trade_dir"] + if len(trade_dir) > 0: + bp_all, bv_all = [], [] + # + for oi, (dec, start, end) in zip(inner_order_indicators, decision_list): + bp_s = oi.get("base_price", pd.Series()).reindex(trade_dir.index) + bv_s = oi.get("base_volume", pd.Series()).reindex(trade_dir.index) + bp_new, bv_new = {}, {} + for pr, v, (inst, direction) in zip(bp_s.values, bv_s.values, trade_dir.items()): + if np.isnan(pr): + bp_new[inst], bv_new[inst] = self._get_base_vol_pri( + inst, + start, + end, + decision=dec, + direction=direction, + trade_exchange=trade_exchange, + pa_config=pa_config, + ) + else: + bp_new[inst], bv_new[inst] = pr, v - base_price = {} - for inst, dir in self.order_indicator["trade_dir"].items(): + bp_new, bv_new = pd.Series(bp_new), pd.Series(bv_new) + bp_all.append(bp_new) + bv_all.append(bv_new) + bp_all = pd.concat(bp_all, axis=1) + bv_all = pd.concat(bv_all, axis=1) - if price == "deal_price": - price_s = trade_exchange.get_deal_price(inst, trade_start_time, trade_end_time, dir, method=None) - else: - raise 
NotImplementedError(f"This type of input is not supported") + self.order_indicator["base_volume"] = bv_all.sum(axis=1) + self.order_indicator["base_price"] = (bp_all * bv_all).sum(axis=1) / self.order_indicator["base_volume"] - # there are some zeros in the trading price. These cases are known meaningless - price_s = price_s.mask(np.isclose(price_s, 0)) + def _agg_order_price_advantage(self): + if not self.order_indicator["trade_price"].empty: + self.order_indicator["pa"] = self.order_indicator["trade_price"] / self.order_indicator["base_price"] - 1 + else: + self.order_indicator["pa"] = pd.Series() - if agg == "vwap": - volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None) - base_price[inst] = ((price_s * volume_s).sum() / volume_s.sum()).item() - elif agg == "twap": - base_price[inst] = price_s.mean().item() - - base_price = pd.Series(base_price) - - # update PA - self.order_indicator["pa"] = self.order_indicator["trade_price"] / base_price - 1 + def agg_order_indicators( + self, + inner_order_indicators: List[Dict[str, pd.Series]], + decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]], + outer_trade_decision: BaseTradeDecision, + trade_exchange: Exchange, + indicator_config={}, + ): + self._agg_order_trade_info(inner_order_indicators) + self._update_trade_amount(outer_trade_decision) + self._agg_order_fulfill_rate() + pa_config = indicator_config.get("pa_config", {}) + self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) + self._agg_order_price_advantage() def _cal_trade_fulfill_rate(self, method="mean"): if method == "mean": @@ -402,7 +498,7 @@ class Indicator: raise ValueError(f"method {method} is not supported!") def _cal_trade_price_advantage(self, method="mean"): - pa_order = self.order_indicator["pa"] * (2 * (self.order_indicator["amount"] < 0).astype(int) - 1) + pa_order = self.order_indicator["pa"] * (1 - self.order_indicator["trade_dir"] * 2) if method 
== "mean": return pa_order.mean() elif method == "amount_weighted": @@ -427,28 +523,6 @@ class Indicator: def _cal_trade_order_count(self): return self.order_indicator["amount"].count() - def update_order_indicators(self, trade_info: list): - self._update_order_trade_info(trade_info=trade_info) - self._update_order_fulfill_rate() - self._update_order_price_advantage() - - def agg_order_indicators( - self, - trade_start_time, - trade_end_time, - inner_order_indicators: List[Dict[str, pd.Series]], - outer_trade_decision: BaseTradeDecision, - trade_exchange: Exchange, - indicator_config={}, - ): - self._agg_order_trade_info(inner_order_indicators) - self._update_trade_amount(outer_trade_decision) - self._agg_order_fulfill_rate() - pa_config = indicator_config.get("pa_config", {}) - self._agg_order_price_advantage( - inner_order_indicators, trade_start_time, trade_end_time, trade_exchange, pa_config=pa_config - ) - def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}): show_indicator = indicator_config.get("show_indicator", False) ffr_config = indicator_config.get("ffr_config", {}) diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 0ba607bdb..5c643df30 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -1,9 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations +from typing import Union, TYPE_CHECKING, Tuple, Union, List, Set + +if TYPE_CHECKING: + from qlib.backtest.order import BaseTradeDecision + from qlib.strategy.base import BaseStrategy import pandas as pd import warnings -from typing import Tuple, Union, List, Set from ..utils.resam import get_resam_calendar from ..data.data import Cal @@ -30,17 +35,20 @@ class TradeCalendarManager: closed end of the trade time range, by default None If `end_time` is None, it must be reset before trading. 
""" - self.freq = freq - self.start_time = pd.Timestamp(start_time) if start_time else None - self.end_time = pd.Timestamp(end_time) if end_time else None - self._init_trade_calendar(freq=freq, start_time=start_time, end_time=end_time) + self.reset(freq=freq, start_time=start_time, end_time=end_time) - def _init_trade_calendar(self, freq, start_time, end_time): + def reset(self, freq, start_time, end_time): """ + Please refer to the docs of `__init__` + Reset the trade calendar - self.trade_len : The total count for trading step - self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1] """ + self.freq = freq + self.start_time = pd.Timestamp(start_time) if start_time else None + self.end_time = pd.Timestamp(end_time) if end_time else None + _calendar, freq, freq_sam = get_resam_calendar(freq=freq) self._calendar = _calendar _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam) @@ -67,6 +75,7 @@ class TradeCalendarManager: return self.freq def get_trade_len(self): + """get the total step length""" return self.trade_len def get_trade_step(self): @@ -99,6 +108,12 @@ class TradeCalendarManager: calendar_index = self.start_index + trade_step return self._calendar[calendar_index], self._calendar[calendar_index + 1] - pd.Timedelta(seconds=1) + def get_cur_step_time(self): + """ + get current step time + """ + return self.get_step_time(self.get_trade_step()) + def get_all_time(self): """Get the start_time and end_time for trading""" return self.start_time, self.end_time @@ -146,5 +161,40 @@ class CommonInfrastructure(BaseInfrastructure): class LevelInfrastructure(BaseInfrastructure): + """level infrastructure is created by executor, and then shared to strategies on the same level""" + def get_support_infra(self): - return ["trade_calendar"] + return ["trade_calendar", "sub_level_infra"] + + def reset_cal(self, freq, start_time, end_time): + """reset trade calendar
manager""" + if self.has("trade_calendar"): + self.get("trade_calendar").reset(freq, start_time=start_time, end_time=end_time) + else: + self.reset_infra(trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time)) + + def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure): + """this will make the calendar access easier when working across multiple levels""" + self.reset_infra(sub_level_infra=sub_level_infra) + + +def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Union[int, int]: + """ + A helper function for getting the decision-level index range limitation for inner strategy + - NOTE: this function is not applicable to order-level + + Parameters + ---------- + trade_calendar : TradeCalendarManager + outer_trade_decision : BaseTradeDecision + the trade decision made by outer strategy + + Returns + ------- + Union[int, int]: + start index and end index + """ + try: + return outer_trade_decision.get_range_limit() + except NotImplementedError: + return 0, trade_calendar.get_trade_len() - 1 diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 3ca325bf6..026afc8bb 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -14,29 +14,7 @@ from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO from ...backtest.exchange import Exchange, OrderHelper from ...backtest.utils import CommonInfrastructure, LevelInfrastructure from qlib.utils.file import get_io_object - - -def get_start_end_idx(strategy: BaseStrategy, outer_trade_decision: BaseTradeDecision) -> Union[int, int]: - """ - A helper function for getting the decision-level index range limitation for inner strategy - - NOTE: this function is not applicable to order-level - - Parameters - ---------- - strategy : BaseStrategy - the inner strawtegy - outer_trade_decision : BaseTradeDecision - the trade decision made by outer strategy - -
Returns - ------- - Union[int, int]: - start index and end index - """ - try: - return outer_trade_decision.get_range_limit() - except NotImplementedError: - return 0, strategy.trade_calendar.get_trade_len() - 1 +from qlib.backtest.utils import get_start_end_idx class TWAPStrategy(BaseStrategy): @@ -105,7 +83,7 @@ class TWAPStrategy(BaseStrategy): # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] trade_step = self.trade_calendar.get_trade_step() # get the total count of trading step - start_idx, end_idx = get_start_end_idx(self, self.outer_trade_decision) + start_idx, end_idx = get_start_end_idx(self.trade_calendar, self.outer_trade_decision) trade_len = end_idx - start_idx + 1 if trade_step < start_idx or trade_step > end_idx: diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index a787c098f..23d6b520a 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+from qlib.backtest.position import BasePosition from typing import List, Union from ..model.base import BaseModel @@ -37,24 +38,26 @@ class BaseStrategy: self.reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision) + @property + def trade_calendar(self) -> TradeCalendarManager: + return self.level_infra.get("trade_calendar") + + @property + def trade_position(self) -> BasePosition: + return self.common_infra.get("trade_account").current + def reset_level_infra(self, level_infra: LevelInfrastructure): if not hasattr(self, "level_infra"): self.level_infra = level_infra else: self.level_infra.update(level_infra) - if level_infra.has("trade_calendar"): - self.trade_calendar: TradeCalendarManager = level_infra.get("trade_calendar") - def reset_common_infra(self, common_infra: CommonInfrastructure): if not hasattr(self, "common_infra"): self.common_infra: CommonInfrastructure = common_infra else: self.common_infra.update(common_infra) - if common_infra.has("trade_account"): - self.trade_position = common_infra.get("trade_account").current - def reset( self, level_infra: LevelInfrastructure = None,