1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 10:01:19 +08:00

supporting seperated buy and sell price

This commit is contained in:
Young
2021-07-06 06:28:14 +00:00
parent cb72857710
commit bdac9f4dda
3 changed files with 75 additions and 35 deletions

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from __future__ import annotations
import copy
from typing import Union, TYPE_CHECKING
from typing import List, Tuple, Union, TYPE_CHECKING
from .account import Account
@@ -35,7 +35,7 @@ def get_exchange(
min_cost=5.0,
trade_unit=None,
limit_threshold=None,
deal_price=None,
deal_price: Union[str, Tuple[str], List[str]] = None,
):
"""get_exchange
@@ -54,8 +54,15 @@ def get_exchange(
min transaction cost.
trade_unit : int
100 for China A.
deal_price: str
dealing price type: 'close', 'open', 'vwap'.
deal_price: Union[str, Tuple[str], List[str]]
The `deal_price` supports following two types of input
- <deal_price> : str
- (<buy_price>, <sell_price>): Tuple[str] or List[str]
<deal_price>, <buy_price> or <sell_price> := <price>
<price> := str
- for example '$close', '$open', '$vwap' ("close" is OK. `Exchange` will help to prepend
"$" to the expression)
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
@@ -69,13 +76,8 @@ def get_exchange(
trade_unit = C.trade_unit
if limit_threshold is None:
limit_threshold = C.limit_threshold
if deal_price is None:
deal_price = C.deal_price
if exchange is None:
logger.info("Create new exchange")
# handle exception for deal_price
if deal_price[0] != "$":
deal_price = "$" + deal_price
exchange = Exchange(
freq=freq,

View File

@@ -4,7 +4,7 @@
import random
import logging
from typing import Union
from typing import List, Tuple, Union
import numpy as np
import pandas as pd
@@ -24,7 +24,7 @@ class Exchange:
start_time=None,
end_time=None,
codes="all",
deal_price=None,
deal_price: Union[str, Tuple[str], List[str]] = None,
subscribe_fields=[],
limit_threshold=None,
volume_threshold=None,
@@ -40,7 +40,17 @@ class Exchange:
:param start_time: closed start time for backtest
:param end_time: closed end time for backtest
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
:param deal_price: str, 'close', 'open', 'vwap'
:param deal_price: Union[str, Tuple[str], List[str]]
The `deal_price` supports following two types of input
- <deal_price> : str
- (<buy_price>, <sell_price>): Tuple[str] or List[str]
<deal_price>, <buy_price> or <sell_price> := <price>
<price> := str
- for example '$close', '$open', '$vwap' ("close" is OK. `Exchange` will help to prepend
"$" to the expression)
:param subscribe_fields: list, subscribe fields
:param limit_threshold: float, 0.1 for example, default None
:param volume_threshold: float, 0.1 for example, default None
@@ -86,10 +96,15 @@ class Exchange:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
if deal_price[0] != "$":
self.deal_price = "$" + deal_price
if isinstance(deal_price, str):
if deal_price[0] != "$":
deal_price = "$" + deal_price
self.buy_price = self.sell_price = deal_price
elif isinstance(deal_price, (tuple, list)):
self.buy_price, self.sell_price = deal_price
else:
self.deal_price = deal_price
raise NotImplementedError(f"This type of input is not supported")
if isinstance(codes, str):
codes = D.instruments(codes)
self.codes = codes
@@ -98,7 +113,7 @@ class Exchange:
# $factor is for rounding to the trading unit
# $change is for calculating the limit of the stock
necessary_fields = {self.deal_price, "$close", "$change", "$factor", "$volume"}
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
subscribe_fields = list(necessary_fields | set(subscribe_fields))
all_fields = list(necessary_fields | set(subscribe_fields))
self.all_fields = all_fields
@@ -118,8 +133,10 @@ class Exchange:
)
self.quote.columns = self.all_fields
if self.quote[self.deal_price].isna().any():
self.logger.warning("{} field data contains nan.".format(self.deal_price))
for attr in "buy_price", "sell_price":
pstr = getattr(self, attr) # price string
if self.quote[pstr].isna().any():
self.logger.warning("{} field data contains nan.".format(pstr))
if self.quote["$factor"].isna().any():
# The 'factor.day.bin' file not exists, and `factor` field contains `nan`
@@ -148,9 +165,11 @@ class Exchange:
# process extra_quote
if "$close" not in self.extra_quote:
raise ValueError("$close is necessray in extra_quote")
if self.deal_price not in self.extra_quote.columns:
self.extra_quote[self.deal_price] = self.extra_quote["$close"]
self.logger.warning("No deal_price set for extra_quote. Use $close as deal_price.")
for attr in "buy_price", "sell_price":
pstr = getattr(self, attr) # price string
if pstr not in self.extra_quote.columns:
self.extra_quote[pstr] = self.extra_quote["$close"]
self.logger.warning(f"No {pstr} set for extra_quote. Use $close as {pstr}.")
if "$factor" not in self.extra_quote.columns:
self.extra_quote["$factor"] = 1.0
self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.")
@@ -241,7 +260,7 @@ class Exchange:
if trade_account is not None and position is not None:
raise ValueError("trade_account and position can only choose one")
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, order.direction)
# NOTE: order will be changed in this function
trade_val, trade_cost = self._calc_trade_info_by_order(
order, trade_account.current if trade_account else position
@@ -266,12 +285,16 @@ class Exchange:
def get_volume(self, stock_id, start_time, end_time):
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum")
def get_deal_price(self, stock_id, start_time, end_time):
deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method=ts_data_last)
def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir):
if direction == OrderDir.SELL:
pstr = self.sell_price
elif direction == OrderDir.BUY:
pstr = self.buy_price
else:
raise NotImplementedError(f"This type of input is not supported")
deal_price = resam_ts_data(self.quote[stock_id][pstr], start_time, end_time, method=ts_data_last)
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
self.logger.warning(
f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!"
)
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
self.logger.warning(f"setting deal_price to close price")
deal_price = self.get_close(stock_id, start_time, end_time)
return deal_price
@@ -288,7 +311,9 @@ class Exchange:
return None
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method=ts_data_last)
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
def generate_amount_position_from_weight_position(
self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY
):
"""
The generate the target position according to the weight and the cash.
NOTE: All the cash will assigned to the tadable stock.
@@ -297,7 +322,10 @@ class Exchange:
weight_position : dict {stock_id : weight}; allocate cash by weight_position
among then, weight must be in this range: 0 < weight < 1
cash : cash
trade_date : trade date
start_time : the start time point of the step
end_time : the end time point of the step
direction : the direction of the deal price for estimating the amount
# NOTE: this function is used for calculating target position. So the default direction is buy
"""
# calculate the total weight of tradable value
@@ -324,7 +352,9 @@ class Exchange:
cash
* weight_position[stock_id]
/ tradable_weight
// self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time)
// self.get_deal_price(
stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
)
)
return amount_dict
@@ -414,10 +444,16 @@ class Exchange:
# return order_list : buy + sell
return sell_order_list + buy_order_list
def calculate_amount_position_value(self, amount_dict, start_time, end_time, only_tradable=False):
def calculate_amount_position_value(
self, amount_dict, start_time, end_time, only_tradable=False, direction=OrderDir.SELL
):
"""Parameter
position : Position()
amount_dict : {stock_id : amount}
direction : the direction of the deal price for estimating the amount
# NOTE:
This function is used for calculating current position value.
So the default direction is sell.
"""
value = 0
for stock_id in amount_dict:
@@ -426,7 +462,9 @@ class Exchange:
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
):
value += (
self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time)
self.get_deal_price(
stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
)
* amount_dict[stock_id]
)
return value
@@ -466,7 +504,7 @@ class Exchange:
:return: trade_val, trade_cost
"""
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
if order.direction == Order.SELL:
# sell
if position is not None:

View File

@@ -6,7 +6,7 @@ import pandas as pd
from ...utils.resam import resam_ts_data
from ...strategy.base import ModelStrategy
from ...backtest.order import Order, BaseTradeDecision, TradeDecisionWO
from ...backtest.order import Order, BaseTradeDecision, OrderDir, TradeDecisionWO
from .order_generator import OrderGenWInteract
@@ -236,7 +236,7 @@ class TopkDropoutStrategy(ModelStrategy):
continue
# buy order
buy_price = self.trade_exchange.get_deal_price(
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
stock_id=code, start_time=trade_start_time, end_time=trade_end_time, direction=OrderDir.BUY
)
buy_amount = value / buy_price
factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)