diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 806f88a96..13213c344 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -160,7 +160,7 @@ class Account: self.accum_info.add_return_value(profit) # note here do not consider cost def update_order(self, order, trade_val, cost, trade_price): - if not self.is_port_metr_enabled(): + if self.current.skip_update(): # TODO: supporting polymorphism for account # updating order for infinite position is meaningless return diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index edcd7baaf..4fb90ff1f 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. +from qlib.backtest.position import Position import random import logging from typing import List, Tuple, Union, Callable, Iterable @@ -281,6 +282,8 @@ class Exchange: """ Deal order when the actual transaction + the results section in `Order` will be changed. + :param order: Deal the order. :param trade_account: Trade account to be updated after dealing the order. :param position: position to be updated after dealing the order. @@ -343,6 +346,7 @@ class Exchange: `None`: if the stock is suspended `None` may be returned `float`: return factor if the factor exists """ + assert (start_time is not None and end_time is not None, "the time range must be given") if stock_id not in self.quote.get_all_stock(): return None return self.quote.get_data(stock_id, start_time, end_time, fields="$factor", method=ts_data_last) @@ -505,20 +509,56 @@ class Exchange: ) return value - def get_amount_of_trade_unit(self, factor): + def _get_factor_or_raise_erorr(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None): + """Please refer to the docs of get_amount_of_trade_unit""" + if factor is None: + if stock_id is not None and start_time is not None and end_time is not None: + factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time) + else: + raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None") + return factor + + def get_amount_of_trade_unit(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None): + """ + get the trade unit of amount based on **factor** + + the factor can be given directly or calculated in given time range and stock id. + `factor` has higher priority than `stock_id`, `start_time` and `end_time` + + Parameters + ---------- + factor : float + the adjusted factor + stock_id : str + the id of the stock + start_time : + the start time of trading range + end_time : + the end time of trading range + """ if not self.trade_w_adj_price and self.trade_unit is not None: + factor = self._get_factor_or_raise_erorr( + factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time + ) return self.trade_unit / factor else: return None - def round_amount_by_trade_unit(self, deal_amount, factor): + def round_amount_by_trade_unit( + self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None + ): """Parameter + Please refer to the docs of get_amount_of_trade_unit + deal_amount : float, adjusted amount factor : float, adjusted factor return : float, real amount """ if not self.trade_w_adj_price and self.trade_unit is not None: # the minimal amount is 1. Add 0.1 for solving precision problem. + factor = self._get_factor_or_raise_erorr( + factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time + ) return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor return deal_amount @@ -529,7 +569,7 @@ class Exchange: else: return deal_amount - def _calc_trade_info_by_order(self, order, position): + def _calc_trade_info_by_order(self, order, position: Position): """ Calculation of trade info @@ -541,6 +581,7 @@ class Exchange: """ trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction) + order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time) if order.direction == Order.SELL: # sell if position is not None: diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 78cdbe5e0..b05b73801 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -1,5 +1,6 @@ from abc import abstractclassmethod, abstractmethod import copy +from qlib.log import get_module_logger from types import GeneratorType from qlib.backtest.account import Account import warnings @@ -102,7 +103,10 @@ class BaseExecutor: self.track_data = track_data self._trade_exchange = trade_exchange self.level_infra = LevelInfrastructure() + self.level_infra.reset_infra(common_infra=common_infra) self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra) + if common_infra is None: + get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}") def reset_common_infra(self, common_infra): """ @@ -239,7 +243,7 @@ class BaseExecutor: # Some concrete executor don't have inner decisions res, kwargs = obj - trade_start_time, trade_end_time = self.trade_calendar.get_cur_step_time() + trade_start_time, trade_end_time = self.trade_calendar.get_step_time() # Account will not be changed in this function self.trade_account.update_bar_end( trade_start_time, @@ -332,7 +336,7 @@ class NestedExecutor(BaseExecutor): self.inner_strategy.reset_common_infra(common_infra) def _init_sub_trading(self, trade_decision): - trade_start_time, trade_end_time = self.trade_calendar.get_cur_step_time() + trade_start_time, trade_end_time = self.trade_calendar.get_step_time() self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time) sub_level_infra = self.inner_executor.get_level_infra() self.level_infra.set_sub_level_infra(sub_level_infra) @@ -379,8 +383,8 @@ class NestedExecutor(BaseExecutor): ) trade_decision.mod_inner_decision(_inner_trade_decision) # propagate part of decision information - # NOTE sub_cal.get_cur_step_time() must be called before collect_data in case of step shifting - decision_list.append((_inner_trade_decision, *sub_cal.get_cur_step_time())) + # NOTE sub_cal.get_step_time() must be called before collect_data in case of step shifting + decision_list.append((_inner_trade_decision, *sub_cal.get_step_time())) # NOTE: Trade Calendar will step forward in the follow line _inner_execute_result = yield from self.inner_executor.collect_data( @@ -478,7 +482,7 @@ class SimulatorExecutor(BaseExecutor): def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0): - trade_start_time, _ = self.trade_calendar.get_cur_step_time() + trade_start_time, _ = self.trade_calendar.get_step_time() execute_result = [] for order in self._get_order_iterator(trade_decision): @@ -491,30 +495,22 @@ class SimulatorExecutor(BaseExecutor): execute_result.append((order, trade_val, trade_cost, trade_price)) if self.verbose: if order.direction == Order.SELL: # sell - print( - "[I {:%Y-%m-%d %H:%M:%S}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format( - trade_start_time, - order.stock_id, - trade_price, - order.amount, - order.deal_amount, - order.factor, - trade_val, - ) - ) + action = "sell" else: - print( - "[I {:%Y-%m-%d %H:%M:%S}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format( - trade_start_time, - order.stock_id, - trade_price, - order.amount, - order.deal_amount, - order.factor, - trade_val, - ) + action = "buy" + print( + "[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}, cach {:.2f}.".format( + trade_start_time, + action, + order.stock_id, + trade_price, + order.amount, + order.deal_amount, + order.factor, + trade_val, + self.trade_account.get_cash(), ) - + ) else: if self.verbose: print("[W {:%Y-%m-%d %H:%M:%S}]: {} wrong.".format(trade_start_time, order.stock_id)) diff --git a/qlib/backtest/order.py b/qlib/backtest/order.py index 84264d0a9..b99cdb8e3 100644 --- a/qlib/backtest/order.py +++ b/qlib/backtest/order.py @@ -3,7 +3,8 @@ # TODO: rename it with decision.py from __future__ import annotations from enum import IntEnum -from qlib.utils.time import concat_date_time +from qlib.data.data import Cal +from qlib.utils.time import concat_date_time, epsilon_change from qlib.log import get_module_logger # try to fix circular imports when enabling type hints @@ -41,16 +42,24 @@ class Order: presents the weight factor assigned in Exchange() """ + # 1) time invariant values + # - they are set by users and is time-invariant. stock_id: str - amount: float # `amount` is a non-negative value + amount: float # `amount` is a non-negative and adjusted value + direction: int + # 2) time variant values: + # - Users may want to set these values when using lower level APIs + # - If users don't, TradeDecisionWO will help users to set them # The interval of the order which belongs to (NOTE: this is not the expected order dealing range time) start_time: pd.Timestamp end_time: pd.Timestamp - direction: int - factor: float + # 3) results + # - users should not care about these values + # - they are set by the backtest system after finishing the results. deal_amount: float = field(init=False) # `deal_amount` is a non-negative value + factor: float = field(init=False) # FIXME: # for compatible now. @@ -127,8 +136,8 @@ class OrderHelper: code: str, amount: float, direction: OrderDir, - start_time: Union[str, pd.Timestamp], - end_time: Union[str, pd.Timestamp], + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, ) -> Order: """ help to create a order @@ -143,9 +152,9 @@ class OrderHelper: **adjusted trading amount** direction : OrderDir trading direction - start_time : Union[str, pd.Timestamp] + start_time : Union[str, pd.Timestamp] (optional) The interval of the order which belongs to - end_time : Union[str, pd.Timestamp] + end_time : Union[str, pd.Timestamp] (optional) The interval of the order which belongs to Returns @@ -153,15 +162,17 @@ class OrderHelper: Order: The created order """ - start_time = pd.Timestamp(start_time) - end_time = pd.Timestamp(end_time) + if start_time is not None: + start_time = pd.Timestamp(start_time) + if end_time is not None: + end_time = pd.Timestamp(end_time) + # NOTE: factor is a value belongs to the results section. User don't have to care about it when creating orders return Order( stock_id=code, amount=amount, start_time=start_time, end_time=end_time, direction=direction, - factor=self.exchange.get_factor(code, start_time, end_time), ) @@ -291,6 +302,7 @@ class BaseTradeDecision: """ self.strategy = strategy + self.start_time, self.end_time = strategy.trade_calendar.get_step_time() self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading` if isinstance(trade_range, Tuple): # for Tuple[int, int] @@ -406,6 +418,62 @@ class BaseTradeDecision: _start_idx, _end_idx = max(0, _start_idx), min(self.total_step - 1, _end_idx) return _start_idx, _end_idx + def get_data_cal_range_limit(self, rtype: str = "full", raise_error: bool = False) -> Tuple[int, int]: + """ + get the range limit based on data calendar + + NOTE: it is **total** range limit instead of a single step + + The following assumptions are made + 1) The frequency of the exchange in common_infra is the same as the data calendar + 2) Users want the index mod by **day** (i.e. 240 min) + + Parameters + ---------- + rtype: str + - "full": return the full limitation of the deicsion in the day + - "step": return the limitation of current step + + raise_error: bool + True: raise error if no trade_range is set + False: return full trade calendar. + + It is useful in following cases + - users want to follow the order specific trading time range when decision level trade range is not + available. Raising NotImplementedError to indicates that range limit is not available + + Returns + ------- + Tuple[int, int]: + the range limit in data calendar + + Raises + ------ + NotImplementedError: + If the following criteria meet + 1) the decision can't provide a unified start and end + 2) raise_error is True + """ + # potential performance issue + day_start = pd.Timestamp(self.start_time.date()) + day_end = epsilon_change(day_start + pd.Timedelta(days=1)) + freq = self.strategy.trade_exchange.freq + _, _, day_start_idx, day_end_idx = Cal.locate_index(day_start, day_end, freq=freq) + if self.trade_range is None: + if raise_error: + raise NotImplementedError(f"There is no trade_range in this case") + else: + return 0, day_end_idx - day_start_idx + else: + if rtype == "full": + val_start, val_end = self.trade_range.clip_time_range(day_start, day_end) + elif rtype == "step": + val_start, val_end = self.trade_range.clip_time_range(self.start_time, self.end_time) + else: + raise ValueError(f"This type of input {rtype} is not supported") + _, _, start_idx, end_index = Cal.locate_index(val_start, val_end, freq=freq) + return start_idx - day_start_idx, end_index - day_start_idx + def empty(self) -> bool: for obj in self.get_decision(): if isinstance(obj, Order): @@ -452,9 +520,15 @@ class TradeDecisionWO(BaseTradeDecision): def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None): super().__init__(strategy, trade_range=trade_range) self.order_list = order_list + start, end = strategy.trade_calendar.get_step_time() + for o in order_list: + if o.start_time is None: + o.start_time = start + if o.end_time is None: + o.end_time = end def get_decision(self) -> List[object]: return self.order_list def __repr__(self) -> str: - return f"strategy: {self.strategy}; trade_range: {self.trade_range}; order_list[{len(self.order_list)}]" + return f"class: {self.__class__.__name__}; strategy: {self.strategy}; trade_range: {self.trade_range}; order_list[{len(self.order_list)}]" diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 98d8b4f63..13ad1eca2 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -351,7 +351,10 @@ class Indicator: trade_exchange: Exchange, pa_config: dict = {}, ): - """Get the base volume and price information""" + """ + Get the base volume and price information + All the base price values are rooted from this function + """ agg = pa_config.get("agg", "twap").lower() price = pa_config.get("price", "deal_price").lower() @@ -374,10 +377,12 @@ class Indicator: # NOTE: there are some zeros in the trading price. These cases are known meaningless # for aligning the previous logic, remove it. - # price_s = price_s.mask(np.isclose(price_s, 0)) + price_s = price_s[~(price_s < 1e-08)] # remove zero and negative values. + # NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8 if agg == "vwap": volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None) + volume_s = volume_s.reindex(price_s.index) elif agg == "twap": volume_s = pd.Series(1, index=price_s.index) else: diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 60a49b0e2..b5ff84c54 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. from __future__ import annotations import bisect +from qlib.utils.time import epsilon_change from typing import Union, TYPE_CHECKING, Tuple, Union, List, Set if TYPE_CHECKING: @@ -22,7 +23,11 @@ class TradeCalendarManager: """ def __init__( - self, freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None + self, + freq: str, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + level_infra: "LevelInfrastructure" = None, ): """ Parameters @@ -36,6 +41,7 @@ class TradeCalendarManager: closed end of the trade time range, by default None If `end_time` is None, it must be reset before trading. """ + self.level_infra = level_infra self.reset(freq=freq, start_time=start_time, end_time=end_time) def reset(self, freq, start_time, end_time): @@ -82,19 +88,19 @@ class TradeCalendarManager: def get_trade_step(self): return self.trade_step - def get_step_time(self, trade_step=0, shift=0): + def get_step_time(self, trade_step=None, shift=0): """ Get the left and right endpoints of the trade_step'th trading interval About the endpoints: - Qlib uses the closed interval in time-series data selection, which has the same performance as pandas.Series.loc - - The returned right endpoints should minus 1 seconds becasue of the closed interval representation in Qlib. - Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time interval. + # - The returned right endpoints should minus 1 seconds becasue of the closed interval representation in Qlib. + # Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time interval. Parameters ---------- trade_step : int, optional - the number of trading step finished, by default 0 + the number of trading step finished, by default None to indicate current step shift : int, optional shift bars , by default 0 @@ -105,15 +111,43 @@ class TradeCalendarManager: - If shift > 0, return the trading time range of the earlier shift bars - If shift < 0, return the trading time range of the later shift bar """ + if trade_step is None: + trade_step = self.get_trade_step() trade_step = trade_step - shift calendar_index = self.start_index + trade_step - return self._calendar[calendar_index], self._calendar[calendar_index + 1] - pd.Timedelta(seconds=1) + return self._calendar[calendar_index], epsilon_change(self._calendar[calendar_index + 1]) - def get_cur_step_time(self): + def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]: """ - get current step time + get the calendar range + The following assumptions are made + 1) The frequency of the exchange in common_infra is the same as the data calendar + 2) Users want the **data index** mod by **day** (i.e. 240 min) + + Parameters + ---------- + rtype: str + - "full": return the full limitation of the deicsion in the day + - "step": return the limitation of current step + + Returns + ------- + Tuple[int, int]: """ - return self.get_step_time(self.get_trade_step()) + # potential performance issue + day_start = pd.Timestamp(self.start_time.date()) + day_end = epsilon_change(day_start + pd.Timedelta(days=1)) + freq = self.level_infra.get("common_infra").get("trade_exchange").freq + _, _, day_start_idx, _ = Cal.locate_index(day_start, day_end, freq=freq) + + if rtype == "full": + _, _, start_idx, end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq) + elif rtype == "step": + _, _, start_idx, end_index = Cal.locate_index(*self.get_step_time(), freq=freq) + else: + raise ValueError(f"This type of input {rtype} is not supported") + + return start_idx - day_start_idx, end_index - day_start_idx def get_all_time(self): """Get the start_time and end_time for trading""" @@ -147,7 +181,7 @@ class TradeCalendarManager: return clip(left), clip(right) def __repr__(self) -> str: - return f"{self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]" + return f"class: {self.__class__.__name__}; {self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]" class BaseInfrastructure: @@ -198,14 +232,16 @@ class LevelInfrastructure(BaseInfrastructure): sub_level_infra: - **NOTE**: this will only work after _init_sub_trading !!! """ - return ["trade_calendar", "sub_level_infra"] + return ["trade_calendar", "sub_level_infra", "common_infra"] def reset_cal(self, freq, start_time, end_time): """reset trade calendar manager""" if self.has("trade_calendar"): self.get("trade_calendar").reset(freq, start_time=start_time, end_time=end_time) else: - self.reset_infra(trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time)) + self.reset_infra( + trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self) + ) def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure): """this will make the calendar access easier when acrossing multi-levels""" diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index f7728f911..e08039413 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -73,20 +73,20 @@ def indicator_analysis(df, method="mean"): Parameters ---------- df : pandas.DataFrame - columns: like ['pa', 'pos', 'ffr', 'amount', 'value']. + columns: like ['pa', 'pos', 'ffr', 'deal_amount', 'value']. Necessary fields: - 'pa' is the price advantage in trade indicators - 'pos' is the positive rate in trade indicators - 'ffr' is the fulfill rate in trade indicators Optional fields: - - 'amount' is the total deal amount, only necessary when method is 'amount_weighted' + - 'deal_amount' is the total deal deal_amount, only necessary when method is 'amount_weighted' - 'value' is the total trade value, only necessary when method is 'value_weighted' index: Index(datetime) method : str, optional statistics method of pa/ffr, by default "mean" - if method is 'mean', count the mean statistical value of each trade indicator - - if method is 'amount_weighted', count the amount weighted mean statistical value of each trade indicator + - if method is 'amount_weighted', count the deal_amount weighted mean statistical value of each trade indicator - if method is 'value_weighted', count the value weighted mean statistical value of each trade indicator Note: statistics method of pos is always "mean" @@ -97,7 +97,7 @@ def indicator_analysis(df, method="mean"): """ weights_dict = { "mean": df["count"], - "amount_weighted": df["amount"].abs(), + "amount_weighted": df["deal_amount"].abs(), "value_weighted": df["value"].abs(), } if method not in weights_dict: diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index e2a79db27..48d96686a 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -64,7 +64,7 @@ class TopkDropoutStrategy(ModelStrategy): """ super(TopkDropoutStrategy, self).__init__( - model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs + model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs ) self.topk = topk self.n_drop = n_drop @@ -73,22 +73,6 @@ class TopkDropoutStrategy(ModelStrategy): self.risk_degree = risk_degree self.hold_thresh = hold_thresh self.only_tradable = only_tradable - if trade_exchange is not None: - self.trade_exchange = trade_exchange - - def reset_common_infra(self, common_infra): - """ - Parameters - ---------- - common_infra : dict, optional - common infrastructure for backtesting, by default None - - It should include `trade_account`, used to get position - - It should include `trade_exchange`, used to provide market info - """ - super(TopkDropoutStrategy, self).reset_common_infra(common_infra) - - if common_infra.has("trade_exchange"): - self.trade_exchange = common_infra.get("trade_exchange") def get_risk_degree(self, trade_step=None): """get_risk_degree @@ -210,7 +194,6 @@ class TopkDropoutStrategy(ModelStrategy): start_time=trade_start_time, end_time=trade_end_time, direction=Order.SELL, # 0 for sell, 1 for buy - factor=factor, ) # is order executable if self.trade_exchange.check_order(sell_order): @@ -247,7 +230,6 @@ class TopkDropoutStrategy(ModelStrategy): start_time=trade_start_time, end_time=trade_end_time, direction=Order.BUY, # 1 for buy - factor=factor, ) buy_order_list.append(buy_order) return TradeDecisionWO(sell_order_list + buy_order_list, self) @@ -278,28 +260,12 @@ class WeightStrategyBase(ModelStrategy): - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended. """ super(WeightStrategyBase, self).__init__( - model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs + model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs ) if isinstance(order_generator_cls_or_obj, type): self.order_generator = order_generator_cls_or_obj() else: self.order_generator = order_generator_cls_or_obj - if trade_exchange is not None: - self.trade_exchange = trade_exchange - - def reset_common_infra(self, common_infra): - """ - Parameters - ---------- - common_infra : dict, optional - common infrastructure for backtesting, by default None - - It should include `trade_account`, used to get position - - It should include `trade_exchange`, used to provide market info - """ - super(WeightStrategyBase, self).reset_common_infra(common_infra) - - if common_infra.has("trade_exchange"): - self.trade_exchange = common_infra.get("trade_exchange") def get_risk_degree(self, trade_step=None): """get_risk_degree diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 970734df5..36059f5a0 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -20,48 +20,6 @@ from qlib.backtest.utils import get_start_end_idx class TWAPStrategy(BaseStrategy): """TWAP Strategy for trading""" - def __init__( - self, - outer_trade_decision: BaseTradeDecision = None, - trade_exchange: Exchange = None, - level_infra: LevelInfrastructure = None, - common_infra: CommonInfrastructure = None, - ): - """ - Parameters - ---------- - outer_trade_decision : BaseTradeDecision - the trade decision of outer strategy which this startegy relies - trade_exchange : Exchange - exchange that provides market info, used to deal order and generate report - - If `trade_exchange` is None, self.trade_exchange will be set with common_infra - - It allowes different trade_exchanges is used in different executions. - - For example: - - In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster. - - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended. - - """ - super(TWAPStrategy, self).__init__( - outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra - ) - - if trade_exchange is not None: - self.trade_exchange = trade_exchange - - def reset_common_infra(self, common_infra): - """ - Parameters - ---------- - common_infra : CommonInfrastructure, optional - common infrastructure for backtesting, by default None - - It should include `trade_account`, used to get position - - It should include `trade_exchange`, used to provide market info - """ - super(TWAPStrategy, self).reset_common_infra(common_infra) - - if common_infra.has("trade_exchange"): - self.trade_exchange = common_infra.get("trade_exchange") - def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs): """ Parameters @@ -105,7 +63,9 @@ class TWAPStrategy(BaseStrategy): stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time ): continue - _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) + _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit( + stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time + ) _order_amount = None # considering trade unit if _amount_trade_unit is None: @@ -141,7 +101,6 @@ class TWAPStrategy(BaseStrategy): start_time=trade_start_time, end_time=trade_end_time, direction=order.direction, # 1 for buy - factor=order.factor, ) order_list.append(_order) return TradeDecisionWO(order_list=order_list, strategy=self) @@ -161,46 +120,6 @@ class SBBStrategyBase(BaseStrategy): # 2. Supporting alter_outer_trade_decision # 3. Supporting checking the availability of trade decision - def __init__( - self, - outer_trade_decision: BaseTradeDecision = None, - trade_exchange: Exchange = None, - level_infra: LevelInfrastructure = None, - common_infra: CommonInfrastructure = None, - ): - """ - Parameters - ---------- - outer_trade_decision : BaseTradeDecision - the trade decision of outer strategy which this startegy relies - trade_exchange : Exchange - exchange that provides market info, used to deal order and generate report - - If `trade_exchange` is None, self.trade_exchange will be set with common_infra - - It allowes different trade_exchanges is used in different executions. - - For example: - - In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster. - - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended. - """ - super(SBBStrategyBase, self).__init__( - outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra - ) - - if trade_exchange is not None: - self.trade_exchange = trade_exchange - - def reset_common_infra(self, common_infra): - """ - Parameters - ---------- - common_infra : dict, optional - common infrastructure for backtesting, by default None - - It should include `trade_account`, used to get position - - It should include `trade_exchange`, used to provide market info - """ - super(SBBStrategyBase, self).reset_common_infra(common_infra) - if common_infra.has("trade_exchange"): - self.trade_exchange = common_infra.get("trade_exchange") - def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs): """ Parameters @@ -250,7 +169,9 @@ class SBBStrategyBase(BaseStrategy): self.trade_trend[order.stock_id] = _pred_trend continue # get amount of one trade unit - _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) + _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit( + stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time + ) if _pred_trend == self.TREND_MID: _order_amount = None # considering trade unit @@ -283,7 +204,6 @@ class SBBStrategyBase(BaseStrategy): start_time=trade_start_time, end_time=trade_end_time, direction=order.direction, - factor=order.factor, ) order_list.append(_order) @@ -330,7 +250,6 @@ class SBBStrategyBase(BaseStrategy): start_time=trade_start_time, end_time=trade_end_time, direction=order.direction, # 1 for buy - factor=order.factor, ) order_list.append(_order) else: @@ -349,7 +268,6 @@ class SBBStrategyBase(BaseStrategy): start_time=trade_start_time, end_time=trade_end_time, direction=order.direction, # 1 for buy - factor=order.factor, ) order_list.append(_order) @@ -395,7 +313,9 @@ class SBBStrategyEMA(SBBStrategyBase): if isinstance(instruments, str): self.instruments = D.instruments(instruments) self.freq = freq - super(SBBStrategyEMA, self).__init__(outer_trade_decision, trade_exchange, level_infra, common_infra, **kwargs) + super(SBBStrategyEMA, self).__init__( + outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs + ) def _reset_signal(self): trade_len = self.trade_calendar.get_trade_len() @@ -417,14 +337,8 @@ class SBBStrategyEMA(SBBStrategyBase): reset level-shared infra - After reset the trade calendar, the signal will be changed """ - if not hasattr(self, "level_infra"): - self.level_infra = level_infra - else: - self.level_infra.update(level_infra) - - if level_infra.has("trade_calendar"): - self.trade_calendar = level_infra.get("trade_calendar") - self._reset_signal() + super().reset_level_infra(level_infra) + self._reset_signal() def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): # if no signal, return mid trend @@ -484,10 +398,9 @@ class ACStrategy(BaseStrategy): if isinstance(instruments, str): self.instruments = D.instruments(instruments) self.freq = freq - super(ACStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs) - - if trade_exchange is not None: - self.trade_exchange = trade_exchange + super(ACStrategy, self).__init__( + outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs + ) def _reset_signal(self): trade_len = self.trade_calendar.get_trade_len() @@ -506,33 +419,13 @@ class ACStrategy(BaseStrategy): for stock_id, stock_val in signal_df.groupby(level="instrument"): self.signal[stock_id] = stock_val["volatility"].droplevel(level="instrument") - def reset_common_infra(self, common_infra): - """ - Parameters - ---------- - common_infra : CommonInfrastructure, optional - common infrastructure for backtesting, by default None - - It should include `trade_account`, used to get position - - It should include `trade_exchange`, used to provide market info - """ - super(ACStrategy, self).reset_common_infra(common_infra) - - if common_infra.has("trade_exchange"): - self.trade_exchange = common_infra.get("trade_exchange") - def reset_level_infra(self, level_infra): """ reset level-shared infra - After reset the trade calendar, the signal will be changed """ - if not hasattr(self, "level_infra"): - self.level_infra = level_infra - else: - self.level_infra.update(level_infra) - - if level_infra.has("trade_calendar"): - self.trade_calendar = level_infra.get("trade_calendar") - self._reset_signal() + super().reset_level_infra(level_infra) + self._reset_signal() def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs): """ @@ -578,7 +471,9 @@ class ACStrategy(BaseStrategy): if sig_sam is None or np.isnan(sig_sam): # no signal, TWAP - _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) + _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit( + stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time + ) if _amount_trade_unit is None: # divide the order into equal parts, and trade one part _order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step) @@ -599,7 +494,9 @@ class ACStrategy(BaseStrategy): np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1)) ) / np.sinh(kappa * trade_len) _order_amount = order.amount * amount_ratio - _order_amount = self.trade_exchange.round_amount_by_trade_unit(_order_amount, order.factor) + _order_amount = self.trade_exchange.round_amount_by_trade_unit( + _order_amount, stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time + ) if order.direction == order.SELL: # sell all amount at last @@ -673,8 +570,6 @@ class RandomOrderStrategy(BaseStrategy): .create( code=stock_id, amount=volume * self.volume_ratio, - start_time=step_time_start, - end_time=step_time_end, direction=self.direction, ) ) @@ -734,9 +629,7 @@ class FileOrderStrategy(BaseStrategy): execute_result will be ignored in FileOrderStrategy """ oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper() - tc = self.trade_calendar - step = tc.get_trade_step() - start, end = tc.get_step_time(step) + start, _ = self.trade_calendar.get_step_time() # CONVERSION: the bar is indexed by the time try: df = self.order_df.loc(axis=0)[start] @@ -750,8 +643,6 @@ class FileOrderStrategy(BaseStrategy): code=idx, amount=row["amount"], direction=Order.parse_dir(row["direction"]), - start_time=start, - end_time=end, ) ) return TradeDecisionWO(order_list, self, self.trade_range) diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 23d6b520a..fa21fae5f 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from qlib.backtest.exchange import Exchange from qlib.backtest.position import BasePosition -from typing import List, Union +from typing import List, Tuple, Union from ..model.base import BaseModel from ..data.dataset import DatasetH @@ -22,6 +23,7 @@ class BaseStrategy: outer_trade_decision: BaseTradeDecision = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, + trade_exchange: Exchange = None, ): """ Parameters @@ -34,9 +36,18 @@ class BaseStrategy: level shared infrastructure for backtesting, including trade calendar common_infra : CommonInfrastructure, optional common infrastructure for backtesting, including trade_account, trade_exchange, .etc + + trade_exchange : Exchange + exchange that provides market info, used to deal order and generate report + - If `trade_exchange` is None, self.trade_exchange will be set with common_infra + - It allowes different trade_exchanges is used in different executions. + - For example: + - In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster. + - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended. """ - self.reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision) + self._reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision) + self._trade_exchange = trade_exchange @property def trade_calendar(self) -> TradeCalendarManager: @@ -46,6 +57,11 @@ class BaseStrategy: def trade_position(self) -> BasePosition: return self.common_infra.get("trade_account").current + @property + def trade_exchange(self) -> Exchange: + """get trade exchange in a prioritized order""" + return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange") + def reset_level_infra(self, level_infra: LevelInfrastructure): if not hasattr(self, "level_infra"): self.level_infra = level_infra @@ -69,6 +85,24 @@ class BaseStrategy: - reset `level_infra`, used to reset trade calendar, .etc - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc - reset `outer_trade_decision`, used to make split decision + + **NOTE**: + split this function into `reset` and `_reset` will make following cases more convenient + 1. Users want to initialize his strategy by overriding `reset`, but they don't want to affect the `_reset` called + when initialization + """ + self._reset( + level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision, **kwargs + ) + + def _reset( + self, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, + outer_trade_decision=None, + ): + """ + Please refer to the docs of `reset` """ if level_infra is not None: self.reset_level_infra(level_infra) @@ -124,6 +158,36 @@ class BaseStrategy: # NOTE: normally, user should do something to the strategy due to the change of outer decision raise NotImplementedError(f"Please implement the `alter_outer_trade_decision` method") + # helper methods: not necessary but for convenience + def get_data_cal_avail_range(self, rtype: str = "full") -> Tuple[int, int]: + """ + return data calendar's available decision range for `self` strategy + the range consider following factors + - data calendar in the charge of `self` strategy + - trading range limitation from the decision of outer strategy + + + related methods + - TradeCalendarManager.get_data_cal_range + - BaseTradeDecision.get_data_cal_range_limit + + Parameters + ---------- + rtype: str + - "full": return the available data index range of the strategy from `start_time` to `end_time` + - "step": return the available data index range of the strategy of current step + + Returns + ------- + Tuple[int, int]: + the available range both sides are closed + """ + cal_range = self.trade_calendar.get_data_cal_range(rtype=rtype) + if self.outer_trade_decision is None: + raise ValueError(f"There is not limitation for strategy {self}") + range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype) + return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1]) + class ModelStrategy(BaseStrategy): """Model-based trading strategy, use model to make predictions for trading""" diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 5900fb286..1cce56918 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -210,10 +210,13 @@ def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleTy the class object and it's arguments. """ if isinstance(config, dict): - module = get_module_by_module_path(config.get("module_path", default_module)) + if isinstance(config["class"], str): + module = get_module_by_module_path(config.get("module_path", default_module)) - # raise AttributeError - klass = getattr(module, config["class"]) + # raise AttributeError + klass = getattr(module, config["class"]) + else: + klass = config["class"] # the class type itself is passed in kwargs = config.get("kwargs", {}) elif isinstance(config, str): module = get_module_by_module_path(default_module) @@ -235,11 +238,17 @@ def init_instance_by_config( ---------- config : Union[str, dict, object] dict example. + case 1) { 'class': 'ClassName', 'kwargs': dict, # It is optional. {} will be used if not given 'model_path': path, # It is optional if module is given } + case 2) + { + 'class': , + 'kwargs': dict, # It is optional. {} will be used if not given + } str example. 1) specify a pickle object - path like 'file:////obj.pkl' diff --git a/qlib/utils/time.py b/qlib/utils/time.py index f4913dde4..e365de6d8 100644 --- a/qlib/utils/time.py +++ b/qlib/utils/time.py @@ -160,5 +160,32 @@ def cal_sam_minute(x: pd.Timestamp, sam_minutes: int) -> pd.Timestamp: return concat_date_time(date, new_time) +def epsilon_change(datetime: pd.Timestamp, direction: str = "backward") -> pd.Timestamp: + """ + change the time by infinitely small quantity. + + + Parameters + ---------- + datetime : pd.Timestamp + the original time + direction : str + the direction the time are going to + - "backward" for going to history + - "forward" for going to the future + + Returns + ------- + pd.Timestamp: + the shifted time + """ + if direction == "backward": + return datetime - pd.Timedelta(seconds=1) + elif direction == "forward": + return datetime + pd.Timedelta(seconds=1) + else: + raise ValueError("Wrong input") + + if __name__ == "__main__": print(get_day_min_idx_range("8:30", "14:59", "10min"))