diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 13213c344..03e51c740 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -73,6 +73,18 @@ class Account: pos_type: str = "Position", port_metr_enabled: bool = True, ): + """the trade account of backtest. + + Parameters + ---------- + init_cash : float, optional + initial cash, by default 1e9 + position_dict : Dict[stock_id, {"amount": int, "price"(optional): float}], optional + initial stocks with amount and price, + if there is no price key in the dict of stocks, it will be filled by latest close price from qlib. + by default {}. + """ + self._pos_type = pos_type self._port_metr_enabled = port_metr_enabled self.init_vars(init_cash, position_dict, freq, benchmark_config) @@ -93,6 +105,8 @@ class Account: "kwargs": { "cash": init_cash, "position_dict": position_dict, + "start_time": benchmark_config["start_time"], + "freq": freq, }, "module_path": "qlib.backtest.position", } diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index 7c32edc81..92b66a342 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -4,10 +4,14 @@ import copy import pathlib -from typing import Dict, List +from typing import Dict, List, Union + import pandas as pd +from datetime import timedelta import numpy as np + from .order import Order +from ..data.data import D class BasePosition: @@ -199,14 +203,72 @@ class Position(BasePosition): } """ - def __init__(self, cash=0, position_dict={}): + def __init__(self, start_time, freq, cash: float = 0, position_dict: Dict[str, Dict[str, float]] = {}): + """Init position by cash and position_dict. + + Parameters + ---------- + start_time : + the start time of backtest. It's for filling the initial value of stocks. + cash : float, optional + initial cash in account, by default 0 + position_dict : Dict[stock_id, {"amount": int, "price"(optional): float}], optional + initial stocks with parameters amount and price, + if there is no price key in the dict of stocks, it will be filled by _fill_stock_value. + by default {}. + """ + # NOTE: The position dict must be copied!!! # Otherwise the initial value self.init_cash = cash - self.position = position_dict.copy() + self.position = self._fill_stock_value(position_dict.copy(), start_time, freq) self.position["cash"] = cash self.position["now_account_value"] = self.calculate_value() + def _fill_stock_value( + self, position_dict: dict, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30 + ): + """fill the stock value by the close price of latest last_days from qlib. + + Parameters + ---------- + position_dict : Dict[stock_id, {"amount": int, "price": float}] + initial holding stocks. + start_time : + the start time of backtest. + last_days : int, optional + the days to get the latest close price, by default 30. + + Return + ---------- + Dict[stock_id, {"amount": int, "price": float}] + initial holding stocks with filled price. + """ + + stock_list = [] + for stock in position_dict: + if ("price" not in position_dict[stock]) or (position_dict[stock]["price"] is None): + stock_list.append(stock) + + if len(stock_list) == 0: + return position_dict + + start_time = pd.Timestamp(start_time) + # note that start time is 2020-01-01 00:00:00 if raw start time is "2020-01-01" + price_end_time = start_time + price_start_time = start_time - timedelta(days=last_days) + price_df = D.features( + stock_list, ["$close"], price_start_time, price_end_time, freq=freq, disk_cache=True + ).dropna() + price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict() + + if len(price_dict) < len(stock_list): + raise ValueError(f"there is no close price in qlib") + + for stock in stock_list: + position_dict[stock]["price"] = price_dict[stock] + return position_dict + def _init_stock(self, stock_id, amount, price=None): """ initialization the stock in current position