diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index b2c3f8c09..bc7210259 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. from __future__ import annotations import copy -from typing import Union, TYPE_CHECKING +from typing import List, Tuple, Union, TYPE_CHECKING from .account import Account @@ -35,7 +35,7 @@ def get_exchange( min_cost=5.0, trade_unit=None, limit_threshold=None, - deal_price=None, + deal_price: Union[str, Tuple[str], List[str]] = None, ): """get_exchange @@ -54,8 +54,15 @@ def get_exchange( min transaction cost. trade_unit : int 100 for China A. - deal_price: str - dealing price type: 'close', 'open', 'vwap'. + deal_price: Union[str, Tuple[str], List[str]] + The `deal_price` supports following two types of input + - : str + - (, ): Tuple[str] or List[str] + + , or := + := str + - for example '$close', '$open', '$vwap' ("close" is OK. `Exchange` will help to prepend + "$" to the expression) limit_threshold : float limit move 0.1 (10%) for example, long and short with same limit. @@ -69,13 +76,8 @@ def get_exchange( trade_unit = C.trade_unit if limit_threshold is None: limit_threshold = C.limit_threshold - if deal_price is None: - deal_price = C.deal_price if exchange is None: logger.info("Create new exchange") - # handle exception for deal_price - if deal_price[0] != "$": - deal_price = "$" + deal_price exchange = Exchange( freq=freq, diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index ccd5f4b45..9d4c96f48 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -4,7 +4,7 @@ import random import logging -from typing import Union +from typing import List, Tuple, Union import numpy as np import pandas as pd @@ -24,7 +24,7 @@ class Exchange: start_time=None, end_time=None, codes="all", - deal_price=None, + deal_price: Union[str, Tuple[str], List[str]] = None, subscribe_fields=[], limit_threshold=None, volume_threshold=None, @@ -40,7 +40,17 @@ class Exchange: :param start_time: closed start time for backtest :param end_time: closed end time for backtest :param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50) - :param deal_price: str, 'close', 'open', 'vwap' + + :param deal_price: Union[str, Tuple[str], List[str]] + The `deal_price` supports following two types of input + - : str + - (, ): Tuple[str] or List[str] + + , or := + := str + - for example '$close', '$open', '$vwap' ("close" is OK. `Exchange` will help to prepend + "$" to the expression) + :param subscribe_fields: list, subscribe fields :param limit_threshold: float, 0.1 for example, default None :param volume_threshold: float, 0.1 for example, default None @@ -86,10 +96,15 @@ class Exchange: if C.region == REG_CN: self.logger.warning(f"limit_threshold may not be set to a reasonable value") - if deal_price[0] != "$": - self.deal_price = "$" + deal_price + if isinstance(deal_price, str): + if deal_price[0] != "$": + deal_price = "$" + deal_price + self.buy_price = self.sell_price = deal_price + elif isinstance(deal_price, (tuple, list)): + self.buy_price, self.sell_price = deal_price else: - self.deal_price = deal_price + raise NotImplementedError(f"This type of input is not supported") + if isinstance(codes, str): codes = D.instruments(codes) self.codes = codes @@ -98,7 +113,7 @@ class Exchange: # $factor is for rounding to the trading unit # $change is for calculating the limit of the stock - necessary_fields = {self.deal_price, "$close", "$change", "$factor", "$volume"} + necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"} subscribe_fields = list(necessary_fields | set(subscribe_fields)) all_fields = list(necessary_fields | set(subscribe_fields)) self.all_fields = all_fields @@ -118,8 +133,10 @@ class Exchange: ) self.quote.columns = self.all_fields - if self.quote[self.deal_price].isna().any(): - self.logger.warning("{} field data contains nan.".format(self.deal_price)) + for attr in "buy_price", "sell_price": + pstr = getattr(self, attr) # price string + if self.quote[pstr].isna().any(): + self.logger.warning("{} field data contains nan.".format(pstr)) if self.quote["$factor"].isna().any(): # The 'factor.day.bin' file not exists, and `factor` field contains `nan` @@ -148,9 +165,11 @@ class Exchange: # process extra_quote if "$close" not in self.extra_quote: raise ValueError("$close is necessray in extra_quote") - if self.deal_price not in self.extra_quote.columns: - self.extra_quote[self.deal_price] = self.extra_quote["$close"] - self.logger.warning("No deal_price set for extra_quote. Use $close as deal_price.") + for attr in "buy_price", "sell_price": + pstr = getattr(self, attr) # price string + if pstr not in self.extra_quote.columns: + self.extra_quote[pstr] = self.extra_quote["$close"] + self.logger.warning(f"No {pstr} set for extra_quote. Use $close as {pstr}.") if "$factor" not in self.extra_quote.columns: self.extra_quote["$factor"] = 1.0 self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.") @@ -241,7 +260,7 @@ class Exchange: if trade_account is not None and position is not None: raise ValueError("trade_account and position can only choose one") - trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time) + trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, order.direction) # NOTE: order will be changed in this function trade_val, trade_cost = self._calc_trade_info_by_order( order, trade_account.current if trade_account else position @@ -266,12 +285,16 @@ class Exchange: def get_volume(self, stock_id, start_time, end_time): return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum") - def get_deal_price(self, stock_id, start_time, end_time): - deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method=ts_data_last) + def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir): + if direction == OrderDir.SELL: + pstr = self.sell_price + elif direction == OrderDir.BUY: + pstr = self.buy_price + else: + raise NotImplementedError(f"This type of input is not supported") + deal_price = resam_ts_data(self.quote[stock_id][pstr], start_time, end_time, method=ts_data_last) if np.isclose(deal_price, 0.0) or np.isnan(deal_price): - self.logger.warning( - f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!" - ) + self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!") self.logger.warning(f"setting deal_price to close price") deal_price = self.get_close(stock_id, start_time, end_time) return deal_price @@ -288,7 +311,9 @@ class Exchange: return None return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method=ts_data_last) - def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time): + def generate_amount_position_from_weight_position( + self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY + ): """ The generate the target position according to the weight and the cash. NOTE: All the cash will assigned to the tadable stock. @@ -297,7 +322,10 @@ class Exchange: weight_position : dict {stock_id : weight}; allocate cash by weight_position among then, weight must be in this range: 0 < weight < 1 cash : cash - trade_date : trade date + start_time : the start time point of the step + end_time : the end time point of the step + direction : the direction of the deal price for estimating the amount + # NOTE: this function is used for calculating target position. So the default direction is buy """ # calculate the total weight of tradable value @@ -324,7 +352,9 @@ class Exchange: cash * weight_position[stock_id] / tradable_weight - // self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) + // self.get_deal_price( + stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction + ) ) return amount_dict @@ -414,10 +444,16 @@ class Exchange: # return order_list : buy + sell return sell_order_list + buy_order_list - def calculate_amount_position_value(self, amount_dict, start_time, end_time, only_tradable=False): + def calculate_amount_position_value( + self, amount_dict, start_time, end_time, only_tradable=False, direction=OrderDir.SELL + ): """Parameter position : Position() amount_dict : {stock_id : amount} + direction : the direction of the deal price for estimating the amount + # NOTE: + This function is used for calculating current position value. + So the default direction is sell. """ value = 0 for stock_id in amount_dict: @@ -426,7 +462,9 @@ class Exchange: and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False ): value += ( - self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) + self.get_deal_price( + stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction + ) * amount_dict[stock_id] ) return value @@ -466,7 +504,7 @@ class Exchange: :return: trade_val, trade_cost """ - trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time) + trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction) if order.direction == Order.SELL: # sell if position is not None: diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index 67ba4c5bc..e2a79db27 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -6,7 +6,7 @@ import pandas as pd from ...utils.resam import resam_ts_data from ...strategy.base import ModelStrategy -from ...backtest.order import Order, BaseTradeDecision, TradeDecisionWO +from ...backtest.order import Order, BaseTradeDecision, OrderDir, TradeDecisionWO from .order_generator import OrderGenWInteract @@ -236,7 +236,7 @@ class TopkDropoutStrategy(ModelStrategy): continue # buy order buy_price = self.trade_exchange.get_deal_price( - stock_id=code, start_time=trade_start_time, end_time=trade_end_time + stock_id=code, start_time=trade_start_time, end_time=trade_end_time, direction=OrderDir.BUY ) buy_amount = value / buy_price factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)