mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 10:01:19 +08:00
supporting seperated buy and sell price
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
from typing import Union, TYPE_CHECKING
|
||||
from typing import List, Tuple, Union, TYPE_CHECKING
|
||||
|
||||
from .account import Account
|
||||
|
||||
@@ -35,7 +35,7 @@ def get_exchange(
|
||||
min_cost=5.0,
|
||||
trade_unit=None,
|
||||
limit_threshold=None,
|
||||
deal_price=None,
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
):
|
||||
"""get_exchange
|
||||
|
||||
@@ -54,8 +54,15 @@ def get_exchange(
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
deal_price: Union[str, Tuple[str], List[str]]
|
||||
The `deal_price` supports following two types of input
|
||||
- <deal_price> : str
|
||||
- (<buy_price>, <sell_price>): Tuple[str] or List[str]
|
||||
|
||||
<deal_price>, <buy_price> or <sell_price> := <price>
|
||||
<price> := str
|
||||
- for example '$close', '$open', '$vwap' ("close" is OK. `Exchange` will help to prepend
|
||||
"$" to the expression)
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
|
||||
@@ -69,13 +76,8 @@ def get_exchange(
|
||||
trade_unit = C.trade_unit
|
||||
if limit_threshold is None:
|
||||
limit_threshold = C.limit_threshold
|
||||
if deal_price is None:
|
||||
deal_price = C.deal_price
|
||||
if exchange is None:
|
||||
logger.info("Create new exchange")
|
||||
# handle exception for deal_price
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
|
||||
exchange = Exchange(
|
||||
freq=freq,
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
import random
|
||||
import logging
|
||||
from typing import Union
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -24,7 +24,7 @@ class Exchange:
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes="all",
|
||||
deal_price=None,
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
subscribe_fields=[],
|
||||
limit_threshold=None,
|
||||
volume_threshold=None,
|
||||
@@ -40,7 +40,17 @@ class Exchange:
|
||||
:param start_time: closed start time for backtest
|
||||
:param end_time: closed end time for backtest
|
||||
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
|
||||
:param deal_price: str, 'close', 'open', 'vwap'
|
||||
|
||||
:param deal_price: Union[str, Tuple[str], List[str]]
|
||||
The `deal_price` supports following two types of input
|
||||
- <deal_price> : str
|
||||
- (<buy_price>, <sell_price>): Tuple[str] or List[str]
|
||||
|
||||
<deal_price>, <buy_price> or <sell_price> := <price>
|
||||
<price> := str
|
||||
- for example '$close', '$open', '$vwap' ("close" is OK. `Exchange` will help to prepend
|
||||
"$" to the expression)
|
||||
|
||||
:param subscribe_fields: list, subscribe fields
|
||||
:param limit_threshold: float, 0.1 for example, default None
|
||||
:param volume_threshold: float, 0.1 for example, default None
|
||||
@@ -86,10 +96,15 @@ class Exchange:
|
||||
if C.region == REG_CN:
|
||||
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
|
||||
|
||||
if deal_price[0] != "$":
|
||||
self.deal_price = "$" + deal_price
|
||||
if isinstance(deal_price, str):
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
self.buy_price = self.sell_price = deal_price
|
||||
elif isinstance(deal_price, (tuple, list)):
|
||||
self.buy_price, self.sell_price = deal_price
|
||||
else:
|
||||
self.deal_price = deal_price
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
if isinstance(codes, str):
|
||||
codes = D.instruments(codes)
|
||||
self.codes = codes
|
||||
@@ -98,7 +113,7 @@ class Exchange:
|
||||
# $factor is for rounding to the trading unit
|
||||
# $change is for calculating the limit of the stock
|
||||
|
||||
necessary_fields = {self.deal_price, "$close", "$change", "$factor", "$volume"}
|
||||
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
|
||||
subscribe_fields = list(necessary_fields | set(subscribe_fields))
|
||||
all_fields = list(necessary_fields | set(subscribe_fields))
|
||||
self.all_fields = all_fields
|
||||
@@ -118,8 +133,10 @@ class Exchange:
|
||||
)
|
||||
self.quote.columns = self.all_fields
|
||||
|
||||
if self.quote[self.deal_price].isna().any():
|
||||
self.logger.warning("{} field data contains nan.".format(self.deal_price))
|
||||
for attr in "buy_price", "sell_price":
|
||||
pstr = getattr(self, attr) # price string
|
||||
if self.quote[pstr].isna().any():
|
||||
self.logger.warning("{} field data contains nan.".format(pstr))
|
||||
|
||||
if self.quote["$factor"].isna().any():
|
||||
# The 'factor.day.bin' file not exists, and `factor` field contains `nan`
|
||||
@@ -148,9 +165,11 @@ class Exchange:
|
||||
# process extra_quote
|
||||
if "$close" not in self.extra_quote:
|
||||
raise ValueError("$close is necessray in extra_quote")
|
||||
if self.deal_price not in self.extra_quote.columns:
|
||||
self.extra_quote[self.deal_price] = self.extra_quote["$close"]
|
||||
self.logger.warning("No deal_price set for extra_quote. Use $close as deal_price.")
|
||||
for attr in "buy_price", "sell_price":
|
||||
pstr = getattr(self, attr) # price string
|
||||
if pstr not in self.extra_quote.columns:
|
||||
self.extra_quote[pstr] = self.extra_quote["$close"]
|
||||
self.logger.warning(f"No {pstr} set for extra_quote. Use $close as {pstr}.")
|
||||
if "$factor" not in self.extra_quote.columns:
|
||||
self.extra_quote["$factor"] = 1.0
|
||||
self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.")
|
||||
@@ -241,7 +260,7 @@ class Exchange:
|
||||
if trade_account is not None and position is not None:
|
||||
raise ValueError("trade_account and position can only choose one")
|
||||
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, order.direction)
|
||||
# NOTE: order will be changed in this function
|
||||
trade_val, trade_cost = self._calc_trade_info_by_order(
|
||||
order, trade_account.current if trade_account else position
|
||||
@@ -266,12 +285,16 @@ class Exchange:
|
||||
def get_volume(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum")
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time):
|
||||
deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method=ts_data_last)
|
||||
def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir):
|
||||
if direction == OrderDir.SELL:
|
||||
pstr = self.sell_price
|
||||
elif direction == OrderDir.BUY:
|
||||
pstr = self.buy_price
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
deal_price = resam_ts_data(self.quote[stock_id][pstr], start_time, end_time, method=ts_data_last)
|
||||
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
|
||||
self.logger.warning(
|
||||
f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!"
|
||||
)
|
||||
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
|
||||
self.logger.warning(f"setting deal_price to close price")
|
||||
deal_price = self.get_close(stock_id, start_time, end_time)
|
||||
return deal_price
|
||||
@@ -288,7 +311,9 @@ class Exchange:
|
||||
return None
|
||||
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method=ts_data_last)
|
||||
|
||||
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
|
||||
def generate_amount_position_from_weight_position(
|
||||
self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY
|
||||
):
|
||||
"""
|
||||
The generate the target position according to the weight and the cash.
|
||||
NOTE: All the cash will assigned to the tadable stock.
|
||||
@@ -297,7 +322,10 @@ class Exchange:
|
||||
weight_position : dict {stock_id : weight}; allocate cash by weight_position
|
||||
among then, weight must be in this range: 0 < weight < 1
|
||||
cash : cash
|
||||
trade_date : trade date
|
||||
start_time : the start time point of the step
|
||||
end_time : the end time point of the step
|
||||
direction : the direction of the deal price for estimating the amount
|
||||
# NOTE: this function is used for calculating target position. So the default direction is buy
|
||||
"""
|
||||
|
||||
# calculate the total weight of tradable value
|
||||
@@ -324,7 +352,9 @@ class Exchange:
|
||||
cash
|
||||
* weight_position[stock_id]
|
||||
/ tradable_weight
|
||||
// self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
// self.get_deal_price(
|
||||
stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
|
||||
)
|
||||
)
|
||||
return amount_dict
|
||||
|
||||
@@ -414,10 +444,16 @@ class Exchange:
|
||||
# return order_list : buy + sell
|
||||
return sell_order_list + buy_order_list
|
||||
|
||||
def calculate_amount_position_value(self, amount_dict, start_time, end_time, only_tradable=False):
|
||||
def calculate_amount_position_value(
|
||||
self, amount_dict, start_time, end_time, only_tradable=False, direction=OrderDir.SELL
|
||||
):
|
||||
"""Parameter
|
||||
position : Position()
|
||||
amount_dict : {stock_id : amount}
|
||||
direction : the direction of the deal price for estimating the amount
|
||||
# NOTE:
|
||||
This function is used for calculating current position value.
|
||||
So the default direction is sell.
|
||||
"""
|
||||
value = 0
|
||||
for stock_id in amount_dict:
|
||||
@@ -426,7 +462,9 @@ class Exchange:
|
||||
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
):
|
||||
value += (
|
||||
self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
self.get_deal_price(
|
||||
stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
|
||||
)
|
||||
* amount_dict[stock_id]
|
||||
)
|
||||
return value
|
||||
@@ -466,7 +504,7 @@ class Exchange:
|
||||
:return: trade_val, trade_cost
|
||||
"""
|
||||
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
|
||||
if order.direction == Order.SELL:
|
||||
# sell
|
||||
if position is not None:
|
||||
|
||||
@@ -6,7 +6,7 @@ import pandas as pd
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...strategy.base import ModelStrategy
|
||||
from ...backtest.order import Order, BaseTradeDecision, TradeDecisionWO
|
||||
from ...backtest.order import Order, BaseTradeDecision, OrderDir, TradeDecisionWO
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
@@ -236,7 +236,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
continue
|
||||
# buy order
|
||||
buy_price = self.trade_exchange.get_deal_price(
|
||||
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
|
||||
stock_id=code, start_time=trade_start_time, end_time=trade_end_time, direction=OrderDir.BUY
|
||||
)
|
||||
buy_amount = value / buy_price
|
||||
factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
|
||||
|
||||
Reference in New Issue
Block a user