diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index bc7210259..6aa83e687 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -33,9 +33,9 @@ def get_exchange( open_cost=0.0015, close_cost=0.0025, min_cost=5.0, - trade_unit=None, limit_threshold=None, deal_price: Union[str, Tuple[str], List[str]] = None, + **kwargs, ): """get_exchange @@ -53,7 +53,7 @@ def get_exchange( min_cost : float min transaction cost. trade_unit : int - 100 for China A. + Included in kwargs. Please refer to the docs of `__init__` of `Exchange` deal_price: Union[str, Tuple[str], List[str]] The `deal_price` supports following two types of input - : str @@ -72,8 +72,6 @@ def get_exchange( an initialized Exchange object """ - if trade_unit is None: - trade_unit = C.trade_unit if limit_threshold is None: limit_threshold = C.limit_threshold if exchange is None: @@ -89,8 +87,8 @@ def get_exchange( limit_threshold=limit_threshold, open_cost=open_cost, close_cost=close_cost, - trade_unit=trade_unit, min_cost=min_cost, + **kwargs ) return exchange else: diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 9d4c96f48..26fae378f 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -30,9 +30,9 @@ class Exchange: volume_threshold=None, open_cost=0.0015, close_cost=0.0025, - trade_unit=None, min_cost=5, extra_quote=None, + **kwargs, ): """__init__ @@ -56,7 +56,11 @@ class Exchange: :param volume_threshold: float, 0.1 for example, default None :param open_cost: cost rate for open, default 0.0015 :param close_cost: cost rate for close, default 0.0025 - :param trade_unit: trade unit, 100 for China A market + :param trade_unit: trade unit, 100 for China A market. + None for disable trade unit. + **NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must + distinguish `not set` and `disable trade_unit` + :param min_cost: min cost, default 5 :param extra_quote: pandas, dataframe consists of columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy']. @@ -77,8 +81,10 @@ class Exchange: self.start_time = start_time self.end_time = end_time - if trade_unit is None: - trade_unit = C.trade_unit + self.trade_unit = kwargs.pop("trade_unit", C.trade_unit) + if len(kwargs) > 0: + raise ValueError(f"Get Unexpected arguments {kwargs}") + if limit_threshold is None: limit_threshold = C.limit_threshold if deal_price is None: @@ -86,7 +92,6 @@ class Exchange: self.logger = get_module_logger("online operator", level=logging.INFO) - self.trade_unit = trade_unit # TODO: the quote, trade_dates, codes are not necessray. # It is just for performance consideration. if limit_threshold is None: diff --git a/tests/backtest/test_file_strategy.py b/tests/backtest/test_file_strategy.py index da52b0d53..8210e4809 100644 --- a/tests/backtest/test_file_strategy.py +++ b/tests/backtest/test_file_strategy.py @@ -62,6 +62,7 @@ class FileStrTest(TestAutoData): "close_cost": 0.0015, "min_cost": 5, "codes": codes, + "trade_unit": None, }, # "pos_type": "InfPosition" # Position with infinitive position }