mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
add position test
This commit is contained in:
@@ -104,19 +104,28 @@ def create_account_instance(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time :
|
||||
start_time
|
||||
start time of the benchmark
|
||||
end_time :
|
||||
end_time
|
||||
end time of the benchmark
|
||||
benchmark : str
|
||||
the benchmark for reporting
|
||||
account : Union[float, int, {"cash": float, "stock1": {"amount": int, "price"(optional): float}, "stock2": {"amount": int}}]
|
||||
account : Union[
|
||||
float,
|
||||
{
|
||||
"cash": float,
|
||||
"stock1": Union[
|
||||
int, # it is equal to {"amount": int}
|
||||
{"amount": int, "price"(optional): float},
|
||||
]
|
||||
},
|
||||
]
|
||||
information for describing how to creating the account
|
||||
For `float` or `int`:
|
||||
For `float`:
|
||||
Using Account with only initial cash
|
||||
For `dict`:
|
||||
key "cash" means initial cash.
|
||||
key "stock1" means the first stock information with amount and price(optional).
|
||||
key "stock1" means the information of first stock with amount and price(optional).
|
||||
...
|
||||
"""
|
||||
if isinstance(account, (int, float)):
|
||||
|
||||
@@ -80,9 +80,15 @@ class Account:
|
||||
----------
|
||||
init_cash : float, optional
|
||||
initial cash, by default 1e9
|
||||
position_dict : Dict[stock_id, {"amount": int, "price"(optional): float}], optional
|
||||
initial stocks with amount and price,
|
||||
if there is no price key in the dict of stocks, it will be filled by latest close price from qlib.
|
||||
position_dict : Dict[
|
||||
stock_id,
|
||||
Union[
|
||||
int, # it is equal to {"amount": int}
|
||||
{"amount": int, "price"(optional): float},
|
||||
]
|
||||
]
|
||||
initial stocks with parameters amount and price,
|
||||
if there is no price key in the dict of stocks, it will be filled by _fill_stock_value.
|
||||
by default {}.
|
||||
"""
|
||||
|
||||
@@ -122,6 +128,8 @@ class Account:
|
||||
self.report = Report(freq, benchmark_config)
|
||||
self.positions = {}
|
||||
# fill stock value
|
||||
# The frequency of account may not align with the trading frequency.
|
||||
# This may result in obscure bugs when data quality is low.
|
||||
self.current.fill_stock_value(self.benchmark_config["start_time"], self.freq)
|
||||
|
||||
# trading related metrics(e.g. high-frequency trading)
|
||||
@@ -186,7 +194,8 @@ class Account:
|
||||
# The cost will be substracted from the cash at last. So the trading logic can ignore the cost calculation
|
||||
if order.direction == Order.SELL:
|
||||
# sell stock
|
||||
self._update_state_from_order(order, trade_val, cost, trade_price)
|
||||
if getattr(self, "accum_info") is not None:
|
||||
self._update_state_from_order(order, trade_val, cost, trade_price)
|
||||
# update current position
|
||||
# for may sell all of stock_id
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
@@ -194,7 +203,8 @@ class Account:
|
||||
# buy stock
|
||||
# deal order, then update state
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
self._update_state_from_order(order, trade_val, cost, trade_price)
|
||||
if getattr(self, "accum_info") is not None:
|
||||
self._update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_bar_count(self):
|
||||
"""at the end of the trading bar, update holding bar, count of stock"""
|
||||
@@ -311,7 +321,6 @@ class Account:
|
||||
self.update_current(trade_start_time, trade_end_time, trade_exchange)
|
||||
if self.is_port_metr_enabled():
|
||||
# report is portfolio related analysis
|
||||
print(trade_start_time, trade_end_time)
|
||||
self.update_report(trade_start_time, trade_end_time)
|
||||
|
||||
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`
|
||||
|
||||
@@ -713,12 +713,11 @@ class Exchange:
|
||||
f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
|
||||
)
|
||||
|
||||
def _cal_trade_amount_by_cash_limit(self, now_trade_amount, trade_price, order, position):
|
||||
def _get_max_amount_by_cash_limit(self, trade_price, order, position):
|
||||
"""return the real order amount after cash limit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
now_trade_amount : float
|
||||
trade_price : float
|
||||
order : Order
|
||||
position : Position
|
||||
@@ -729,27 +728,24 @@ class Exchange:
|
||||
the real order amount after cash limit.
|
||||
"""
|
||||
cash = position.get_cash()
|
||||
trade_val = now_trade_amount * trade_price
|
||||
if order.direction == Order.SELL:
|
||||
if cash < trade_val * self.close_cost:
|
||||
# The money is not enough
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
return self.round_amount_by_trade_unit(cash / self.close_cost, order.factor)
|
||||
elif order.direction == Order.BUY:
|
||||
if cash < trade_val * (1 + self.open_cost):
|
||||
# The money is not enough
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
return self.round_amount_by_trade_unit(cash / (1 + self.open_cost) / trade_price, order.factor)
|
||||
max_trade_amount = 0
|
||||
if cash >= self.min_cost:
|
||||
if order.direction == Order.SELL:
|
||||
max_trade_amount = cash / self.close_cost / trade_price
|
||||
elif order.direction == Order.BUY:
|
||||
critical_amount = self.min_cost / (self.open_cost * trade_price)
|
||||
critical_price = critical_amount * trade_price + self.min_cost
|
||||
if cash >= critical_price:
|
||||
max_trade_amount = cash / (1 + self.open_cost) / trade_price
|
||||
else:
|
||||
max_trade_amount = (cash - self.min_cost) / trade_price
|
||||
return max_trade_amount
|
||||
|
||||
# The money is enough
|
||||
return self.round_amount_by_trade_unit(now_trade_amount, order.factor)
|
||||
|
||||
def _cal_trade_amount_by_stock_limit(self, now_trade_amount, order, position):
|
||||
def _get_max_amount_by_stock_limit(self, order, position):
|
||||
"""return the real order amount after stock amount limit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
now_trade_amount : float
|
||||
order : Order
|
||||
position : Position
|
||||
|
||||
@@ -760,15 +756,9 @@ class Exchange:
|
||||
"""
|
||||
if order.direction == Order.SELL:
|
||||
current_amount = position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0
|
||||
if np.isclose(now_trade_amount, current_amount):
|
||||
# when selling last stock. The amount don't need rounding
|
||||
return now_trade_amount
|
||||
elif now_trade_amount > current_amount:
|
||||
return self.round_amount_by_trade_unit(current_amount, order.factor)
|
||||
else:
|
||||
return self.round_amount_by_trade_unit(now_trade_amount, order.factor)
|
||||
return current_amount
|
||||
elif order.direction == Order.BUY:
|
||||
return self.round_amount_by_trade_unit(now_trade_amount, order.factor)
|
||||
return np.inf
|
||||
|
||||
def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amount):
|
||||
"""
|
||||
@@ -779,18 +769,33 @@ class Exchange:
|
||||
:param order:
|
||||
:param position: Position
|
||||
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
|
||||
:return: trade_val, trade_cost
|
||||
:return: trade_price, trade_val, trade_cost
|
||||
"""
|
||||
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
|
||||
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
|
||||
# get all limits amount
|
||||
# cash limit
|
||||
cash_max_amount = self._get_max_amount_by_cash_limit(trade_price, order, position)
|
||||
# held stock limit
|
||||
stock_max_amount = self._get_max_amount_by_stock_limit(order, position)
|
||||
|
||||
if order.direction == Order.SELL:
|
||||
# sell
|
||||
if position is not None:
|
||||
now_trade_amount = order.amount
|
||||
now_trade_amount = self._cal_trade_amount_by_stock_limit(now_trade_amount, order, position)
|
||||
now_trade_amount = self._cal_trade_amount_by_cash_limit(now_trade_amount, trade_price, order, position)
|
||||
order.deal_amount = now_trade_amount
|
||||
if np.isclose(order.amount, stock_max_amount):
|
||||
# when selling last stock. The amount don't need rounding
|
||||
if stock_max_amount <= cash_max_amount:
|
||||
order.deal_amount = stock_max_amount
|
||||
else:
|
||||
order.deal_amount = self.round_amount_by_trade_unit(cash_max_amount, order.factor)
|
||||
else:
|
||||
now_trade_amount = min(order.amount, stock_max_amount)
|
||||
if now_trade_amount > cash_max_amount:
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
min(now_trade_amount, cash_max_amount), order.factor
|
||||
)
|
||||
else:
|
||||
# TODO: We don't know current position.
|
||||
# We choose to sell all
|
||||
@@ -802,9 +807,9 @@ class Exchange:
|
||||
elif order.direction == Order.BUY:
|
||||
# buy
|
||||
if position is not None:
|
||||
now_trade_amount = order.amount
|
||||
now_trade_amount = self._cal_trade_amount_by_cash_limit(now_trade_amount, trade_price, order, position)
|
||||
order.deal_amount = now_trade_amount
|
||||
if order.amount > cash_max_amount:
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
order.deal_amount = self.round_amount_by_trade_unit(min(order.amount, cash_max_amount), order.factor)
|
||||
else:
|
||||
# Unknown amount of money. Just round the amount
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
|
||||
|
||||
@@ -246,7 +246,13 @@ class Position(BasePosition):
|
||||
the start time of backtest. It's for filling the initial value of stocks.
|
||||
cash : float, optional
|
||||
initial cash in account, by default 0
|
||||
position_dict : Dict[stock_id, {"amount": int, "price"(optional): float}], optional
|
||||
position_dict : Dict[
|
||||
stock_id,
|
||||
Union[
|
||||
int, # it is equal to {"amount": int}
|
||||
{"amount": int, "price"(optional): float},
|
||||
]
|
||||
]
|
||||
initial stocks with parameters amount and price,
|
||||
if there is no price key in the dict of stocks, it will be filled by _fill_stock_value.
|
||||
by default {}.
|
||||
@@ -256,8 +262,10 @@ class Position(BasePosition):
|
||||
# NOTE: The position dict must be copied!!!
|
||||
# Otherwise the initial value
|
||||
self.init_cash = cash
|
||||
self.init_stock_info = position_dict.copy()
|
||||
self.position = self.init_stock_info.copy()
|
||||
self.position = position_dict.copy()
|
||||
for stock in self.position:
|
||||
if isinstance(self.position[stock], int):
|
||||
self.position[stock] = {"amount": self.position[stock]}
|
||||
self.position["cash"] = cash
|
||||
|
||||
# If the stock price information is missing, the account value will not be calculated temporarily
|
||||
@@ -277,7 +285,9 @@ class Position(BasePosition):
|
||||
the days to get the latest close price, by default 30.
|
||||
"""
|
||||
stock_list = []
|
||||
for stock in self.init_stock_info:
|
||||
for stock in self.position:
|
||||
if not isinstance(self.position[stock], dict):
|
||||
continue
|
||||
if ("price" not in self.position[stock]) or (self.position[stock]["price"] is None):
|
||||
stock_list.append(stock)
|
||||
|
||||
@@ -298,8 +308,7 @@ class Position(BasePosition):
|
||||
raise ValueError(f"{lack_stock} doesn't have close price in qlib in the latest {last_days} days")
|
||||
|
||||
for stock in stock_list:
|
||||
self.init_stock_info[stock]["price"] = price_dict[stock]
|
||||
self.position.update(self.init_stock_info)
|
||||
self.position[stock]["price"] = price_dict[stock]
|
||||
self.position["now_account_value"] = self.calculate_value()
|
||||
|
||||
def _init_stock(self, stock_id, amount, price=None):
|
||||
|
||||
@@ -27,6 +27,8 @@ class FileStrTest(TestAutoData):
|
||||
["20200102", self.TEST_INST, "1000", "sell"],
|
||||
["20200103", self.TEST_INST, "1000", "buy"],
|
||||
["20200106", self.TEST_INST, "1000", "sell"],
|
||||
["20200106", self.TEST_INST, "1000", "buy"],
|
||||
["20200106", self.TEST_INST, "949.7773413058803", "sell"],
|
||||
]
|
||||
return pd.DataFrame(orders, columns=headers).set_index(["datetime", "instrument"])
|
||||
|
||||
@@ -62,7 +64,7 @@ class FileStrTest(TestAutoData):
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
"codes": codes,
|
||||
"trade_unit": None,
|
||||
"trade_unit": 100,
|
||||
},
|
||||
# "pos_type": "InfPosition" # Position with infinitive position
|
||||
}
|
||||
|
||||
119
tests/backtest/test_init_position.py
Normal file
119
tests/backtest/test_init_position.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import unittest
|
||||
import qlib
|
||||
from qlib.backtest import backtest, order
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.backtest.order import TradeDecisionWO, TradeRangeByTime
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class FileStrTest(TestAutoData):
|
||||
|
||||
TEST_INST = "SH600519"
|
||||
|
||||
def init_qlib(self):
|
||||
provider_uri_day = "/nfs_data1/stock_data/huaxia_1d_qlib"
|
||||
provider_uri_1min = "/nfs_data1/stock_data/huaxia_1min_qlib"
|
||||
provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
|
||||
|
||||
client_config = {
|
||||
"calendar_provider": {
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileCalendarStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
"feature_provider": {
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileFeatureStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri_day, **client_config, expression_cache=None, dataset_cache=None)
|
||||
|
||||
def test_file_str(self):
|
||||
freq = "1min"
|
||||
inst = ["SH600000", "SH600011"]
|
||||
start_time = "2020-01-01"
|
||||
end_time = "2020-01-15 15:00"
|
||||
|
||||
strategy_config = {
|
||||
"class": "RandomOrderStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"trade_range": TradeRangeByTime("9:30", "15:00"),
|
||||
"sample_ratio": 1.0,
|
||||
"volume_ratio": 0.01,
|
||||
"market": inst,
|
||||
},
|
||||
}
|
||||
position_dict = {
|
||||
"cash": 100000000,
|
||||
"SH600000": {"amount": 100},
|
||||
"SH600011": {"amount": 101},
|
||||
}
|
||||
backtest_config = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"account": position_dict,
|
||||
"benchmark": None, # benchmark is not required here for trading
|
||||
"exchange_kwargs": {
|
||||
"freq": freq,
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
"codes": inst,
|
||||
},
|
||||
"pos_type": "Position", # Position with infinitive position
|
||||
}
|
||||
executor_config = {
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "day",
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": freq,
|
||||
"generate_report": False,
|
||||
"verbose": False,
|
||||
# "verbose": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "TWAPStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
},
|
||||
"track_data": True,
|
||||
"generate_report": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
self.init_qlib()
|
||||
backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user