1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

update position and negative cash

This commit is contained in:
wangwenxi.handsome
2021-08-06 04:34:30 +00:00
parent 8e87950292
commit 74e1ee6921
5 changed files with 106 additions and 59 deletions

View File

@@ -96,7 +96,7 @@ def get_exchange(
def create_account_instance(
start_time, end_time, benchmark: str, account: Union[float, int, Position], pos_type: str = "Position"
start_time, end_time, benchmark: str, account: Union[float, int, dict], pos_type: str = "Position"
) -> Account:
"""
# TODO: is very strange pass benchmark_config in the account(maybe for report)
@@ -110,19 +110,23 @@ def create_account_instance(
end time of the benchmark
benchmark : str
the benchmark for reporting
account : Union[float, int, Position]
account : Union[float, int, {"cash": float, "stock1": {"amount": int, "price"(optional): float}, "stock2": {"amount": int}}]
information for describing how to creating the account
For `float` or `int`:
Using Account with only initial cash
For `Position`:
Using Account with a Position
For `dict`:
key "cash" means initial cash.
key "stock1" means the first stock information with amount and price(optional).
...
"""
if isinstance(account, (int, float)):
pos_kwargs = {"init_cash": account}
elif isinstance(account, Position):
elif isinstance(account, dict):
init_cash = account["cash"]
del account["cash"]
pos_kwargs = {
"init_cash": account.position["cash"],
"position_dict": account.position,
"init_cash": init_cash,
"position_dict": account,
}
else:
raise ValueError("account must be in (int, float, Position)")

View File

@@ -100,7 +100,6 @@ class Account:
"module_path": "qlib.backtest.position",
}
)
self.accum_info = AccumulatedInfo()
self.report = None
self.positions = {}
@@ -119,8 +118,11 @@ class Account:
def reset_report(self, freq, benchmark_config):
# portfolio related metrics
if self.is_port_metr_enabled():
self.accum_info = AccumulatedInfo()
self.report = Report(freq, benchmark_config)
self.positions = {}
# fill stock value
self.current.fill_stock_value(self.benchmark_config["start_time"], self.freq)
# trading related metrics(e.g. high-frequency trading)
self.indicator = Indicator()
@@ -309,6 +311,7 @@ class Account:
self.update_current(trade_start_time, trade_end_time, trade_exchange)
if self.is_port_metr_enabled():
# report is portfolio related analysis
print(trade_start_time, trade_end_time)
self.update_report(trade_start_time, trade_end_time)
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`

View File

@@ -394,9 +394,8 @@ class Exchange:
if trade_account is not None and position is not None:
raise ValueError("trade_account and position can only choose one")
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, order.direction)
# NOTE: order will be changed in this function
trade_val, trade_cost = self._calc_trade_info_by_order(
trade_price, trade_val, trade_cost = self._calc_trade_info_by_order(
order, trade_account.current if trade_account else position, dealt_order_amount
)
if order.deal_amount > 1e-5:
@@ -714,6 +713,63 @@ class Exchange:
f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
)
def _cal_trade_amount_by_cash_limit(self, now_trade_amount, trade_price, order, position):
"""return the real order amount after cash limit.
Parameters
----------
now_trade_amount : float
trade_price : float
order : Order
position : Position
Return
----------
float
the real order amount after cash limit.
"""
cash = position.get_cash()
trade_val = now_trade_amount * trade_price
if order.direction == Order.SELL:
if cash < trade_val * self.close_cost:
# The money is not enough
self.logger.debug(f"Order clipped due to cash limitation: {order}")
return self.round_amount_by_trade_unit(cash / self.close_cost, order.factor)
elif order.direction == Order.BUY:
if cash < trade_val * (1 + self.open_cost):
# The money is not enough
self.logger.debug(f"Order clipped due to cash limitation: {order}")
return self.round_amount_by_trade_unit(cash / (1 + self.open_cost) / trade_price, order.factor)
# The money is enough
return self.round_amount_by_trade_unit(now_trade_amount, order.factor)
def _cal_trade_amount_by_stock_limit(self, now_trade_amount, order, position):
"""return the real order amount after stock amount limit.
Parameters
----------
now_trade_amount : float
order : Order
position : Position
Return
----------
float
the real order amount after stock amount limit.
"""
if order.direction == Order.SELL:
current_amount = position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0
if np.isclose(now_trade_amount, current_amount):
# when selling last stock. The amount don't need rounding
return now_trade_amount
elif now_trade_amount > current_amount:
return self.round_amount_by_trade_unit(current_amount, order.factor)
else:
return self.round_amount_by_trade_unit(now_trade_amount, order.factor)
elif order.direction == Order.BUY:
return self.round_amount_by_trade_unit(now_trade_amount, order.factor)
def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amount):
"""
Calculation of trade info
@@ -731,16 +787,10 @@ class Exchange:
if order.direction == Order.SELL:
# sell
if position is not None:
current_amount = (
position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0
)
if np.isclose(order.amount, current_amount):
# when selling last stock. The amount don't need rounding
order.deal_amount = order.amount
elif order.amount > current_amount:
order.deal_amount = self.round_amount_by_trade_unit(current_amount, order.factor)
else:
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
now_trade_amount = order.amount
now_trade_amount = self._cal_trade_amount_by_stock_limit(now_trade_amount, order, position)
now_trade_amount = self._cal_trade_amount_by_cash_limit(now_trade_amount, trade_price, order, position)
order.deal_amount = now_trade_amount
else:
# TODO: We don't know current position.
# We choose to sell all
@@ -752,17 +802,9 @@ class Exchange:
elif order.direction == Order.BUY:
# buy
if position is not None:
cash = position.get_cash()
trade_val = order.amount * trade_price
if cash < trade_val * (1 + self.open_cost):
# The money is not enough
order.deal_amount = self.round_amount_by_trade_unit(
cash / (1 + self.open_cost) / trade_price, order.factor
)
self.logger.debug(f"Order clipped due to cash limitation: {order}")
else:
# THe money is enough
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
now_trade_amount = order.amount
now_trade_amount = self._cal_trade_amount_by_cash_limit(now_trade_amount, trade_price, order, position)
order.deal_amount = now_trade_amount
else:
# Unknown amount of money. Just round the amount
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
@@ -773,7 +815,7 @@ class Exchange:
else:
raise NotImplementedError("order type {} error".format(order.type))
return trade_val, trade_cost
return trade_price, trade_val, trade_cost
def get_order_helper(self) -> OrderHelper:
if not hasattr(self, "_order_helper"):

View File

@@ -256,37 +256,33 @@ class Position(BasePosition):
# NOTE: The position dict must be copied!!!
# Otherwise the initial value
self.init_cash = cash
self.position = position_dict.copy()
self.init_stock_info = position_dict.copy()
self.position = self.init_stock_info.copy()
self.position["cash"] = cash
self.position["now_account_value"] = self.calculate_value()
def _fill_stock_value(
self, position_dict: dict, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30
):
# If the stock price information is missing, the account value will not be calculated temporarily
try:
self.position["now_account_value"] = self.calculate_value()
except KeyError:
pass
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30):
"""fill the stock value by the close price of latest last_days from qlib.
Parameters
----------
position_dict : Dict[stock_id, {"amount": int, "price": float}]
initial holding stocks.
start_time :
the start time of backtest.
last_days : int, optional
the days to get the latest close price, by default 30.
Return
----------
Dict[stock_id, {"amount": int, "price": float}]
initial holding stocks with filled price.
"""
stock_list = []
for stock in position_dict:
if ("price" not in position_dict[stock]) or (position_dict[stock]["price"] is None):
for stock in self.init_stock_info:
if ("price" not in self.position[stock]) or (self.position[stock]["price"] is None):
stock_list.append(stock)
if len(stock_list) == 0:
return position_dict
return
start_time = pd.Timestamp(start_time)
# note that start time is 2020-01-01 00:00:00 if raw start time is "2020-01-01"
@@ -298,11 +294,13 @@ class Position(BasePosition):
price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict()
if len(price_dict) < len(stock_list):
raise ValueError(f"there is no close price in qlib")
lack_stock = set(stock_list) - set(price_dict)
raise ValueError(f"{lack_stock} doesn't have close price in qlib in the latest {last_days} days")
for stock in stock_list:
position_dict[stock]["price"] = price_dict[stock]
return position_dict
self.init_stock_info[stock]["price"] = price_dict[stock]
self.position.update(self.init_stock_info)
self.position["now_account_value"] = self.calculate_value()
def _init_stock(self, stock_id, amount, price=None):
"""

View File

@@ -97,13 +97,13 @@ class Freq:
return _count, _freq_format_dict[_freq]
cn_time = [
CN_TIME = [
datetime.strptime("9:30", "%H:%M"),
datetime.strptime("11:30", "%H:%M"),
datetime.strptime("13:00", "%H:%M"),
datetime.strptime("15:00", "%H:%M"),
]
us_time = [datetime.strptime("9:30", "%H:%M"), datetime.strptime("16:00", "%H:%M")]
US_TIME = [datetime.strptime("9:30", "%H:%M"), datetime.strptime("16:00", "%H:%M")]
def time_to_day_index(time_obj: Union[str, datetime], region: str = "cn"):
@@ -111,15 +111,15 @@ def time_to_day_index(time_obj: Union[str, datetime], region: str = "cn"):
time_obj = datetime.strptime(time_obj, "%H:%M")
if region == "cn":
if time_obj >= cn_time[0] and time_obj < cn_time[1]:
return int((time_obj - cn_time[0]).total_seconds() / 60)
elif time_obj >= cn_time[2] and time_obj < cn_time[3]:
return int((time_obj - cn_time[2]).total_seconds() / 60) + 120
if time_obj >= CN_TIME[0] and time_obj < CN_TIME[1]:
return int((time_obj - CN_TIME[0]).total_seconds() / 60)
elif time_obj >= CN_TIME[2] and time_obj < CN_TIME[3]:
return int((time_obj - CN_TIME[2]).total_seconds() / 60) + 120
else:
raise ValueError(f"{time_obj} is not the opening time of the {region} stock market")
elif region == "us":
if time_obj >= us_time[0] and time_obj < us_time[1]:
return int((time_obj - us_time[0]).total_seconds() / 60)
if time_obj >= US_TIME[0] and time_obj < US_TIME[1]:
return int((time_obj - US_TIME[0]).total_seconds() / 60)
else:
raise ValueError(f"{time_obj} is not the opening time of the {region} stock market")
else: