1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 10:31:00 +08:00
Files
qlib/qlib/backtest/exchange.py
you-n-g be4646b4b7 Adjust rolling api (#1594)
* Intermediate version

* Fix yaml template & Successfully run rolling

* Be compatible with benchmark

* Get same results with previous linear model

* Black formatting

* Update black

* Update the placeholder mechanism

* Update CI

* Update CI

* Upgrade Black

* Fix CI and simplify code

* Fix CI

* Move the data processing caching mechanism into utils.

* Adjusting DDG-DA

* Organize import
2023-07-14 12:16:12 +08:00

958 lines
43 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
from ..utils.index_data import IndexData
if TYPE_CHECKING:
from .account import Account
import random
import numpy as np
import pandas as pd
from qlib.backtest.position import BasePosition
from ..config import C
from ..constant import REG_CN, REG_TW
from ..data.data import D
from ..log import get_module_logger
from .decision import Order, OrderDir, OrderHelper
from .high_performance_ds import BaseQuote, NumpyQuote
class Exchange:
# `quote_df` is a pd.DataFrame that contains the basic information for backtesting.
# After some processing, the data is maintained by a `quote_cls` object for faster data retrieval.
# Some conventions for `quote_df`:
# - $close is for calculating the total value at the end of each day.
# - if $close is None, the stock on that day is regarded as suspended.
# - $factor is for rounding to the trading unit;
# - if any $factor is missing when $close exists, trading unit rounding will be disabled
quote_df: pd.DataFrame
def __init__(
self,
freq: str = "day",
start_time: Union[pd.Timestamp, str] = None,
end_time: Union[pd.Timestamp, str] = None,
codes: Union[list, str] = "all",
deal_price: Union[str, Tuple[str, str], List[str], None] = None,
subscribe_fields: list = [],
limit_threshold: Union[Tuple[str, str], float, None] = None,
volume_threshold: Union[tuple, dict, None] = None,
open_cost: float = 0.0015,
close_cost: float = 0.0025,
min_cost: float = 5.0,
impact_cost: float = 0.0,
extra_quote: pd.DataFrame = None,
quote_cls: Type[BaseQuote] = NumpyQuote,
**kwargs: Any,
) -> None:
"""__init__
:param freq: frequency of data
:param start_time: closed start time for backtest
:param end_time: closed end time for backtest
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
:param deal_price: Union[str, Tuple[str, str], List[str]]
The `deal_price` supports following two types of input
- <deal_price> : str
- (<buy_price>, <sell_price>): Tuple[str] or List[str]
<deal_price>, <buy_price> or <sell_price> := <price>
<price> := str
- for example '$close', '$open', '$vwap' ("close" is OK. `Exchange` will help to prepend
"$" to the expression)
:param subscribe_fields: list, subscribe fields. This expressions will be added to the query and `self.quote`.
It is useful when users want more fields to be queried
:param limit_threshold: Union[Tuple[str, str], float, None]
1) `None`: no limitation
2) float, 0.1 for example, default None
3) Tuple[str, str]: (<the expression for buying stock limitation>,
<the expression for sell stock limitation>)
`False` value indicates the stock is tradable
`True` value indicates the stock is limited and not tradable
:param volume_threshold: Union[
Dict[
"all": ("cum" or "current", limit_str),
"buy": ("cum" or "current", limit_str),
"sell":("cum" or "current", limit_str),
],
("cum" or "current", limit_str),
]
1) ("cum" or "current", limit_str) denotes a single volume limit.
- limit_str is qlib data expression which is allowed to define your own Operator.
Please refer to qlib/contrib/ops/high_freq.py, here are any custom operator for
high frequency, such as DayCumsum. !!!NOTE: if you want you use the custom
operator, you need to register it in qlib_init.
- "cum" means that this is a cumulative value over time, such as cumulative market
volume. So when it is used as a volume limit, it is necessary to subtract the dealt
amount.
- "current" means that this is a real-time value and will not accumulate over time,
so it can be directly used as a capacity limit.
e.g. ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"), ("current", "$bidV1")
2) "all" means the volume limits are both buying and selling.
"buy" means the volume limits of buying. "sell" means the volume limits of selling.
Different volume limits will be aggregated with min(). If volume_threshold is only
("cum" or "current", limit_str) instead of a dict, the volume limits are for
both by default. In other words, it is same as {"all": ("cum" or "current", limit_str)}.
3) e.g. "volume_threshold": {
"all": ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"),
"buy": ("current", "$askV1"),
"sell": ("current", "$bidV1"),
}
:param open_cost: cost rate for open, default 0.0015
:param close_cost: cost rate for close, default 0.0025
:param trade_unit: trade unit, 100 for China A market.
None for disable trade unit.
**NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must
distinguish `not set` and `disable trade_unit`
:param min_cost: min cost, default 5
:param impact_cost: market impact cost rate (a.k.a. slippage). A recommended value is 0.1.
:param extra_quote: pandas, dataframe consists of
columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy'].
The limit indicates that the etf is tradable on a specific day.
Necessary fields:
$close is for calculating the total value at end of each day.
Optional fields:
$volume is only necessary when we limit the trade amount or calculate
PA(vwap) indicator
$vwap is only necessary when we use the $vwap price as the deal price
$factor is for rounding to the trading unit
limit_sell will be set to False by default (False indicates we can sell
this target on this day).
limit_buy will be set to False by default (False indicates we can buy
this target on this day).
index: MultipleIndex(instrument, pd.Datetime)
"""
self.freq = freq
self.start_time = start_time
self.end_time = end_time
self.trade_unit = kwargs.pop("trade_unit", C.trade_unit)
if len(kwargs) > 0:
raise ValueError(f"Get Unexpected arguments {kwargs}")
if limit_threshold is None:
limit_threshold = C.limit_threshold
if deal_price is None:
deal_price = C.deal_price
# we have some verbose information here. So logging is enabled
self.logger = get_module_logger("online operator")
# TODO: the quote, trade_dates, codes are not necessary.
# It is just for performance consideration.
self.limit_type = self._get_limit_type(limit_threshold)
if limit_threshold is None:
if C.region in [REG_CN, REG_TW]:
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1:
if C.region in [REG_CN, REG_TW]:
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
if isinstance(deal_price, str):
if deal_price[0] != "$":
deal_price = "$" + deal_price
self.buy_price = self.sell_price = deal_price
elif isinstance(deal_price, (tuple, list)):
self.buy_price, self.sell_price = cast(Tuple[str, str], deal_price)
else:
raise NotImplementedError(f"This type of input is not supported")
if isinstance(codes, str):
codes = D.instruments(codes)
self.codes = codes
# Necessary fields
# $close is for calculating the total value at end of each day.
# - if $close is None, the stock on that day is regarded as suspended.
# $factor is for rounding to the trading unit
# $change is for calculating the limit of the stock
#  get volume limit from kwargs
self.buy_vol_limit, self.sell_vol_limit, vol_lt_fields = self._get_vol_limit(volume_threshold)
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
if self.limit_type == self.LT_TP_EXP:
assert isinstance(limit_threshold, tuple)
for exp in limit_threshold:
necessary_fields.add(exp)
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
self.all_fields = all_fields
self.open_cost = open_cost
self.close_cost = close_cost
self.min_cost = min_cost
self.impact_cost = impact_cost
self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold
self.volume_threshold = volume_threshold
self.extra_quote = extra_quote
self.get_quote_from_qlib()
# init quote by quote_df
self.quote_cls = quote_cls
self.quote: BaseQuote = self.quote_cls(self.quote_df, freq)
def get_quote_from_qlib(self) -> None:
    """Load `self.quote_df` from qlib and derive the trading flags from it.

    Steps, in order:
    1. query `self.all_fields` for `self.codes` through `D.features`;
    2. warn if the buy/sell price columns contain NaN;
    3. set `self.trade_w_adj_price` (True when `$factor` is missing while `$close` exists);
    4. fill the `limit_buy` / `limit_sell` columns via `_update_limit`;
    5. if `self.extra_quote` is given, fill its missing columns with defaults and append it.

    NOTE: missing columns are added to `self.extra_quote` in place.

    :raises ValueError: if `extra_quote` has no `$close` column
    :raises AssertionError: if the columns of `extra_quote` do not match `quote_df` (ignoring `$change`)
    """
    # get stock data from qlib
    if len(self.codes) == 0:
        self.codes = D.instruments()
    self.quote_df = D.features(
        self.codes,
        self.all_fields,
        self.start_time,
        self.end_time,
        freq=self.freq,
        disk_cache=True,
    )
    self.quote_df.columns = self.all_fields

    # check buy_price data and sell_price data
    for attr in ("buy_price", "sell_price"):
        pstr = getattr(self, attr)  # price string
        if self.quote_df[pstr].isna().any():
            self.logger.warning("{} field data contains nan.".format(pstr))

    # update trade_w_adj_price
    if (self.quote_df["$factor"].isna() & ~self.quote_df["$close"].isna()).any():
        # The 'factor.day.bin' file not exists, and `factor` field contains `nan`
        # Use adjusted price
        self.trade_w_adj_price = True
        self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.")
        if self.trade_unit is not None:
            self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.")
    else:
        # The `factor.day.bin` file exists and all data `close` and `factor` are not `nan`
        # Use normal price
        self.trade_w_adj_price = False

    # update limit
    self._update_limit(self.limit_threshold)

    # concat extra_quote
    if self.extra_quote is not None:
        # process extra_quote
        if "$close" not in self.extra_quote:
            raise ValueError("$close is necessary in extra_quote")
        for attr in ("buy_price", "sell_price"):
            pstr = getattr(self, attr)  # price string
            if pstr not in self.extra_quote.columns:
                self.extra_quote[pstr] = self.extra_quote["$close"]
                self.logger.warning(f"No {pstr} set for extra_quote. Use $close as {pstr}.")
        if "$factor" not in self.extra_quote.columns:
            self.extra_quote["$factor"] = 1.0
            self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.")
        if "limit_sell" not in self.extra_quote.columns:
            self.extra_quote["limit_sell"] = False
            self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.")
        if "limit_buy" not in self.extra_quote.columns:
            self.extra_quote["limit_buy"] = False
            self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.")

        # `$change` is only used to derive the limits above, so extra_quote need not provide it.
        extra_columns = set(self.extra_quote.columns)
        expected_columns = set(self.quote_df.columns) - {"$change"}
        assert extra_columns == expected_columns, (
            "extra_quote columns do not match quote_df: "
            f"missing {sorted(expected_columns - extra_columns)}, "
            f"unexpected {sorted(extra_columns - expected_columns)}"
        )
        self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0)
# Kinds of trading limitation, as reported by `_get_limit_type`.
LT_TP_EXP = "(exp)" # Tuple[str, str]: the limitation is calculated by a pair of Qlib expressions (buy, sell).
LT_FLT = "float" # float: buying is limited when `$change >= limit_threshold`, selling when `$change <= -limit_threshold`
LT_NONE = "none" # none: there is no trading limitation
def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
"""get limit type"""
if isinstance(limit_threshold, tuple):
return self.LT_TP_EXP
elif isinstance(limit_threshold, float):
return self.LT_FLT
elif limit_threshold is None:
return self.LT_NONE
else:
raise NotImplementedError(f"This type of `limit_threshold` is not supported")
def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None:
# $close may contain NaN, the nan indicates that the stock is not tradable at that timestamp
suspended = self.quote_df["$close"].isna()
# check limit_threshold
limit_type = self._get_limit_type(limit_threshold)
if limit_type == self.LT_NONE:
self.quote_df["limit_buy"] = suspended
self.quote_df["limit_sell"] = suspended
elif limit_type == self.LT_TP_EXP:
# set limit
limit_threshold = cast(tuple, limit_threshold)
# astype bool is necessary, because quote_df is an expression and could be float
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]].astype("bool") | suspended
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]].astype("bool") | suspended
elif limit_type == self.LT_FLT:
limit_threshold = cast(float, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) | suspended
self.quote_df["limit_sell"] = (
self.quote_df["$change"].le(-limit_threshold) | suspended
) # pylint: disable=E1130
@staticmethod
def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
"""
preprocess the volume limit.
get the fields need to get from qlib.
get the volume limit list of buying and selling which is composed of all limits.
Parameters
----------
volume_threshold :
please refer to the doc of exchange.
Returns
-------
fields: set
the fields need to get from qlib.
buy_vol_limit: List[Tuple[str]]
all volume limits of buying.
sell_vol_limit: List[Tuple[str]]
all volume limits of selling.
Raises
------
ValueError
the format of volume_threshold is not supported.
"""
if volume_threshold is None:
return None, None, set()
fields = set()
buy_vol_limit = []
sell_vol_limit = []
if isinstance(volume_threshold, tuple):
volume_threshold = {"all": volume_threshold}
assert isinstance(volume_threshold, dict)
for key, vol_limit in volume_threshold.items():
assert isinstance(vol_limit, tuple)
fields.add(vol_limit[1])
if key in ("buy", "all"):
buy_vol_limit.append(vol_limit)
if key in ("sell", "all"):
sell_vol_limit.append(vol_limit)
return buy_vol_limit, sell_vol_limit, fields
def check_stock_limit(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: int | None = None,
) -> bool:
"""
Parameters
----------
stock_id : str
start_time: pd.Timestamp
end_time: pd.Timestamp
direction : int, optional
trade direction, by default None
- if direction is None, check if tradable for buying and selling.
- if direction == Order.BUY, check the if tradable for buying
- if direction == Order.SELL, check the sell limit for selling.
Returns
-------
True: the trading of the stock is limited (maybe hit the highest/lowest price), hence the stock is not tradable
False: the trading of the stock is not limited, hence the stock may be tradable
"""
# NOTE:
# **all** is used when checking limitation.
# For example, the stock trading is limited in a day if every minute is limited in a day if every minute is limited.
if direction is None:
# The trading limitation is related to the trading direction
# if the direction is not provided, then any limitation from buy or sell will result in trading limitation
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
return bool(buy_limit or sell_limit)
elif direction == Order.BUY:
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all"))
elif direction == Order.SELL:
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all"))
else:
raise ValueError(f"direction {direction} is not supported!")
def check_stock_suspended(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
) -> bool:
"""if stock is suspended(hence not tradable), True will be returned"""
# is suspended
if stock_id in self.quote.get_all_stock():
# suspended stocks are represented by None $close stock
# The $close may contain NaN,
close = self.quote.get_data(stock_id, start_time, end_time, "$close")
if close is None:
# if no close record exists
return True
elif isinstance(close, IndexData):
# **any** non-NaN $close represents trading opportunity may exist
# if all returned is nan, then the stock is suspended
return cast(bool, cast(IndexData, close).isna().all())
else:
# it is single value, make sure is not None
return np.isnan(close)
else:
# if the stock is not in the stock list, then it is not tradable and regarded as suspended
return True
def is_stock_tradable(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: int | None = None,
) -> bool:
# check if stock can be traded
return not (
self.check_stock_suspended(stock_id, start_time, end_time)
or self.check_stock_limit(stock_id, start_time, end_time, direction)
)
def check_order(self, order: Order) -> bool:
# check limit and suspended
return self.is_stock_tradable(order.stock_id, order.start_time, order.end_time, order.direction)
def deal_order(
self,
order: Order,
trade_account: Account | None = None,
position: BasePosition | None = None,
dealt_order_amount: Dict[str, float] = defaultdict(float),
) -> Tuple[float, float, float]:
"""
Deal order when the actual transaction
the results section in `Order` will be changed.
:param order: Deal the order.
:param trade_account: Trade account to be updated after dealing the order.
:param position: position to be updated after dealing the order.
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
:return: trade_val, trade_cost, trade_price
"""
# check order first.
if not self.check_order(order):
order.deal_amount = 0.0
# using np.nan instead of None to make it more convenient to show the value in format string
self.logger.debug(f"Order failed due to trading limitation: {order}")
return 0.0, 0.0, np.nan
if trade_account is not None and position is not None:
raise ValueError("trade_account and position can only choose one")
# NOTE: order will be changed in this function
trade_price, trade_val, trade_cost = self._calc_trade_info_by_order(
order,
trade_account.current_position if trade_account else position,
dealt_order_amount,
)
if trade_val > 1e-5:
# If the order can only be deal 0 value. Nothing to be updated
# Otherwise, it will result in
# 1) some stock with 0 value in the position
# 2) `trade_unit` of trade_cost will be lost in user account
if trade_account:
trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
elif position:
position.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
return trade_val, trade_cost, trade_price
def get_quote_info(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
field: str,
method: str = "ts_data_last",
) -> Union[None, int, float, bool, IndexData]:
return self.quote.get_data(stock_id, start_time, end_time, field=field, method=method)
def get_close(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
method: str = "ts_data_last",
) -> Union[None, int, float, bool, IndexData]:
return self.quote.get_data(stock_id, start_time, end_time, field="$close", method=method)
def get_volume(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
method: Optional[str] = "sum",
) -> Union[None, int, float, bool, IndexData]:
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
def get_deal_price(
    self,
    stock_id: str,
    start_time: pd.Timestamp,
    end_time: pd.Timestamp,
    direction: OrderDir,
    method: Optional[str] = "ts_data_last",
) -> Union[None, int, float, bool, IndexData]:
    """Return the deal price of `stock_id` for the given trade direction.

    The price field is `self.sell_price` for `OrderDir.SELL` and `self.buy_price` for
    `OrderDir.BUY`; any other direction raises NotImplementedError. When `method` is not
    None and the queried price is missing, NaN or not above 1e-08, a warning is logged
    and the close price is returned instead.
    """
    if direction == OrderDir.SELL:
        price_field = self.sell_price
    elif direction == OrderDir.BUY:
        price_field = self.buy_price
    else:
        raise NotImplementedError("This type of input is not supported")

    price = self.quote.get_data(stock_id, start_time, end_time, field=price_field, method=method)

    # the sanity check only applies to aggregated (scalar) results
    unusable = method is not None and (price is None or np.isnan(price) or price <= 1e-08)
    if unusable:
        self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {price_field}): {price}!!!")
        self.logger.warning("setting deal_price to close price")
        price = self.get_close(stock_id, start_time, end_time, method)
    return price
def get_factor(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
) -> Optional[float]:
"""
Returns
-------
Optional[float]:
`None`: if the stock is suspended `None` may be returned
`float`: return factor if the factor exists
"""
assert start_time is not None and end_time is not None, "the time range must be given"
if stock_id not in self.quote.get_all_stock():
return None
return self.quote.get_data(stock_id, start_time, end_time, field="$factor", method="ts_data_last")
def generate_amount_position_from_weight_position(
self,
weight_position: dict,
cash: float,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: OrderDir = OrderDir.BUY,
) -> dict:
"""
Generates the target position according to the weight and the cash.
NOTE: All the cash will be assigned to the tradable stock.
Parameter:
weight_position : dict {stock_id : weight}; allocate cash by weight_position
among then, weight must be in this range: 0 < weight < 1
cash : cash
start_time : the start time point of the step
end_time : the end time point of the step
direction : the direction of the deal price for estimating the amount
# NOTE: this function is used for calculating target position. So the default direction is buy
"""
# calculate the total weight of tradable value
tradable_weight = 0.0
for stock_id, wp in weight_position.items():
if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
# weight_position must be greater than 0 and less than 1
if wp < 0 or wp > 1:
raise ValueError(
"weight_position is {}, " "weight_position is not in the range of (0, 1).".format(wp),
)
tradable_weight += wp
if tradable_weight - 1.0 >= 1e-5:
raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight))
amount_dict = {}
for stock_id in weight_position:
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(
stock_id=stock_id,
start_time=start_time,
end_time=end_time,
):
amount_dict[stock_id] = (
cash
* weight_position[stock_id]
/ tradable_weight
// self.get_deal_price(
stock_id=stock_id,
start_time=start_time,
end_time=end_time,
direction=direction,
)
)
return amount_dict
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float | None = None) -> float:
"""
Calculate the real adjust deal amount when considering the trading unit
:param current_amount:
:param target_amount:
:param factor:
:return real_deal_amount; Positive deal_amount indicates buying more stock.
"""
if current_amount == target_amount:
return 0
elif current_amount < target_amount:
deal_amount = target_amount - current_amount
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
return deal_amount
else:
if target_amount == 0:
return -current_amount
else:
deal_amount = current_amount - target_amount
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
return -deal_amount
def generate_order_for_target_amount_position(
self,
target_position: dict,
current_position: dict,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
) -> List[Order]:
"""
Note: some future information is used in this function
Parameter:
target_position : dict { stock_id : amount }
current_position : dict { stock_id : amount}
trade_unit : trade_unit
down sample : for amount 321 and trade_unit 100, deal_amount is 300
deal order on trade_date
"""
# split buy and sell for further use
buy_order_list = []
sell_order_list = []
# three parts: kept stock_id, dropped stock_id, new stock_id
# handle kept stock_id
# because the order of the set is not fixed, the trading order of the stock is different, so that the backtest
# results of the same parameter are different;
# so here we sort stock_id, and then randomly shuffle the order of stock_id
# because the same random seed is used, the final stock_id order is fixed
sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys())))
random.seed(0)
random.shuffle(sorted_ids)
for stock_id in sorted_ids:
# Do not generate order for the non-tradable stocks
if not self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
continue
target_amount = target_position.get(stock_id, 0)
current_amount = current_position.get(stock_id, 0)
factor = self.get_factor(stock_id, start_time=start_time, end_time=end_time)
deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor)
if deal_amount == 0:
continue
if deal_amount > 0:
# buy stock
buy_order_list.append(
Order(
stock_id=stock_id,
amount=deal_amount,
direction=Order.BUY,
start_time=start_time,
end_time=end_time,
factor=factor,
),
)
else:
# sell stock
sell_order_list.append(
Order(
stock_id=stock_id,
amount=abs(deal_amount),
direction=Order.SELL,
start_time=start_time,
end_time=end_time,
factor=factor,
),
)
# return order_list : buy + sell
return sell_order_list + buy_order_list
def calculate_amount_position_value(
self,
amount_dict: dict,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
only_tradable: bool = False,
direction: OrderDir = OrderDir.SELL,
) -> float:
"""Parameter
position : Position()
amount_dict : {stock_id : amount}
direction : the direction of the deal price for estimating the amount
# NOTE:
This function is used for calculating current position value.
So the default direction is sell.
"""
value = 0
for stock_id in amount_dict:
if not only_tradable or (
not self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time)
and not self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time)
):
value += (
self.get_deal_price(
stock_id=stock_id,
start_time=start_time,
end_time=end_time,
direction=direction,
)
* amount_dict[stock_id]
)
return value
def _get_factor_or_raise_error(
self,
factor: float | None = None,
stock_id: str | None = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> float:
"""Please refer to the docs of get_amount_of_trade_unit"""
if factor is None:
if stock_id is not None and start_time is not None and end_time is not None:
factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
else:
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
assert factor is not None
return factor
def get_amount_of_trade_unit(
self,
factor: float | None = None,
stock_id: str | None = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> Optional[float]:
"""
get the trade unit of amount based on **factor**
the factor can be given directly or calculated in given time range and stock id.
`factor` has higher priority than `stock_id`, `start_time` and `end_time`
Parameters
----------
factor : float
the adjusted factor
stock_id : str
the id of the stock
start_time :
the start time of trading range
end_time :
the end time of trading range
"""
if not self.trade_w_adj_price and self.trade_unit is not None:
factor = self._get_factor_or_raise_error(
factor=factor,
stock_id=stock_id,
start_time=start_time,
end_time=end_time,
)
return self.trade_unit / factor
else:
return None
def round_amount_by_trade_unit(
self,
deal_amount: float,
factor: float | None = None,
stock_id: str | None = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> float:
"""Parameter
Please refer to the docs of get_amount_of_trade_unit
deal_amount : float, adjusted amount
factor : float, adjusted factor
return : float, real amount
"""
if not self.trade_w_adj_price and self.trade_unit is not None:
# the minimal amount is 1. Add 0.1 for solving precision problem.
factor = self._get_factor_or_raise_error(
factor=factor,
stock_id=stock_id,
start_time=start_time,
end_time=end_time,
)
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
return deal_amount
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Optional[float]:
    """parse the capacity limit string and return the actual amount of orders that can be executed.

    NOTE:
        this function will change the order.deal_amount **inplace**
        - This will make the order info more accurate

    Parameters
    ----------
    order : Order
        the order to be executed.
    dealt_order_amount : dict
        the dealt order amount dict with the format of {stock_id: float};
        a stock missing from the dict is treated as having dealt 0.

    Returns
    -------
    Optional[float]
        `order.deal_amount` when no volume limit applies to the order's direction,
        otherwise `None` (the clipped amount is written to `order.deal_amount`).

    Raises
    ------
    ValueError
        if a limit type is neither "current" nor "cum".
    """
    vol_limit = self.buy_vol_limit if order.direction == Order.BUY else self.sell_vol_limit

    # `None` means no volume_threshold at all; an empty list means it only restricts the other
    # direction. Either way nothing is clipped (an empty list used to crash in `min()` below).
    if not vol_limit:
        return order.deal_amount

    vol_limit_num: List[float] = []
    for limit in vol_limit:
        assert isinstance(limit, tuple)
        if limit[0] == "current":
            # real-time capacity: the sum over the range is directly the limit
            limit_value = self.quote.get_data(
                order.stock_id,
                order.start_time,
                order.end_time,
                field=limit[1],
                method="sum",
            )
            vol_limit_num.append(cast(float, limit_value))
        elif limit[0] == "cum":
            # cumulative capacity: subtract what has already been dealt for this stock
            limit_value = self.quote.get_data(
                order.stock_id,
                order.start_time,
                order.end_time,
                field=limit[1],
                method="ts_data_last",
            )
            vol_limit_num.append(limit_value - dealt_order_amount.get(order.stock_id, 0.0))
        else:
            raise ValueError(f"{limit[0]} is not supported")

    # all limits are aggregated with min(); the amount never goes below 0
    vol_limit_min = min(vol_limit_num)
    orig_deal_amount = order.deal_amount
    order.deal_amount = max(min(vol_limit_min, orig_deal_amount), 0)
    if vol_limit_min < orig_deal_amount:
        self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}")

    return None
def _get_buy_amount_by_cash_limit(self, trade_price: float, cash: float, cost_ratio: float) -> float:
"""return the real order amount after cash limit for buying.
Parameters
----------
trade_price : float
cash : float
cost_ratio : float
Return
----------
float
the real order amount after cash limit for buying.
"""
max_trade_amount = 0.0
if cash >= self.min_cost:
# critical_price means the stock transaction price when the service fee is equal to min_cost.
critical_price = self.min_cost / cost_ratio + self.min_cost
if cash >= critical_price:
# the service fee is equal to cost_ratio * trade_amount
max_trade_amount = cash / (1 + cost_ratio) / trade_price
else:
# the service fee is equal to min_cost
max_trade_amount = (cash - self.min_cost) / trade_price
return max_trade_amount
def _calc_trade_info_by_order(
    self,
    order: Order,
    position: Optional[BasePosition],
    dealt_order_amount: dict,
) -> Tuple[float, float, float]:
    """
    Work out the price, value and cost of dealing `order`.

    **NOTE**: Order will be changed in this function (`factor` and `deal_amount` are set)

    :param order: the order to deal
    :param position: Position used to clip the amount by holdings / cash; may be None
    :param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
    :return: trade_price, trade_val, trade_cost
    :raises NotImplementedError: if the order direction is neither sell nor buy
    """
    stock_id = order.stock_id
    trade_price = cast(
        float,
        self.get_deal_price(stock_id, order.start_time, order.end_time, direction=order.direction),
    )
    market_volume = cast(float, self.get_volume(stock_id, order.start_time, order.end_time))
    total_trade_val = market_volume * trade_price

    order.factor = self.get_factor(stock_id, order.start_time, order.end_time)
    # start from the full amount and clip it step by step
    order.deal_amount = order.amount

    # Clipping amount first
    # - It simulates that the order is rejected directly by the exchange due to large order
    # Another choice is placing it after rounding the order
    # - It simulates that the large order is submitted, but partial is dealt regardless of rounding by trading unit.
    self._clip_amount_by_volume(order, dealt_order_amount)

    # TODO: the adjusted cost ratio can be overestimated as deal_amount will be clipped in the next steps
    trade_val = order.deal_amount * trade_price
    if total_trade_val and not np.isnan(total_trade_val):
        # impact cost grows with the square of the order's share of the market value
        adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2
    else:
        # TODO: assert trade_val == 0, f"trade_val != 0, total_trade_val: {total_trade_val}; order info: {order}"
        adj_cost_ratio = self.impact_cost

    if order.direction == Order.SELL:
        cost_ratio = self.close_cost + adj_cost_ratio
        # Without a position everything requested is sold;
        # with one, the amount is clipped by the current holding.
        if position is not None:
            held_amount = position.get_stock_amount(stock_id) if position.check_stock(stock_id) else 0
            if not np.isclose(order.deal_amount, held_amount):
                # not selling out the holding, so rounding is necessary
                order.deal_amount = self.round_amount_by_trade_unit(
                    min(held_amount, order.deal_amount),
                    order.factor,
                )

            # in case of negative value of cash
            proceeds = order.deal_amount * trade_price
            fee = max(proceeds * cost_ratio, self.min_cost)
            if position.get_cash() + proceeds < fee:
                order.deal_amount = 0
                self.logger.debug(f"Order clipped due to cash limitation: {order}")

    elif order.direction == Order.BUY:
        cost_ratio = self.open_cost + adj_cost_ratio
        if position is None:
            # Unknown amount of money. Just round the amount
            order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
        else:
            cash = position.get_cash()
            trade_val = order.deal_amount * trade_price
            fee = max(trade_val * cost_ratio, self.min_cost)
            if cash < fee:
                # cash cannot cover cost
                order.deal_amount = 0
                self.logger.debug(f"Order clipped due to cost higher than cash: {order}")
            elif cash < trade_val + fee:
                # The money is not enough
                affordable_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio)
                order.deal_amount = self.round_amount_by_trade_unit(
                    min(affordable_amount, order.deal_amount),
                    order.factor,
                )
                self.logger.debug(f"Order clipped due to cash limitation: {order}")
            else:
                # The money is enough
                order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)

    else:
        raise NotImplementedError("order direction {} error".format(order.direction))

    trade_val = order.deal_amount * trade_price
    if trade_val <= 1e-5:
        # if dealing is not successful, the trade_cost should be zero.
        trade_cost = 0
    else:
        trade_cost = max(trade_val * cost_ratio, self.min_cost)

    return trade_price, trade_val, trade_cost
def get_order_helper(self) -> OrderHelper:
if not hasattr(self, "_order_helper"):
# cache to avoid recreate the same instance
self._order_helper = OrderHelper(self)
return self._order_helper