diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index 1babd08c7..948af670a 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -96,7 +96,7 @@ def get_exchange( def create_account_instance( - start_time, end_time, benchmark: str, account: Union[float, int, Position], pos_type: str = "Position" + start_time, end_time, benchmark: str, account: Union[float, int, dict], pos_type: str = "Position" ) -> Account: """ # TODO: is very strange pass benchmark_config in the account(maybe for report) @@ -110,19 +110,23 @@ def create_account_instance( end time of the benchmark benchmark : str the benchmark for reporting - account : Union[float, int, Position] + account : Union[float, int, {"cash": float, "stock1": {"amount": int, "price"(optional): float}, "stock2": {"amount": int}}] information for describing how to creating the account For `float` or `int`: Using Account with only initial cash - For `Position`: - Using Account with a Position + For `dict`: + key "cash" means initial cash. + key "stock1" means the first stock information with amount and price(optional). + ... """ if isinstance(account, (int, float)): pos_kwargs = {"init_cash": account} - elif isinstance(account, Position): + elif isinstance(account, dict): + init_cash = account["cash"] + del account["cash"] pos_kwargs = { - "init_cash": account.position["cash"], - "position_dict": account.position, + "init_cash": init_cash, + "position_dict": account, } else: raise ValueError("account must be in (int, float, Position)") diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 542c0fba2..cc984b061 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -100,7 +100,6 @@ class Account: "module_path": "qlib.backtest.position", } ) - self.accum_info = AccumulatedInfo() self.report = None self.positions = {} @@ -119,8 +118,11 @@ class Account: def reset_report(self, freq, benchmark_config): # portfolio related metrics if self.is_port_metr_enabled(): + self.accum_info = AccumulatedInfo() self.report = Report(freq, benchmark_config) self.positions = {} + # fill stock value + self.current.fill_stock_value(self.benchmark_config["start_time"], self.freq) # trading related metrics(e.g. high-frequency trading) self.indicator = Indicator() @@ -309,6 +311,7 @@ class Account: self.update_current(trade_start_time, trade_end_time, trade_exchange) if self.is_port_metr_enabled(): # report is portfolio related analysis + print(trade_start_time, trade_end_time) self.update_report(trade_start_time, trade_end_time) # TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():` diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 0aab35e67..d64af0172 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -394,9 +394,8 @@ class Exchange: if trade_account is not None and position is not None: raise ValueError("trade_account and position can only choose one") - trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, order.direction) # NOTE: order will be changed in this function - trade_val, trade_cost = self._calc_trade_info_by_order( + trade_price, trade_val, trade_cost = self._calc_trade_info_by_order( order, trade_account.current if trade_account else position, dealt_order_amount ) if order.deal_amount > 1e-5: @@ -714,6 +713,63 @@ class Exchange: f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}" ) + def _cal_trade_amount_by_cash_limit(self, now_trade_amount, trade_price, order, position): + """return the real order amount after cash limit. + + Parameters + ---------- + now_trade_amount : float + trade_price : float + order : Order + position : Position + + Return + ---------- + float + the real order amount after cash limit. + """ + cash = position.get_cash() + trade_val = now_trade_amount * trade_price + if order.direction == Order.SELL: + if cash < trade_val * self.close_cost: + # The money is not enough + self.logger.debug(f"Order clipped due to cash limitation: {order}") + return self.round_amount_by_trade_unit(cash / self.close_cost, order.factor) + elif order.direction == Order.BUY: + if cash < trade_val * (1 + self.open_cost): + # The money is not enough + self.logger.debug(f"Order clipped due to cash limitation: {order}") + return self.round_amount_by_trade_unit(cash / (1 + self.open_cost) / trade_price, order.factor) + + # The money is enough + return self.round_amount_by_trade_unit(now_trade_amount, order.factor) + + def _cal_trade_amount_by_stock_limit(self, now_trade_amount, order, position): + """return the real order amount after stock amount limit. + + Parameters + ---------- + now_trade_amount : float + order : Order + position : Position + + Return + ---------- + float + the real order amount after stock amount limit. + """ + if order.direction == Order.SELL: + current_amount = position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0 + if np.isclose(now_trade_amount, current_amount): + # when selling last stock. The amount don't need rounding + return now_trade_amount + elif now_trade_amount > current_amount: + return self.round_amount_by_trade_unit(current_amount, order.factor) + else: + return self.round_amount_by_trade_unit(now_trade_amount, order.factor) + elif order.direction == Order.BUY: + return self.round_amount_by_trade_unit(now_trade_amount, order.factor) + def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amount): """ Calculation of trade info @@ -731,16 +787,10 @@ class Exchange: if order.direction == Order.SELL: # sell if position is not None: - current_amount = ( - position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0 - ) - if np.isclose(order.amount, current_amount): - # when selling last stock. The amount don't need rounding - order.deal_amount = order.amount - elif order.amount > current_amount: - order.deal_amount = self.round_amount_by_trade_unit(current_amount, order.factor) - else: - order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) + now_trade_amount = order.amount + now_trade_amount = self._cal_trade_amount_by_stock_limit(now_trade_amount, order, position) + now_trade_amount = self._cal_trade_amount_by_cash_limit(now_trade_amount, trade_price, order, position) + order.deal_amount = now_trade_amount else: # TODO: We don't know current position. # We choose to sell all @@ -752,17 +802,9 @@ class Exchange: elif order.direction == Order.BUY: # buy if position is not None: - cash = position.get_cash() - trade_val = order.amount * trade_price - if cash < trade_val * (1 + self.open_cost): - # The money is not enough - order.deal_amount = self.round_amount_by_trade_unit( - cash / (1 + self.open_cost) / trade_price, order.factor - ) - self.logger.debug(f"Order clipped due to cash limitation: {order}") - else: - # THe money is enough - order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) + now_trade_amount = order.amount + now_trade_amount = self._cal_trade_amount_by_cash_limit(now_trade_amount, trade_price, order, position) + order.deal_amount = now_trade_amount else: # Unknown amount of money. Just round the amount order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) @@ -773,7 +815,7 @@ class Exchange: else: raise NotImplementedError("order type {} error".format(order.type)) - return trade_val, trade_cost + return trade_price, trade_val, trade_cost def get_order_helper(self) -> OrderHelper: if not hasattr(self, "_order_helper"): diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index e4f1ab40c..6747d7a7a 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -256,37 +256,33 @@ class Position(BasePosition): # NOTE: The position dict must be copied!!! # Otherwise the initial value self.init_cash = cash - self.position = position_dict.copy() + self.init_stock_info = position_dict.copy() + self.position = self.init_stock_info.copy() self.position["cash"] = cash - self.position["now_account_value"] = self.calculate_value() - def _fill_stock_value( - self, position_dict: dict, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30 - ): + # If the stock price information is missing, the account value will not be calculated temporarily + try: + self.position["now_account_value"] = self.calculate_value() + except KeyError: + pass + + def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30): """fill the stock value by the close price of latest last_days from qlib. Parameters ---------- - position_dict : Dict[stock_id, {"amount": int, "price": float}] - initial holding stocks. start_time : the start time of backtest. last_days : int, optional the days to get the latest close price, by default 30. - - Return - ---------- - Dict[stock_id, {"amount": int, "price": float}] - initial holding stocks with filled price. """ - stock_list = [] - for stock in position_dict: - if ("price" not in position_dict[stock]) or (position_dict[stock]["price"] is None): + for stock in self.init_stock_info: + if ("price" not in self.position[stock]) or (self.position[stock]["price"] is None): stock_list.append(stock) if len(stock_list) == 0: - return position_dict + return start_time = pd.Timestamp(start_time) # note that start time is 2020-01-01 00:00:00 if raw start time is "2020-01-01" @@ -298,11 +294,13 @@ class Position(BasePosition): price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict() if len(price_dict) < len(stock_list): - raise ValueError(f"there is no close price in qlib") + lack_stock = set(stock_list) - set(price_dict) + raise ValueError(f"{lack_stock} doesn't have close price in qlib in the latest {last_days} days") for stock in stock_list: - position_dict[stock]["price"] = price_dict[stock] - return position_dict + self.init_stock_info[stock]["price"] = price_dict[stock] + self.position.update(self.init_stock_info) + self.position["now_account_value"] = self.calculate_value() def _init_stock(self, stock_id, amount, price=None): """ diff --git a/qlib/utils/time.py b/qlib/utils/time.py index f61c825d2..c18d76b14 100644 --- a/qlib/utils/time.py +++ b/qlib/utils/time.py @@ -97,13 +97,13 @@ class Freq: return _count, _freq_format_dict[_freq] -cn_time = [ +CN_TIME = [ datetime.strptime("9:30", "%H:%M"), datetime.strptime("11:30", "%H:%M"), datetime.strptime("13:00", "%H:%M"), datetime.strptime("15:00", "%H:%M"), ] -us_time = [datetime.strptime("9:30", "%H:%M"), datetime.strptime("16:00", "%H:%M")] +US_TIME = [datetime.strptime("9:30", "%H:%M"), datetime.strptime("16:00", "%H:%M")] def time_to_day_index(time_obj: Union[str, datetime], region: str = "cn"): @@ -111,15 +111,15 @@ def time_to_day_index(time_obj: Union[str, datetime], region: str = "cn"): time_obj = datetime.strptime(time_obj, "%H:%M") if region == "cn": - if time_obj >= cn_time[0] and time_obj < cn_time[1]: - return int((time_obj - cn_time[0]).total_seconds() / 60) - elif time_obj >= cn_time[2] and time_obj < cn_time[3]: - return int((time_obj - cn_time[2]).total_seconds() / 60) + 120 + if time_obj >= CN_TIME[0] and time_obj < CN_TIME[1]: + return int((time_obj - CN_TIME[0]).total_seconds() / 60) + elif time_obj >= CN_TIME[2] and time_obj < CN_TIME[3]: + return int((time_obj - CN_TIME[2]).total_seconds() / 60) + 120 else: raise ValueError(f"{time_obj} is not the opening time of the {region} stock market") elif region == "us": - if time_obj >= us_time[0] and time_obj < us_time[1]: - return int((time_obj - us_time[0]).total_seconds() / 60) + if time_obj >= US_TIME[0] and time_obj < US_TIME[1]: + return int((time_obj - US_TIME[0]).total_seconds() / 60) else: raise ValueError(f"{time_obj} is not the opening time of the {region} stock market") else: