mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
update position and negative cash
This commit is contained in:
@@ -96,7 +96,7 @@ def get_exchange(
|
||||
|
||||
|
||||
def create_account_instance(
|
||||
start_time, end_time, benchmark: str, account: Union[float, int, Position], pos_type: str = "Position"
|
||||
start_time, end_time, benchmark: str, account: Union[float, int, dict], pos_type: str = "Position"
|
||||
) -> Account:
|
||||
"""
|
||||
# TODO: is very strange pass benchmark_config in the account(maybe for report)
|
||||
@@ -110,19 +110,23 @@ def create_account_instance(
|
||||
end time of the benchmark
|
||||
benchmark : str
|
||||
the benchmark for reporting
|
||||
account : Union[float, int, Position]
|
||||
account : Union[float, int, {"cash": float, "stock1": {"amount": int, "price"(optional): float}, "stock2": {"amount": int}}]
|
||||
information for describing how to creating the account
|
||||
For `float` or `int`:
|
||||
Using Account with only initial cash
|
||||
For `Position`:
|
||||
Using Account with a Position
|
||||
For `dict`:
|
||||
key "cash" means initial cash.
|
||||
key "stock1" means the first stock information with amount and price(optional).
|
||||
...
|
||||
"""
|
||||
if isinstance(account, (int, float)):
|
||||
pos_kwargs = {"init_cash": account}
|
||||
elif isinstance(account, Position):
|
||||
elif isinstance(account, dict):
|
||||
init_cash = account["cash"]
|
||||
del account["cash"]
|
||||
pos_kwargs = {
|
||||
"init_cash": account.position["cash"],
|
||||
"position_dict": account.position,
|
||||
"init_cash": init_cash,
|
||||
"position_dict": account,
|
||||
}
|
||||
else:
|
||||
raise ValueError("account must be in (int, float, Position)")
|
||||
|
||||
@@ -100,7 +100,6 @@ class Account:
|
||||
"module_path": "qlib.backtest.position",
|
||||
}
|
||||
)
|
||||
self.accum_info = AccumulatedInfo()
|
||||
self.report = None
|
||||
self.positions = {}
|
||||
|
||||
@@ -119,8 +118,11 @@ class Account:
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
# portfolio related metrics
|
||||
if self.is_port_metr_enabled():
|
||||
self.accum_info = AccumulatedInfo()
|
||||
self.report = Report(freq, benchmark_config)
|
||||
self.positions = {}
|
||||
# fill stock value
|
||||
self.current.fill_stock_value(self.benchmark_config["start_time"], self.freq)
|
||||
|
||||
# trading related metrics(e.g. high-frequency trading)
|
||||
self.indicator = Indicator()
|
||||
@@ -309,6 +311,7 @@ class Account:
|
||||
self.update_current(trade_start_time, trade_end_time, trade_exchange)
|
||||
if self.is_port_metr_enabled():
|
||||
# report is portfolio related analysis
|
||||
print(trade_start_time, trade_end_time)
|
||||
self.update_report(trade_start_time, trade_end_time)
|
||||
|
||||
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`
|
||||
|
||||
@@ -394,9 +394,8 @@ class Exchange:
|
||||
if trade_account is not None and position is not None:
|
||||
raise ValueError("trade_account and position can only choose one")
|
||||
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, order.direction)
|
||||
# NOTE: order will be changed in this function
|
||||
trade_val, trade_cost = self._calc_trade_info_by_order(
|
||||
trade_price, trade_val, trade_cost = self._calc_trade_info_by_order(
|
||||
order, trade_account.current if trade_account else position, dealt_order_amount
|
||||
)
|
||||
if order.deal_amount > 1e-5:
|
||||
@@ -714,6 +713,63 @@ class Exchange:
|
||||
f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
|
||||
)
|
||||
|
||||
def _cal_trade_amount_by_cash_limit(self, now_trade_amount, trade_price, order, position):
|
||||
"""return the real order amount after cash limit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
now_trade_amount : float
|
||||
trade_price : float
|
||||
order : Order
|
||||
position : Position
|
||||
|
||||
Return
|
||||
----------
|
||||
float
|
||||
the real order amount after cash limit.
|
||||
"""
|
||||
cash = position.get_cash()
|
||||
trade_val = now_trade_amount * trade_price
|
||||
if order.direction == Order.SELL:
|
||||
if cash < trade_val * self.close_cost:
|
||||
# The money is not enough
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
return self.round_amount_by_trade_unit(cash / self.close_cost, order.factor)
|
||||
elif order.direction == Order.BUY:
|
||||
if cash < trade_val * (1 + self.open_cost):
|
||||
# The money is not enough
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
return self.round_amount_by_trade_unit(cash / (1 + self.open_cost) / trade_price, order.factor)
|
||||
|
||||
# The money is enough
|
||||
return self.round_amount_by_trade_unit(now_trade_amount, order.factor)
|
||||
|
||||
def _cal_trade_amount_by_stock_limit(self, now_trade_amount, order, position):
|
||||
"""return the real order amount after stock amount limit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
now_trade_amount : float
|
||||
order : Order
|
||||
position : Position
|
||||
|
||||
Return
|
||||
----------
|
||||
float
|
||||
the real order amount after stock amount limit.
|
||||
"""
|
||||
if order.direction == Order.SELL:
|
||||
current_amount = position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0
|
||||
if np.isclose(now_trade_amount, current_amount):
|
||||
# when selling last stock. The amount don't need rounding
|
||||
return now_trade_amount
|
||||
elif now_trade_amount > current_amount:
|
||||
return self.round_amount_by_trade_unit(current_amount, order.factor)
|
||||
else:
|
||||
return self.round_amount_by_trade_unit(now_trade_amount, order.factor)
|
||||
elif order.direction == Order.BUY:
|
||||
return self.round_amount_by_trade_unit(now_trade_amount, order.factor)
|
||||
|
||||
def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amount):
|
||||
"""
|
||||
Calculation of trade info
|
||||
@@ -731,16 +787,10 @@ class Exchange:
|
||||
if order.direction == Order.SELL:
|
||||
# sell
|
||||
if position is not None:
|
||||
current_amount = (
|
||||
position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0
|
||||
)
|
||||
if np.isclose(order.amount, current_amount):
|
||||
# when selling last stock. The amount don't need rounding
|
||||
order.deal_amount = order.amount
|
||||
elif order.amount > current_amount:
|
||||
order.deal_amount = self.round_amount_by_trade_unit(current_amount, order.factor)
|
||||
else:
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
|
||||
now_trade_amount = order.amount
|
||||
now_trade_amount = self._cal_trade_amount_by_stock_limit(now_trade_amount, order, position)
|
||||
now_trade_amount = self._cal_trade_amount_by_cash_limit(now_trade_amount, trade_price, order, position)
|
||||
order.deal_amount = now_trade_amount
|
||||
else:
|
||||
# TODO: We don't know current position.
|
||||
# We choose to sell all
|
||||
@@ -752,17 +802,9 @@ class Exchange:
|
||||
elif order.direction == Order.BUY:
|
||||
# buy
|
||||
if position is not None:
|
||||
cash = position.get_cash()
|
||||
trade_val = order.amount * trade_price
|
||||
if cash < trade_val * (1 + self.open_cost):
|
||||
# The money is not enough
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
cash / (1 + self.open_cost) / trade_price, order.factor
|
||||
)
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
else:
|
||||
# THe money is enough
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
|
||||
now_trade_amount = order.amount
|
||||
now_trade_amount = self._cal_trade_amount_by_cash_limit(now_trade_amount, trade_price, order, position)
|
||||
order.deal_amount = now_trade_amount
|
||||
else:
|
||||
# Unknown amount of money. Just round the amount
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
|
||||
@@ -773,7 +815,7 @@ class Exchange:
|
||||
else:
|
||||
raise NotImplementedError("order type {} error".format(order.type))
|
||||
|
||||
return trade_val, trade_cost
|
||||
return trade_price, trade_val, trade_cost
|
||||
|
||||
def get_order_helper(self) -> OrderHelper:
|
||||
if not hasattr(self, "_order_helper"):
|
||||
|
||||
@@ -256,37 +256,33 @@ class Position(BasePosition):
|
||||
# NOTE: The position dict must be copied!!!
|
||||
# Otherwise the initial value
|
||||
self.init_cash = cash
|
||||
self.position = position_dict.copy()
|
||||
self.init_stock_info = position_dict.copy()
|
||||
self.position = self.init_stock_info.copy()
|
||||
self.position["cash"] = cash
|
||||
self.position["now_account_value"] = self.calculate_value()
|
||||
|
||||
def _fill_stock_value(
|
||||
self, position_dict: dict, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30
|
||||
):
|
||||
# If the stock price information is missing, the account value will not be calculated temporarily
|
||||
try:
|
||||
self.position["now_account_value"] = self.calculate_value()
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30):
|
||||
"""fill the stock value by the close price of latest last_days from qlib.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
position_dict : Dict[stock_id, {"amount": int, "price": float}]
|
||||
initial holding stocks.
|
||||
start_time :
|
||||
the start time of backtest.
|
||||
last_days : int, optional
|
||||
the days to get the latest close price, by default 30.
|
||||
|
||||
Return
|
||||
----------
|
||||
Dict[stock_id, {"amount": int, "price": float}]
|
||||
initial holding stocks with filled price.
|
||||
"""
|
||||
|
||||
stock_list = []
|
||||
for stock in position_dict:
|
||||
if ("price" not in position_dict[stock]) or (position_dict[stock]["price"] is None):
|
||||
for stock in self.init_stock_info:
|
||||
if ("price" not in self.position[stock]) or (self.position[stock]["price"] is None):
|
||||
stock_list.append(stock)
|
||||
|
||||
if len(stock_list) == 0:
|
||||
return position_dict
|
||||
return
|
||||
|
||||
start_time = pd.Timestamp(start_time)
|
||||
# note that start time is 2020-01-01 00:00:00 if raw start time is "2020-01-01"
|
||||
@@ -298,11 +294,13 @@ class Position(BasePosition):
|
||||
price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict()
|
||||
|
||||
if len(price_dict) < len(stock_list):
|
||||
raise ValueError(f"there is no close price in qlib")
|
||||
lack_stock = set(stock_list) - set(price_dict)
|
||||
raise ValueError(f"{lack_stock} doesn't have close price in qlib in the latest {last_days} days")
|
||||
|
||||
for stock in stock_list:
|
||||
position_dict[stock]["price"] = price_dict[stock]
|
||||
return position_dict
|
||||
self.init_stock_info[stock]["price"] = price_dict[stock]
|
||||
self.position.update(self.init_stock_info)
|
||||
self.position["now_account_value"] = self.calculate_value()
|
||||
|
||||
def _init_stock(self, stock_id, amount, price=None):
|
||||
"""
|
||||
|
||||
@@ -97,13 +97,13 @@ class Freq:
|
||||
return _count, _freq_format_dict[_freq]
|
||||
|
||||
|
||||
cn_time = [
|
||||
CN_TIME = [
|
||||
datetime.strptime("9:30", "%H:%M"),
|
||||
datetime.strptime("11:30", "%H:%M"),
|
||||
datetime.strptime("13:00", "%H:%M"),
|
||||
datetime.strptime("15:00", "%H:%M"),
|
||||
]
|
||||
us_time = [datetime.strptime("9:30", "%H:%M"), datetime.strptime("16:00", "%H:%M")]
|
||||
US_TIME = [datetime.strptime("9:30", "%H:%M"), datetime.strptime("16:00", "%H:%M")]
|
||||
|
||||
|
||||
def time_to_day_index(time_obj: Union[str, datetime], region: str = "cn"):
|
||||
@@ -111,15 +111,15 @@ def time_to_day_index(time_obj: Union[str, datetime], region: str = "cn"):
|
||||
time_obj = datetime.strptime(time_obj, "%H:%M")
|
||||
|
||||
if region == "cn":
|
||||
if time_obj >= cn_time[0] and time_obj < cn_time[1]:
|
||||
return int((time_obj - cn_time[0]).total_seconds() / 60)
|
||||
elif time_obj >= cn_time[2] and time_obj < cn_time[3]:
|
||||
return int((time_obj - cn_time[2]).total_seconds() / 60) + 120
|
||||
if time_obj >= CN_TIME[0] and time_obj < CN_TIME[1]:
|
||||
return int((time_obj - CN_TIME[0]).total_seconds() / 60)
|
||||
elif time_obj >= CN_TIME[2] and time_obj < CN_TIME[3]:
|
||||
return int((time_obj - CN_TIME[2]).total_seconds() / 60) + 120
|
||||
else:
|
||||
raise ValueError(f"{time_obj} is not the opening time of the {region} stock market")
|
||||
elif region == "us":
|
||||
if time_obj >= us_time[0] and time_obj < us_time[1]:
|
||||
return int((time_obj - us_time[0]).total_seconds() / 60)
|
||||
if time_obj >= US_TIME[0] and time_obj < US_TIME[1]:
|
||||
return int((time_obj - US_TIME[0]).total_seconds() / 60)
|
||||
else:
|
||||
raise ValueError(f"{time_obj} is not the opening time of the {region} stock market")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user