diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index a4c20f730..0de290f02 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -92,7 +92,9 @@ def get_exchange( return init_instance_by_config(exchange, accept_types=Exchange) -def create_account_instance(start_time, end_time, benchmark: str, account: float, pos_type: str="Position") -> Account: +def create_account_instance( + start_time, end_time, benchmark: str, account: float, pos_type: str = "Position" +) -> Account: """ # TODO: is very strange pass benchmark_config in the account(maybe for report) # There should be a post-step to process the report. @@ -119,26 +121,25 @@ def create_account_instance(start_time, end_time, benchmark: str, account: float "start_time": start_time, "end_time": end_time, }, - "pos_type": pos_type + "pos_type": pos_type, } return Account(**kwargs) -def get_strategy_executor(start_time, - end_time, - strategy: BaseStrategy, - executor: BaseExecutor, - benchmark: str = "SH000300", - account: Union[float, str] = 1e9, - exchange_kwargs: dict = {}, - pos_type: str = "Position", - ): +def get_strategy_executor( + start_time, + end_time, + strategy: BaseStrategy, + executor: BaseExecutor, + benchmark: str = "SH000300", + account: Union[float, str] = 1e9, + exchange_kwargs: dict = {}, + pos_type: str = "Position", +): - trade_account = create_account_instance(start_time=start_time, - end_time=end_time, - benchmark=benchmark, - account=account, - pos_type=pos_type) + trade_account = create_account_instance( + start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type + ) exchange_kwargs = copy.copy(exchange_kwargs) if "start_time" not in exchange_kwargs: @@ -154,14 +155,16 @@ def get_strategy_executor(start_time, return trade_strategy, trade_executor -def backtest(start_time, - end_time, - strategy, - executor, - benchmark="SH000300", - account=1e9, - exchange_kwargs={}, - pos_type: str = "Position"): +def backtest( + start_time, + end_time, + strategy, + executor, + benchmark="SH000300", + account=1e9, + exchange_kwargs={}, + pos_type: str = "Position", +): trade_strategy, trade_executor = get_strategy_executor( start_time, @@ -178,14 +181,16 @@ def backtest(start_time, return report_dict, indicator_dict -def collect_data(start_time, - end_time, - strategy, - executor, - benchmark="SH000300", - account=1e9, - exchange_kwargs={}, - pos_type: str = "Position"): +def collect_data( + start_time, + end_time, + strategy, + executor, + benchmark="SH000300", + account=1e9, + exchange_kwargs={}, + pos_type: str = "Position", +): trade_strategy, trade_executor = get_strategy_executor( start_time, diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 64a814dba..a6ef2f6b8 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -63,7 +63,9 @@ class AccumulatedInfo: class Account: - def __init__(self, init_cash: float=1e9, freq: str = "day", benchmark_config: dict = {}, pos_type:str = "Position"): + def __init__( + self, init_cash: float = 1e9, freq: str = "day", benchmark_config: dict = {}, pos_type: str = "Position" + ): self.pos_type = pos_type self.init_vars(init_cash, freq, benchmark_config) @@ -71,13 +73,13 @@ class Account: # init cash self.init_cash = init_cash - self.current: BasePosition = init_instance_by_config({ - 'class': self.pos_type, - 'kwargs': { - "cash": init_cash - }, - 'module_path': "qlib.backtest.position", - }) + self.current: BasePosition = init_instance_by_config( + { + "class": self.pos_type, + "kwargs": {"cash": init_cash}, + "module_path": "qlib.backtest.position", + } + ) self.accum_info = AccumulatedInfo() self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True) diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index 81395dc73..0ac4581da 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -23,7 +23,9 @@ def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_exec return return_value.get("report"), return_value.get("indicator") -def collect_data_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict = None): +def collect_data_loop( + start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict = None +): """Generator for collecting the trade decision data for rl training Parameters @@ -68,7 +70,7 @@ def collect_data_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_ } all_indicators = {} for _executor in all_executors: - key = "{}{}".format( *Freq.parse(_executor.time_per_step)) + key = "{}{}".format(*Freq.parse(_executor.time_per_step)) all_indicators[key] = _executor.get_trade_indicator().generate_trade_indicators_dataframe() all_indicators[key + "_obj"] = _executor.get_trade_indicator() return_value.update({"report": all_reports, "indicator": all_indicators}) diff --git a/qlib/backtest/order.py b/qlib/backtest/order.py index 6324a9be9..19ea807c1 100644 --- a/qlib/backtest/order.py +++ b/qlib/backtest/order.py @@ -2,8 +2,10 @@ # Licensed under the MIT License. # TODO: rename it with decision.py from __future__ import annotations + # try to fix circular imports when enabling type hints from typing import TYPE_CHECKING + if TYPE_CHECKING: from qlib.strategy.base import BaseStrategy from qlib.backtest.utils import TradeCalendarManager @@ -59,6 +61,7 @@ class BaseTradeDecision: 1. The outer strategy's decision is available at the start of the interval 2. Same as `case 1.3` """ + def __init__(self, strategy: BaseStrategy): """ Parameters @@ -125,7 +128,8 @@ class TradeDecisionWO(BaseTradeDecision): Trade Decision (W)ith (O)rder. Besides, the time_range is also included. """ - def __init__(self, order_list: List[Order], strategy: BaseStrategy, idx_range: Tuple=None): + + def __init__(self, order_list: List[Order], strategy: BaseStrategy, idx_range: Tuple = None): super().__init__(strategy) self.order_list = order_list self.idx_range = idx_range @@ -198,8 +202,7 @@ class TradeDecisionWithOrderPool: class BaseDecisionUpdater: def update_decision(self, decision, trade_calendar) -> BaseTradeDecision: - """[summary] - + """ Parameters ---------- decision : BaseTradeDecision diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index 70272f688..0f36e4959 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -15,7 +15,8 @@ class BasePosition: The Position want to maintain the position like a dictionary Please refer to the `Position` class for the position """ - def __init__(self, cash=0., *args, **kwargs) -> None: + + def __init__(self, cash=0.0, *args, **kwargs) -> None: pass def skip_update(self) -> bool: @@ -46,7 +47,6 @@ class BasePosition: """ raise NotImplementedError(f"Please implement the `check_stock` method") - def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float): """ Parameters @@ -86,6 +86,7 @@ class BasePosition: the value(money) of all the stock """ raise NotImplementedError(f"Please implement the `calculate_stock_value` method") + def get_stock_list(self) -> List: """ Get the list of stocks in the position. @@ -140,7 +141,7 @@ class BasePosition: """ raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method") - def get_stock_weight_dict(self, only_stock: bool=False) -> Dict: + def get_stock_weight_dict(self, only_stock: bool = False) -> Dict: """ generate stock weight dict {stock_id : value weight of stock in the position} it is meaningful in the beginning or the end of each trade date @@ -399,13 +400,13 @@ class Position(BasePosition): self.position["now_account_value"] = now_account_value - class InfPosition(BasePosition): """ Position with infinite cash and amount. This is useful for generating random orders. """ + def skip_update(self) -> bool: """ Updating state is meaningless for InfPosition """ return True diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 3f2649839..f217ea169 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -18,7 +18,7 @@ from ..tests.config import CSI300_BENCH class Report: - ''' + """ Motivation: Report is for supporting portfolio related metrics. @@ -26,7 +26,8 @@ class Report: daily report of the account contain those followings: returns, costs turnovers, accounts, cash, bench, value update report - ''' + """ + def __init__(self, freq: str = "day", benchmark_config: dict = {}): """ Parameters diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 720eb627e..0ba607bdb 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -140,7 +140,6 @@ class BaseInfrastructure: self.reset_infra(**infra_dict) - class CommonInfrastructure(BaseInfrastructure): def get_support_infra(self): return ["trade_account", "trade_exchange"] diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index 2e72cb32c..67ba4c5bc 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -15,6 +15,7 @@ class TopkDropoutStrategy(ModelStrategy): # TODO: # 1. Supporting leverage the get_range_limit result from the decision # 2. Supporting alter_outer_trade_decision + # 3. Supporting checking the availability of trade decision def __init__( self, model, @@ -104,7 +105,7 @@ class TopkDropoutStrategy(ModelStrategy): pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") if pred_score is None: - return [] + return TradeDecisionWO([], self) if self.only_tradable: # If The strategy only consider tradable stock when make decision # It needs following actions to filter stocks @@ -256,6 +257,7 @@ class WeightStrategyBase(ModelStrategy): # TODO: # 1. Supporting leverage the get_range_limit result from the decision # 2. Supporting alter_outer_trade_decision + # 3. Supporting checking the availability of trade decision def __init__( self, model, @@ -332,9 +334,9 @@ class WeightStrategyBase(ModelStrategy): pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") if pred_score is None: - return [] + return TradeDecisionWO([], self) current_temp = copy.deepcopy(self.trade_position) - assert(isinstance(current_temp, Position)) # Avoid InfPosition + assert isinstance(current_temp, Position) # Avoid InfPosition target_weight_position = self.generate_target_weight_position( score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index b8a900b85..0fb98e8ac 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -102,7 +102,7 @@ class TWAPStrategy(BaseStrategy): trade_step = self.trade_calendar.get_trade_step() # get the total count of trading step start_idx, end_idx = get_start_end_idx(self, self.outer_trade_decision) - trade_len = end_idx - start_idx + 1 + trade_len = end_idx - start_idx + 1 if trade_step < start_idx: # It is not time to start trading @@ -137,12 +137,16 @@ class TWAPStrategy(BaseStrategy): # calculate the amount of one part, ceil the amount # floor((trade_unit_cnt + trade_len - rel_trade_step) / (trade_len - rel_trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - rel_trade_step + 1)) _order_amount = ( - (trade_unit_cnt + trade_len - rel_trade_step - 1) // (trade_len - rel_trade_step) * _amount_trade_unit + (trade_unit_cnt + trade_len - rel_trade_step - 1) + // (trade_len - rel_trade_step) + * _amount_trade_unit ) if order.direction == order.SELL: # sell all amount at last - if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or rel_trade_step == trade_len - 1): + if self.trade_amount[order.stock_id] > 1e-5 and ( + _order_amount < 1e-5 or rel_trade_step == trade_len - 1 + ): _order_amount = self.trade_amount[order.stock_id] _order_amount = min(_order_amount, self.trade_amount[order.stock_id]) @@ -173,6 +177,7 @@ class SBBStrategyBase(BaseStrategy): # TODO: # 1. Supporting leverage the get_range_limit result from the decision # 2. Supporting alter_outer_trade_decision + # 3. Supporting checking the availability of trade decision def __init__( self, @@ -225,8 +230,7 @@ class SBBStrategyBase(BaseStrategy): self.trade_trend = {} self.trade_amount = {} # init the trade amount of order and predicted trade trend - outer_order_generator = outer_trade_decision.generator() - for order in outer_order_generator: + for order in outer_trade_decision.get_decision(): self.trade_trend[order.stock_id] = self.TREND_MID self.trade_amount[order.stock_id] = order.amount @@ -248,8 +252,7 @@ class SBBStrategyBase(BaseStrategy): pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) order_list = [] # for each order in in self.outer_trade_decision - outer_order_generator = self.outer_trade_decision.generator(only_enable=True) - for order in outer_order_generator: + for order in self.outer_trade_decision.get_decision(): # get the price trend if trade_step % 2 == 0: # in the first of two adjacent bars, predict the price trend @@ -379,9 +382,11 @@ class SBBStrategyEMA(SBBStrategyBase): """ (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA) signal. """ + # TODO: # 1. Supporting leverage the get_range_limit result from the decision # 2. Supporting alter_outer_trade_decision + # 3. Supporting checking the availability of trade decision def __init__( self, @@ -463,6 +468,7 @@ class ACStrategy(BaseStrategy): # TODO: # 1. Supporting leverage the get_range_limit result from the decision # 2. Supporting alter_outer_trade_decision + # 3. Supporting checking the availability of trade decision def __init__( self, lamb: float = 1e-6, @@ -555,8 +561,7 @@ class ACStrategy(BaseStrategy): if outer_trade_decision is not None: self.trade_amount = {} # init the trade amount of order and predicted trade trend - outer_order_generator = outer_trade_decision.generator() - for order in outer_order_generator: + for order in outer_trade_decision.get_decision(): self.trade_amount[order.stock_id] = order.amount def generate_trade_decision(self, execute_result=None): @@ -564,8 +569,6 @@ class ACStrategy(BaseStrategy): trade_step = self.trade_calendar.get_trade_step() # get the total count of trading step trade_len = self.trade_calendar.get_trade_len() - # update outer trade decision - self.outer_trade_decision.update(self.trade_calendar) # update the order amount if execute_result is not None: @@ -575,8 +578,7 @@ class ACStrategy(BaseStrategy): trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) order_list = [] - outer_order_generator = self.outer_trade_decision.generator(only_enable=True) - for order in outer_order_generator: + for order in self.outer_trade_decision.get_decision(): # if not tradable, continue if not self.trade_exchange.is_stock_tradable( stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time @@ -638,14 +640,16 @@ class ACStrategy(BaseStrategy): class RandomOrderStrategy(BaseStrategy): - - def __init__(self, - index_range: Tuple[int, int], # The range is closed on both left and right. - sample_ratio: float = 1., - volume_ratio: float = 0.01, - market: str = "all", - *args, - **kwargs): + def __init__( + self, + index_range: Tuple[int, int], # The range is closed on both left and right. + sample_ratio: float = 1.0, + volume_ratio: float = 0.01, + market: str = "all", + direction: int = Order.BUY, + *args, + **kwargs, + ): """ Parameters ---------- @@ -667,9 +671,12 @@ class RandomOrderStrategy(BaseStrategy): self.sample_ratio = sample_ratio self.volume_ratio = volume_ratio self.market = market + self.direction = direction exch: Exchange = self.common_infra.get("trade_exchange") # TODO: this can't be online - self.volume = D.features(D.instruments(market), ["Mean(Ref($volume, 1), 10)"], start_time=exch.start_time, end_time=exch.end_time) + self.volume = D.features( + D.instruments(market), ["Mean(Ref($volume, 1), 10)"], start_time=exch.start_time, end_time=exch.end_time + ) self.volume_df = self.volume.iloc[:, 0].unstack() def generate_trade_decision(self, execute_result=None): @@ -677,15 +684,15 @@ class RandomOrderStrategy(BaseStrategy): step_time_start, step_time_end = self.trade_calendar.get_step_time(trade_step) order_list = [] - for direction in Order.SELL, Order.BUY: - if step_time_start in self.volume_df: - for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items(): - order_list.append( - self.common_infra.get("trade_exchange").create_order( - code=stock_id, - amount=volume * self.volume_ratio, - start_time=step_time_start, - end_time=step_time_end, - direction=direction, # 1 for buy - )) + if step_time_start in self.volume_df: + for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items(): + order_list.append( + self.common_infra.get("trade_exchange").create_order( + code=stock_id, + amount=volume * self.volume_ratio, + start_time=step_time_start, + end_time=step_time_end, + direction=self.direction, + ) + ) return TradeDecisionWO(order_list, self, self.index_range) diff --git a/qlib/data/data.py b/qlib/data/data.py index 116861e78..d6735b4e6 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -213,7 +213,7 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin): self.backend = kwargs.get("backend", {}) @staticmethod - def instruments(market: Union[List, str]="all", filter_pipe: Union[List, None]=None): + def instruments(market: Union[List, str] = "all", filter_pipe: Union[List, None] = None): """Get the general config dictionary for a base market adding several dynamic filters. Parameters diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 734d25721..c8a326e80 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -85,7 +85,9 @@ class BaseStrategy: """ raise NotImplementedError("generate_trade_decision is not implemented!") - def update_trade_decision(self, trade_decision: BaseTradeDecision, trade_calendar: TradeCalendarManager) -> Union[BaseTradeDecision, None]: + def update_trade_decision( + self, trade_decision: BaseTradeDecision, trade_calendar: TradeCalendarManager + ) -> Union[BaseTradeDecision, None]: """ update trade decision in each step of inner execution, this method enable all order diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index 76d97e1bc..4df155946 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -9,6 +9,7 @@ from . import lazy_sort_index from ..config import C from .time import Freq, cal_sam_minute + def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray: """ Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam diff --git a/qlib/utils/time.py b/qlib/utils/time.py index fb37fd0a4..bfbdb9f1f 100644 --- a/qlib/utils/time.py +++ b/qlib/utils/time.py @@ -14,7 +14,7 @@ import functools @functools.lru_cache(maxsize=240) -def get_min_cal(shift: int=0) -> List[time]: +def get_min_cal(shift: int = 0) -> List[time]: """ get the minute level calendar in day period @@ -30,8 +30,9 @@ def get_min_cal(shift: int=0) -> List[time]: """ cal = [] - for ts in list(pd.date_range("9:30", "11:29", freq="1min") - pd.Timedelta(minutes=shift)) +\ - list(pd.date_range("13:00", "14:59", freq="1min") - pd.Timedelta(minutes=shift)): + for ts in list(pd.date_range("9:30", "11:29", freq="1min") - pd.Timedelta(minutes=shift)) + list( + pd.date_range("13:00", "14:59", freq="1min") - pd.Timedelta(minutes=shift) + ): cal.append(ts.time()) return cal @@ -115,7 +116,7 @@ def get_day_min_idx_range(start: str, end: str, freq: str) -> Tuple[int, int]: start = pd.Timestamp(start).time() end = pd.Timestamp(end).time() freq = Freq(freq) - in_day_cal = Freq.MIN_CAL[::freq.count] + in_day_cal = Freq.MIN_CAL[:: freq.count] left_idx = bisect.bisect_left(in_day_cal, start) right_idx = bisect.bisect_right(in_day_cal, end) - 1 return left_idx, right_idx @@ -141,15 +142,19 @@ def cal_sam_minute(x: pd.Timestamp, sam_minutes: int) -> pd.Timestamp: """ cal = get_min_cal(C.min_data_shift)[::sam_minutes] idx = bisect.bisect_right(cal, x.time()) - 1 - date, new_time = x.date(), cal[idx] + date, new_time = x.date(), cal[idx] return pd.Timestamp( - datetime(date.year, - month=date.month, - day=date.day, - hour=new_time.hour, - minute=new_time.minute, - second=new_time.second, - microsecond=new_time.microsecond)) + datetime( + date.year, + month=date.month, + day=date.day, + hour=new_time.hour, + minute=new_time.minute, + second=new_time.second, + microsecond=new_time.microsecond, + ) + ) + if __name__ == "__main__": print(get_day_min_idx_range("8:30", "14:59", "10min"))