1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
This commit is contained in:
bxdd
2021-04-24 22:37:36 +08:00
parent b14efa1129
commit af0053eb17
29 changed files with 314 additions and 2247 deletions

View File

@@ -7,13 +7,9 @@ from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.backtest import backtest
from qlib.contrib.strategy import TopkDropoutStrategy
from qlib.contrib.backtest import backtest
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.tests.data import GetData
if __name__ == "__main__":
@@ -67,9 +63,9 @@ if __name__ == "__main__":
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"train": ("2012-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
"test": ("2017-01-01", "2018-01-31"),
},
},
},
@@ -79,41 +75,40 @@ if __name__ == "__main__":
dataset = init_instance_by_config(task["dataset"])
model.fit(dataset)
trade_start_time = "2017-01-01"
trade_end_time = "2020-08-01"
trade_exchange = get_exchange(start_time=trade_start_time, end_time=trade_end_time)
trade_start_time = "2017-01-31"
trade_end_time = "2018-01-31"
backtest_config={
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.dl_strategy",
"kwargs": {
"step_bar": "day",
"step_bar": "week",
"model": model,
"dataset": dataset,
"trade_exchange": trade_exchange,
"topk": 50,
"n_drop": 5,
},
},
"env":{
"class": "SplitEnv",
"module_path": "qlib.backtest.env",
"module_path": "qlib.contrib.backtest.env",
"kwargs": {
"step_bar": "day",
"step_bar": "week",
"sub_env": {
"class": "SimulatorEnv",
"module_path": "qlib.backtest.env",
"module_path": "qlib.contrib.backtest.env",
"kwargs": {
"step_bar": "1min",
"trade_exchange": trade_exchange,
"step_bar": "day",
}
},
"sub_strategy": {
"class": "SBBStrategyEMA",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {
"step_bar": "1min",
"step_bar": "day",
"freq": "day",
"instruments": "csi300",
}
}
}
@@ -121,4 +116,4 @@ if __name__ == "__main__":
}
backtest(**backtest_config, )
report_dict = backtest(start_time=trade_start_time, end_time=trade_end_time, **backtest_config, account=1e8, deal_price="$close", verbose=False)

View File

@@ -1,130 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .order import Order
from .position import Position
from .exchange import Exchange
from .report import Report
from .backtest import backtest as backtest_func
import copy
import numpy as np
import inspect
from ..utils import init_instance_by_config
from ..log import get_module_logger
from ..config import C
logger = get_module_logger("backtest caller")
def get_exchange(
pred,
exchange=None,
start_time=None,
end_time=None,
codes = "all",
subscribe_fields=[],
open_cost=0.0015,
close_cost=0.0025,
min_cost=5.0,
trade_unit=None,
limit_threshold=None,
deal_price=None,
shift=1,
):
"""get_exchange
Parameters
----------
# exchange related arguments
exchange: Exchange().
subscribe_fields: list
subscribe fields.
open_cost : float
open transaction cost.
close_cost : float
close transaction cost.
min_cost : float
min transaction cost.
trade_unit : int
100 for China A.
deal_price: str
dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
Returns
-------
:class: Exchange
an initialized Exchange object
"""
if trade_unit is None:
trade_unit = C.trade_unit
if limit_threshold is None:
limit_threshold = C.limit_threshold
if deal_price is None:
deal_price = C.deal_price
if exchange is None:
logger.info("Create new exchange")
# handle exception for deal_price
if deal_price[0] != "$":
deal_price = "$" + deal_price
exchange = Exchange(
start_time=start_time,
end_time=end_time,
codes=codes,
deal_price=deal_price,
subscribe_fields=subscribe_fields,
limit_threshold=limit_threshold,
open_cost=open_cost,
close_cost=close_cost,
trade_unit=trade_unit,
min_cost=min_cost,
)
else:
return init_instance_by_config(exchange, accept_types=Exchange)
def init_env_instance_by_config(env):
if isinstance(env, dict):
env_config = copy.copy(env)
if "kwargs" in env_config:
env_kwargs = copy.copy(env_config["kwargs"]):
if "sub_env" in env_kwargs:
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
if "sub_strategy" in env_kwargs:
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
env_config["kwargs"] = env_kwargs
return init_instance_by_config(env_config)
else:
return env
def setup_exchange(root_instance, trade_exchange=None, force=False):
if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args:
if force:
root_instance.reset(trade_exchange=trade_exchange)
else:
if not hasattr(root_instance, "trade_exchange") or root_instance.trade_exchange is None:
root_instance.reset(trade_exchange=trade_exchange)
if hasattr(root_instance, "sub_env"):
setup_exchange(root_instance.sub_env, trade_exchange)
if hasattr(root_instance, "sub_strategy"):
setup_exchange(root_instance.sub_strategy, trade_exchange)
def backtest(start_time, end_time, strategy, env, benchmark, account=1e9, **kwargs):
trade_strategy = init_instance_by_config(strategy)
trade_env = init_env_instance_by_config(env)
spec = inspect.getfullargspec(get_exchange)
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
trade_exchange = get_exchange(pred, **ex_args)
setup_exchange(trade_env, trade_exchange)
setup_exchange(trade_strategy, trade_exchange)
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env, benchmark, account)
return report_dict

View File

@@ -1,170 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
from .position import Position
from .report import Report
from .order import Order
"""
rtn & earning in the Account
rtn:
from order's view
1.change if any order is executed, sell order or buy order
2.change at the end of today, (today_clse - stock_price) * amount
earning
from value of current position
earning will be updated at the end of trade date
earning = today_value - pre_value
**is consider cost**
while earning is the difference of two position value, so it considers cost, it is the true return rate
in the specific accomplishment for rtn, it does not consider cost, in other words, rtn - cost = earning
"""
class Account:
def __init__(self, init_cash, last_trade_time=None):
self.init_vars(init_cash, last_trade_time)
def init_vars(self, init_cash, last_trade_time=None):
# init cash
self.init_cash = init_cash
self.current = Position(cash=init_cash)
self.positions = {}
self.rtn = 0
self.ct = 0
self.to = 0
self.val = 0
self.report = Report()
self.earning = 0
self.last_trade_time = last_trade_time
def get_positions(self):
return self.positions
def get_cash(self):
return self.current.position["cash"]
def update_state_from_order(self, order, trade_val, cost, trade_price):
# update turnover
self.to += trade_val
# update cost
self.ct += cost
# update return
# update self.rtn from order
trade_amount = trade_val / trade_price
if order.direction == Order.SELL: # 0 for sell
# when sell stock, get profit from price change
profit = trade_val - self.current.get_stock_price(order.stock_id) * trade_amount
self.rtn += profit # note here do not consider cost
elif order.direction == Order.BUY: # 1 for buy
# when buy stock, we get return for the rtn computing method
# profit in buy order is to make self.rtn is consistent with self.earning at the end of date
profit = self.current.get_stock_price(order.stock_id) * trade_amount - trade_val
self.rtn += profit
def update_order(self, order, trade_val, cost, trade_price):
# if stock is sold out, no stock price information in Position, then we should update account first, then update current position
# if stock is bought, there is no stock in current position, update current, then update account
# The cost will be substracted from the cash at last. So the trading logic can ignore the cost calculation
trade_amount = trade_val / trade_price
if order.direction == Order.SELL:
# sell stock
self.update_state_from_order(order, trade_val, cost, trade_price)
# update current position
# for may sell all of stock_id
self.current.update_order(order, trade_val, cost, trade_price)
else:
# buy stock
# deal order, then update state
self.current.update_order(order, trade_val, cost, trade_price)
self.update_state_from_order(order, trade_val, cost, trade_price)
def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange):
"""
start_time: pd.TimeStamp
end_time: pd.TimeStamp
quote: pd.DataFrame (code, date), collumns
when the end of trade date
- update rtn
- update price for each asset
- update value for this account
- update earning (2nd view of return )
- update holding day, count of stock
- update position hitory
- update report
:return: None
"""
# update price for stock in the position and the profit from changed_price
stock_list = self.current.get_stock_list()
profit = 0
for code in stock_list:
# if suspend, no new price to be updated, profit is 0
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
continue
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
profit += (bar_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
self.current.update_stock_price(stock_id=code, price=bar_close)
self.rtn += profit
# update holding day count
self.current.add_count_all()
# update value
self.val = self.current.calculate_value()
# update earning (2nd view of return)
# account_value - last_account_value
# for the first trade date, account_value - init_cash
# self.report.is_empty() to judge is_first_trade_date
# get last_account_value, now_account_value, now_stock_value
if self.report.is_empty():
last_account_value = self.init_cash
else:
last_account_value = self.report.get_latest_account_value()
now_account_value = self.current.calculate_value()
now_stock_value = self.current.calculate_stock_value()
self.earning = now_account_value - last_account_value
# update report for today
# judge whether the the trading is begin.
# and don't add init account state into report, due to we don't have excess return in those days.
self.report.update_report_record(
trade_time=trade_start_time,
account_value=now_account_value,
cash=self.current.position["cash"],
return_rate=(self.earning + self.ct) / last_account_value,
# here use earning to calculate return, position's view, earning consider cost, true return
# in order to make same definition with original backtest in evaluate.py
turnover_rate=self.to / last_account_value,
cost_rate=self.ct / last_account_value,
stock_value=now_stock_value,
)
# set now_account_value to position
self.current.position["now_account_value"] = now_account_value
self.current.update_weight_all()
# update positions
# note use deepcopy
self.positions[trade_start_time] = copy.deepcopy(self.current)
# finish today's updation
# reset the daily variables
self.rtn = 0
self.ct = 0
self.to = 0
self.last_trade_time = (trade_start_time, trade_end_time)
def load_account(self, account_path):
report = Report()
position = Position()
last_trade_time = position.load_position(account_path / "position.xlsx")
report.load_report(account_path / "report.csv")
# assign values
self.init_vars(position.init_cash)
self.current = position
self.report = report
self.last_trade_time = last_trade_time
def save_account(self, account_path):
self.current.save_position(account_path / "position.xlsx", self.last_trade_time)
self.report.save_report(account_path / "report.csv")

View File

@@ -1,26 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
from .account import Account
def backtest(trade_strategy, trade_env, benchmark, account):
trade_account = Account(init_cash=account)
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
trade_strategy.reset(start_time=start_time, end_time=end_time)
trade_state = self.sub_env.get_init_state()
while not trade_env.finished():
_order_list = self.sub_strategy.generate_order(**trade_state)
trade_state, trade_info = self.sub_env.execute(sub_order_list)
report_df = trade_account.report.generate_report_dataframe()
positions = trade_account.get_positions()
report_dict = {"report_df": report_df, "positions": positions}
return report_dict

View File

@@ -1,429 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import random
import logging
import numpy as np
import pandas as pd
from ..data import D
from ..utils import sample_feature
from .order import Order
from ..config import C, REG_CN
from ..log import get_module_logger
class Exchange:
def __init__(
self,
start_time=None,
end_time=None,
codes="all",
deal_price=None,
subscribe_fields=[],
limit_threshold=None,
open_cost=0.0015,
close_cost=0.0025,
trade_unit=None,
min_cost=5,
extra_quote=None,
):
"""__init__
:param start_time: start time for backtest
:param end_time: end time for backtest
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
:param deal_price: str, 'close', 'open', 'vwap'
:param subscribe_fields: list, subscribe fields
:param limit_threshold: float, 0.1 for example, default None
:param open_cost: cost rate for open, default 0.0015
:param close_cost: cost rate for close, default 0.0025
:param trade_unit: trade unit, 100 for China A market
:param min_cost: min cost, default 5
:param extra_quote: pandas, dataframe consists of
columns: like ['$vwap', '$close', '$factor', 'limit'].
The limit indicates that the etf is tradable on a specific day.
Necessary fields:
$close is for calculating the total value at end of each day.
Optional fields:
$vwap is only necessary when we use the $vwap price as the deal price
$factor is for rounding to the trading unit
limit will be set to False by default(False indicates we can buy this
target on this day).
index: MultipleIndex(instrument, pd.Datetime)
"""
self.start_time = start_time
self.end_time = end_time
if trade_unit is None:
trade_unit = C.trade_unit
if limit_threshold is None:
limit_threshold = C.limit_threshold
if deal_price is None:
deal_price = C.deal_price
self.logger = get_module_logger("online operator", level=logging.INFO)
self.trade_unit = trade_unit
# TODO: the quote, trade_dates, codes are not necessray.
# It is just for performance consideration.
if limit_threshold is None:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
elif abs(limit_threshold) > 0.1:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
if deal_price[0] != "$":
self.deal_price = "$" + deal_price
else:
self.deal_price = deal_price
if isinstance(codes, str):
codes = D.instruments(codes)
self.codes = codes
# Necessary fields
# $close is for calculating the total value at end of each day.
# $factor is for rounding to the trading unit
# $change is for calculating the limit of the stock
necessary_fields = {self.deal_price, "$close", "$change", "$factor"}
subscribe_fields = list(necessary_fields | set(subscribe_fields))
all_fields = list(necessary_fields | set(subscribe_fields))
self.all_fields = all_fields
self.open_cost = open_cost
self.close_cost = close_cost
self.min_cost = min_cost
self.limit_threshold = limit_threshold
self.extra_quote = extra_quote
self.set_quote(codes, start_time, end_time)
def set_quote(self, codes, start_time, end_time):
if len(codes) == 0:
codes = D.instruments()
self.quote = D.features(codes, self.all_fields, start_time, end_time, disk_cache=True).dropna(subset=["$close"])
self.quote.columns = self.all_fields
if self.quote[self.deal_price].isna().any():
self.logger.warning("{} field data contains nan.".format(self.deal_price))
if self.quote["$factor"].isna().any():
# The 'factor.day.bin' file not exists, and `factor` field contains `nan`
# Use adjusted price
self.trade_w_adj_price = True
self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.")
else:
# The `factor.day.bin` file exists and all data `close` and `factor` are not `nan`
# Use normal price
self.trade_w_adj_price = False
# update limit
# check limit_threshold
if self.limit_threshold is None:
self.quote["limit"] = False
else:
# set limit
self._update_limit(buy_limit=self.limit_threshold, sell_limit=self.limit_threshold)
quote_df = self.quote
if self.extra_quote is not None:
# process extra_quote
if "$close" not in self.extra_quote:
raise ValueError("$close is necessray in extra_quote")
if self.deal_price not in self.extra_quote.columns:
self.extra_quote[self.deal_price] = self.extra_quote["$close"]
self.logger.warning("No deal_price set for extra_quote. Use $close as deal_price.")
if "$factor" not in self.extra_quote.columns:
self.extra_quote["$factor"] = 1.0
self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.")
if "limit" not in self.extra_quote.columns:
self.extra_quote["limit"] = False
self.logger.warning("No limit set for extra_quote. All stock will be tradable.")
assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"}
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
# update quote: pd.DataFrame to dict, for search use
self.quote = quote_df
def _update_limit(self, buy_limit, sell_limit):
self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False)
def check_stock_limit(self, stock_id, start_time, end_time):
"""Parameter
stock_id
trade_date
is limtited
"""
return sample_feature(self.quote, stock_id, start_time, end_time, fields="limit", method="any").iloc[0, 0]
def check_stock_suspended(self, stock_id, start_time, end_time):
# is suspended
return sample_feature(self.quote, stock_id, start_time, end_time).empty is False
def is_stock_tradable(self, stock_id, start_time, end_time):
# check if stock can be traded
# same as check in check_order
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(stock_id, start_time, end_time):
return False
else:
return True
def check_order(self, order):
# check limit and suspended
if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit(
order.stock_id, order.start_time, order.end_time
):
return False
else:
return True
def deal_order(self, order, trade_account=None, position=None):
"""
Deal order when the actual transaction
:param order: Deal the order.
:param trade_account: Trade account to be updated after dealing the order.
:param position: position to be updated after dealing the order.
:return: trade_val, trade_cost, trade_price
"""
# need to check order first
# TODO: check the order unit limit in the exchange!!!!
# The order limit is related to the adj factor and the cur_amount.
# factor = self.quote[(order.stock_id, order.trade_date)]['$factor']
# cur_amount = trade_account.current.get_stock_amount(order.stock_id)
if self.check_order(order) is False:
raise AttributeError("need to check order first")
if trade_account is not None and position is not None:
raise ValueError("trade_account and position can only choose one")
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
trade_val, trade_cost = self._calc_trade_info_by_order(
order, trade_account.current if trade_account else position
)
# update account
if trade_val > 0:
# If the order can only be deal 0 trade_val. Nothing to be updated
# Otherwise, it will result some stock with 0 amount in the position
if trade_account:
trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
elif position:
position.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
return trade_val, trade_cost, trade_price
def get_quote_info(self, stock_id, start_time, end_time):
return sample_feature(self.quote, stock_id, start_time, end_time)
def get_close(self, stock_id, start_time, end_time):
return sample_feature(self.quote, stock_id, start_time, end_time, fields="$close", method="last")
def get_deal_price(self, stock_id, start_time, end_time):
deal_price = sample_feature(self.quote, stock_id, start_time, end_time, fields=self.deal_price, method="last")
deal_price = self.quote[(stock_id, trade_date)][self.deal_price]
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
self.logger.warning(f"(stock_id:{stock_id}, trade_date:{trade_date}, {self.deal_price}): {deal_price}!!!")
self.logger.warning(f"setting deal_price to close price")
deal_price = self.get_close(stock_id, start_time, end_time)
return deal_price
def get_factor(self, stock_id, start_time, end_time):
return sample_feature(self.quote, stock_id, start_time, end_time, fields="$factor", method="last")
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
"""
The generate the target position according to the weight and the cash.
NOTE: All the cash will assigned to the tadable stock.
Parameter:
weight_position : dict {stock_id : weight}; allocate cash by weight_position
among then, weight must be in this range: 0 < weight < 1
cash : cash
trade_date : trade date
"""
# calculate the total weight of tradable value
tradable_weight = 0.0
for stock_id in weight_position:
if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
# weight_position must be greater than 0 and less than 1
if weight_position[stock_id] < 0 or weight_position[stock_id] > 1:
raise ValueError(
"weight_position is {}, "
"weight_position is not in the range of (0, 1).".format(weight_position[stock_id])
)
tradable_weight += weight_position[stock_id]
if tradable_weight - 1.0 >= 1e-5:
raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight))
amount_dict = {}
for stock_id in weight_position:
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
amount_dict[stock_id] = (
cash
* weight_position[stock_id]
/ tradable_weight
// self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time)
)
return amount_dict
def get_real_deal_amount(self, current_amount, target_amount, factor):
"""
Calculate the real adjust deal amount when considering the trading unit
:param current_amount:
:param target_amount:
:param factor:
:return real_deal_amount; Positive deal_amount indicates buying more stock.
"""
if current_amount == target_amount:
return 0
elif current_amount < target_amount:
deal_amount = target_amount - current_amount
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
return deal_amount
else:
if target_amount == 0:
return -current_amount
else:
deal_amount = current_amount - target_amount
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
return -deal_amount
def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
"""Parameter:
target_position : dict { stock_id : amount }
current_postion : dict { stock_id : amount}
trade_unit : trade_unit
down sample : for amount 321 and trade_unit 100, deal_amount is 300
deal order on trade_date
"""
# split buy and sell for further use
buy_order_list = []
sell_order_list = []
# three parts: kept stock_id, dropped stock_id, new stock_id
# handle kept stock_id
# because the order of the set is not fixed, the trading order of the stock is different, so that the backtest results of the same parameter are different;
# so here we sort stock_id, and then randomly shuffle the order of stock_id
# because the same random seed is used, the final stock_id order is fixed
sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys())))
random.seed(0)
random.shuffle(sorted_ids)
for stock_id in sorted_ids:
# Do not generate order for the nontradable stocks
if not self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
continue
target_amount = target_position.get(stock_id, 0)
current_amount = current_position.get(stock_id, 0)
factor = self.get_factor(stock_id, start_time=start_time, end_time=end_time)
deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor)
if deal_amount == 0:
continue
elif deal_amount > 0:
# buy stock
buy_order_list.append(
Order(
stock_id=stock_id,
amount=deal_amount,
direction=Order.BUY,
start_time=start_time,
end_time=end_time,
factor=factor,
)
)
else:
# sell stock
sell_order_list.append(
Order(
stock_id=stock_id,
amount=abs(deal_amount),
direction=Order.SELL,
start_time=start_time,
end_time=end_time,
factor=factor,
)
)
# return order_list : buy + sell
return sell_order_list + buy_order_list
def calculate_amount_position_value(self, amount_dict, start_time, end_time, only_tradable=False):
"""Parameter
position : Position()
amount_dict : {stock_id : amount}
"""
value = 0
for stock_id in amount_dict:
if (
self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
):
value += self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) * amount_dict[stock_id]
return value
def round_amount_by_trade_unit(self, deal_amount, factor):
"""Parameter
deal_amount : float, adjusted amount
factor : float, adjusted factor
return : float, real amount
"""
if not self.trade_w_adj_price:
# the minimal amount is 1. Add 0.1 for solving precision problem.
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
return deal_amount
def _calc_trade_info_by_order(self, order, position):
"""
Calculation of trade info
:param order:
:param position: Position
:return: trade_val, trade_cost
"""
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
if order.direction == Order.SELL:
# sell
if position is not None:
if np.isclose(order.amount, position.get_stock_amount(order.stock_id)):
# when selling last stock. The amount don't need rounding
order.deal_amount = order.amount
else:
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
else:
# TODO: We don't know current position.
# We choose to sell all
order.deal_amount = order.amount
trade_val = order.deal_amount * trade_price
trade_cost = max(trade_val * self.close_cost, self.min_cost)
elif order.direction == Order.BUY:
# buy
if position is not None:
cash = position.get_cash()
trade_val = order.amount * trade_price
if cash < trade_val * (1 + self.open_cost):
# The money is not enough
order.deal_amount = self.round_amount_by_trade_unit(
cash / (1 + self.open_cost) / trade_price, order.factor
)
else:
# THe money is enough
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
else:
# Unknown amount of money. Just round the amount
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
trade_val = order.deal_amount * trade_price
trade_cost = trade_val * self.open_cost
else:
raise NotImplementedError("order type {} error".format(order.type))
return trade_val, trade_cost

View File

@@ -1,90 +0,0 @@
class HighFreqOrderNorm(Processor):
def __init__(self, fit_start_time, fit_end_time, feature_save_dir, price_dim=5, order_price_dim=2, volume_dim=1, order_volume_dim=8, day_length=240):
self.fit_start_time = fit_start_time
self.fit_end_time = fit_end_time
self.price_dim = price_dim
self.volume_dim = volume_dim
self.order_price_dim = order_price_dim
self.order_volume_dim = order_volume_dim
self.feature_save_dir = feature_save_dir
self.day_length = day_length
self.names = dict()
column_dim = self.price_dim + self.order_price_dim + self.volume_dim + self.order_volume_dim
fields = [("price", self.price_dim), ("order_price", self.order_price_dim), ("volume", self.volume_dim), ("order_volume", self.order_volume_dim)]
last_dim = 0
for field, field_dim in fields:
self.names[field] = list(range(last_dim, last_dim + field_dim)) + list((range(column_dim + last_dim, column_dim + last_dim + field_dim)))
last_dim += field_dim
@profile
def fit(self, df_features):
# fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime")
print("end")
if not os.path.exists(self.feature_save_dir):
os.makedirs(self.feature_save_dir)
for name, name_val in self.names.items():
print(name)
df_values = df_features.iloc(axis=1)[name_val].values
if name == "volume" or name == "order_volume":
df_values = np.log1p(df_values)
self.feature_med = np.nanmedian(df_values)
np.save(self.feature_save_dir + name + "_med.npy", self.feature_med)
df_values = df_values - self.feature_med
self.feature_std = np.nanmedian(np.absolute(df_values)) * 1.4826 + 1e-12
np.save(self.feature_save_dir + name + "_std.npy", self.feature_std)
df_values = df_values / self.feature_std
np.save(self.feature_save_dir + name + "_vmax.npy", np.nanmax(df_values))
np.save(self.feature_save_dir + name + "_vmin.npy", np.nanmin(df_values))
def __call__(self, df_features):
df_features.set_index("date", append=True, drop=True, inplace=True)
df_values = df_features.values
df_values_dict = dict()
for name, name_val in self.names.items():
self.feature_med = np.load(self.feature_save_dir + name + "_med.npy")
self.feature_std = np.load(self.feature_save_dir + name + "_std.npy")
self.feature_vmax = np.load(self.feature_save_dir + name + "_vmax.npy")
self.feature_vmin = np.load(self.feature_save_dir + name + "_vmin.npy")
df_values = df_features.iloc(axis=1)[name_val].values
if name == "volume" or name == "order_volume":
df_values[:] = np.log1p(df_values)
df_values[:] -= self.feature_med
df_values[:] /= self.feature_std
slice0 = df_values > 3.0
slice1 = df_values > 3.5
slice2 = df_values < -3.0
slice3 = df_values < -3.5
df_values[slice0] = (
3.0 + (df_values[slice0] - 3.0) / (self.feature_vmax - 3) * 0.5
)
df_values[slice1] = 3.5
df_values[slice2] = (
-3.0 - (df_values[slice2] + 3.0) / (self.feature_vmin + 3) * 0.5
)
df_values[slice3] = -3.5
df_values_dict[name] = df_values
idx = df_features.index.droplevel("datetime").drop_duplicates()
idx.set_names(["instrument", "datetime"], inplace=True)
# Reshape is specifically for adapting to RL high-freq executor
feat = df_values[:, list(range(self.price_dim)) + list(range(self.price_dim * 2, self.price_dim * 2 + self.order_price_dim))
+ list(range((self.price_dim + self.order_price_dim) * 2, (self.price_dim + self.order_price_dim) * 2 + self.volume_dim))
+ list(range((self.price_dim + self.order_price_dim + self.volume_dim) * 2, (self.price_dim + self.order_price_dim + self.volume_dim) * 2 + self.order_volume_dim))
].reshape(-1, (self.price_dim + self.order_price_dim + self.volume_dim + self.order_volume_dim) * self.day_length)
feat_1 = df_values[:, list(np.arange(self.price_dim) + self.price_dim) + list(np.arange(self.price_dim * 2, self.price_dim * 2 + self.order_price_dim) + self.order_price_dim)
+ list(np.arange((self.price_dim + self.order_price_dim) * 2, (self.price_dim + self.order_price_dim) * 2 + self.volume_dim) + self.volume_dim)
+ list(np.arange((self.price_dim + self.order_price_dim + self.volume_dim) * 2, (self.price_dim + self.order_price_dim + self.volume_dim) * 2 + self.order_volume_dim) + self.order_volume_dim)
].reshape(-1, (self.price_dim + self.order_price_dim + self.volume_dim + self.order_volume_dim) * self.day_length)
df_new_features = pd.DataFrame(
data=np.concatenate((feat, feat_1), axis=1),
index=idx,
columns=range(2 * (self.price_dim + self.order_price_dim + self.volume_dim + self.order_volume_dim) * self.day_length),
).sort_index()
return df_new_features

View File

@@ -1,132 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .order import Order
from .account import Account
from .position import Position
from .exchange import Exchange
from .report import Report
from .backtest import backtest as backtest_func, get_date_range
import copy
import numpy as np
import inspect
from ..utils import init_instance_by_config
from ..log import get_module_logger
from ..config import C
logger = get_module_logger("backtest caller")
def init_env_instance_by_config(env):
if isinstance(env, dict):
env_config = copy.copy(env)
if "kwargs" in env_config:
env_kwargs = copy.copy(env_config["kwargs"]):
if "sub_env" in env_kwargs:
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
if "sub_strategy" in env_kwargs:
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
env_config["kwargs"] = env_kwargs
return init_instance_by_config(env_config)
else:
return env
def get_exchange(
exchange=None,
start_time=None,
end_time=None,
codes = "all",
subscribe_fields=[],
open_cost=0.0015,
close_cost=0.0025,
min_cost=5.0,
trade_unit=None,
limit_threshold=None,
deal_price=None,
shift=1,
):
"""get_exchange
Parameters
----------
# exchange related arguments
exchange: Exchange().
subscribe_fields: list
subscribe fields.
open_cost : float
open transaction cost.
close_cost : float
close transaction cost.
min_cost : float
min transaction cost.
trade_unit : int
100 for China A.
deal_price: str
dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
Returns
-------
:class: Exchange
an initialized Exchange object
"""
if trade_unit is None:
trade_unit = C.trade_unit
if limit_threshold is None:
limit_threshold = C.limit_threshold
if deal_price is None:
deal_price = C.deal_price
if exchange is None:
logger.info("Create new exchange")
# handle exception for deal_price
if deal_price[0] != "$":
deal_price = "$" + deal_price
exchange = Exchange(
start_time=start_time,
end_time=end_time,
codes=codes,
deal_price=deal_price,
subscribe_fields=subscribe_fields,
limit_threshold=limit_threshold,
open_cost=open_cost,
close_cost=close_cost,
trade_unit=trade_unit,
min_cost=min_cost,
)
else:
return init_instance_by_config(exchange, accept_types=Exchange)
def backtest(start_time, end_time, strategy, env, account=1e9, **kwargs):
trade_strategy = init_instance_by_config(strategy)
trade_env = init_env_instance_by_config(env)
trade_account = Account(init_cash=account)
spec = inspect.getfullargspec(get_exchange)
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
trade_exchange = get_exchange(pred, **ex_args)
# temp_env = trade_env
# while True:
# if hasattr(temp_env, "trade_exchange"):
# temp_env.reset(trade_exchange=trade_exchange)
# if hasattr(temp_env, "sub_env"):
# temp_env = temp_env.sub_env
# else:
# break
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
trade_state, _reset_info = self.sub_env.get_first_state()
trade_strategy.reset(**_reset_info)
while not trade_env.finished():
_order_list = self.sub_strategy.generate_order(**trade_state)
trade_state, trade_info = self.sub_env.execute(sub_order_list)
return

View File

@@ -1,30 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
class Order:
SELL = 0
BUY = 1
def __init__(self, stock_id, amount, start_time, end_time, direction, factor):
"""Parameter
direction : Order.SELL for sell; Order.BUY for buy
stock_id : str
amount : float
trade_date : pd.Timestamp
factor : float
presents the weight factor assigned in Exchange()
"""
# check direction
if direction not in {Order.SELL, Order.BUY}:
raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
self.stock_id = stock_id
# amount of generated orders
self.amount = amount
# amount of successfully completed orders
self.deal_amount = 0
self.start_time = start_time
self.end_time = end_time
self.direction = direction
self.factor = factor

View File

@@ -1,217 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
import copy
import pathlib
from .order import Order
"""
Position module
"""
"""
current state of position
a typical example is :{
<instrument_id>: {
'count': <how many days the security has been hold>,
'amount': <the amount of the security>,
'price': <the close price of security in the last trading day>,
'weight': <the security weight of total position value>,
},
}
"""
class Position:
"""Position"""
def __init__(self, cash=0, position_dict={}, today_account_value=0):
# NOTE: The position dict must be copied!!!
# Otherwise the initial value
self.init_cash = cash
self.position = position_dict.copy()
self.position["cash"] = cash
self.position["today_account_value"] = today_account_value
def init_stock(self, stock_id, amount, price=None):
self.position[stock_id] = {}
self.position[stock_id]["count"] = 0 # update count in the end of this date
self.position[stock_id]["amount"] = amount
self.position[stock_id]["price"] = price
self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
def buy_stock(self, stock_id, trade_val, cost, trade_price):
trade_amount = trade_val / trade_price
if stock_id not in self.position:
self.init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
else:
# exist, add amount
self.position[stock_id]["amount"] += trade_amount
self.position["cash"] -= trade_val + cost
def sell_stock(self, stock_id, trade_val, cost, trade_price):
trade_amount = trade_val / trade_price
if stock_id not in self.position:
raise KeyError("{} not in current position".format(stock_id))
else:
# decrease the amount of stock
self.position[stock_id]["amount"] -= trade_amount
# check if to delete
if self.position[stock_id]["amount"] < -1e-5:
raise ValueError(
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
)
elif abs(self.position[stock_id]["amount"]) <= 1e-5:
self.del_stock(stock_id)
self.position["cash"] += trade_val - cost
def del_stock(self, stock_id):
del self.position[stock_id]
def update_order(self, order, trade_val, cost, trade_price):
# handle order, order is a order class, defined in exchange.py
if order.direction == Order.BUY:
# BUY
self.buy_stock(order.stock_id, trade_val, cost, trade_price)
elif order.direction == Order.SELL:
# SELL
self.sell_stock(order.stock_id, trade_val, cost, trade_price)
else:
raise NotImplementedError("do not support order direction {}".format(order.direction))
def update_stock_price(self, stock_id, price):
self.position[stock_id]["price"] = price
def update_stock_count(self, stock_id, count):
self.position[stock_id]["count"] = count
def update_stock_weight(self, stock_id, weight):
self.position[stock_id]["weight"] = weight
def update_cash(self, cash):
self.position["cash"] = cash
def calculate_stock_value(self):
stock_list = self.get_stock_list()
value = 0
for stock_id in stock_list:
value += self.position[stock_id]["amount"] * self.position[stock_id]["price"]
return value
def calculate_value(self):
value = self.calculate_stock_value()
value += self.position["cash"]
return value
def get_stock_list(self):
stock_list = list(set(self.position.keys()) - {"cash", "today_account_value"})
return stock_list
def get_stock_price(self, code):
return self.position[code]["price"]
def get_stock_amount(self, code):
return self.position[code]["amount"]
def get_stock_count(self, code):
return self.position[code]["count"]
def get_stock_weight(self, code):
return self.position[code]["weight"]
def get_cash(self):
return self.position["cash"]
def get_stock_amount_dict(self):
"""generate stock amount dict {stock_id : amount of stock} """
d = {}
stock_list = self.get_stock_list()
for stock_code in stock_list:
d[stock_code] = self.get_stock_amount(code=stock_code)
return d
def get_stock_weight_dict(self, only_stock=False):
"""get_stock_weight_dict
generate stock weight fict {stock_id : value weight of stock in the position}
it is meaningful in the beginning or the end of each trade date
:param only_stock: If only_stock=True, the weight of each stock in total stock will be returned
If only_stock=False, the weight of each stock in total assets(stock + cash) will be returned
"""
if only_stock:
position_value = self.calculate_stock_value()
else:
position_value = self.calculate_value()
d = {}
stock_list = self.get_stock_list()
for stock_code in stock_list:
d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value
return d
def add_count_all(self):
stock_list = self.get_stock_list()
for code in stock_list:
self.position[code]["count"] += 1
def update_weight_all(self):
weight_dict = self.get_stock_weight_dict()
for stock_code, weight in weight_dict.items():
self.update_stock_weight(stock_code, weight)
def save_position(self, path, last_trade_time):
path = pathlib.Path(path)
p = copy.deepcopy(self.position)
cash = pd.Series(dtype=np.float)
cash["init_cash"] = self.init_cash
cash["cash"] = p["cash"]
cash["today_account_value"] = p["today_account_value"]
cash["last_trade_start_time"] = str(last_trade_time[0]) if last_trade_time else None
cash["last_trade_end_time"] = str(last_trade_time[1]) if last_trade_time else None
del p["cash"]
del p["today_account_value"]
positions = pd.DataFrame.from_dict(p, orient="index")
with pd.ExcelWriter(path) as writer:
positions.to_excel(writer, sheet_name="position")
cash.to_excel(writer, sheet_name="info")
def load_position(self, path):
"""load position information from a file
should have format below
sheet "position"
columns: ['stock', 'count', 'amount', 'price', 'weight']
'count': <how many days the security has been hold>,
'amount': <the amount of the security>,
'price': <the close price of security in the last trading day>,
'weight': <the security weight of total position value>,
sheet "cash"
index: ['init_cash', 'cash', 'today_account_value']
'init_cash': <inital cash when account was created>,
'cash': <current cash in account>,
'today_account_value': <current total account value, should equal to sum(price[stock]*amount[stock])>
"""
path = pathlib.Path(path)
positions = pd.read_excel(open(path, "rb"), sheet_name="position", index_col=0)
cash_record = pd.read_excel(open(path, "rb"), sheet_name="info", index_col=0)
positions = positions.to_dict(orient="index")
init_cash = cash_record.loc["init_cash"].values[0]
cash = cash_record.loc["cash"].values[0]
today_account_value = cash_record.loc["today_account_value"].values[0]
last_trade_start_time = cash_record.loc["last_trade_start_time"].values[0]
last_trade_end_time = cash_record.loc["last_trade_end_time"].values[0]
# assign values
self.position = {}
self.init_cash = init_cash
self.position = positions
self.position["cash"] = cash
self.position["today_account_value"] = today_account_value
last_trade_start_time = None is pd.isna(last_trade_start_time) else pd.Timestamp(last_trade_start_time)
last_trade_end_time = None is pd.isna(last_trade_end_time) else pd.Timestamp(last_trade_end_time)
return last_trade_start_time, last_trade_end_time

View File

@@ -1,324 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
from .position import Position
from ...data import D
from ...config import C
import datetime
from pathlib import Path
def get_benchmark_weight(
bench,
start_date=None,
end_date=None,
path=None,
):
"""get_benchmark_weight
get the stock weight distribution of the benchmark
:param bench:
:param start_date:
:param end_date:
:param path:
:return: The weight distribution of the the benchmark described by a pandas dataframe
Every row corresponds to a trading day.
Every column corresponds to a stock.
Every cell represents the strategy.
"""
if not path:
path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
# TODO: the storage of weights should be implemented in a more elegent way
# TODO: The benchmark is not consistant with the filename in instruments.
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])
bench_weight_df = bench_weight_df[bench_weight_df["index"] == bench]
bench_weight_df["date"] = pd.to_datetime(bench_weight_df["date"])
if start_date is not None:
bench_weight_df = bench_weight_df[bench_weight_df.date >= start_date]
if end_date is not None:
bench_weight_df = bench_weight_df[bench_weight_df.date <= end_date]
bench_stock_weight = bench_weight_df.pivot_table(index="date", columns="code", values="weight") / 100.0
return bench_stock_weight
def get_stock_weight_df(positions):
"""get_stock_weight_df
:param positions: Given a positions from backtest result.
:return: A weight distribution for the position
"""
stock_weight = []
index = []
for date in sorted(positions.keys()):
pos = positions[date]
if isinstance(pos, dict):
pos = Position(position_dict=pos)
index.append(date)
stock_weight.append(pos.get_stock_weight_dict(only_stock=True))
return pd.DataFrame(stock_weight, index=index)
def decompose_portofolio_weight(stock_weight_df, stock_group_df):
"""decompose_portofolio_weight
'''
:param stock_weight_df: a pandas dataframe to describe the portofolio by weight.
every row corresponds to a day
every column corresponds to a stock.
Here is an example below.
code SH600004 SH600006 SH600017 SH600022 SH600026 SH600037 \
date
2016-01-05 0.001543 0.001570 0.002732 0.001320 0.003000 NaN
2016-01-06 0.001538 0.001569 0.002770 0.001417 0.002945 NaN
....
:param stock_group_df: a pandas dataframe to describe the stock group.
every row corresponds to a day
every column corresponds to a stock.
the value in the cell repreponds the group id.
Here is a example by for stock_group_df for industry. The value is the industry code
instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \
datetime
2016-01-05 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
2016-01-06 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
...
:return: Two dict will be returned. The group_weight and the stock_weight_in_group.
The key is the group. The value is a Series or Dataframe to describe the weight of group or weight of stock
"""
all_group = np.unique(stock_group_df.values.flatten())
all_group = all_group[~np.isnan(all_group)]
group_weight = {}
stock_weight_in_group = {}
for group_key in all_group:
group_mask = stock_group_df == group_key
group_weight[group_key] = stock_weight_df[group_mask].sum(axis=1)
stock_weight_in_group[group_key] = stock_weight_df[group_mask].divide(group_weight[group_key], axis=0)
return group_weight, stock_weight_in_group
def decompose_portofolio(stock_weight_df, stock_group_df, stock_ret_df):
"""
:param stock_weight_df: a pandas dataframe to describe the portofolio by weight.
every row corresponds to a day
every column corresponds to a stock.
Here is an example below.
code SH600004 SH600006 SH600017 SH600022 SH600026 SH600037 \
date
2016-01-05 0.001543 0.001570 0.002732 0.001320 0.003000 NaN
2016-01-06 0.001538 0.001569 0.002770 0.001417 0.002945 NaN
2016-01-07 0.001555 0.001546 0.002772 0.001393 0.002904 NaN
2016-01-08 0.001564 0.001527 0.002791 0.001506 0.002948 NaN
2016-01-11 0.001597 0.001476 0.002738 0.001493 0.003043 NaN
....
:param stock_group_df: a pandas dataframe to describe the stock group.
every row corresponds to a day
every column corresponds to a stock.
the value in the cell repreponds the group id.
Here is a example by for stock_group_df for industry. The value is the industry code
instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \
datetime
2016-01-05 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
2016-01-06 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
2016-01-07 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
2016-01-08 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
2016-01-11 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
...
:param stock_ret_df: a pandas dataframe to describe the stock return.
every row corresponds to a day
every column corresponds to a stock.
the value in the cell repreponds the return of the group.
Here is a example by for stock_ret_df.
instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \
datetime
2016-01-05 0.007795 0.022070 0.099099 0.024707 0.009473 0.016216
2016-01-06 -0.032597 -0.075205 -0.098361 -0.098985 -0.099707 -0.098936
2016-01-07 -0.001142 0.022544 0.100000 0.004225 0.000651 0.047226
2016-01-08 -0.025157 -0.047244 -0.038567 -0.098177 -0.099609 -0.074408
2016-01-11 0.023460 0.004959 -0.034384 0.018663 0.014461 0.010962
...
:return: It will decompose the portofolio to the group weight and group return.
"""
all_group = np.unique(stock_group_df.values.flatten())
all_group = all_group[~np.isnan(all_group)]
group_weight, stock_weight_in_group = decompose_portofolio_weight(stock_weight_df, stock_group_df)
group_ret = {}
for group_key in stock_weight_in_group:
stock_weight_in_group_start_date = min(stock_weight_in_group[group_key].index)
stock_weight_in_group_end_date = max(stock_weight_in_group[group_key].index)
temp_stock_ret_df = stock_ret_df[
(stock_ret_df.index >= stock_weight_in_group_start_date)
& (stock_ret_df.index <= stock_weight_in_group_end_date)
]
group_ret[group_key] = (temp_stock_ret_df * stock_weight_in_group[group_key]).sum(axis=1)
# If no weight is assigned, then the return of group will be np.nan
group_ret[group_key][group_weight[group_key] == 0.0] = np.nan
group_weight_df = pd.DataFrame(group_weight)
group_ret_df = pd.DataFrame(group_ret)
return group_weight_df, group_ret_df
def get_daily_bin_group(bench_values, stock_values, group_n):
"""get_daily_bin_group
Group the values of the stocks of benchmark into several bins in a day.
Put the stocks into these bins.
:param bench_values: A series contains the value of stocks in benchmark.
The index is the stock code.
:param stock_values: A series contains the value of stocks of your portofolio
The index is the stock code.
:param group_n: Bins will be produced
:return: A series with the same size and index as the stock_value.
The value in the series is the group id of the bins.
The No.1 bin contains the biggest values.
"""
stock_group = stock_values.copy()
# get the bin split points based on the daily proportion of benchmark
split_points = np.percentile(bench_values[~bench_values.isna()], np.linspace(0, 100, group_n + 1))
# Modify the biggest uppper bound and smallest lowerbound
split_points[0], split_points[-1] = -np.inf, np.inf
for i, (lb, up) in enumerate(zip(split_points, split_points[1:])):
stock_group.loc[stock_values[(stock_values >= lb) & (stock_values < up)].index] = group_n - i
return stock_group
def get_stock_group(stock_group_field_df, bench_stock_weight_df, group_method, group_n=None):
if group_method == "category":
# use the value of the benchmark as the category
return stock_group_field_df
elif group_method == "bins":
assert group_n is not None
# place the values into `group_n` fields.
# Each bin corresponds to a category.
new_stock_group_df = stock_group_field_df.copy().loc[
bench_stock_weight_df.index.min() : bench_stock_weight_df.index.max()
]
for idx, row in (~bench_stock_weight_df.isna()).iterrows():
bench_values = stock_group_field_df.loc[idx, row[row].index]
new_stock_group_df.loc[idx] = get_daily_bin_group(
bench_values, stock_group_field_df.loc[idx], group_n=group_n
)
return new_stock_group_df
def brinson_pa(
positions,
bench="SH000905",
group_field="industry",
group_method="category",
group_n=None,
deal_price="vwap",
):
"""brinson profit attribution
:param positions: The position produced by the backtest class
:param bench: The benchmark for comparing. TODO: if no benchmark is set, the equal-weighted is used.
:param group_field: The field used to set the group for assets allocation.
`industry` and `market_value` is often used.
:param group_method: 'category' or 'bins'. The method used to set the group for asstes allocation
`bin` will split the value into `group_n` bins and each bins represents a group
:param group_n: . Only used when group_method == 'bins'.
:return:
A dataframe with three columns: RAA(excess Return of Assets Allocation), RSS(excess Return of Stock Selectino), RTotal(Total excess Return)
Every row corresponds to a trading day, the value corresponds to the next return for this trading day
The middle info of brinson profit attribution
"""
# group_method will decide how to group the group_field.
dates = sorted(positions.keys())
start_date, end_date = min(dates), max(dates)
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date)
# The attributes for allocation will not
if not group_field.startswith("$"):
group_field = "$" + group_field
if not deal_price.startswith("$"):
deal_price = "$" + deal_price
# FIXME: In current version. Some attributes(such as market_value) of some
# suspend stock is NAN. So we have to get more date to forward fill the NAN
shift_start_date = start_date - datetime.timedelta(days=250)
instruments = D.list_instruments(
D.instruments(market="all"),
start_time=shift_start_date,
end_time=end_date,
as_list=True,
)
stock_df = D.features(
instruments,
[group_field, deal_price],
start_time=shift_start_date,
end_time=end_date,
freq="day",
)
stock_df.columns = [group_field, "deal_price"]
stock_group_field = stock_df[group_field].unstack().T
# FIXME: some attributes of some suspend stock is NAN.
stock_group_field = stock_group_field.fillna(method="ffill")
stock_group_field = stock_group_field.loc[start_date:end_date]
stock_group = get_stock_group(stock_group_field, bench_stock_weight, group_method, group_n)
deal_price_df = stock_df["deal_price"].unstack().T
deal_price_df = deal_price_df.fillna(method="ffill")
# NOTE:
# The return will be slightly different from the of the return in the report.
# Here the position are adjusted at the end of the trading day with close
stock_ret = (deal_price_df - deal_price_df.shift(1)) / deal_price_df.shift(1)
stock_ret = stock_ret.shift(-1).loc[start_date:end_date]
port_stock_weight_df = get_stock_weight_df(positions)
# decomposing the portofolio
port_group_weight_df, port_group_ret_df = decompose_portofolio(port_stock_weight_df, stock_group, stock_ret)
bench_group_weight_df, bench_group_ret_df = decompose_portofolio(bench_stock_weight, stock_group, stock_ret)
# if the group return of the portofolio is NaN, replace it with the market
# value
mod_port_group_ret_df = port_group_ret_df.copy()
mod_port_group_ret_df[mod_port_group_ret_df.isna()] = bench_group_ret_df
Q1 = (bench_group_weight_df * bench_group_ret_df).sum(axis=1)
Q2 = (port_group_weight_df * bench_group_ret_df).sum(axis=1)
Q3 = (bench_group_weight_df * mod_port_group_ret_df).sum(axis=1)
Q4 = (port_group_weight_df * mod_port_group_ret_df).sum(axis=1)
return (
pd.DataFrame(
{
"RAA": Q2 - Q1, # The excess profit from the assets allocation
"RSS": Q3 - Q1, # The excess profit from the stocks selection
# The excess profit from the interaction of assets allocation and stocks selection
"RIN": Q4 - Q3 - Q2 + Q1,
"RTotal": Q4 - Q1, # The totoal excess profit
}
),
{
"port_group_ret": port_group_ret_df,
"port_group_weight": port_group_weight_df,
"bench_group_ret": bench_group_ret_df,
"bench_group_weight": bench_group_weight_df,
"stock_group": stock_group,
"bench_stock_weight": bench_stock_weight,
"port_stock_weight": port_stock_weight_df,
"stock_ret": stock_ret,
},
)

View File

@@ -1,106 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from collections import OrderedDict
import pandas as pd
import pathlib
class Report:
# daily report of the account
# contain those followings: returns, costs turnovers, accounts, cash, bench, value
# update report
def __init__(self):
self.init_vars()
def init_vars(self):
self.accounts = OrderedDict() # account postion value for each trade date
self.returns = OrderedDict() # daily return rate for each trade date
self.turnovers = OrderedDict() # turnover for each trade date
self.costs = OrderedDict() # trade cost for each trade date
self.values = OrderedDict() # value for each trade date
self.cashes = OrderedDict()
self.latest_report_time = None # pd.TimeStamp
def is_empty(self):
return len(self.accounts) == 0
def get_latest_date(self):
return self.latest_report_time
def get_latest_account_value(self):
return self.accounts[self.latest_report_time]
def update_report_record(
self,
trade_time=None,
account_value=None,
cash=None,
return_rate=None,
turnover_rate=None,
cost_rate=None,
stock_value=None,
):
# check data
if None in [
trade_time,
account_value,
cash,
return_rate,
turnover_rate,
cost_rate,
stock_value,
]:
raise ValueError(
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
)
# update report data
self.accounts[trade_time] = account_value
self.returns[trade_time] = return_rate
self.turnovers[trade_time] = turnover_rate
self.costs[trade_time] = cost_rate
self.values[trade_time] = stock_value
self.cashes[trade_time] = cash
# update latest_report_date
self.latest_report_time = trade_time
# finish daily report update
def generate_report_dataframe(self):
report = pd.DataFrame()
report["account"] = pd.Series(self.accounts)
report["return"] = pd.Series(self.returns)
report["turnover"] = pd.Series(self.turnovers)
report["cost"] = pd.Series(self.costs)
report["value"] = pd.Series(self.values)
report["cash"] = pd.Series(self.cashes)
report.index.name = "trade_time"
return report
def save_report(self, path):
r = self.generate_report_dataframe()
r.to_csv(path)
def load_report(self, path):
"""load report from a file
should have format like
columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash']
:param
path: str/ pathlib.Path()
"""
path = pathlib.Path(path)
r = pd.read_csv(open(path, "rb"), index_col=0)
r.index = pd.DatetimeIndex(r.index)
index = r.index
self.init_vars()
for trade_time in index:
self.update_report_record(
trade_time=trade_time,
account_value=r.loc[trade_time]["account"],
cash=r.loc[trade_time]["cash"],
return_rate=r.loc[trade_time]["return"],
turnover_rate=r.loc[trade_time]["turnover"],
cost_rate=r.loc[trade_time]["cost"],
stock_value=r.loc[trade_time]["value"],
)

View File

@@ -2,12 +2,12 @@
# Licensed under the MIT License.
from .order import Order
from .account import Account
from .position import Position
from .exchange import Exchange
from .report import Report
from .backtest import backtest as backtest_func, get_date_range
from .backtest import backtest as backtest_func
import copy
import numpy as np
import inspect
from ...utils import init_instance_by_config
@@ -17,86 +17,11 @@ from ...config import C
logger = get_module_logger("backtest caller")
def get_strategy(
strategy=None,
topk=50,
margin=0.5,
n_drop=5,
risk_degree=0.95,
str_type="dropout",
adjust_dates=None,
):
"""get_strategy
There will be 3 ways to return a stratgy. Please follow the code.
Parameters
----------
strategy : Strategy()
strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
- if isinstance(margin, int):
sell_limit = margin
- else:
sell_limit = pred_in_a_day.count() * margin
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
sell_limit should be no less than topk.
n_drop : int
number of stocks to be replaced in each trading date.
risk_degree: float
0-1, 0.95 for example, use 95% money to trade.
str_type: 'amount', 'weight' or 'dropout'
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
Returns
-------
:class: Strategy
an initialized strategy object
"""
# There will be 3 ways to return a strategy.
if strategy is None:
# 1) create strategy with param `strategy`
str_cls_dict = {
"amount": "TopkAmountStrategy",
"weight": "TopkWeightStrategy",
"dropout": "TopkDropoutStrategy",
}
logger.info("Create new strategy ")
from .. import strategy as strategy_pool
str_cls = getattr(strategy_pool, str_cls_dict.get(str_type))
strategy = str_cls(
topk=topk,
buffer_margin=margin,
n_drop=n_drop,
risk_degree=risk_degree,
adjust_dates=adjust_dates,
)
elif isinstance(strategy, (dict, str)):
# 2) create strategy with init_instance_by_config
logger.info("Create new strategy ")
strategy = init_instance_by_config(strategy)
from ..strategy.strategy import BaseStrategy
# else: nothing happens. 3) Use the strategy directly
if not isinstance(strategy, BaseStrategy):
raise TypeError("Strategy not supported")
return strategy
def get_exchange(
pred,
exchange=None,
start_time=None,
end_time=None,
codes = "all",
subscribe_fields=[],
open_cost=0.0015,
close_cost=0.0025,
@@ -104,7 +29,6 @@ def get_exchange(
trade_unit=None,
limit_threshold=None,
deal_price=None,
extract_codes=False,
shift=1,
):
"""get_exchange
@@ -128,9 +52,6 @@ def get_exchange(
dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
NOTE: This will be faster with offline qlib.
Returns
-------
@@ -149,176 +70,61 @@ def get_exchange(
# handle exception for deal_price
if deal_price[0] != "$":
deal_price = "$" + deal_price
if extract_codes:
codes = sorted(pred.index.get_level_values("instrument").unique())
else:
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
dates = sorted(pred.index.get_level_values("datetime").unique())
dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift))
exchange = Exchange(
trade_dates=dates,
start_time=start_time,
end_time=end_time,
codes=codes,
deal_price=deal_price,
subscribe_fields=subscribe_fields,
limit_threshold=limit_threshold,
open_cost=open_cost,
close_cost=close_cost,
min_cost=min_cost,
trade_unit=trade_unit,
min_cost=min_cost,
)
return exchange
return exchange
else:
return init_instance_by_config(exchange, accept_types=Exchange)
def init_env_instance_by_config(env):
if isinstance(env, dict):
env_config = copy.copy(env)
if "kwargs" in env_config:
env_kwargs = copy.copy(env_config["kwargs"])
if "sub_env" in env_kwargs:
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
if "sub_strategy" in env_kwargs:
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
env_config["kwargs"] = env_kwargs
return init_instance_by_config(env_config)
else:
return env
def get_executor(
executor=None,
trade_exchange=None,
verbose=True,
):
"""get_executor
def setup_exchange(root_instance, trade_exchange=None, force=False):
if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args:
if force:
root_instance.reset(trade_exchange=trade_exchange)
else:
if not hasattr(root_instance, "trade_exchange") or root_instance.trade_exchange is None:
root_instance.reset(trade_exchange=trade_exchange)
if hasattr(root_instance, "sub_env"):
setup_exchange(root_instance.sub_env, trade_exchange)
if hasattr(root_instance, "sub_strategy"):
setup_exchange(root_instance.sub_strategy, trade_exchange)
def backtest(start_time, end_time, strategy, env, benchmark=None, account=1e9, **kwargs):
trade_strategy = init_instance_by_config(strategy)
trade_env = init_env_instance_by_config(env)
There will be 3 ways to return a executor. Please follow the code.
Parameters
----------
executor : BaseExecutor
executor used in backtest.
trade_exchange : Exchange
exchange used in executor
verbose : bool
whether to print log.
Returns
-------
:class: BaseExecutor
an initialized BaseExecutor object
"""
# There will be 3 ways to return a executor.
if executor is None:
# 1) create executor with param `executor`
logger.info("Create new executor ")
from ..online.executor import SimulatorExecutor
executor = SimulatorExecutor(trade_exchange=trade_exchange, verbose=verbose)
elif isinstance(executor, (dict, str)):
# 2) create executor with config
logger.info("Create new executor ")
executor = init_instance_by_config(executor)
from ..online.executor import BaseExecutor
# 3) Use the executor directly
if not isinstance(executor, BaseExecutor):
raise TypeError("Executor not supported")
return executor
# This is the API for compatibility for legacy code
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, return_order=False, **kwargs):
"""This function will help you set a reasonable Exchange and provide default value for strategy
Parameters
----------
- **backtest workflow related or commmon arguments**
pred : pandas.DataFrame
predict should has <datetime, instrument> index and one `score` column.
account : float
init account value.
shift : int
whether to shift prediction by one day.
benchmark : str
benchmark code, default is SH000905 CSI 500.
verbose : bool
whether to print log.
return_order : bool
whether to return order list
- **strategy related arguments**
strategy : Strategy()
strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
- if isinstance(margin, int):
sell_limit = margin
- else:
sell_limit = pred_in_a_day.count() * margin
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
sell_limit should be no less than topk.
n_drop : int
number of stocks to be replaced in each trading date.
risk_degree: float
0-1, 0.95 for example, use 95% money to trade.
str_type: 'amount', 'weight' or 'dropout'
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
- **exchange related arguments**
exchange: Exchange()
pass the exchange for speeding up.
subscribe_fields: list
subscribe fields.
open_cost : float
open transaction cost. The default value is 0.002(0.2%).
close_cost : float
close transaction cost. The default value is 0.002(0.2%).
min_cost : float
min transaction cost.
trade_unit : int
100 for China A.
deal_price: str
dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
.. note:: This will be faster with offline qlib.
- **executor related arguments**
executor : BaseExecutor()
executor used in backtest.
verbose : bool
whether to print log.
"""
# check strategy:
spec = inspect.getfullargspec(get_strategy)
str_args = {k: v for k, v in kwargs.items() if k in spec.args}
strategy = get_strategy(**str_args)
# init exchange:
spec = inspect.getfullargspec(get_exchange)
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
trade_exchange = get_exchange(pred, **ex_args)
exchange_args = {k: v for k, v in kwargs.items() if k in spec.args}
trade_exchange = get_exchange(**exchange_args)
# init executor:
executor = get_executor(executor=kwargs.get("executor"), trade_exchange=trade_exchange, verbose=verbose)
setup_exchange(trade_env, trade_exchange)
setup_exchange(trade_strategy, trade_exchange)
# run backtest
report_dict = backtest_func(
pred=pred,
strategy=strategy,
executor=executor,
trade_exchange=trade_exchange,
shift=shift,
verbose=verbose,
account=account,
benchmark=benchmark,
return_order=return_order,
)
# for compatibility of the old API. return the dict positions
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env, benchmark, account)
positions = report_dict.get("positions")
report_dict.update({"positions": {k: p.position for k, p in positions.items()}})
return report_dict

View File

@@ -26,10 +26,10 @@ rtn & earning in the Account
class Account:
def __init__(self, init_cash, last_trade_date=None):
self.init_vars(init_cash, last_trade_date)
def __init__(self, init_cash, last_trade_time=None):
self.init_vars(init_cash, last_trade_time)
def init_vars(self, init_cash, last_trade_date=None):
def init_vars(self, init_cash, last_trade_time=None):
# init cash
self.init_cash = init_cash
self.current = Position(cash=init_cash)
@@ -40,7 +40,7 @@ class Account:
self.val = 0
self.report = Report()
self.earning = 0
self.last_trade_date = last_trade_date
self.last_trade_time = last_trade_time
def get_positions(self):
return self.positions
@@ -83,9 +83,10 @@ class Account:
self.current.update_order(order, trade_val, cost, trade_price)
self.update_state_from_order(order, trade_val, cost, trade_price)
def update_daily_end(self, today, trader):
def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange):
"""
today: pd.TimeStamp
start_time: pd.TimeStamp
end_time: pd.TimeStamp
quote: pd.DataFrame (code, date), collumns
when the end of trade date
- update rtn
@@ -102,11 +103,11 @@ class Account:
profit = 0
for code in stock_list:
# if suspend, no new price to be updated, profit is 0
if trader.check_stock_suspended(code, today):
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
continue
today_close = trader.get_close(code, today)
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
self.current.update_stock_price(stock_id=code, price=today_close)
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
profit += (bar_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
self.current.update_stock_price(stock_id=code, price=bar_close)
self.rtn += profit
# update holding day count
self.current.add_count_all()
@@ -116,54 +117,54 @@ class Account:
# account_value - last_account_value
# for the first trade date, account_value - init_cash
# self.report.is_empty() to judge is_first_trade_date
# get last_account_value, today_account_value, today_stock_value
# get last_account_value, now_account_value, now_stock_value
if self.report.is_empty():
last_account_value = self.init_cash
else:
last_account_value = self.report.get_latest_account_value()
today_account_value = self.current.calculate_value()
today_stock_value = self.current.calculate_stock_value()
self.earning = today_account_value - last_account_value
now_account_value = self.current.calculate_value()
now_stock_value = self.current.calculate_stock_value()
self.earning = now_account_value - last_account_value
# update report for today
# judge whether the the trading is begin.
# and don't add init account state into report, due to we don't have excess return in those days.
self.report.update_report_record(
trade_date=today,
account_value=today_account_value,
trade_time=trade_start_time,
account_value=now_account_value,
cash=self.current.position["cash"],
return_rate=(self.earning + self.ct) / last_account_value,
# here use earning to calculate return, position's view, earning consider cost, true return
# in order to make same definition with original backtest in evaluate.py
turnover_rate=self.to / last_account_value,
cost_rate=self.ct / last_account_value,
stock_value=today_stock_value,
stock_value=now_stock_value,
)
# set today_account_value to position
self.current.position["today_account_value"] = today_account_value
# set now_account_value to position
self.current.position["now_account_value"] = now_account_value
self.current.update_weight_all()
# update positions
# note use deepcopy
self.positions[today] = copy.deepcopy(self.current)
self.positions[trade_start_time] = copy.deepcopy(self.current)
# finish today's updation
# reset the daily variables
self.rtn = 0
self.ct = 0
self.to = 0
self.last_trade_date = today
self.last_trade_time = (trade_start_time, trade_end_time)
def load_account(self, account_path):
report = Report()
position = Position()
last_trade_date = position.load_position(account_path / "position.xlsx")
last_trade_time = position.load_position(account_path / "position.xlsx")
report.load_report(account_path / "report.csv")
# assign values
self.init_vars(position.init_cash)
self.current = position
self.report = report
self.last_trade_date = last_trade_date if last_trade_date else None
self.last_trade_time = last_trade_time
def save_account(self, account_path):
self.current.save_position(account_path / "position.xlsx", self.last_trade_date)
self.current.save_position(account_path / "position.xlsx", self.last_trade_time)
self.report.save_report(account_path / "report.csv")

View File

@@ -4,140 +4,24 @@
import numpy as np
import pandas as pd
from ...utils import get_date_by_shift, get_date_range
from ...data import D
from .account import Account
from ...config import C
from ...log import get_module_logger
from ...data.dataset.utils import get_level_index
LOG = get_module_logger("backtest")
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
"""Parameters
----------
pred : pandas.DataFrame
predict should has <datetime, instrument> index and one `score` column
Qlib want to support multi-singal strategy in the future. So pd.Series is not used.
strategy : Strategy()
strategy part for backtest
trade_exchange : Exchange()
exchage for backtest
shift : int
whether to shift prediction by one day
verbose : bool
whether to print log
account : float
init account value
benchmark : str/list/pd.Series
`benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
example:
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
2017-01-04 0.011693
2017-01-05 0.000721
2017-01-06 -0.004322
2017-01-09 0.006874
2017-01-10 -0.003350
`benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
`benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000905 CSI500
"""
# Convert format if the input format is not expected
if get_level_index(pred, level="datetime") == 1:
pred = pred.swaplevel().sort_index()
if isinstance(pred, pd.Series):
pred = pred.to_frame("score")
def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account):
trade_account = Account(init_cash=account)
_pred_dates = pred.index.get_level_values(level="datetime")
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
if isinstance(benchmark, pd.Series):
bench = benchmark
else:
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
_temp_result = D.features(
_codes,
["$close/Ref($close,1)-1"],
predict_dates[0],
get_date_by_shift(predict_dates[-1], shift=shift),
disk_cache=1,
)
if len(_temp_result) == 0:
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
trade_strategy.reset(start_time=start_time, end_time=end_time)
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
if return_order:
multi_order_list = []
# trading apart
for pred_date, trade_date in zip(predict_dates, trade_dates):
# for loop predict date and trading date
# print
if verbose:
LOG.info("[I {:%Y-%m-%d}]: trade begin.".format(trade_date))
# 1. Load the score_series at pred_date
try:
score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate
score_series = score.reset_index(level="datetime", drop=True)[
"score"
] # pd.Series(index:stock_id, data: score)
except KeyError:
LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date))
score_series = None
if score_series is not None and score_series.count() > 0: # in case of the scores are all None
# 2. Update your strategy (and model)
strategy.update(score_series, pred_date, trade_date)
# 3. Generate order list
order_list = strategy.generate_order_list(
score_series=score_series,
current=trade_account.current,
trade_exchange=trade_exchange,
pred_date=pred_date,
trade_date=trade_date,
)
else:
order_list = []
if return_order:
multi_order_list.append((trade_account, order_list, trade_date))
# 4. Get result after executing order list
# NOTE: The following operation will modify order.amount.
# NOTE: If it is buy and the cash is insufficient, the tradable amount will be recalculated
trade_info = executor.execute(trade_account, order_list, trade_date)
# 5. Update account information according to transaction
update_account(trade_account, trade_info, trade_exchange, trade_date)
# generate backtest report
trade_state = trade_env.get_init_state()
while not trade_env.finished():
_order_list = trade_strategy.generate_order_list(**trade_state)
print("_order_list", _order_list)
trade_state, trade_info = trade_env.execute(_order_list)
report_df = trade_account.report.generate_report_dataframe()
report_df["bench"] = bench
positions = trade_account.get_positions()
report_dict = {"report_df": report_df, "positions": positions}
if return_order:
report_dict.update({"order_list": multi_order_list})
return report_dict
def update_account(trade_account, trade_info, trade_exchange, trade_date):
"""Update the account and strategy
Parameters
----------
trade_account : Account()
trade_info : list of [Order(), float, float, float]
(order, trade_val, trade_cost, trade_price), trade_info with out factor
trade_exchange : Exchange()
used to get the $close_price at trade_date to update account
trade_date : pd.Timestamp
"""
# update account
for [order, trade_val, trade_cost, trade_price] in trade_info:
if order.deal_amount == 0:
continue
trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
# at the end of trade date, update the account based the $close_price of stocks.
trade_account.update_daily_end(today=trade_date, trader=trade_exchange)

View File

@@ -5,13 +5,13 @@ import json
import copy
import warnings
import pathlib
import numpy as np
import pandas as pd
from loguru import Logger
from ...data import D, Cal
from ...utils import get_date_in_file_name
from ...utils import get_pre_trading_date
from ..backtest.order import Order
from ..utils import init_instance_by_config
from ...data.data import Cal
from ...utils import get_sample_freq_calendar
from .order import Order
class TradeCalendarBase:
def _reset_trade_calendar(self, start_time, end_time):
@@ -20,10 +20,10 @@ class TradeCalendarBase:
if end_time:
self.end_time = pd.Timestamp(end_time)
if self.start_time and self.end_time:
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=step_bar)
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar)
self.calendar = _calendar
_start_time, _end_time, _start_index, _end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq, freq_sam=freq_sam)
_trade_calendar = self.calendar[_start_index, _end_index + 1]
_trade_calendar = self.calendar[_start_index: _end_index + 1]
if _start_time != self.start_time:
self.trade_calendar = np.hstack((self.start_time, _trade_calendar, self.end_time))
self.start_index = _start_index - 1
@@ -40,7 +40,7 @@ class TradeCalendarBase:
trade_index = trade_index - shift
if 0 < trade_index < self.trade_len - 1:
trade_start_time = self.trade_calendar[trade_index - 1]
trade_end_time = self.trade_calendar[trade_index] - pd.Timestamp(second=1)
trade_end_time = self.trade_calendar[trade_index] - pd.Timedelta(seconds=1)
return trade_start_time, trade_end_time
elif trade_index == self.trade_len - 1:
trade_start_time = self.trade_calendar[trade_index - 1]
@@ -68,7 +68,7 @@ class BaseEnv(TradeCalendarBase):
end_time=None,
trade_account=None,
verbose=False,
**kwargs
**kwargs,
):
self.step_bar = step_bar
self.verbose = verbose
@@ -76,24 +76,24 @@ class BaseEnv(TradeCalendarBase):
def _get_position(self):
return self.trade_account.current
def reset(self, start_time=None, end_time=None, trade_account=None, **kwargs):
if start_time or end_time:
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
self.trade_account = trade_account
if trade_account:
self.trade_account = trade_account
for k, v in kwargs:
if hasattr(self, k):
setattr(self, k, v)
def get_first_state(self):
def get_init_state(self):
init_state = {"current": self._get_position()}
return init_state
def execute(self, order_list, **kwargs):
def execute(self, order_list=None, **kwargs):
self.trade_index = self.trade_index + 1
def finished(self):
@@ -122,13 +122,13 @@ class SplitEnv(BaseEnv):
#if self.track:
# yield action
#episode_reward = 0
super(SimulatorEnv, self).execute(**kwargs)
super(SplitEnv, self).execute(**kwargs)
trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time, trade_account=self.trade_account)
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
trade_state = self.sub_env.get_init_state()
while not self.sub_env.finished():
_order_list = self.sub_strategy.generate_order(**trade_state)
_order_list = self.sub_strategy.generate_order_list(**trade_state)
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
#episode_reward += sub_reward
_obs = {"current": self._get_position()}
@@ -149,11 +149,12 @@ class SimulatorEnv(BaseEnv):
verbose=False,
**kwargs,
):
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, verbose=verbose)
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, verbose=verbose, **kwargs)
def reset(trade_exchange=None, **kwargs):
def reset(self, trade_exchange=None, **kwargs):
super(SimulatorEnv, self).reset(**kwargs)
self.trade_exchange=trade_exchange
if trade_exchange:
self.trade_exchange=trade_exchange
def execute(self, order_list, **kwargs):
"""
@@ -162,7 +163,7 @@ class SimulatorEnv(BaseEnv):
if self.finished():
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
super(SimulatorEnv, self).execute(**kwargs)
ttrade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
trade_info = []
for order in order_list:
if self.trade_exchange.check_order(order) is True:

View File

@@ -8,16 +8,19 @@ import logging
import numpy as np
import pandas as pd
from ...data import D
from .order import Order
from ...data.data import D
from ...config import C, REG_CN
from ...utils import sample_feature
from ...log import get_module_logger
from .order import Order
class Exchange:
def __init__(
self,
trade_dates=None,
start_time=None,
end_time=None,
codes="all",
deal_price=None,
subscribe_fields=[],
@@ -30,7 +33,8 @@ class Exchange:
):
"""__init__
:param trade_dates: list of pd.Timestamp
:param start_time: start time for backtest
:param end_time: end time for backtest
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
:param deal_price: str, 'close', 'open', 'vwap'
:param subscribe_fields: list, subscribe fields
@@ -51,6 +55,8 @@ class Exchange:
target on this day).
index: MultipleIndex(instrument, pd.Datetime)
"""
self.start_time = start_time
self.end_time = end_time
if trade_unit is None:
trade_unit = C.trade_unit
if limit_threshold is None:
@@ -91,21 +97,15 @@ class Exchange:
self.close_cost = close_cost
self.min_cost = min_cost
self.limit_threshold = limit_threshold
# TODO: the quote, trade_dates, codes are not necessray.
# It is just for performance consideration.
if trade_dates is not None and len(trade_dates):
start_date, end_date = trade_dates[0], trade_dates[-1]
else:
self.logger.warning("trade_dates have not been assigned, all dates will be loaded")
start_date, end_date = None, None
self.extra_quote = extra_quote
self.set_quote(codes, start_date, end_date)
self.set_quote(codes, start_time, end_time)
def set_quote(self, codes, start_date, end_date):
def set_quote(self, codes, start_time, end_time):
if len(codes) == 0:
codes = D.instruments()
self.quote = D.features(codes, self.all_fields, start_date, end_date, disk_cache=True).dropna(subset=["$close"])
self.quote = D.features(codes, self.all_fields, start_time, end_time, disk_cache=True).dropna(subset=["$close"])
self.quote.columns = self.all_fields
if self.quote[self.deal_price].isna().any():
@@ -146,35 +146,37 @@ class Exchange:
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
# update quote: pd.DataFrame to dict, for search use
self.quote = quote_df.to_dict("index")
self.quote = quote_df
def _update_limit(self, buy_limit, sell_limit):
self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False)
def check_stock_limit(self, stock_id, trade_date):
def check_stock_limit(self, stock_id, start_time, end_time):
"""Parameter
stock_id
trade_date
is limtited
"""
return self.quote[(stock_id, trade_date)]["limit"]
return sample_feature(self.quote, stock_id, start_time, end_time, fields="limit", method="any").iloc[0]
def check_stock_suspended(self, stock_id, trade_date):
def check_stock_suspended(self, stock_id, start_time, end_time):
# is suspended
return (stock_id, trade_date) not in self.quote
return sample_feature(self.quote, stock_id, start_time, end_time).empty
def is_stock_tradable(self, stock_id, trade_date):
def is_stock_tradable(self, stock_id, start_time, end_time):
# check if stock can be traded
# same as check in check_order
if self.check_stock_suspended(stock_id, trade_date) or self.check_stock_limit(stock_id, trade_date):
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(stock_id, start_time, end_time):
return False
else:
return True
def check_order(self, order):
# check limit and suspended
if self.check_stock_suspended(order.stock_id, order.trade_date) or self.check_stock_limit(
order.stock_id, order.trade_date
if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit(
order.stock_id, order.start_time, order.end_time
):
return False
else:
@@ -199,7 +201,7 @@ class Exchange:
if trade_account is not None and position is not None:
raise ValueError("trade_account and position can only choose one")
trade_price = self.get_deal_price(order.stock_id, order.trade_date)
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
trade_val, trade_cost = self._calc_trade_info_by_order(
order, trade_account.current if trade_account else position
)
@@ -214,24 +216,24 @@ class Exchange:
return trade_val, trade_cost, trade_price
def get_quote_info(self, stock_id, trade_date):
return self.quote[(stock_id, trade_date)]
def get_quote_info(self, stock_id, start_time, end_time):
return sample_feature(self.quote, stock_id, start_time, end_time)
def get_close(self, stock_id, trade_date):
return self.quote[(stock_id, trade_date)]["$close"]
def get_close(self, stock_id, start_time, end_time):
return sample_feature(self.quote, stock_id, start_time, end_time, fields="$close", method="last").iloc[0]
def get_deal_price(self, stock_id, trade_date):
deal_price = self.quote[(stock_id, trade_date)][self.deal_price]
def get_deal_price(self, stock_id, start_time, end_time):
deal_price = sample_feature(self.quote, stock_id, start_time, end_time, fields=self.deal_price, method="last").iloc[0]
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
self.logger.warning(f"(stock_id:{stock_id}, trade_date:{trade_date}, {self.deal_price}): {deal_price}!!!")
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!")
self.logger.warning(f"setting deal_price to close price")
deal_price = self.get_close(stock_id, trade_date)
deal_price = self.get_close(stock_id, start_time, end_time)
return deal_price
def get_factor(self, stock_id, trade_date):
return self.quote[(stock_id, trade_date)]["$factor"]
def get_factor(self, stock_id, start_time, end_time):
return sample_feature(self.quote, stock_id, start_time, end_time, fields="$factor", method="last").iloc[0]
def generate_amount_position_from_weight_position(self, weight_position, cash, trade_date):
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
"""
The generate the target position according to the weight and the cash.
NOTE: All the cash will assigned to the tadable stock.
@@ -246,7 +248,7 @@ class Exchange:
# calculate the total weight of tradable value
tradable_weight = 0.0
for stock_id in weight_position:
if self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
# weight_position must be greater than 0 and less than 1
if weight_position[stock_id] < 0 or weight_position[stock_id] > 1:
raise ValueError(
@@ -260,12 +262,12 @@ class Exchange:
amount_dict = {}
for stock_id in weight_position:
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
amount_dict[stock_id] = (
cash
* weight_position[stock_id]
/ tradable_weight
// self.get_deal_price(stock_id=stock_id, trade_date=trade_date)
// self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time)
)
return amount_dict
@@ -292,7 +294,7 @@ class Exchange:
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
return -deal_amount
def generate_order_for_target_amount_position(self, target_position, current_position, trade_date):
def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
"""Parameter:
target_position : dict { stock_id : amount }
current_postion : dict { stock_id : amount}
@@ -315,12 +317,12 @@ class Exchange:
for stock_id in sorted_ids:
# Do not generate order for the nontradable stocks
if not self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
if not self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
continue
target_amount = target_position.get(stock_id, 0)
current_amount = current_position.get(stock_id, 0)
factor = self.quote[(stock_id, trade_date)]["$factor"]
factor = self.get_factor(stock_id, start_time=start_time, end_time=end_time)
deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor)
if deal_amount == 0:
@@ -332,7 +334,8 @@ class Exchange:
stock_id=stock_id,
amount=deal_amount,
direction=Order.BUY,
trade_date=trade_date,
start_time=start_time,
end_time=end_time,
factor=factor,
)
)
@@ -343,14 +346,15 @@ class Exchange:
stock_id=stock_id,
amount=abs(deal_amount),
direction=Order.SELL,
trade_date=trade_date,
start_time=start_time,
end_time=end_time,
factor=factor,
)
)
# return order_list : buy + sell
return sell_order_list + buy_order_list
def calculate_amount_position_value(self, amount_dict, trade_date, only_tradable=False):
def calculate_amount_position_value(self, amount_dict, start_time, end_time, only_tradable=False):
"""Parameter
position : Position()
amount_dict : {stock_id : amount}
@@ -358,10 +362,10 @@ class Exchange:
value = 0
for stock_id in amount_dict:
if (
self.check_stock_suspended(stock_id=stock_id, trade_date=trade_date) is False
and self.check_stock_limit(stock_id=stock_id, trade_date=trade_date) is False
self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
):
value += self.get_deal_price(stock_id=stock_id, trade_date=trade_date) * amount_dict[stock_id]
value += self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) * amount_dict[stock_id]
return value
def round_amount_by_trade_unit(self, deal_amount, factor):
@@ -384,7 +388,7 @@ class Exchange:
:return: trade_val, trade_cost
"""
trade_price = self.get_deal_price(order.stock_id, order.trade_date)
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
if order.direction == Order.SELL:
# sell
if position is not None:

View File

@@ -0,0 +1,15 @@
class BaseInterpreter:
@staticmethod
def interpret(**kwargs):
raise NotImplementedError("interpret is not implemented!")
class ActionInterpreter:
@staticmethod
def interpret(action, **kwargs):
return action
class StateInterpreter:
@staticmethod
def interpret(state, **kwargs):
return state

View File

@@ -7,7 +7,7 @@ class Order:
SELL = 0
BUY = 1
def __init__(self, stock_id, amount, trade_date, direction, factor):
def __init__(self, stock_id, amount, start_time, end_time, direction, factor):
"""Parameter
direction : Order.SELL for sell; Order.BUY for buy
stock_id : str
@@ -24,6 +24,7 @@ class Order:
self.amount = amount
# amount of successfully completed orders
self.deal_amount = 0
self.trade_date = trade_date
self.start_time = start_time
self.end_time = end_time
self.direction = direction
self.factor = factor

View File

@@ -28,13 +28,13 @@ a typical example is :{
class Position:
"""Position"""
def __init__(self, cash=0, position_dict={}, today_account_value=0):
def __init__(self, cash=0, position_dict={}, now_account_value=0):
# NOTE: The position dict must be copied!!!
# Otherwise the initial value
self.init_cash = cash
self.position = position_dict.copy()
self.position["cash"] = cash
self.position["today_account_value"] = today_account_value
self.position["now_account_value"] = now_account_value
def init_stock(self, stock_id, amount, price=None):
self.position[stock_id] = {}
@@ -82,7 +82,7 @@ class Position:
# SELL
self.sell_stock(order.stock_id, trade_val, cost, trade_price)
else:
raise NotImplementedError("do not suppotr order direction {}".format(order.direction))
raise NotImplementedError("do not support order direction {}".format(order.direction))
def update_stock_price(self, stock_id, price):
self.position[stock_id]["price"] = price
@@ -109,7 +109,7 @@ class Position:
return value
def get_stock_list(self):
stock_list = list(set(self.position.keys()) - {"cash", "today_account_value"})
stock_list = list(set(self.position.keys()) - {"cash", "now_account_value"})
return stock_list
def get_stock_price(self, code):
@@ -163,16 +163,17 @@ class Position:
for stock_code, weight in weight_dict.items():
self.update_stock_weight(stock_code, weight)
def save_position(self, path, last_trade_date):
def save_position(self, path, last_trade_time):
path = pathlib.Path(path)
p = copy.deepcopy(self.position)
cash = pd.Series(dtype=np.float)
cash["init_cash"] = self.init_cash
cash["cash"] = p["cash"]
cash["today_account_value"] = p["today_account_value"]
cash["last_trade_date"] = str(last_trade_date.date()) if last_trade_date else None
cash["now_account_value"] = p["now_account_value"]
cash["last_trade_start_time"] = str(last_trade_time[0]) if last_trade_time else None
cash["last_trade_end_time"] = str(last_trade_time[1]) if last_trade_time else None
del p["cash"]
del p["today_account_value"]
del p["now_account_value"]
positions = pd.DataFrame.from_dict(p, orient="index")
with pd.ExcelWriter(path) as writer:
positions.to_excel(writer, sheet_name="position")
@@ -189,10 +190,10 @@ class Position:
'weight': <the security weight of total position value>,
sheet "cash"
index: ['init_cash', 'cash', 'today_account_value']
index: ['init_cash', 'cash', 'now_account_value']
'init_cash': <inital cash when account was created>,
'cash': <current cash in account>,
'today_account_value': <current total account value, should equal to sum(price[stock]*amount[stock])>
'now_account_value': <current total account value, should equal to sum(price[stock]*amount[stock])>
"""
path = pathlib.Path(path)
positions = pd.read_excel(open(path, "rb"), sheet_name="position", index_col=0)
@@ -200,14 +201,17 @@ class Position:
positions = positions.to_dict(orient="index")
init_cash = cash_record.loc["init_cash"].values[0]
cash = cash_record.loc["cash"].values[0]
today_account_value = cash_record.loc["today_account_value"].values[0]
last_trade_date = cash_record.loc["last_trade_date"].values[0]
now_account_value = cash_record.loc["now_account_value"].values[0]
last_trade_start_time = cash_record.loc["last_trade_start_time"].values[0]
last_trade_end_time = cash_record.loc["last_trade_end_time"].values[0]
# assign values
self.position = {}
self.init_cash = init_cash
self.position = positions
self.position["cash"] = cash
self.position["today_account_value"] = today_account_value
self.position["now_account_value"] = now_account_value
return None if pd.isna(last_trade_date) else pd.Timestamp(last_trade_date)
last_trade_start_time = None if pd.isna(last_trade_start_time) else pd.Timestamp(last_trade_start_time)
last_trade_end_time = None if pd.isna(last_trade_end_time) else pd.Timestamp(last_trade_end_time)
return last_trade_start_time, last_trade_end_time

View File

@@ -21,20 +21,20 @@ class Report:
self.costs = OrderedDict() # trade cost for each trade date
self.values = OrderedDict() # value for each trade date
self.cashes = OrderedDict()
self.latest_report_date = None # pd.TimeStamp
self.latest_report_time = None # pd.TimeStamp
def is_empty(self):
return len(self.accounts) == 0
def get_latest_date(self):
return self.latest_report_date
return self.latest_report_time
def get_latest_account_value(self):
return self.accounts[self.latest_report_date]
return self.accounts[self.latest_report_time]
def update_report_record(
self,
trade_date=None,
trade_time=None,
account_value=None,
cash=None,
return_rate=None,
@@ -44,7 +44,7 @@ class Report:
):
# check data
if None in [
trade_date,
trade_time,
account_value,
cash,
return_rate,
@@ -56,14 +56,14 @@ class Report:
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
)
# update report data
self.accounts[trade_date] = account_value
self.returns[trade_date] = return_rate
self.turnovers[trade_date] = turnover_rate
self.costs[trade_date] = cost_rate
self.values[trade_date] = stock_value
self.cashes[trade_date] = cash
self.accounts[trade_time] = account_value
self.returns[trade_time] = return_rate
self.turnovers[trade_time] = turnover_rate
self.costs[trade_time] = cost_rate
self.values[trade_time] = stock_value
self.cashes[trade_time] = cash
# update latest_report_date
self.latest_report_date = trade_date
self.latest_report_time = trade_time
# finish daily report update
def generate_report_dataframe(self):
@@ -74,7 +74,7 @@ class Report:
report["cost"] = pd.Series(self.costs)
report["value"] = pd.Series(self.values)
report["cash"] = pd.Series(self.cashes)
report.index.name = "date"
report.index.name = "trade_time"
return report
def save_report(self, path):
@@ -94,13 +94,13 @@ class Report:
index = r.index
self.init_vars()
for date in index:
for trade_time in index:
self.update_report_record(
trade_date=date,
account_value=r.loc[date]["account"],
cash=r.loc[date]["cash"],
return_rate=r.loc[date]["return"],
turnover_rate=r.loc[date]["turnover"],
cost_rate=r.loc[date]["cost"],
stock_value=r.loc[date]["value"],
trade_time=trade_time,
account_value=r.loc[trade_time]["account"],
cash=r.loc[trade_time]["cash"],
return_rate=r.loc[trade_time]["return"],
turnover_rate=r.loc[trade_time]["turnover"],
cost_rate=r.loc[trade_time]["cost"],
stock_value=r.loc[trade_time]["value"],
)

View File

@@ -4,13 +4,13 @@
from .dl_strategy import (
TopkDropoutStrategy,
BaseStrategy,
WeightStrategyBase,
)
from .rule_strategy import(
TWAPStrategy,
SBBEMAStrategy
SBBStrategyBase,
SBBStrategyEMA,
)
from .cost_control import (

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from .strategy import WeightStrategyBase
from .dl_strategy import WeightStrategyBase
import copy

View File

@@ -4,12 +4,12 @@ import numpy as np
import pandas as pd
from ...utils import sample_feature
from ...strategy.base import DLStrategy
from ...backtest.order import Order
from ...strategy.base import ModelStrategy
from ..backtest.order import Order
from .order_generator import OrderGenWInteract
class TopkDropoutStrategy(DLStrategy):
class TopkDropoutStrategy(ModelStrategy):
def __init__(
self,
step_bar,
@@ -53,7 +53,7 @@ class TopkDropoutStrategy(DLStrategy):
else:
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
"""
super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time)
super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange)
self.topk = topk
self.n_drop = n_drop
self.method_sell = method_sell
@@ -67,9 +67,10 @@ class TopkDropoutStrategy(DLStrategy):
self.only_tradable = only_tradable
def reset(trade_exchange=None, **kwargs):
def reset(self, trade_exchange=None, **kwargs):
super(TopkDropoutStrategy, self).reset(**kwargs)
self.trade_exchange = trade_exchange
if trade_exchange:
self.trade_exchange = trade_exchange
def get_risk_degree(self, trade_index):
"""get_risk_degree
@@ -189,7 +190,7 @@ class TopkDropoutStrategy(DLStrategy):
# update cash
cash += trade_val - trade_cost
# sold
del self.stock_count[code]
self.stock_count[code] = 0
else:
# no buy signal, but the stock is kept
self.stock_count[code] += 1
@@ -210,10 +211,10 @@ class TopkDropoutStrategy(DLStrategy):
# value = value / (1+self.trade_exchange.open_cost) # set open_cost limit
for code in buy:
# check is stock suspended
if not self.trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
if not self.trade_exchange.is_stock_tradable(stock_id=code, start_time=trade_start_time, end_time=trade_end_time):
continue
# buy order
buy_price = self.trade_exchange.get_deal_price(stock_id=code, trade_date=trade_date)
buy_price = self.trade_exchange.get_deal_price(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
buy_amount = value / buy_price
factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
@@ -229,8 +230,8 @@ class TopkDropoutStrategy(DLStrategy):
self.stock_count[code] = 1
return sell_order_list + buy_order_list
class WeightStrategyBase(DLStrategy):
def __init__(self, trade_exchange, order_generator_cls_or_obj=OrderGenWInteract, start_time=None, end_time=None, **kwargs):
class WeightStrategyBase(ModelStrategy):
def __init__(self, step_bar, start_time=None, end_time=None, order_generator_cls_or_obj=OrderGenWInteract, trade_exchange=None, **kwargs):
super(WeightStrategyBase, self).__init__(step_bar, start_time, end_time)
self.trade_exchange = trade_exchange
if isinstance(order_generator_cls_or_obj, type):

View File

@@ -4,8 +4,8 @@
"""
This order generator is for strategies based on WeightStrategyBase
"""
from ...backtest.position import Position
from ...backtest.exchange import Exchange
from ..backtest.position import Position
from ..backtest.exchange import Exchange
import pandas as pd
import copy

View File

@@ -4,18 +4,20 @@ import numpy as np
import pandas as pd
from ...utils import sample_feature
from ...data.data import D
from ...strategy.base import RuleStrategy, TradingEnhancement
from ...backtest.order import Order
from ..backtest.order import Order
class TWAPStrategy(RuleStrategy, TradingEnhancement):
def reset(self, trade_order_list=None, **kwargs):
super(TWAPStrategy, self).reset(**kwargs)
TradingEnhancement.reset(trade_order_list=trade_order_list)
self.trade_amount = {}
for order in self.trade_order_list:
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
if trade_order_list:
self.trade_amount = {}
for order in self.trade_order_list:
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
def generate_order_list(self, **kwargs):
@@ -43,13 +45,15 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
TREND_LONG = 2
def reset(self, trade_order_list=None, **kwargs):
TradingEnhancement.reset(trade_order_list=trade_order_list)
self.trade_amount = {}
self.trade_delay = {}
for order in self.trade_order_list:
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
self.trade_trend[(order.stock_id, order.direction)] = TREND_MID
super(SBBStrategyBase, self).reset(**kwargs)
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
if trade_order_list:
self.trade_amount = {}
self.trade_trend = {}
for order in self.trade_order_list:
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
raise NotImplementedError("pred_price_trend method is not implemented!")
@@ -64,7 +68,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
_pred_trend = self._pred_price_trend(order.stock_id)
else:
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
if _pred_trend == TREND_MID:
if _pred_trend == self.TREND_MID:
_order = Order(
stock_id=order.stock_id,
amount=self.trade_amount[(order.stock_id, order.direction)],
@@ -97,7 +101,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
factor=order.factor,
)
order_list.append(_order)
if self.trade_index % 2 == 1
if self.trade_index % 2 == 1:
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
return order_list
@@ -110,8 +114,8 @@ class SBBStrategyEMA(SBBStrategyBase):
def __init__(
self,
step_bar,
start_time,
end_time,
start_time=None,
end_time=None,
instruments="csi300",
freq="day",
**kwargs,
@@ -121,21 +125,23 @@ class SBBStrategyEMA(SBBStrategyBase):
warnings.warn("`instruments` is not set, will load all stocks")
self.instruments = "all"
if isinstance(instruments, str):
self.instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
self.instruments = D.instruments(instruments)
self.freq = freq
def _reset_trade_calendar(self, start_time=None, end_time=None, _calendar=None):
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time, _calendar=_calendar)
fields = [("EMA($close, 10) - EMA($close, 20)", "signal")]
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
self.signal = D.features(instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq)
def _reset_trade_calendar(self, start_time=None, end_time=None):
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time)
if self.start_time and self.end_time:
fields = ["EMA($close, 10)-EMA($close, 20)"]
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
self.signal = D.features(self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq)
self.signal.columns = ["signal"]
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
_sample_signal = sample_feature(self.signal, stock_id, start_time=pred_start_time, end_time=pred_end_time, fields="signal", method="last")
if _sample_signal.empty:
return SBBStrategy.TREND_MID
elif _sample_signal.iloc[0, 0] > 0:
return SBBStrategy.TREND_LONG
return self.TREND_MID
elif _sample_signal.iloc[0] > 0:
return self.TREND_LONG
else:
return SBBStrategy.TREND_SHORT
return self.TREND_SHORT

View File

@@ -117,6 +117,7 @@ class CalendarProvider(abc.ABC):
flag = f"{freq}_sam_{freq_sam}_future_{future}"
if flag in H["c"]:
_calendar, _calendar_index = H["c"][flag]
return _calendar, _calendar_index
else:
flag_raw = f"{freq}_sam_{None}_future_{future}"
if flag_raw in H["c"]:
@@ -125,6 +126,7 @@ class CalendarProvider(abc.ABC):
_calendar = np.array(self.load_calendar(freq, future))
_calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search
H["c"][flag_raw] = _calendar, _calendar_index
if freq_sam is None:
return _calendar, _calendar_index
else:
@@ -132,6 +134,7 @@ class CalendarProvider(abc.ABC):
_calendar_sam_index = {x: i for i, x in enumerate(_calendar_sam)}
H["c"][flag] = _calendar_sam, _calendar_sam_index
return _calendar_sam, _calendar_sam_index
def _uri(self, start_time, end_time, freq, future=False):
"""Get the uri of calendar generation task."""
@@ -541,8 +544,8 @@ class LocalCalendarProvider(CalendarProvider):
with open(fname) as f:
return [pd.Timestamp(x.strip()) for x in f]
def calendar(self, start_time=None, end_time=None, freq="day", future=False, freq_sam=None):
_calendar, _ = self._get_calendar(freq=freq, future=future)
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
_calendar, _ = self._get_calendar(freq=freq, freq_sam=freq_sam, future=future)
# strip
if start_time:
start_time = pd.Timestamp(start_time)
@@ -764,6 +767,7 @@ class ClientCalendarProvider(CalendarProvider):
self.conn = conn
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
self.conn.send_request(
request_type="calendar",
request_content={

View File

@@ -10,8 +10,9 @@ import pandas as pd
from ..utils import get_sample_freq_calendar
from ..data.dataset import DatasetH
from ..backtest.order import Order
from ..backtest.env import TradeCalendarBase
from ..data.dataset.utils import get_level_index
from ..contrib.backtest.order import Order
from ..contrib.backtest.env import TradeCalendarBase
"""
1. BaseStrategy 的粒度一定是数据粒度的整数倍
@@ -24,26 +25,14 @@ class BaseStrategy(TradeCalendarBase):
self.step_bar = step_bar
self.reset(start_time=start_time, end_time=end_time, **kwargs)
def reset(self, start_time=None, end_time=None, _calendar=None, **kwargs):
def reset(self, start_time=None, end_time=None, **kwargs):
if start_time or end_time :
self._reset_trade_calendar(start_time=start_time, end_time=end_time, calendar=calendar)
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
for k, v in kwargs:
if hasattr(self, k):
setattr(self, k, v)
def _get_trade_time(self):
if 0 < self.trade_index < self.trade_len - 1:
trade_start_time = self.trade_calendar[self.trade_index - 1]
trade_end_time = self.trade_calendar[self.trade_index] - pd.Timestamp(second=1)
return trade_start_time, trade_end_time
elif self.trade_index == self.trade_len - 1:
trade_start_time = self.trade_calendar[self.trade_index - 1]
trade_end_time = self.trade_calendar[self.trade_index]
return trade_start_time, trade_end_time
else:
raise RuntimeError("trade_index out of range")
def generate_order_list(self, **kwargs):
self.trade_index = self.trade_index + 1
@@ -52,20 +41,26 @@ class BaseStrategy(TradeCalendarBase):
class RuleStrategy(BaseStrategy):
pass
class DLStrategy(BaseStrategy):
def __init__(self, step_bar, model, dataset:DatasetH, start_time=None, end_time=None):
class ModelStrategy(BaseStrategy):
def __init__(self, step_bar, model, dataset:DatasetH, start_time=None, end_time=None, **kwargs):
self.model = model
self.dataset = dataset
self.pred_scores = self.model.predict(dataset)
self.pred_scores = self._convert_index_format(self.model.predict(dataset))
#pred_score_dates = self.pred_scores.index.get_level_values(level="datetime")
super(DLStrategy, self).__init__(step_bar, start_time, end_time)
super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
def _update_model(self):
def _convert_index_format(self, df):
if get_level_index(df, level="datetime") == 0:
df = df.swaplevel().sort_index()
return df
def _update_model(self):
"""update pred score
"""
pass
class TradingEnhancement:
def reset(self, trade_order_list):
self.trade_order_list = trade_order_list
def reset(self, trade_order_list=None):
if trade_order_list:
self.trade_order_list = trade_order_list

View File

@@ -15,6 +15,7 @@ import bisect
import shutil
import difflib
import hashlib
import warnings
import datetime
import requests
import tempfile
@@ -918,37 +919,40 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
else:
raise ValueError("sample freq must be xmin, xd, xw, xm")
def get_sample_freq_calendar(start_time=None, end_time=None, freq, **kwargs):
def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwargs):
from ..data.data import Cal
try:
_calendar = D.calendar(start_time=start_time, end_time=end_time, freq=freq, **kwargs)
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=freq, **kwargs)
freq, freq_sam = freq, None
except ValueError:
freq_sam = freq
if freq.endswith(("m", "month", "w", "week", "d", "day")):
try:
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq, **kwargs)
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
freq = "min"
except ValueError:
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="day", freq_sam=freq, **kwargs)
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, **kwargs)
freq = "day"
elif freq.endswith(("min", "minute")):
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq, **kwargs)
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
freq = "min"
else:
raise ValueError(f"freq {freq} is not supported")
return _calendar, freq, freq_sam
def sample_feature(feature, instruments=None, start_time=None, end_time=None, fields=None, method=None, method_kwargs={}):
if instruments and type(instruments) is not list:
if instruments and not isinstance(instruments, list):
instruments = [instruments]
if fields and type(fields) is not list:
fields = [fields]
selector_inst = slice(None) if instruments is None else instruments
selector_datetime = slice(start_time, end_time)
if fields is not None and type(fields) is not list:
fields = [fields]
selector_fields = slice(None) if fields is None else fields
feature = feature.loc[(selector_inst, selector_datetime), selector_fields]
if isinstance(feature, pd.Series):
feature = feature.loc[(selector_inst, selector_datetime)]
if fields:
warnings.warn(f"sample series feature, {fields} is ignored!")
elif isinstance(feature, pd.DataFrame):
selector_fields = slice(None) if fields is None else fields
feature = feature.loc[(selector_inst, selector_datetime), selector_fields]
if method:
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
else: