mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix bug
This commit is contained in:
@@ -7,13 +7,9 @@ from pathlib import Path
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.backtest import backtest
|
||||
from qlib.contrib.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.backtest import backtest
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -67,9 +63,9 @@ if __name__ == "__main__":
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"train": ("2012-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
"test": ("2017-01-01", "2018-01-31"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -79,41 +75,40 @@ if __name__ == "__main__":
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
trade_start_time = "2017-01-01"
|
||||
trade_end_time = "2020-08-01"
|
||||
trade_exchange = get_exchange(start_time=trade_start_time, end_time=trade_end_time)
|
||||
trade_start_time = "2017-01-31"
|
||||
trade_end_time = "2018-01-31"
|
||||
|
||||
backtest_config={
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.dl_strategy",
|
||||
"kwargs": {
|
||||
"step_bar": "day",
|
||||
"step_bar": "week",
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"trade_exchange": trade_exchange,
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
},
|
||||
"env":{
|
||||
"class": "SplitEnv",
|
||||
"module_path": "qlib.backtest.env",
|
||||
"module_path": "qlib.contrib.backtest.env",
|
||||
"kwargs": {
|
||||
"step_bar": "day",
|
||||
"step_bar": "week",
|
||||
"sub_env": {
|
||||
"class": "SimulatorEnv",
|
||||
"module_path": "qlib.backtest.env",
|
||||
"module_path": "qlib.contrib.backtest.env",
|
||||
"kwargs": {
|
||||
"step_bar": "1min",
|
||||
"trade_exchange": trade_exchange,
|
||||
"step_bar": "day",
|
||||
}
|
||||
},
|
||||
"sub_strategy": {
|
||||
"class": "SBBStrategyEMA",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"step_bar": "1min",
|
||||
"step_bar": "day",
|
||||
"freq": "day",
|
||||
"instruments": "csi300",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -121,4 +116,4 @@ if __name__ == "__main__":
|
||||
}
|
||||
|
||||
|
||||
backtest(**backtest_config, )
|
||||
report_dict = backtest(start_time=trade_start_time, end_time=trade_end_time, **backtest_config, account=1e8, deal_price="$close", verbose=False)
|
||||
@@ -1,130 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .order import Order
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .report import Report
|
||||
from .backtest import backtest as backtest_func
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import inspect
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
from ..config import C
|
||||
|
||||
logger = get_module_logger("backtest caller")
|
||||
|
||||
|
||||
def get_exchange(
|
||||
pred,
|
||||
exchange=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes = "all",
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5.0,
|
||||
trade_unit=None,
|
||||
limit_threshold=None,
|
||||
deal_price=None,
|
||||
shift=1,
|
||||
):
|
||||
"""get_exchange
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange().
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost.
|
||||
close_cost : float
|
||||
close transaction cost.
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Exchange
|
||||
an initialized Exchange object
|
||||
"""
|
||||
|
||||
if trade_unit is None:
|
||||
trade_unit = C.trade_unit
|
||||
if limit_threshold is None:
|
||||
limit_threshold = C.limit_threshold
|
||||
if deal_price is None:
|
||||
deal_price = C.deal_price
|
||||
if exchange is None:
|
||||
logger.info("Create new exchange")
|
||||
# handle exception for deal_price
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
|
||||
exchange = Exchange(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
codes=codes,
|
||||
deal_price=deal_price,
|
||||
subscribe_fields=subscribe_fields,
|
||||
limit_threshold=limit_threshold,
|
||||
open_cost=open_cost,
|
||||
close_cost=close_cost,
|
||||
trade_unit=trade_unit,
|
||||
min_cost=min_cost,
|
||||
)
|
||||
else:
|
||||
return init_instance_by_config(exchange, accept_types=Exchange)
|
||||
|
||||
def init_env_instance_by_config(env):
|
||||
if isinstance(env, dict):
|
||||
env_config = copy.copy(env)
|
||||
if "kwargs" in env_config:
|
||||
env_kwargs = copy.copy(env_config["kwargs"]):
|
||||
if "sub_env" in env_kwargs:
|
||||
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
|
||||
if "sub_strategy" in env_kwargs:
|
||||
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
|
||||
env_config["kwargs"] = env_kwargs
|
||||
return init_instance_by_config(env_config)
|
||||
else:
|
||||
return env
|
||||
|
||||
def setup_exchange(root_instance, trade_exchange=None, force=False):
|
||||
if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args:
|
||||
if force:
|
||||
root_instance.reset(trade_exchange=trade_exchange)
|
||||
else:
|
||||
if not hasattr(root_instance, "trade_exchange") or root_instance.trade_exchange is None:
|
||||
root_instance.reset(trade_exchange=trade_exchange)
|
||||
if hasattr(root_instance, "sub_env"):
|
||||
setup_exchange(root_instance.sub_env, trade_exchange)
|
||||
if hasattr(root_instance, "sub_strategy"):
|
||||
setup_exchange(root_instance.sub_strategy, trade_exchange)
|
||||
|
||||
|
||||
def backtest(start_time, end_time, strategy, env, benchmark, account=1e9, **kwargs):
|
||||
trade_strategy = init_instance_by_config(strategy)
|
||||
trade_env = init_env_instance_by_config(env)
|
||||
|
||||
spec = inspect.getfullargspec(get_exchange)
|
||||
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
trade_exchange = get_exchange(pred, **ex_args)
|
||||
|
||||
setup_exchange(trade_env, trade_exchange)
|
||||
setup_exchange(trade_strategy, trade_exchange)
|
||||
|
||||
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env, benchmark, account)
|
||||
|
||||
return report_dict
|
||||
@@ -1,170 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
from .position import Position
|
||||
from .report import Report
|
||||
from .order import Order
|
||||
|
||||
|
||||
"""
|
||||
rtn & earning in the Account
|
||||
rtn:
|
||||
from order's view
|
||||
1.change if any order is executed, sell order or buy order
|
||||
2.change at the end of today, (today_clse - stock_price) * amount
|
||||
earning
|
||||
from value of current position
|
||||
earning will be updated at the end of trade date
|
||||
earning = today_value - pre_value
|
||||
**is consider cost**
|
||||
while earning is the difference of two position value, so it considers cost, it is the true return rate
|
||||
in the specific accomplishment for rtn, it does not consider cost, in other words, rtn - cost = earning
|
||||
"""
|
||||
|
||||
|
||||
class Account:
|
||||
def __init__(self, init_cash, last_trade_time=None):
|
||||
self.init_vars(init_cash, last_trade_time)
|
||||
|
||||
def init_vars(self, init_cash, last_trade_time=None):
|
||||
# init cash
|
||||
self.init_cash = init_cash
|
||||
self.current = Position(cash=init_cash)
|
||||
self.positions = {}
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
self.to = 0
|
||||
self.val = 0
|
||||
self.report = Report()
|
||||
self.earning = 0
|
||||
self.last_trade_time = last_trade_time
|
||||
|
||||
def get_positions(self):
|
||||
return self.positions
|
||||
|
||||
def get_cash(self):
|
||||
return self.current.position["cash"]
|
||||
|
||||
def update_state_from_order(self, order, trade_val, cost, trade_price):
|
||||
# update turnover
|
||||
self.to += trade_val
|
||||
# update cost
|
||||
self.ct += cost
|
||||
# update return
|
||||
# update self.rtn from order
|
||||
trade_amount = trade_val / trade_price
|
||||
if order.direction == Order.SELL: # 0 for sell
|
||||
# when sell stock, get profit from price change
|
||||
profit = trade_val - self.current.get_stock_price(order.stock_id) * trade_amount
|
||||
self.rtn += profit # note here do not consider cost
|
||||
elif order.direction == Order.BUY: # 1 for buy
|
||||
# when buy stock, we get return for the rtn computing method
|
||||
# profit in buy order is to make self.rtn is consistent with self.earning at the end of date
|
||||
profit = self.current.get_stock_price(order.stock_id) * trade_amount - trade_val
|
||||
self.rtn += profit
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
# if stock is sold out, no stock price information in Position, then we should update account first, then update current position
|
||||
# if stock is bought, there is no stock in current position, update current, then update account
|
||||
# The cost will be substracted from the cash at last. So the trading logic can ignore the cost calculation
|
||||
trade_amount = trade_val / trade_price
|
||||
if order.direction == Order.SELL:
|
||||
# sell stock
|
||||
self.update_state_from_order(order, trade_val, cost, trade_price)
|
||||
# update current position
|
||||
# for may sell all of stock_id
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
else:
|
||||
# buy stock
|
||||
# deal order, then update state
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
self.update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""
|
||||
start_time: pd.TimeStamp
|
||||
end_time: pd.TimeStamp
|
||||
quote: pd.DataFrame (code, date), collumns
|
||||
when the end of trade date
|
||||
- update rtn
|
||||
- update price for each asset
|
||||
- update value for this account
|
||||
- update earning (2nd view of return )
|
||||
- update holding day, count of stock
|
||||
- update position hitory
|
||||
- update report
|
||||
:return: None
|
||||
"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
stock_list = self.current.get_stock_list()
|
||||
profit = 0
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
profit += (bar_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
|
||||
self.current.update_stock_price(stock_id=code, price=bar_close)
|
||||
self.rtn += profit
|
||||
# update holding day count
|
||||
self.current.add_count_all()
|
||||
# update value
|
||||
self.val = self.current.calculate_value()
|
||||
# update earning (2nd view of return)
|
||||
# account_value - last_account_value
|
||||
# for the first trade date, account_value - init_cash
|
||||
# self.report.is_empty() to judge is_first_trade_date
|
||||
# get last_account_value, now_account_value, now_stock_value
|
||||
if self.report.is_empty():
|
||||
last_account_value = self.init_cash
|
||||
else:
|
||||
last_account_value = self.report.get_latest_account_value()
|
||||
now_account_value = self.current.calculate_value()
|
||||
now_stock_value = self.current.calculate_stock_value()
|
||||
self.earning = now_account_value - last_account_value
|
||||
# update report for today
|
||||
# judge whether the the trading is begin.
|
||||
# and don't add init account state into report, due to we don't have excess return in those days.
|
||||
self.report.update_report_record(
|
||||
trade_time=trade_start_time,
|
||||
account_value=now_account_value,
|
||||
cash=self.current.position["cash"],
|
||||
return_rate=(self.earning + self.ct) / last_account_value,
|
||||
# here use earning to calculate return, position's view, earning consider cost, true return
|
||||
# in order to make same definition with original backtest in evaluate.py
|
||||
turnover_rate=self.to / last_account_value,
|
||||
cost_rate=self.ct / last_account_value,
|
||||
stock_value=now_stock_value,
|
||||
)
|
||||
# set now_account_value to position
|
||||
self.current.position["now_account_value"] = now_account_value
|
||||
self.current.update_weight_all()
|
||||
# update positions
|
||||
# note use deepcopy
|
||||
self.positions[trade_start_time] = copy.deepcopy(self.current)
|
||||
|
||||
# finish today's updation
|
||||
# reset the daily variables
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
self.to = 0
|
||||
self.last_trade_time = (trade_start_time, trade_end_time)
|
||||
|
||||
def load_account(self, account_path):
|
||||
report = Report()
|
||||
position = Position()
|
||||
last_trade_time = position.load_position(account_path / "position.xlsx")
|
||||
report.load_report(account_path / "report.csv")
|
||||
|
||||
# assign values
|
||||
self.init_vars(position.init_cash)
|
||||
self.current = position
|
||||
self.report = report
|
||||
self.last_trade_time = last_trade_time
|
||||
|
||||
def save_account(self, account_path):
|
||||
self.current.save_position(account_path / "position.xlsx", self.last_trade_time)
|
||||
self.report.save_report(account_path / "report.csv")
|
||||
@@ -1,26 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .account import Account
|
||||
|
||||
def backtest(trade_strategy, trade_env, benchmark, account):
|
||||
|
||||
trade_account = Account(init_cash=account)
|
||||
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
|
||||
trade_strategy.reset(start_time=start_time, end_time=end_time)
|
||||
|
||||
trade_state = self.sub_env.get_init_state()
|
||||
while not trade_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(sub_order_list)
|
||||
|
||||
report_df = trade_account.report.generate_report_dataframe()
|
||||
positions = trade_account.get_positions()
|
||||
report_dict = {"report_df": report_df, "positions": positions}
|
||||
|
||||
return report_dict
|
||||
|
||||
@@ -1,429 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import random
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..data import D
|
||||
from ..utils import sample_feature
|
||||
from .order import Order
|
||||
from ..config import C, REG_CN
|
||||
from ..log import get_module_logger
|
||||
|
||||
|
||||
class Exchange:
|
||||
def __init__(
|
||||
self,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes="all",
|
||||
deal_price=None,
|
||||
subscribe_fields=[],
|
||||
limit_threshold=None,
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
trade_unit=None,
|
||||
min_cost=5,
|
||||
extra_quote=None,
|
||||
):
|
||||
"""__init__
|
||||
|
||||
:param start_time: start time for backtest
|
||||
:param end_time: end time for backtest
|
||||
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
|
||||
:param deal_price: str, 'close', 'open', 'vwap'
|
||||
:param subscribe_fields: list, subscribe fields
|
||||
:param limit_threshold: float, 0.1 for example, default None
|
||||
:param open_cost: cost rate for open, default 0.0015
|
||||
:param close_cost: cost rate for close, default 0.0025
|
||||
:param trade_unit: trade unit, 100 for China A market
|
||||
:param min_cost: min cost, default 5
|
||||
:param extra_quote: pandas, dataframe consists of
|
||||
columns: like ['$vwap', '$close', '$factor', 'limit'].
|
||||
The limit indicates that the etf is tradable on a specific day.
|
||||
Necessary fields:
|
||||
$close is for calculating the total value at end of each day.
|
||||
Optional fields:
|
||||
$vwap is only necessary when we use the $vwap price as the deal price
|
||||
$factor is for rounding to the trading unit
|
||||
limit will be set to False by default(False indicates we can buy this
|
||||
target on this day).
|
||||
index: MultipleIndex(instrument, pd.Datetime)
|
||||
"""
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
if trade_unit is None:
|
||||
trade_unit = C.trade_unit
|
||||
if limit_threshold is None:
|
||||
limit_threshold = C.limit_threshold
|
||||
if deal_price is None:
|
||||
deal_price = C.deal_price
|
||||
|
||||
self.logger = get_module_logger("online operator", level=logging.INFO)
|
||||
|
||||
self.trade_unit = trade_unit
|
||||
|
||||
# TODO: the quote, trade_dates, codes are not necessray.
|
||||
# It is just for performance consideration.
|
||||
if limit_threshold is None:
|
||||
if C.region == REG_CN:
|
||||
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
|
||||
elif abs(limit_threshold) > 0.1:
|
||||
if C.region == REG_CN:
|
||||
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
|
||||
|
||||
if deal_price[0] != "$":
|
||||
self.deal_price = "$" + deal_price
|
||||
else:
|
||||
self.deal_price = deal_price
|
||||
if isinstance(codes, str):
|
||||
codes = D.instruments(codes)
|
||||
self.codes = codes
|
||||
# Necessary fields
|
||||
# $close is for calculating the total value at end of each day.
|
||||
# $factor is for rounding to the trading unit
|
||||
# $change is for calculating the limit of the stock
|
||||
|
||||
necessary_fields = {self.deal_price, "$close", "$change", "$factor"}
|
||||
subscribe_fields = list(necessary_fields | set(subscribe_fields))
|
||||
all_fields = list(necessary_fields | set(subscribe_fields))
|
||||
self.all_fields = all_fields
|
||||
self.open_cost = open_cost
|
||||
self.close_cost = close_cost
|
||||
self.min_cost = min_cost
|
||||
self.limit_threshold = limit_threshold
|
||||
|
||||
|
||||
self.extra_quote = extra_quote
|
||||
self.set_quote(codes, start_time, end_time)
|
||||
|
||||
def set_quote(self, codes, start_time, end_time):
|
||||
if len(codes) == 0:
|
||||
codes = D.instruments()
|
||||
self.quote = D.features(codes, self.all_fields, start_time, end_time, disk_cache=True).dropna(subset=["$close"])
|
||||
self.quote.columns = self.all_fields
|
||||
|
||||
if self.quote[self.deal_price].isna().any():
|
||||
self.logger.warning("{} field data contains nan.".format(self.deal_price))
|
||||
|
||||
if self.quote["$factor"].isna().any():
|
||||
# The 'factor.day.bin' file not exists, and `factor` field contains `nan`
|
||||
# Use adjusted price
|
||||
self.trade_w_adj_price = True
|
||||
self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.")
|
||||
else:
|
||||
# The `factor.day.bin` file exists and all data `close` and `factor` are not `nan`
|
||||
# Use normal price
|
||||
self.trade_w_adj_price = False
|
||||
# update limit
|
||||
# check limit_threshold
|
||||
if self.limit_threshold is None:
|
||||
self.quote["limit"] = False
|
||||
else:
|
||||
# set limit
|
||||
self._update_limit(buy_limit=self.limit_threshold, sell_limit=self.limit_threshold)
|
||||
|
||||
quote_df = self.quote
|
||||
if self.extra_quote is not None:
|
||||
# process extra_quote
|
||||
if "$close" not in self.extra_quote:
|
||||
raise ValueError("$close is necessray in extra_quote")
|
||||
if self.deal_price not in self.extra_quote.columns:
|
||||
self.extra_quote[self.deal_price] = self.extra_quote["$close"]
|
||||
self.logger.warning("No deal_price set for extra_quote. Use $close as deal_price.")
|
||||
if "$factor" not in self.extra_quote.columns:
|
||||
self.extra_quote["$factor"] = 1.0
|
||||
self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.")
|
||||
if "limit" not in self.extra_quote.columns:
|
||||
self.extra_quote["limit"] = False
|
||||
self.logger.warning("No limit set for extra_quote. All stock will be tradable.")
|
||||
assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"}
|
||||
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
|
||||
|
||||
# update quote: pd.DataFrame to dict, for search use
|
||||
self.quote = quote_df
|
||||
|
||||
def _update_limit(self, buy_limit, sell_limit):
|
||||
self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False)
|
||||
|
||||
def check_stock_limit(self, stock_id, start_time, end_time):
|
||||
"""Parameter
|
||||
stock_id
|
||||
trade_date
|
||||
is limtited
|
||||
"""
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time, fields="limit", method="any").iloc[0, 0]
|
||||
|
||||
|
||||
def check_stock_suspended(self, stock_id, start_time, end_time):
|
||||
# is suspended
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time).empty is False
|
||||
|
||||
|
||||
def is_stock_tradable(self, stock_id, start_time, end_time):
|
||||
# check if stock can be traded
|
||||
# same as check in check_order
|
||||
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(stock_id, start_time, end_time):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def check_order(self, order):
|
||||
# check limit and suspended
|
||||
if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit(
|
||||
order.stock_id, order.start_time, order.end_time
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def deal_order(self, order, trade_account=None, position=None):
|
||||
"""
|
||||
Deal order when the actual transaction
|
||||
|
||||
:param order: Deal the order.
|
||||
:param trade_account: Trade account to be updated after dealing the order.
|
||||
:param position: position to be updated after dealing the order.
|
||||
:return: trade_val, trade_cost, trade_price
|
||||
"""
|
||||
# need to check order first
|
||||
# TODO: check the order unit limit in the exchange!!!!
|
||||
# The order limit is related to the adj factor and the cur_amount.
|
||||
# factor = self.quote[(order.stock_id, order.trade_date)]['$factor']
|
||||
# cur_amount = trade_account.current.get_stock_amount(order.stock_id)
|
||||
if self.check_order(order) is False:
|
||||
raise AttributeError("need to check order first")
|
||||
if trade_account is not None and position is not None:
|
||||
raise ValueError("trade_account and position can only choose one")
|
||||
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
|
||||
trade_val, trade_cost = self._calc_trade_info_by_order(
|
||||
order, trade_account.current if trade_account else position
|
||||
)
|
||||
# update account
|
||||
if trade_val > 0:
|
||||
# If the order can only be deal 0 trade_val. Nothing to be updated
|
||||
# Otherwise, it will result some stock with 0 amount in the position
|
||||
if trade_account:
|
||||
trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
|
||||
elif position:
|
||||
position.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
|
||||
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time):
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time)
|
||||
|
||||
def get_close(self, stock_id, start_time, end_time):
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time, fields="$close", method="last")
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time):
|
||||
deal_price = sample_feature(self.quote, stock_id, start_time, end_time, fields=self.deal_price, method="last")
|
||||
deal_price = self.quote[(stock_id, trade_date)][self.deal_price]
|
||||
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
|
||||
self.logger.warning(f"(stock_id:{stock_id}, trade_date:{trade_date}, {self.deal_price}): {deal_price}!!!")
|
||||
self.logger.warning(f"setting deal_price to close price")
|
||||
deal_price = self.get_close(stock_id, start_time, end_time)
|
||||
return deal_price
|
||||
|
||||
def get_factor(self, stock_id, start_time, end_time):
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time, fields="$factor", method="last")
|
||||
|
||||
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
|
||||
"""
|
||||
The generate the target position according to the weight and the cash.
|
||||
NOTE: All the cash will assigned to the tadable stock.
|
||||
|
||||
Parameter:
|
||||
weight_position : dict {stock_id : weight}; allocate cash by weight_position
|
||||
among then, weight must be in this range: 0 < weight < 1
|
||||
cash : cash
|
||||
trade_date : trade date
|
||||
"""
|
||||
|
||||
# calculate the total weight of tradable value
|
||||
tradable_weight = 0.0
|
||||
for stock_id in weight_position:
|
||||
if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
|
||||
# weight_position must be greater than 0 and less than 1
|
||||
if weight_position[stock_id] < 0 or weight_position[stock_id] > 1:
|
||||
raise ValueError(
|
||||
"weight_position is {}, "
|
||||
"weight_position is not in the range of (0, 1).".format(weight_position[stock_id])
|
||||
)
|
||||
tradable_weight += weight_position[stock_id]
|
||||
|
||||
if tradable_weight - 1.0 >= 1e-5:
|
||||
raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight))
|
||||
|
||||
amount_dict = {}
|
||||
for stock_id in weight_position:
|
||||
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
|
||||
amount_dict[stock_id] = (
|
||||
cash
|
||||
* weight_position[stock_id]
|
||||
/ tradable_weight
|
||||
// self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
)
|
||||
return amount_dict
|
||||
|
||||
def get_real_deal_amount(self, current_amount, target_amount, factor):
|
||||
"""
|
||||
Calculate the real adjust deal amount when considering the trading unit
|
||||
|
||||
:param current_amount:
|
||||
:param target_amount:
|
||||
:param factor:
|
||||
:return real_deal_amount; Positive deal_amount indicates buying more stock.
|
||||
"""
|
||||
if current_amount == target_amount:
|
||||
return 0
|
||||
elif current_amount < target_amount:
|
||||
deal_amount = target_amount - current_amount
|
||||
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
|
||||
return deal_amount
|
||||
else:
|
||||
if target_amount == 0:
|
||||
return -current_amount
|
||||
else:
|
||||
deal_amount = current_amount - target_amount
|
||||
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
|
||||
return -deal_amount
|
||||
|
||||
def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
|
||||
"""Parameter:
|
||||
target_position : dict { stock_id : amount }
|
||||
current_postion : dict { stock_id : amount}
|
||||
trade_unit : trade_unit
|
||||
down sample : for amount 321 and trade_unit 100, deal_amount is 300
|
||||
deal order on trade_date
|
||||
"""
|
||||
# split buy and sell for further use
|
||||
buy_order_list = []
|
||||
sell_order_list = []
|
||||
# three parts: kept stock_id, dropped stock_id, new stock_id
|
||||
# handle kept stock_id
|
||||
|
||||
# because the order of the set is not fixed, the trading order of the stock is different, so that the backtest results of the same parameter are different;
|
||||
# so here we sort stock_id, and then randomly shuffle the order of stock_id
|
||||
# because the same random seed is used, the final stock_id order is fixed
|
||||
sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys())))
|
||||
random.seed(0)
|
||||
random.shuffle(sorted_ids)
|
||||
for stock_id in sorted_ids:
|
||||
|
||||
# Do not generate order for the nontradable stocks
|
||||
if not self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
|
||||
continue
|
||||
|
||||
target_amount = target_position.get(stock_id, 0)
|
||||
current_amount = current_position.get(stock_id, 0)
|
||||
factor = self.get_factor(stock_id, start_time=start_time, end_time=end_time)
|
||||
|
||||
deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor)
|
||||
if deal_amount == 0:
|
||||
continue
|
||||
elif deal_amount > 0:
|
||||
# buy stock
|
||||
buy_order_list.append(
|
||||
Order(
|
||||
stock_id=stock_id,
|
||||
amount=deal_amount,
|
||||
direction=Order.BUY,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
factor=factor,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# sell stock
|
||||
sell_order_list.append(
|
||||
Order(
|
||||
stock_id=stock_id,
|
||||
amount=abs(deal_amount),
|
||||
direction=Order.SELL,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
factor=factor,
|
||||
)
|
||||
)
|
||||
# return order_list : buy + sell
|
||||
return sell_order_list + buy_order_list
|
||||
|
||||
def calculate_amount_position_value(self, amount_dict, start_time, end_time, only_tradable=False):
|
||||
"""Parameter
|
||||
position : Position()
|
||||
amount_dict : {stock_id : amount}
|
||||
"""
|
||||
value = 0
|
||||
for stock_id in amount_dict:
|
||||
if (
|
||||
self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
):
|
||||
value += self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) * amount_dict[stock_id]
|
||||
return value
|
||||
|
||||
def round_amount_by_trade_unit(self, deal_amount, factor):
|
||||
"""Parameter
|
||||
deal_amount : float, adjusted amount
|
||||
factor : float, adjusted factor
|
||||
return : float, real amount
|
||||
"""
|
||||
if not self.trade_w_adj_price:
|
||||
# the minimal amount is 1. Add 0.1 for solving precision problem.
|
||||
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
|
||||
return deal_amount
|
||||
|
||||
def _calc_trade_info_by_order(self, order, position):
|
||||
"""
|
||||
Calculation of trade info
|
||||
|
||||
:param order:
|
||||
:param position: Position
|
||||
:return: trade_val, trade_cost
|
||||
"""
|
||||
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
|
||||
if order.direction == Order.SELL:
|
||||
# sell
|
||||
if position is not None:
|
||||
if np.isclose(order.amount, position.get_stock_amount(order.stock_id)):
|
||||
# when selling last stock. The amount don't need rounding
|
||||
order.deal_amount = order.amount
|
||||
else:
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
|
||||
else:
|
||||
# TODO: We don't know current position.
|
||||
# We choose to sell all
|
||||
order.deal_amount = order.amount
|
||||
|
||||
trade_val = order.deal_amount * trade_price
|
||||
trade_cost = max(trade_val * self.close_cost, self.min_cost)
|
||||
elif order.direction == Order.BUY:
|
||||
# buy
|
||||
if position is not None:
|
||||
cash = position.get_cash()
|
||||
trade_val = order.amount * trade_price
|
||||
if cash < trade_val * (1 + self.open_cost):
|
||||
# The money is not enough
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
cash / (1 + self.open_cost) / trade_price, order.factor
|
||||
)
|
||||
else:
|
||||
# THe money is enough
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
|
||||
else:
|
||||
# Unknown amount of money. Just round the amount
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
|
||||
|
||||
trade_val = order.deal_amount * trade_price
|
||||
trade_cost = trade_val * self.open_cost
|
||||
else:
|
||||
raise NotImplementedError("order type {} error".format(order.type))
|
||||
|
||||
return trade_val, trade_cost
|
||||
@@ -1,90 +0,0 @@
|
||||
class HighFreqOrderNorm(Processor):
|
||||
def __init__(self, fit_start_time, fit_end_time, feature_save_dir, price_dim=5, order_price_dim=2, volume_dim=1, order_volume_dim=8, day_length=240):
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.price_dim = price_dim
|
||||
self.volume_dim = volume_dim
|
||||
self.order_price_dim = order_price_dim
|
||||
self.order_volume_dim = order_volume_dim
|
||||
self.feature_save_dir = feature_save_dir
|
||||
self.day_length = day_length
|
||||
self.names = dict()
|
||||
column_dim = self.price_dim + self.order_price_dim + self.volume_dim + self.order_volume_dim
|
||||
fields = [("price", self.price_dim), ("order_price", self.order_price_dim), ("volume", self.volume_dim), ("order_volume", self.order_volume_dim)]
|
||||
last_dim = 0
|
||||
for field, field_dim in fields:
|
||||
self.names[field] = list(range(last_dim, last_dim + field_dim)) + list((range(column_dim + last_dim, column_dim + last_dim + field_dim)))
|
||||
last_dim += field_dim
|
||||
|
||||
@profile
|
||||
def fit(self, df_features):
|
||||
# fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime")
|
||||
|
||||
|
||||
print("end")
|
||||
if not os.path.exists(self.feature_save_dir):
|
||||
os.makedirs(self.feature_save_dir)
|
||||
for name, name_val in self.names.items():
|
||||
print(name)
|
||||
df_values = df_features.iloc(axis=1)[name_val].values
|
||||
if name == "volume" or name == "order_volume":
|
||||
df_values = np.log1p(df_values)
|
||||
self.feature_med = np.nanmedian(df_values)
|
||||
np.save(self.feature_save_dir + name + "_med.npy", self.feature_med)
|
||||
df_values = df_values - self.feature_med
|
||||
self.feature_std = np.nanmedian(np.absolute(df_values)) * 1.4826 + 1e-12
|
||||
np.save(self.feature_save_dir + name + "_std.npy", self.feature_std)
|
||||
df_values = df_values / self.feature_std
|
||||
np.save(self.feature_save_dir + name + "_vmax.npy", np.nanmax(df_values))
|
||||
np.save(self.feature_save_dir + name + "_vmin.npy", np.nanmin(df_values))
|
||||
|
||||
|
||||
def __call__(self, df_features):
|
||||
df_features.set_index("date", append=True, drop=True, inplace=True)
|
||||
df_values = df_features.values
|
||||
df_values_dict = dict()
|
||||
for name, name_val in self.names.items():
|
||||
self.feature_med = np.load(self.feature_save_dir + name + "_med.npy")
|
||||
self.feature_std = np.load(self.feature_save_dir + name + "_std.npy")
|
||||
self.feature_vmax = np.load(self.feature_save_dir + name + "_vmax.npy")
|
||||
self.feature_vmin = np.load(self.feature_save_dir + name + "_vmin.npy")
|
||||
|
||||
df_values = df_features.iloc(axis=1)[name_val].values
|
||||
if name == "volume" or name == "order_volume":
|
||||
df_values[:] = np.log1p(df_values)
|
||||
df_values[:] -= self.feature_med
|
||||
df_values[:] /= self.feature_std
|
||||
slice0 = df_values > 3.0
|
||||
slice1 = df_values > 3.5
|
||||
slice2 = df_values < -3.0
|
||||
slice3 = df_values < -3.5
|
||||
|
||||
df_values[slice0] = (
|
||||
3.0 + (df_values[slice0] - 3.0) / (self.feature_vmax - 3) * 0.5
|
||||
)
|
||||
df_values[slice1] = 3.5
|
||||
df_values[slice2] = (
|
||||
-3.0 - (df_values[slice2] + 3.0) / (self.feature_vmin + 3) * 0.5
|
||||
)
|
||||
df_values[slice3] = -3.5
|
||||
df_values_dict[name] = df_values
|
||||
|
||||
idx = df_features.index.droplevel("datetime").drop_duplicates()
|
||||
idx.set_names(["instrument", "datetime"], inplace=True)
|
||||
|
||||
# Reshape is specifically for adapting to RL high-freq executor
|
||||
feat = df_values[:, list(range(self.price_dim)) + list(range(self.price_dim * 2, self.price_dim * 2 + self.order_price_dim))
|
||||
+ list(range((self.price_dim + self.order_price_dim) * 2, (self.price_dim + self.order_price_dim) * 2 + self.volume_dim))
|
||||
+ list(range((self.price_dim + self.order_price_dim + self.volume_dim) * 2, (self.price_dim + self.order_price_dim + self.volume_dim) * 2 + self.order_volume_dim))
|
||||
].reshape(-1, (self.price_dim + self.order_price_dim + self.volume_dim + self.order_volume_dim) * self.day_length)
|
||||
|
||||
feat_1 = df_values[:, list(np.arange(self.price_dim) + self.price_dim) + list(np.arange(self.price_dim * 2, self.price_dim * 2 + self.order_price_dim) + self.order_price_dim)
|
||||
+ list(np.arange((self.price_dim + self.order_price_dim) * 2, (self.price_dim + self.order_price_dim) * 2 + self.volume_dim) + self.volume_dim)
|
||||
+ list(np.arange((self.price_dim + self.order_price_dim + self.volume_dim) * 2, (self.price_dim + self.order_price_dim + self.volume_dim) * 2 + self.order_volume_dim) + self.order_volume_dim)
|
||||
].reshape(-1, (self.price_dim + self.order_price_dim + self.volume_dim + self.order_volume_dim) * self.day_length)
|
||||
df_new_features = pd.DataFrame(
|
||||
data=np.concatenate((feat, feat_1), axis=1),
|
||||
index=idx,
|
||||
columns=range(2 * (self.price_dim + self.order_price_dim + self.volume_dim + self.order_volume_dim) * self.day_length),
|
||||
).sort_index()
|
||||
return df_new_features
|
||||
@@ -1,132 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .order import Order
|
||||
from .account import Account
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .report import Report
|
||||
from .backtest import backtest as backtest_func, get_date_range
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import inspect
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
from ..config import C
|
||||
|
||||
logger = get_module_logger("backtest caller")
|
||||
|
||||
|
||||
def init_env_instance_by_config(env):
|
||||
if isinstance(env, dict):
|
||||
env_config = copy.copy(env)
|
||||
if "kwargs" in env_config:
|
||||
env_kwargs = copy.copy(env_config["kwargs"]):
|
||||
if "sub_env" in env_kwargs:
|
||||
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
|
||||
if "sub_strategy" in env_kwargs:
|
||||
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
|
||||
env_config["kwargs"] = env_kwargs
|
||||
return init_instance_by_config(env_config)
|
||||
else:
|
||||
return env
|
||||
|
||||
|
||||
def get_exchange(
|
||||
exchange=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes = "all",
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5.0,
|
||||
trade_unit=None,
|
||||
limit_threshold=None,
|
||||
deal_price=None,
|
||||
shift=1,
|
||||
):
|
||||
"""get_exchange
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange().
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost.
|
||||
close_cost : float
|
||||
close transaction cost.
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Exchange
|
||||
an initialized Exchange object
|
||||
"""
|
||||
|
||||
if trade_unit is None:
|
||||
trade_unit = C.trade_unit
|
||||
if limit_threshold is None:
|
||||
limit_threshold = C.limit_threshold
|
||||
if deal_price is None:
|
||||
deal_price = C.deal_price
|
||||
if exchange is None:
|
||||
logger.info("Create new exchange")
|
||||
# handle exception for deal_price
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
|
||||
exchange = Exchange(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
codes=codes,
|
||||
deal_price=deal_price,
|
||||
subscribe_fields=subscribe_fields,
|
||||
limit_threshold=limit_threshold,
|
||||
open_cost=open_cost,
|
||||
close_cost=close_cost,
|
||||
trade_unit=trade_unit,
|
||||
min_cost=min_cost,
|
||||
)
|
||||
else:
|
||||
return init_instance_by_config(exchange, accept_types=Exchange)
|
||||
|
||||
def backtest(start_time, end_time, strategy, env, account=1e9, **kwargs):
|
||||
trade_strategy = init_instance_by_config(strategy)
|
||||
trade_env = init_env_instance_by_config(env)
|
||||
trade_account = Account(init_cash=account)
|
||||
|
||||
spec = inspect.getfullargspec(get_exchange)
|
||||
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
trade_exchange = get_exchange(pred, **ex_args)
|
||||
|
||||
# temp_env = trade_env
|
||||
# while True:
|
||||
# if hasattr(temp_env, "trade_exchange"):
|
||||
# temp_env.reset(trade_exchange=trade_exchange)
|
||||
# if hasattr(temp_env, "sub_env"):
|
||||
# temp_env = temp_env.sub_env
|
||||
# else:
|
||||
# break
|
||||
|
||||
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
|
||||
trade_state, _reset_info = self.sub_env.get_first_state()
|
||||
trade_strategy.reset(**_reset_info)
|
||||
|
||||
|
||||
while not trade_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(sub_order_list)
|
||||
|
||||
return
|
||||
@@ -1,30 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
class Order:
|
||||
|
||||
SELL = 0
|
||||
BUY = 1
|
||||
|
||||
def __init__(self, stock_id, amount, start_time, end_time, direction, factor):
|
||||
"""Parameter
|
||||
direction : Order.SELL for sell; Order.BUY for buy
|
||||
stock_id : str
|
||||
amount : float
|
||||
trade_date : pd.Timestamp
|
||||
factor : float
|
||||
presents the weight factor assigned in Exchange()
|
||||
"""
|
||||
# check direction
|
||||
if direction not in {Order.SELL, Order.BUY}:
|
||||
raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
|
||||
self.stock_id = stock_id
|
||||
# amount of generated orders
|
||||
self.amount = amount
|
||||
# amount of successfully completed orders
|
||||
self.deal_amount = 0
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
self.direction = direction
|
||||
self.factor = factor
|
||||
@@ -1,217 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import pandas as pd
|
||||
import copy
|
||||
import pathlib
|
||||
from .order import Order
|
||||
|
||||
"""
|
||||
Position module
|
||||
"""
|
||||
|
||||
"""
|
||||
current state of position
|
||||
a typical example is :{
|
||||
<instrument_id>: {
|
||||
'count': <how many days the security has been hold>,
|
||||
'amount': <the amount of the security>,
|
||||
'price': <the close price of security in the last trading day>,
|
||||
'weight': <the security weight of total position value>,
|
||||
},
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class Position:
|
||||
"""Position"""
|
||||
|
||||
def __init__(self, cash=0, position_dict={}, today_account_value=0):
|
||||
# NOTE: The position dict must be copied!!!
|
||||
# Otherwise the initial value
|
||||
self.init_cash = cash
|
||||
self.position = position_dict.copy()
|
||||
self.position["cash"] = cash
|
||||
self.position["today_account_value"] = today_account_value
|
||||
|
||||
def init_stock(self, stock_id, amount, price=None):
|
||||
self.position[stock_id] = {}
|
||||
self.position[stock_id]["count"] = 0 # update count in the end of this date
|
||||
self.position[stock_id]["amount"] = amount
|
||||
self.position[stock_id]["price"] = price
|
||||
self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
|
||||
|
||||
def buy_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
trade_amount = trade_val / trade_price
|
||||
if stock_id not in self.position:
|
||||
self.init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
|
||||
else:
|
||||
# exist, add amount
|
||||
self.position[stock_id]["amount"] += trade_amount
|
||||
|
||||
self.position["cash"] -= trade_val + cost
|
||||
|
||||
def sell_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
trade_amount = trade_val / trade_price
|
||||
if stock_id not in self.position:
|
||||
raise KeyError("{} not in current position".format(stock_id))
|
||||
else:
|
||||
# decrease the amount of stock
|
||||
self.position[stock_id]["amount"] -= trade_amount
|
||||
# check if to delete
|
||||
if self.position[stock_id]["amount"] < -1e-5:
|
||||
raise ValueError(
|
||||
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
|
||||
)
|
||||
elif abs(self.position[stock_id]["amount"]) <= 1e-5:
|
||||
self.del_stock(stock_id)
|
||||
|
||||
self.position["cash"] += trade_val - cost
|
||||
|
||||
def del_stock(self, stock_id):
|
||||
del self.position[stock_id]
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
# handle order, order is a order class, defined in exchange.py
|
||||
if order.direction == Order.BUY:
|
||||
# BUY
|
||||
self.buy_stock(order.stock_id, trade_val, cost, trade_price)
|
||||
elif order.direction == Order.SELL:
|
||||
# SELL
|
||||
self.sell_stock(order.stock_id, trade_val, cost, trade_price)
|
||||
else:
|
||||
raise NotImplementedError("do not support order direction {}".format(order.direction))
|
||||
|
||||
def update_stock_price(self, stock_id, price):
|
||||
self.position[stock_id]["price"] = price
|
||||
|
||||
def update_stock_count(self, stock_id, count):
|
||||
self.position[stock_id]["count"] = count
|
||||
|
||||
def update_stock_weight(self, stock_id, weight):
|
||||
self.position[stock_id]["weight"] = weight
|
||||
|
||||
def update_cash(self, cash):
|
||||
self.position["cash"] = cash
|
||||
|
||||
def calculate_stock_value(self):
|
||||
stock_list = self.get_stock_list()
|
||||
value = 0
|
||||
for stock_id in stock_list:
|
||||
value += self.position[stock_id]["amount"] * self.position[stock_id]["price"]
|
||||
return value
|
||||
|
||||
def calculate_value(self):
|
||||
value = self.calculate_stock_value()
|
||||
value += self.position["cash"]
|
||||
return value
|
||||
|
||||
def get_stock_list(self):
|
||||
stock_list = list(set(self.position.keys()) - {"cash", "today_account_value"})
|
||||
return stock_list
|
||||
|
||||
def get_stock_price(self, code):
|
||||
return self.position[code]["price"]
|
||||
|
||||
def get_stock_amount(self, code):
|
||||
return self.position[code]["amount"]
|
||||
|
||||
def get_stock_count(self, code):
|
||||
return self.position[code]["count"]
|
||||
|
||||
def get_stock_weight(self, code):
|
||||
return self.position[code]["weight"]
|
||||
|
||||
def get_cash(self):
|
||||
return self.position["cash"]
|
||||
|
||||
def get_stock_amount_dict(self):
|
||||
"""generate stock amount dict {stock_id : amount of stock} """
|
||||
d = {}
|
||||
stock_list = self.get_stock_list()
|
||||
for stock_code in stock_list:
|
||||
d[stock_code] = self.get_stock_amount(code=stock_code)
|
||||
return d
|
||||
|
||||
def get_stock_weight_dict(self, only_stock=False):
|
||||
"""get_stock_weight_dict
|
||||
generate stock weight fict {stock_id : value weight of stock in the position}
|
||||
it is meaningful in the beginning or the end of each trade date
|
||||
|
||||
:param only_stock: If only_stock=True, the weight of each stock in total stock will be returned
|
||||
If only_stock=False, the weight of each stock in total assets(stock + cash) will be returned
|
||||
"""
|
||||
if only_stock:
|
||||
position_value = self.calculate_stock_value()
|
||||
else:
|
||||
position_value = self.calculate_value()
|
||||
d = {}
|
||||
stock_list = self.get_stock_list()
|
||||
for stock_code in stock_list:
|
||||
d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value
|
||||
return d
|
||||
|
||||
def add_count_all(self):
|
||||
stock_list = self.get_stock_list()
|
||||
for code in stock_list:
|
||||
self.position[code]["count"] += 1
|
||||
|
||||
def update_weight_all(self):
|
||||
weight_dict = self.get_stock_weight_dict()
|
||||
for stock_code, weight in weight_dict.items():
|
||||
self.update_stock_weight(stock_code, weight)
|
||||
|
||||
def save_position(self, path, last_trade_time):
|
||||
path = pathlib.Path(path)
|
||||
p = copy.deepcopy(self.position)
|
||||
cash = pd.Series(dtype=np.float)
|
||||
cash["init_cash"] = self.init_cash
|
||||
cash["cash"] = p["cash"]
|
||||
cash["today_account_value"] = p["today_account_value"]
|
||||
cash["last_trade_start_time"] = str(last_trade_time[0]) if last_trade_time else None
|
||||
cash["last_trade_end_time"] = str(last_trade_time[1]) if last_trade_time else None
|
||||
del p["cash"]
|
||||
del p["today_account_value"]
|
||||
positions = pd.DataFrame.from_dict(p, orient="index")
|
||||
with pd.ExcelWriter(path) as writer:
|
||||
positions.to_excel(writer, sheet_name="position")
|
||||
cash.to_excel(writer, sheet_name="info")
|
||||
|
||||
def load_position(self, path):
|
||||
"""load position information from a file
|
||||
should have format below
|
||||
sheet "position"
|
||||
columns: ['stock', 'count', 'amount', 'price', 'weight']
|
||||
'count': <how many days the security has been hold>,
|
||||
'amount': <the amount of the security>,
|
||||
'price': <the close price of security in the last trading day>,
|
||||
'weight': <the security weight of total position value>,
|
||||
|
||||
sheet "cash"
|
||||
index: ['init_cash', 'cash', 'today_account_value']
|
||||
'init_cash': <inital cash when account was created>,
|
||||
'cash': <current cash in account>,
|
||||
'today_account_value': <current total account value, should equal to sum(price[stock]*amount[stock])>
|
||||
"""
|
||||
path = pathlib.Path(path)
|
||||
positions = pd.read_excel(open(path, "rb"), sheet_name="position", index_col=0)
|
||||
cash_record = pd.read_excel(open(path, "rb"), sheet_name="info", index_col=0)
|
||||
positions = positions.to_dict(orient="index")
|
||||
init_cash = cash_record.loc["init_cash"].values[0]
|
||||
cash = cash_record.loc["cash"].values[0]
|
||||
today_account_value = cash_record.loc["today_account_value"].values[0]
|
||||
last_trade_start_time = cash_record.loc["last_trade_start_time"].values[0]
|
||||
last_trade_end_time = cash_record.loc["last_trade_end_time"].values[0]
|
||||
|
||||
# assign values
|
||||
self.position = {}
|
||||
self.init_cash = init_cash
|
||||
self.position = positions
|
||||
self.position["cash"] = cash
|
||||
self.position["today_account_value"] = today_account_value
|
||||
|
||||
last_trade_start_time = None is pd.isna(last_trade_start_time) else pd.Timestamp(last_trade_start_time)
|
||||
last_trade_end_time = None is pd.isna(last_trade_end_time) else pd.Timestamp(last_trade_end_time)
|
||||
return last_trade_start_time, last_trade_end_time
|
||||
@@ -1,324 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from .position import Position
|
||||
from ...data import D
|
||||
from ...config import C
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_benchmark_weight(
|
||||
bench,
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
path=None,
|
||||
):
|
||||
"""get_benchmark_weight
|
||||
|
||||
get the stock weight distribution of the benchmark
|
||||
|
||||
:param bench:
|
||||
:param start_date:
|
||||
:param end_date:
|
||||
:param path:
|
||||
|
||||
:return: The weight distribution of the the benchmark described by a pandas dataframe
|
||||
Every row corresponds to a trading day.
|
||||
Every column corresponds to a stock.
|
||||
Every cell represents the strategy.
|
||||
|
||||
"""
|
||||
if not path:
|
||||
path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
|
||||
# TODO: the storage of weights should be implemented in a more elegent way
|
||||
# TODO: The benchmark is not consistant with the filename in instruments.
|
||||
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])
|
||||
bench_weight_df = bench_weight_df[bench_weight_df["index"] == bench]
|
||||
bench_weight_df["date"] = pd.to_datetime(bench_weight_df["date"])
|
||||
if start_date is not None:
|
||||
bench_weight_df = bench_weight_df[bench_weight_df.date >= start_date]
|
||||
if end_date is not None:
|
||||
bench_weight_df = bench_weight_df[bench_weight_df.date <= end_date]
|
||||
bench_stock_weight = bench_weight_df.pivot_table(index="date", columns="code", values="weight") / 100.0
|
||||
return bench_stock_weight
|
||||
|
||||
|
||||
def get_stock_weight_df(positions):
|
||||
"""get_stock_weight_df
|
||||
:param positions: Given a positions from backtest result.
|
||||
:return: A weight distribution for the position
|
||||
"""
|
||||
stock_weight = []
|
||||
index = []
|
||||
for date in sorted(positions.keys()):
|
||||
pos = positions[date]
|
||||
if isinstance(pos, dict):
|
||||
pos = Position(position_dict=pos)
|
||||
index.append(date)
|
||||
stock_weight.append(pos.get_stock_weight_dict(only_stock=True))
|
||||
return pd.DataFrame(stock_weight, index=index)
|
||||
|
||||
|
||||
def decompose_portofolio_weight(stock_weight_df, stock_group_df):
|
||||
"""decompose_portofolio_weight
|
||||
|
||||
'''
|
||||
:param stock_weight_df: a pandas dataframe to describe the portofolio by weight.
|
||||
every row corresponds to a day
|
||||
every column corresponds to a stock.
|
||||
Here is an example below.
|
||||
code SH600004 SH600006 SH600017 SH600022 SH600026 SH600037 \
|
||||
date
|
||||
2016-01-05 0.001543 0.001570 0.002732 0.001320 0.003000 NaN
|
||||
2016-01-06 0.001538 0.001569 0.002770 0.001417 0.002945 NaN
|
||||
....
|
||||
:param stock_group_df: a pandas dataframe to describe the stock group.
|
||||
every row corresponds to a day
|
||||
every column corresponds to a stock.
|
||||
the value in the cell repreponds the group id.
|
||||
Here is a example by for stock_group_df for industry. The value is the industry code
|
||||
instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \
|
||||
datetime
|
||||
2016-01-05 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
|
||||
2016-01-06 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
|
||||
...
|
||||
:return: Two dict will be returned. The group_weight and the stock_weight_in_group.
|
||||
The key is the group. The value is a Series or Dataframe to describe the weight of group or weight of stock
|
||||
"""
|
||||
all_group = np.unique(stock_group_df.values.flatten())
|
||||
all_group = all_group[~np.isnan(all_group)]
|
||||
|
||||
group_weight = {}
|
||||
stock_weight_in_group = {}
|
||||
for group_key in all_group:
|
||||
group_mask = stock_group_df == group_key
|
||||
group_weight[group_key] = stock_weight_df[group_mask].sum(axis=1)
|
||||
stock_weight_in_group[group_key] = stock_weight_df[group_mask].divide(group_weight[group_key], axis=0)
|
||||
return group_weight, stock_weight_in_group
|
||||
|
||||
|
||||
def decompose_portofolio(stock_weight_df, stock_group_df, stock_ret_df):
|
||||
"""
|
||||
:param stock_weight_df: a pandas dataframe to describe the portofolio by weight.
|
||||
every row corresponds to a day
|
||||
every column corresponds to a stock.
|
||||
Here is an example below.
|
||||
code SH600004 SH600006 SH600017 SH600022 SH600026 SH600037 \
|
||||
date
|
||||
2016-01-05 0.001543 0.001570 0.002732 0.001320 0.003000 NaN
|
||||
2016-01-06 0.001538 0.001569 0.002770 0.001417 0.002945 NaN
|
||||
2016-01-07 0.001555 0.001546 0.002772 0.001393 0.002904 NaN
|
||||
2016-01-08 0.001564 0.001527 0.002791 0.001506 0.002948 NaN
|
||||
2016-01-11 0.001597 0.001476 0.002738 0.001493 0.003043 NaN
|
||||
....
|
||||
|
||||
:param stock_group_df: a pandas dataframe to describe the stock group.
|
||||
every row corresponds to a day
|
||||
every column corresponds to a stock.
|
||||
the value in the cell repreponds the group id.
|
||||
Here is a example by for stock_group_df for industry. The value is the industry code
|
||||
instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \
|
||||
datetime
|
||||
2016-01-05 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
|
||||
2016-01-06 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
|
||||
2016-01-07 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
|
||||
2016-01-08 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
|
||||
2016-01-11 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
|
||||
...
|
||||
|
||||
:param stock_ret_df: a pandas dataframe to describe the stock return.
|
||||
every row corresponds to a day
|
||||
every column corresponds to a stock.
|
||||
the value in the cell repreponds the return of the group.
|
||||
Here is a example by for stock_ret_df.
|
||||
instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \
|
||||
datetime
|
||||
2016-01-05 0.007795 0.022070 0.099099 0.024707 0.009473 0.016216
|
||||
2016-01-06 -0.032597 -0.075205 -0.098361 -0.098985 -0.099707 -0.098936
|
||||
2016-01-07 -0.001142 0.022544 0.100000 0.004225 0.000651 0.047226
|
||||
2016-01-08 -0.025157 -0.047244 -0.038567 -0.098177 -0.099609 -0.074408
|
||||
2016-01-11 0.023460 0.004959 -0.034384 0.018663 0.014461 0.010962
|
||||
...
|
||||
|
||||
:return: It will decompose the portofolio to the group weight and group return.
|
||||
"""
|
||||
all_group = np.unique(stock_group_df.values.flatten())
|
||||
all_group = all_group[~np.isnan(all_group)]
|
||||
|
||||
group_weight, stock_weight_in_group = decompose_portofolio_weight(stock_weight_df, stock_group_df)
|
||||
|
||||
group_ret = {}
|
||||
for group_key in stock_weight_in_group:
|
||||
stock_weight_in_group_start_date = min(stock_weight_in_group[group_key].index)
|
||||
stock_weight_in_group_end_date = max(stock_weight_in_group[group_key].index)
|
||||
|
||||
temp_stock_ret_df = stock_ret_df[
|
||||
(stock_ret_df.index >= stock_weight_in_group_start_date)
|
||||
& (stock_ret_df.index <= stock_weight_in_group_end_date)
|
||||
]
|
||||
|
||||
group_ret[group_key] = (temp_stock_ret_df * stock_weight_in_group[group_key]).sum(axis=1)
|
||||
# If no weight is assigned, then the return of group will be np.nan
|
||||
group_ret[group_key][group_weight[group_key] == 0.0] = np.nan
|
||||
|
||||
group_weight_df = pd.DataFrame(group_weight)
|
||||
group_ret_df = pd.DataFrame(group_ret)
|
||||
return group_weight_df, group_ret_df
|
||||
|
||||
|
||||
def get_daily_bin_group(bench_values, stock_values, group_n):
|
||||
"""get_daily_bin_group
|
||||
Group the values of the stocks of benchmark into several bins in a day.
|
||||
Put the stocks into these bins.
|
||||
|
||||
:param bench_values: A series contains the value of stocks in benchmark.
|
||||
The index is the stock code.
|
||||
:param stock_values: A series contains the value of stocks of your portofolio
|
||||
The index is the stock code.
|
||||
:param group_n: Bins will be produced
|
||||
|
||||
:return: A series with the same size and index as the stock_value.
|
||||
The value in the series is the group id of the bins.
|
||||
The No.1 bin contains the biggest values.
|
||||
"""
|
||||
stock_group = stock_values.copy()
|
||||
|
||||
# get the bin split points based on the daily proportion of benchmark
|
||||
split_points = np.percentile(bench_values[~bench_values.isna()], np.linspace(0, 100, group_n + 1))
|
||||
# Modify the biggest uppper bound and smallest lowerbound
|
||||
split_points[0], split_points[-1] = -np.inf, np.inf
|
||||
for i, (lb, up) in enumerate(zip(split_points, split_points[1:])):
|
||||
stock_group.loc[stock_values[(stock_values >= lb) & (stock_values < up)].index] = group_n - i
|
||||
return stock_group
|
||||
|
||||
|
||||
def get_stock_group(stock_group_field_df, bench_stock_weight_df, group_method, group_n=None):
|
||||
if group_method == "category":
|
||||
# use the value of the benchmark as the category
|
||||
return stock_group_field_df
|
||||
elif group_method == "bins":
|
||||
assert group_n is not None
|
||||
# place the values into `group_n` fields.
|
||||
# Each bin corresponds to a category.
|
||||
new_stock_group_df = stock_group_field_df.copy().loc[
|
||||
bench_stock_weight_df.index.min() : bench_stock_weight_df.index.max()
|
||||
]
|
||||
for idx, row in (~bench_stock_weight_df.isna()).iterrows():
|
||||
bench_values = stock_group_field_df.loc[idx, row[row].index]
|
||||
new_stock_group_df.loc[idx] = get_daily_bin_group(
|
||||
bench_values, stock_group_field_df.loc[idx], group_n=group_n
|
||||
)
|
||||
return new_stock_group_df
|
||||
|
||||
|
||||
def brinson_pa(
|
||||
positions,
|
||||
bench="SH000905",
|
||||
group_field="industry",
|
||||
group_method="category",
|
||||
group_n=None,
|
||||
deal_price="vwap",
|
||||
):
|
||||
"""brinson profit attribution
|
||||
|
||||
:param positions: The position produced by the backtest class
|
||||
:param bench: The benchmark for comparing. TODO: if no benchmark is set, the equal-weighted is used.
|
||||
:param group_field: The field used to set the group for assets allocation.
|
||||
`industry` and `market_value` is often used.
|
||||
:param group_method: 'category' or 'bins'. The method used to set the group for asstes allocation
|
||||
`bin` will split the value into `group_n` bins and each bins represents a group
|
||||
:param group_n: . Only used when group_method == 'bins'.
|
||||
|
||||
:return:
|
||||
A dataframe with three columns: RAA(excess Return of Assets Allocation), RSS(excess Return of Stock Selectino), RTotal(Total excess Return)
|
||||
Every row corresponds to a trading day, the value corresponds to the next return for this trading day
|
||||
The middle info of brinson profit attribution
|
||||
"""
|
||||
# group_method will decide how to group the group_field.
|
||||
dates = sorted(positions.keys())
|
||||
|
||||
start_date, end_date = min(dates), max(dates)
|
||||
|
||||
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date)
|
||||
|
||||
# The attributes for allocation will not
|
||||
if not group_field.startswith("$"):
|
||||
group_field = "$" + group_field
|
||||
if not deal_price.startswith("$"):
|
||||
deal_price = "$" + deal_price
|
||||
|
||||
# FIXME: In current version. Some attributes(such as market_value) of some
|
||||
# suspend stock is NAN. So we have to get more date to forward fill the NAN
|
||||
shift_start_date = start_date - datetime.timedelta(days=250)
|
||||
instruments = D.list_instruments(
|
||||
D.instruments(market="all"),
|
||||
start_time=shift_start_date,
|
||||
end_time=end_date,
|
||||
as_list=True,
|
||||
)
|
||||
stock_df = D.features(
|
||||
instruments,
|
||||
[group_field, deal_price],
|
||||
start_time=shift_start_date,
|
||||
end_time=end_date,
|
||||
freq="day",
|
||||
)
|
||||
stock_df.columns = [group_field, "deal_price"]
|
||||
|
||||
stock_group_field = stock_df[group_field].unstack().T
|
||||
# FIXME: some attributes of some suspend stock is NAN.
|
||||
stock_group_field = stock_group_field.fillna(method="ffill")
|
||||
stock_group_field = stock_group_field.loc[start_date:end_date]
|
||||
|
||||
stock_group = get_stock_group(stock_group_field, bench_stock_weight, group_method, group_n)
|
||||
|
||||
deal_price_df = stock_df["deal_price"].unstack().T
|
||||
deal_price_df = deal_price_df.fillna(method="ffill")
|
||||
|
||||
# NOTE:
|
||||
# The return will be slightly different from the of the return in the report.
|
||||
# Here the position are adjusted at the end of the trading day with close
|
||||
stock_ret = (deal_price_df - deal_price_df.shift(1)) / deal_price_df.shift(1)
|
||||
stock_ret = stock_ret.shift(-1).loc[start_date:end_date]
|
||||
|
||||
port_stock_weight_df = get_stock_weight_df(positions)
|
||||
|
||||
# decomposing the portofolio
|
||||
port_group_weight_df, port_group_ret_df = decompose_portofolio(port_stock_weight_df, stock_group, stock_ret)
|
||||
bench_group_weight_df, bench_group_ret_df = decompose_portofolio(bench_stock_weight, stock_group, stock_ret)
|
||||
|
||||
# if the group return of the portofolio is NaN, replace it with the market
|
||||
# value
|
||||
mod_port_group_ret_df = port_group_ret_df.copy()
|
||||
mod_port_group_ret_df[mod_port_group_ret_df.isna()] = bench_group_ret_df
|
||||
|
||||
Q1 = (bench_group_weight_df * bench_group_ret_df).sum(axis=1)
|
||||
Q2 = (port_group_weight_df * bench_group_ret_df).sum(axis=1)
|
||||
Q3 = (bench_group_weight_df * mod_port_group_ret_df).sum(axis=1)
|
||||
Q4 = (port_group_weight_df * mod_port_group_ret_df).sum(axis=1)
|
||||
|
||||
return (
|
||||
pd.DataFrame(
|
||||
{
|
||||
"RAA": Q2 - Q1, # The excess profit from the assets allocation
|
||||
"RSS": Q3 - Q1, # The excess profit from the stocks selection
|
||||
# The excess profit from the interaction of assets allocation and stocks selection
|
||||
"RIN": Q4 - Q3 - Q2 + Q1,
|
||||
"RTotal": Q4 - Q1, # The totoal excess profit
|
||||
}
|
||||
),
|
||||
{
|
||||
"port_group_ret": port_group_ret_df,
|
||||
"port_group_weight": port_group_weight_df,
|
||||
"bench_group_ret": bench_group_ret_df,
|
||||
"bench_group_weight": bench_group_weight_df,
|
||||
"stock_group": stock_group,
|
||||
"bench_stock_weight": bench_stock_weight,
|
||||
"port_stock_weight": port_stock_weight_df,
|
||||
"stock_ret": stock_ret,
|
||||
},
|
||||
)
|
||||
@@ -1,106 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from collections import OrderedDict
|
||||
import pandas as pd
|
||||
import pathlib
|
||||
|
||||
|
||||
class Report:
|
||||
# daily report of the account
|
||||
# contain those followings: returns, costs turnovers, accounts, cash, bench, value
|
||||
# update report
|
||||
def __init__(self):
|
||||
self.init_vars()
|
||||
|
||||
def init_vars(self):
|
||||
self.accounts = OrderedDict() # account postion value for each trade date
|
||||
self.returns = OrderedDict() # daily return rate for each trade date
|
||||
self.turnovers = OrderedDict() # turnover for each trade date
|
||||
self.costs = OrderedDict() # trade cost for each trade date
|
||||
self.values = OrderedDict() # value for each trade date
|
||||
self.cashes = OrderedDict()
|
||||
self.latest_report_time = None # pd.TimeStamp
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.accounts) == 0
|
||||
|
||||
def get_latest_date(self):
|
||||
return self.latest_report_time
|
||||
|
||||
def get_latest_account_value(self):
|
||||
return self.accounts[self.latest_report_time]
|
||||
|
||||
def update_report_record(
|
||||
self,
|
||||
trade_time=None,
|
||||
account_value=None,
|
||||
cash=None,
|
||||
return_rate=None,
|
||||
turnover_rate=None,
|
||||
cost_rate=None,
|
||||
stock_value=None,
|
||||
):
|
||||
# check data
|
||||
if None in [
|
||||
trade_time,
|
||||
account_value,
|
||||
cash,
|
||||
return_rate,
|
||||
turnover_rate,
|
||||
cost_rate,
|
||||
stock_value,
|
||||
]:
|
||||
raise ValueError(
|
||||
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
|
||||
)
|
||||
# update report data
|
||||
self.accounts[trade_time] = account_value
|
||||
self.returns[trade_time] = return_rate
|
||||
self.turnovers[trade_time] = turnover_rate
|
||||
self.costs[trade_time] = cost_rate
|
||||
self.values[trade_time] = stock_value
|
||||
self.cashes[trade_time] = cash
|
||||
# update latest_report_date
|
||||
self.latest_report_time = trade_time
|
||||
# finish daily report update
|
||||
|
||||
def generate_report_dataframe(self):
|
||||
report = pd.DataFrame()
|
||||
report["account"] = pd.Series(self.accounts)
|
||||
report["return"] = pd.Series(self.returns)
|
||||
report["turnover"] = pd.Series(self.turnovers)
|
||||
report["cost"] = pd.Series(self.costs)
|
||||
report["value"] = pd.Series(self.values)
|
||||
report["cash"] = pd.Series(self.cashes)
|
||||
report.index.name = "trade_time"
|
||||
return report
|
||||
|
||||
def save_report(self, path):
|
||||
r = self.generate_report_dataframe()
|
||||
r.to_csv(path)
|
||||
|
||||
def load_report(self, path):
|
||||
"""load report from a file
|
||||
should have format like
|
||||
columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash']
|
||||
:param
|
||||
path: str/ pathlib.Path()
|
||||
"""
|
||||
path = pathlib.Path(path)
|
||||
r = pd.read_csv(open(path, "rb"), index_col=0)
|
||||
r.index = pd.DatetimeIndex(r.index)
|
||||
|
||||
index = r.index
|
||||
self.init_vars()
|
||||
for trade_time in index:
|
||||
self.update_report_record(
|
||||
trade_time=trade_time,
|
||||
account_value=r.loc[trade_time]["account"],
|
||||
cash=r.loc[trade_time]["cash"],
|
||||
return_rate=r.loc[trade_time]["return"],
|
||||
turnover_rate=r.loc[trade_time]["turnover"],
|
||||
cost_rate=r.loc[trade_time]["cost"],
|
||||
stock_value=r.loc[trade_time]["value"],
|
||||
)
|
||||
@@ -2,12 +2,12 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .order import Order
|
||||
from .account import Account
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .report import Report
|
||||
from .backtest import backtest as backtest_func, get_date_range
|
||||
from .backtest import backtest as backtest_func
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import inspect
|
||||
from ...utils import init_instance_by_config
|
||||
@@ -17,86 +17,11 @@ from ...config import C
|
||||
logger = get_module_logger("backtest caller")
|
||||
|
||||
|
||||
def get_strategy(
|
||||
strategy=None,
|
||||
topk=50,
|
||||
margin=0.5,
|
||||
n_drop=5,
|
||||
risk_degree=0.95,
|
||||
str_type="dropout",
|
||||
adjust_dates=None,
|
||||
):
|
||||
"""get_strategy
|
||||
|
||||
There will be 3 ways to return a stratgy. Please follow the code.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
- if isinstance(margin, int):
|
||||
|
||||
sell_limit = margin
|
||||
|
||||
- else:
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Strategy
|
||||
an initialized strategy object
|
||||
"""
|
||||
|
||||
# There will be 3 ways to return a strategy.
|
||||
if strategy is None:
|
||||
# 1) create strategy with param `strategy`
|
||||
str_cls_dict = {
|
||||
"amount": "TopkAmountStrategy",
|
||||
"weight": "TopkWeightStrategy",
|
||||
"dropout": "TopkDropoutStrategy",
|
||||
}
|
||||
logger.info("Create new strategy ")
|
||||
from .. import strategy as strategy_pool
|
||||
|
||||
str_cls = getattr(strategy_pool, str_cls_dict.get(str_type))
|
||||
strategy = str_cls(
|
||||
topk=topk,
|
||||
buffer_margin=margin,
|
||||
n_drop=n_drop,
|
||||
risk_degree=risk_degree,
|
||||
adjust_dates=adjust_dates,
|
||||
)
|
||||
elif isinstance(strategy, (dict, str)):
|
||||
# 2) create strategy with init_instance_by_config
|
||||
logger.info("Create new strategy ")
|
||||
strategy = init_instance_by_config(strategy)
|
||||
|
||||
from ..strategy.strategy import BaseStrategy
|
||||
|
||||
# else: nothing happens. 3) Use the strategy directly
|
||||
if not isinstance(strategy, BaseStrategy):
|
||||
raise TypeError("Strategy not supported")
|
||||
return strategy
|
||||
|
||||
|
||||
def get_exchange(
|
||||
pred,
|
||||
exchange=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes = "all",
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
@@ -104,7 +29,6 @@ def get_exchange(
|
||||
trade_unit=None,
|
||||
limit_threshold=None,
|
||||
deal_price=None,
|
||||
extract_codes=False,
|
||||
shift=1,
|
||||
):
|
||||
"""get_exchange
|
||||
@@ -128,9 +52,6 @@ def get_exchange(
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
NOTE: This will be faster with offline qlib.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -149,176 +70,61 @@ def get_exchange(
|
||||
# handle exception for deal_price
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
if extract_codes:
|
||||
codes = sorted(pred.index.get_level_values("instrument").unique())
|
||||
else:
|
||||
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
|
||||
|
||||
dates = sorted(pred.index.get_level_values("datetime").unique())
|
||||
dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift))
|
||||
|
||||
exchange = Exchange(
|
||||
trade_dates=dates,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
codes=codes,
|
||||
deal_price=deal_price,
|
||||
subscribe_fields=subscribe_fields,
|
||||
limit_threshold=limit_threshold,
|
||||
open_cost=open_cost,
|
||||
close_cost=close_cost,
|
||||
min_cost=min_cost,
|
||||
trade_unit=trade_unit,
|
||||
min_cost=min_cost,
|
||||
)
|
||||
return exchange
|
||||
return exchange
|
||||
else:
|
||||
return init_instance_by_config(exchange, accept_types=Exchange)
|
||||
|
||||
def init_env_instance_by_config(env):
|
||||
if isinstance(env, dict):
|
||||
env_config = copy.copy(env)
|
||||
if "kwargs" in env_config:
|
||||
env_kwargs = copy.copy(env_config["kwargs"])
|
||||
if "sub_env" in env_kwargs:
|
||||
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
|
||||
if "sub_strategy" in env_kwargs:
|
||||
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
|
||||
env_config["kwargs"] = env_kwargs
|
||||
return init_instance_by_config(env_config)
|
||||
else:
|
||||
return env
|
||||
|
||||
def get_executor(
|
||||
executor=None,
|
||||
trade_exchange=None,
|
||||
verbose=True,
|
||||
):
|
||||
"""get_executor
|
||||
def setup_exchange(root_instance, trade_exchange=None, force=False):
|
||||
if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args:
|
||||
if force:
|
||||
root_instance.reset(trade_exchange=trade_exchange)
|
||||
else:
|
||||
if not hasattr(root_instance, "trade_exchange") or root_instance.trade_exchange is None:
|
||||
root_instance.reset(trade_exchange=trade_exchange)
|
||||
if hasattr(root_instance, "sub_env"):
|
||||
setup_exchange(root_instance.sub_env, trade_exchange)
|
||||
if hasattr(root_instance, "sub_strategy"):
|
||||
setup_exchange(root_instance.sub_strategy, trade_exchange)
|
||||
|
||||
|
||||
def backtest(start_time, end_time, strategy, env, benchmark=None, account=1e9, **kwargs):
|
||||
trade_strategy = init_instance_by_config(strategy)
|
||||
trade_env = init_env_instance_by_config(env)
|
||||
|
||||
There will be 3 ways to return a executor. Please follow the code.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
executor : BaseExecutor
|
||||
executor used in backtest.
|
||||
trade_exchange : Exchange
|
||||
exchange used in executor
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: BaseExecutor
|
||||
an initialized BaseExecutor object
|
||||
"""
|
||||
|
||||
# There will be 3 ways to return a executor.
|
||||
if executor is None:
|
||||
# 1) create executor with param `executor`
|
||||
logger.info("Create new executor ")
|
||||
from ..online.executor import SimulatorExecutor
|
||||
|
||||
executor = SimulatorExecutor(trade_exchange=trade_exchange, verbose=verbose)
|
||||
elif isinstance(executor, (dict, str)):
|
||||
# 2) create executor with config
|
||||
logger.info("Create new executor ")
|
||||
executor = init_instance_by_config(executor)
|
||||
|
||||
from ..online.executor import BaseExecutor
|
||||
|
||||
# 3) Use the executor directly
|
||||
if not isinstance(executor, BaseExecutor):
|
||||
raise TypeError("Executor not supported")
|
||||
return executor
|
||||
|
||||
|
||||
# This is the API for compatibility for legacy code
|
||||
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, return_order=False, **kwargs):
|
||||
"""This function will help you set a reasonable Exchange and provide default value for strategy
|
||||
Parameters
|
||||
----------
|
||||
|
||||
- **backtest workflow related or commmon arguments**
|
||||
|
||||
pred : pandas.DataFrame
|
||||
predict should has <datetime, instrument> index and one `score` column.
|
||||
account : float
|
||||
init account value.
|
||||
shift : int
|
||||
whether to shift prediction by one day.
|
||||
benchmark : str
|
||||
benchmark code, default is SH000905 CSI 500.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
return_order : bool
|
||||
whether to return order list
|
||||
|
||||
- **strategy related arguments**
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
- if isinstance(margin, int):
|
||||
|
||||
sell_limit = margin
|
||||
|
||||
- else:
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
- **exchange related arguments**
|
||||
|
||||
exchange: Exchange()
|
||||
pass the exchange for speeding up.
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost. The default value is 0.002(0.2%).
|
||||
close_cost : float
|
||||
close transaction cost. The default value is 0.002(0.2%).
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
|
||||
.. note:: This will be faster with offline qlib.
|
||||
|
||||
- **executor related arguments**
|
||||
|
||||
executor : BaseExecutor()
|
||||
executor used in backtest.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
|
||||
"""
|
||||
# check strategy:
|
||||
spec = inspect.getfullargspec(get_strategy)
|
||||
str_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
strategy = get_strategy(**str_args)
|
||||
|
||||
# init exchange:
|
||||
spec = inspect.getfullargspec(get_exchange)
|
||||
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
trade_exchange = get_exchange(pred, **ex_args)
|
||||
exchange_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
trade_exchange = get_exchange(**exchange_args)
|
||||
|
||||
# init executor:
|
||||
executor = get_executor(executor=kwargs.get("executor"), trade_exchange=trade_exchange, verbose=verbose)
|
||||
setup_exchange(trade_env, trade_exchange)
|
||||
setup_exchange(trade_strategy, trade_exchange)
|
||||
|
||||
# run backtest
|
||||
report_dict = backtest_func(
|
||||
pred=pred,
|
||||
strategy=strategy,
|
||||
executor=executor,
|
||||
trade_exchange=trade_exchange,
|
||||
shift=shift,
|
||||
verbose=verbose,
|
||||
account=account,
|
||||
benchmark=benchmark,
|
||||
return_order=return_order,
|
||||
)
|
||||
# for compatibility of the old API. return the dict positions
|
||||
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env, benchmark, account)
|
||||
|
||||
positions = report_dict.get("positions")
|
||||
report_dict.update({"positions": {k: p.position for k, p in positions.items()}})
|
||||
return report_dict
|
||||
|
||||
@@ -26,10 +26,10 @@ rtn & earning in the Account
|
||||
|
||||
|
||||
class Account:
|
||||
def __init__(self, init_cash, last_trade_date=None):
|
||||
self.init_vars(init_cash, last_trade_date)
|
||||
def __init__(self, init_cash, last_trade_time=None):
|
||||
self.init_vars(init_cash, last_trade_time)
|
||||
|
||||
def init_vars(self, init_cash, last_trade_date=None):
|
||||
def init_vars(self, init_cash, last_trade_time=None):
|
||||
# init cash
|
||||
self.init_cash = init_cash
|
||||
self.current = Position(cash=init_cash)
|
||||
@@ -40,7 +40,7 @@ class Account:
|
||||
self.val = 0
|
||||
self.report = Report()
|
||||
self.earning = 0
|
||||
self.last_trade_date = last_trade_date
|
||||
self.last_trade_time = last_trade_time
|
||||
|
||||
def get_positions(self):
|
||||
return self.positions
|
||||
@@ -83,9 +83,10 @@ class Account:
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
self.update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_daily_end(self, today, trader):
|
||||
def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""
|
||||
today: pd.TimeStamp
|
||||
start_time: pd.TimeStamp
|
||||
end_time: pd.TimeStamp
|
||||
quote: pd.DataFrame (code, date), collumns
|
||||
when the end of trade date
|
||||
- update rtn
|
||||
@@ -102,11 +103,11 @@ class Account:
|
||||
profit = 0
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trader.check_stock_suspended(code, today):
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
today_close = trader.get_close(code, today)
|
||||
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
|
||||
self.current.update_stock_price(stock_id=code, price=today_close)
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
profit += (bar_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
|
||||
self.current.update_stock_price(stock_id=code, price=bar_close)
|
||||
self.rtn += profit
|
||||
# update holding day count
|
||||
self.current.add_count_all()
|
||||
@@ -116,54 +117,54 @@ class Account:
|
||||
# account_value - last_account_value
|
||||
# for the first trade date, account_value - init_cash
|
||||
# self.report.is_empty() to judge is_first_trade_date
|
||||
# get last_account_value, today_account_value, today_stock_value
|
||||
# get last_account_value, now_account_value, now_stock_value
|
||||
if self.report.is_empty():
|
||||
last_account_value = self.init_cash
|
||||
else:
|
||||
last_account_value = self.report.get_latest_account_value()
|
||||
today_account_value = self.current.calculate_value()
|
||||
today_stock_value = self.current.calculate_stock_value()
|
||||
self.earning = today_account_value - last_account_value
|
||||
now_account_value = self.current.calculate_value()
|
||||
now_stock_value = self.current.calculate_stock_value()
|
||||
self.earning = now_account_value - last_account_value
|
||||
# update report for today
|
||||
# judge whether the the trading is begin.
|
||||
# and don't add init account state into report, due to we don't have excess return in those days.
|
||||
self.report.update_report_record(
|
||||
trade_date=today,
|
||||
account_value=today_account_value,
|
||||
trade_time=trade_start_time,
|
||||
account_value=now_account_value,
|
||||
cash=self.current.position["cash"],
|
||||
return_rate=(self.earning + self.ct) / last_account_value,
|
||||
# here use earning to calculate return, position's view, earning consider cost, true return
|
||||
# in order to make same definition with original backtest in evaluate.py
|
||||
turnover_rate=self.to / last_account_value,
|
||||
cost_rate=self.ct / last_account_value,
|
||||
stock_value=today_stock_value,
|
||||
stock_value=now_stock_value,
|
||||
)
|
||||
# set today_account_value to position
|
||||
self.current.position["today_account_value"] = today_account_value
|
||||
# set now_account_value to position
|
||||
self.current.position["now_account_value"] = now_account_value
|
||||
self.current.update_weight_all()
|
||||
# update positions
|
||||
# note use deepcopy
|
||||
self.positions[today] = copy.deepcopy(self.current)
|
||||
self.positions[trade_start_time] = copy.deepcopy(self.current)
|
||||
|
||||
# finish today's updation
|
||||
# reset the daily variables
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
self.to = 0
|
||||
self.last_trade_date = today
|
||||
self.last_trade_time = (trade_start_time, trade_end_time)
|
||||
|
||||
def load_account(self, account_path):
|
||||
report = Report()
|
||||
position = Position()
|
||||
last_trade_date = position.load_position(account_path / "position.xlsx")
|
||||
last_trade_time = position.load_position(account_path / "position.xlsx")
|
||||
report.load_report(account_path / "report.csv")
|
||||
|
||||
# assign values
|
||||
self.init_vars(position.init_cash)
|
||||
self.current = position
|
||||
self.report = report
|
||||
self.last_trade_date = last_trade_date if last_trade_date else None
|
||||
self.last_trade_time = last_trade_time
|
||||
|
||||
def save_account(self, account_path):
|
||||
self.current.save_position(account_path / "position.xlsx", self.last_trade_date)
|
||||
self.current.save_position(account_path / "position.xlsx", self.last_trade_time)
|
||||
self.report.save_report(account_path / "report.csv")
|
||||
|
||||
@@ -4,140 +4,24 @@
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from ...utils import get_date_by_shift, get_date_range
|
||||
from ...data import D
|
||||
|
||||
from .account import Account
|
||||
from ...config import C
|
||||
from ...log import get_module_logger
|
||||
from ...data.dataset.utils import get_level_index
|
||||
|
||||
LOG = get_module_logger("backtest")
|
||||
|
||||
|
||||
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
|
||||
"""Parameters
|
||||
----------
|
||||
pred : pandas.DataFrame
|
||||
predict should has <datetime, instrument> index and one `score` column
|
||||
Qlib want to support multi-singal strategy in the future. So pd.Series is not used.
|
||||
strategy : Strategy()
|
||||
strategy part for backtest
|
||||
trade_exchange : Exchange()
|
||||
exchage for backtest
|
||||
shift : int
|
||||
whether to shift prediction by one day
|
||||
verbose : bool
|
||||
whether to print log
|
||||
account : float
|
||||
init account value
|
||||
benchmark : str/list/pd.Series
|
||||
`benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
|
||||
example:
|
||||
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
|
||||
2017-01-04 0.011693
|
||||
2017-01-05 0.000721
|
||||
2017-01-06 -0.004322
|
||||
2017-01-09 0.006874
|
||||
2017-01-10 -0.003350
|
||||
|
||||
`benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
|
||||
`benchmark` is str, will use the daily change as the 'bench'.
|
||||
benchmark code, default is SH000905 CSI500
|
||||
"""
|
||||
# Convert format if the input format is not expected
|
||||
if get_level_index(pred, level="datetime") == 1:
|
||||
pred = pred.swaplevel().sort_index()
|
||||
if isinstance(pred, pd.Series):
|
||||
pred = pred.to_frame("score")
|
||||
def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account):
|
||||
|
||||
trade_account = Account(init_cash=account)
|
||||
_pred_dates = pred.index.get_level_values(level="datetime")
|
||||
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
|
||||
if isinstance(benchmark, pd.Series):
|
||||
bench = benchmark
|
||||
else:
|
||||
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
|
||||
_temp_result = D.features(
|
||||
_codes,
|
||||
["$close/Ref($close,1)-1"],
|
||||
predict_dates[0],
|
||||
get_date_by_shift(predict_dates[-1], shift=shift),
|
||||
disk_cache=1,
|
||||
)
|
||||
if len(_temp_result) == 0:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
|
||||
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
|
||||
trade_strategy.reset(start_time=start_time, end_time=end_time)
|
||||
|
||||
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
|
||||
if return_order:
|
||||
multi_order_list = []
|
||||
# trading apart
|
||||
for pred_date, trade_date in zip(predict_dates, trade_dates):
|
||||
# for loop predict date and trading date
|
||||
# print
|
||||
if verbose:
|
||||
LOG.info("[I {:%Y-%m-%d}]: trade begin.".format(trade_date))
|
||||
|
||||
# 1. Load the score_series at pred_date
|
||||
try:
|
||||
score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate
|
||||
score_series = score.reset_index(level="datetime", drop=True)[
|
||||
"score"
|
||||
] # pd.Series(index:stock_id, data: score)
|
||||
except KeyError:
|
||||
LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date))
|
||||
score_series = None
|
||||
|
||||
if score_series is not None and score_series.count() > 0: # in case of the scores are all None
|
||||
# 2. Update your strategy (and model)
|
||||
strategy.update(score_series, pred_date, trade_date)
|
||||
|
||||
# 3. Generate order list
|
||||
order_list = strategy.generate_order_list(
|
||||
score_series=score_series,
|
||||
current=trade_account.current,
|
||||
trade_exchange=trade_exchange,
|
||||
pred_date=pred_date,
|
||||
trade_date=trade_date,
|
||||
)
|
||||
else:
|
||||
order_list = []
|
||||
if return_order:
|
||||
multi_order_list.append((trade_account, order_list, trade_date))
|
||||
# 4. Get result after executing order list
|
||||
# NOTE: The following operation will modify order.amount.
|
||||
# NOTE: If it is buy and the cash is insufficient, the tradable amount will be recalculated
|
||||
trade_info = executor.execute(trade_account, order_list, trade_date)
|
||||
|
||||
# 5. Update account information according to transaction
|
||||
update_account(trade_account, trade_info, trade_exchange, trade_date)
|
||||
|
||||
# generate backtest report
|
||||
trade_state = trade_env.get_init_state()
|
||||
while not trade_env.finished():
|
||||
_order_list = trade_strategy.generate_order_list(**trade_state)
|
||||
print("_order_list", _order_list)
|
||||
trade_state, trade_info = trade_env.execute(_order_list)
|
||||
|
||||
report_df = trade_account.report.generate_report_dataframe()
|
||||
report_df["bench"] = bench
|
||||
positions = trade_account.get_positions()
|
||||
|
||||
report_dict = {"report_df": report_df, "positions": positions}
|
||||
if return_order:
|
||||
report_dict.update({"order_list": multi_order_list})
|
||||
|
||||
return report_dict
|
||||
|
||||
|
||||
def update_account(trade_account, trade_info, trade_exchange, trade_date):
|
||||
"""Update the account and strategy
|
||||
Parameters
|
||||
----------
|
||||
trade_account : Account()
|
||||
trade_info : list of [Order(), float, float, float]
|
||||
(order, trade_val, trade_cost, trade_price), trade_info with out factor
|
||||
trade_exchange : Exchange()
|
||||
used to get the $close_price at trade_date to update account
|
||||
trade_date : pd.Timestamp
|
||||
"""
|
||||
# update account
|
||||
for [order, trade_val, trade_cost, trade_price] in trade_info:
|
||||
if order.deal_amount == 0:
|
||||
continue
|
||||
trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
|
||||
# at the end of trade date, update the account based the $close_price of stocks.
|
||||
trade_account.update_daily_end(today=trade_date, trader=trade_exchange)
|
||||
|
||||
@@ -5,13 +5,13 @@ import json
|
||||
import copy
|
||||
import warnings
|
||||
import pathlib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import Logger
|
||||
from ...data import D, Cal
|
||||
from ...utils import get_date_in_file_name
|
||||
from ...utils import get_pre_trading_date
|
||||
from ..backtest.order import Order
|
||||
from ..utils import init_instance_by_config
|
||||
from ...data.data import Cal
|
||||
from ...utils import get_sample_freq_calendar
|
||||
from .order import Order
|
||||
|
||||
|
||||
class TradeCalendarBase:
|
||||
|
||||
def _reset_trade_calendar(self, start_time, end_time):
|
||||
@@ -20,10 +20,10 @@ class TradeCalendarBase:
|
||||
if end_time:
|
||||
self.end_time = pd.Timestamp(end_time)
|
||||
if self.start_time and self.end_time:
|
||||
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=step_bar)
|
||||
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar)
|
||||
self.calendar = _calendar
|
||||
_start_time, _end_time, _start_index, _end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq, freq_sam=freq_sam)
|
||||
_trade_calendar = self.calendar[_start_index, _end_index + 1]
|
||||
_trade_calendar = self.calendar[_start_index: _end_index + 1]
|
||||
if _start_time != self.start_time:
|
||||
self.trade_calendar = np.hstack((self.start_time, _trade_calendar, self.end_time))
|
||||
self.start_index = _start_index - 1
|
||||
@@ -40,7 +40,7 @@ class TradeCalendarBase:
|
||||
trade_index = trade_index - shift
|
||||
if 0 < trade_index < self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[trade_index] - pd.Timestamp(second=1)
|
||||
trade_end_time = self.trade_calendar[trade_index] - pd.Timedelta(seconds=1)
|
||||
return trade_start_time, trade_end_time
|
||||
elif trade_index == self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[trade_index - 1]
|
||||
@@ -68,7 +68,7 @@ class BaseEnv(TradeCalendarBase):
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
verbose=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self.step_bar = step_bar
|
||||
self.verbose = verbose
|
||||
@@ -76,24 +76,24 @@ class BaseEnv(TradeCalendarBase):
|
||||
|
||||
def _get_position(self):
|
||||
return self.trade_account.current
|
||||
|
||||
|
||||
|
||||
def reset(self, start_time=None, end_time=None, trade_account=None, **kwargs):
|
||||
if start_time or end_time:
|
||||
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
self.trade_account = trade_account
|
||||
if trade_account:
|
||||
self.trade_account = trade_account
|
||||
|
||||
for k, v in kwargs:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
def get_first_state(self):
|
||||
def get_init_state(self):
|
||||
init_state = {"current": self._get_position()}
|
||||
return init_state
|
||||
|
||||
|
||||
def execute(self, order_list, **kwargs):
|
||||
def execute(self, order_list=None, **kwargs):
|
||||
self.trade_index = self.trade_index + 1
|
||||
|
||||
def finished(self):
|
||||
@@ -122,13 +122,13 @@ class SplitEnv(BaseEnv):
|
||||
#if self.track:
|
||||
# yield action
|
||||
#episode_reward = 0
|
||||
super(SimulatorEnv, self).execute(**kwargs)
|
||||
super(SplitEnv, self).execute(**kwargs)
|
||||
trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
|
||||
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time, trade_account=self.trade_account)
|
||||
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
|
||||
trade_state = self.sub_env.get_init_state()
|
||||
while not self.sub_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order(**trade_state)
|
||||
_order_list = self.sub_strategy.generate_order_list(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
|
||||
#episode_reward += sub_reward
|
||||
_obs = {"current": self._get_position()}
|
||||
@@ -149,11 +149,12 @@ class SimulatorEnv(BaseEnv):
|
||||
verbose=False,
|
||||
**kwargs,
|
||||
):
|
||||
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, verbose=verbose)
|
||||
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, verbose=verbose, **kwargs)
|
||||
|
||||
def reset(trade_exchange=None, **kwargs):
|
||||
def reset(self, trade_exchange=None, **kwargs):
|
||||
super(SimulatorEnv, self).reset(**kwargs)
|
||||
self.trade_exchange=trade_exchange
|
||||
if trade_exchange:
|
||||
self.trade_exchange=trade_exchange
|
||||
|
||||
def execute(self, order_list, **kwargs):
|
||||
"""
|
||||
@@ -162,7 +163,7 @@ class SimulatorEnv(BaseEnv):
|
||||
if self.finished():
|
||||
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
|
||||
super(SimulatorEnv, self).execute(**kwargs)
|
||||
ttrade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
|
||||
trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
|
||||
trade_info = []
|
||||
for order in order_list:
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
@@ -8,16 +8,19 @@ import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...data import D
|
||||
from .order import Order
|
||||
from ...data.data import D
|
||||
from ...config import C, REG_CN
|
||||
from ...utils import sample_feature
|
||||
from ...log import get_module_logger
|
||||
from .order import Order
|
||||
|
||||
|
||||
|
||||
class Exchange:
|
||||
def __init__(
|
||||
self,
|
||||
trade_dates=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes="all",
|
||||
deal_price=None,
|
||||
subscribe_fields=[],
|
||||
@@ -30,7 +33,8 @@ class Exchange:
|
||||
):
|
||||
"""__init__
|
||||
|
||||
:param trade_dates: list of pd.Timestamp
|
||||
:param start_time: start time for backtest
|
||||
:param end_time: end time for backtest
|
||||
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
|
||||
:param deal_price: str, 'close', 'open', 'vwap'
|
||||
:param subscribe_fields: list, subscribe fields
|
||||
@@ -51,6 +55,8 @@ class Exchange:
|
||||
target on this day).
|
||||
index: MultipleIndex(instrument, pd.Datetime)
|
||||
"""
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
if trade_unit is None:
|
||||
trade_unit = C.trade_unit
|
||||
if limit_threshold is None:
|
||||
@@ -91,21 +97,15 @@ class Exchange:
|
||||
self.close_cost = close_cost
|
||||
self.min_cost = min_cost
|
||||
self.limit_threshold = limit_threshold
|
||||
# TODO: the quote, trade_dates, codes are not necessray.
|
||||
# It is just for performance consideration.
|
||||
if trade_dates is not None and len(trade_dates):
|
||||
start_date, end_date = trade_dates[0], trade_dates[-1]
|
||||
else:
|
||||
self.logger.warning("trade_dates have not been assigned, all dates will be loaded")
|
||||
start_date, end_date = None, None
|
||||
|
||||
|
||||
self.extra_quote = extra_quote
|
||||
self.set_quote(codes, start_date, end_date)
|
||||
self.set_quote(codes, start_time, end_time)
|
||||
|
||||
def set_quote(self, codes, start_date, end_date):
|
||||
def set_quote(self, codes, start_time, end_time):
|
||||
if len(codes) == 0:
|
||||
codes = D.instruments()
|
||||
self.quote = D.features(codes, self.all_fields, start_date, end_date, disk_cache=True).dropna(subset=["$close"])
|
||||
self.quote = D.features(codes, self.all_fields, start_time, end_time, disk_cache=True).dropna(subset=["$close"])
|
||||
self.quote.columns = self.all_fields
|
||||
|
||||
if self.quote[self.deal_price].isna().any():
|
||||
@@ -146,35 +146,37 @@ class Exchange:
|
||||
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
|
||||
|
||||
# update quote: pd.DataFrame to dict, for search use
|
||||
self.quote = quote_df.to_dict("index")
|
||||
self.quote = quote_df
|
||||
|
||||
def _update_limit(self, buy_limit, sell_limit):
|
||||
self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False)
|
||||
|
||||
def check_stock_limit(self, stock_id, trade_date):
|
||||
def check_stock_limit(self, stock_id, start_time, end_time):
|
||||
"""Parameter
|
||||
stock_id
|
||||
trade_date
|
||||
is limtited
|
||||
"""
|
||||
return self.quote[(stock_id, trade_date)]["limit"]
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time, fields="limit", method="any").iloc[0]
|
||||
|
||||
|
||||
def check_stock_suspended(self, stock_id, trade_date):
|
||||
def check_stock_suspended(self, stock_id, start_time, end_time):
|
||||
# is suspended
|
||||
return (stock_id, trade_date) not in self.quote
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time).empty
|
||||
|
||||
def is_stock_tradable(self, stock_id, trade_date):
|
||||
|
||||
def is_stock_tradable(self, stock_id, start_time, end_time):
|
||||
# check if stock can be traded
|
||||
# same as check in check_order
|
||||
if self.check_stock_suspended(stock_id, trade_date) or self.check_stock_limit(stock_id, trade_date):
|
||||
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(stock_id, start_time, end_time):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def check_order(self, order):
|
||||
# check limit and suspended
|
||||
if self.check_stock_suspended(order.stock_id, order.trade_date) or self.check_stock_limit(
|
||||
order.stock_id, order.trade_date
|
||||
if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit(
|
||||
order.stock_id, order.start_time, order.end_time
|
||||
):
|
||||
return False
|
||||
else:
|
||||
@@ -199,7 +201,7 @@ class Exchange:
|
||||
if trade_account is not None and position is not None:
|
||||
raise ValueError("trade_account and position can only choose one")
|
||||
|
||||
trade_price = self.get_deal_price(order.stock_id, order.trade_date)
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
|
||||
trade_val, trade_cost = self._calc_trade_info_by_order(
|
||||
order, trade_account.current if trade_account else position
|
||||
)
|
||||
@@ -214,24 +216,24 @@ class Exchange:
|
||||
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def get_quote_info(self, stock_id, trade_date):
|
||||
return self.quote[(stock_id, trade_date)]
|
||||
def get_quote_info(self, stock_id, start_time, end_time):
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time)
|
||||
|
||||
def get_close(self, stock_id, trade_date):
|
||||
return self.quote[(stock_id, trade_date)]["$close"]
|
||||
def get_close(self, stock_id, start_time, end_time):
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time, fields="$close", method="last").iloc[0]
|
||||
|
||||
def get_deal_price(self, stock_id, trade_date):
|
||||
deal_price = self.quote[(stock_id, trade_date)][self.deal_price]
|
||||
def get_deal_price(self, stock_id, start_time, end_time):
|
||||
deal_price = sample_feature(self.quote, stock_id, start_time, end_time, fields=self.deal_price, method="last").iloc[0]
|
||||
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
|
||||
self.logger.warning(f"(stock_id:{stock_id}, trade_date:{trade_date}, {self.deal_price}): {deal_price}!!!")
|
||||
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!")
|
||||
self.logger.warning(f"setting deal_price to close price")
|
||||
deal_price = self.get_close(stock_id, trade_date)
|
||||
deal_price = self.get_close(stock_id, start_time, end_time)
|
||||
return deal_price
|
||||
|
||||
def get_factor(self, stock_id, trade_date):
|
||||
return self.quote[(stock_id, trade_date)]["$factor"]
|
||||
def get_factor(self, stock_id, start_time, end_time):
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time, fields="$factor", method="last").iloc[0]
|
||||
|
||||
def generate_amount_position_from_weight_position(self, weight_position, cash, trade_date):
|
||||
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
|
||||
"""
|
||||
The generate the target position according to the weight and the cash.
|
||||
NOTE: All the cash will assigned to the tadable stock.
|
||||
@@ -246,7 +248,7 @@ class Exchange:
|
||||
# calculate the total weight of tradable value
|
||||
tradable_weight = 0.0
|
||||
for stock_id in weight_position:
|
||||
if self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
|
||||
if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
|
||||
# weight_position must be greater than 0 and less than 1
|
||||
if weight_position[stock_id] < 0 or weight_position[stock_id] > 1:
|
||||
raise ValueError(
|
||||
@@ -260,12 +262,12 @@ class Exchange:
|
||||
|
||||
amount_dict = {}
|
||||
for stock_id in weight_position:
|
||||
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
|
||||
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
|
||||
amount_dict[stock_id] = (
|
||||
cash
|
||||
* weight_position[stock_id]
|
||||
/ tradable_weight
|
||||
// self.get_deal_price(stock_id=stock_id, trade_date=trade_date)
|
||||
// self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
)
|
||||
return amount_dict
|
||||
|
||||
@@ -292,7 +294,7 @@ class Exchange:
|
||||
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
|
||||
return -deal_amount
|
||||
|
||||
def generate_order_for_target_amount_position(self, target_position, current_position, trade_date):
|
||||
def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
|
||||
"""Parameter:
|
||||
target_position : dict { stock_id : amount }
|
||||
current_postion : dict { stock_id : amount}
|
||||
@@ -315,12 +317,12 @@ class Exchange:
|
||||
for stock_id in sorted_ids:
|
||||
|
||||
# Do not generate order for the nontradable stocks
|
||||
if not self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
|
||||
if not self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
|
||||
continue
|
||||
|
||||
target_amount = target_position.get(stock_id, 0)
|
||||
current_amount = current_position.get(stock_id, 0)
|
||||
factor = self.quote[(stock_id, trade_date)]["$factor"]
|
||||
factor = self.get_factor(stock_id, start_time=start_time, end_time=end_time)
|
||||
|
||||
deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor)
|
||||
if deal_amount == 0:
|
||||
@@ -332,7 +334,8 @@ class Exchange:
|
||||
stock_id=stock_id,
|
||||
amount=deal_amount,
|
||||
direction=Order.BUY,
|
||||
trade_date=trade_date,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
factor=factor,
|
||||
)
|
||||
)
|
||||
@@ -343,14 +346,15 @@ class Exchange:
|
||||
stock_id=stock_id,
|
||||
amount=abs(deal_amount),
|
||||
direction=Order.SELL,
|
||||
trade_date=trade_date,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
factor=factor,
|
||||
)
|
||||
)
|
||||
# return order_list : buy + sell
|
||||
return sell_order_list + buy_order_list
|
||||
|
||||
def calculate_amount_position_value(self, amount_dict, trade_date, only_tradable=False):
|
||||
def calculate_amount_position_value(self, amount_dict, start_time, end_time, only_tradable=False):
|
||||
"""Parameter
|
||||
position : Position()
|
||||
amount_dict : {stock_id : amount}
|
||||
@@ -358,10 +362,10 @@ class Exchange:
|
||||
value = 0
|
||||
for stock_id in amount_dict:
|
||||
if (
|
||||
self.check_stock_suspended(stock_id=stock_id, trade_date=trade_date) is False
|
||||
and self.check_stock_limit(stock_id=stock_id, trade_date=trade_date) is False
|
||||
self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
):
|
||||
value += self.get_deal_price(stock_id=stock_id, trade_date=trade_date) * amount_dict[stock_id]
|
||||
value += self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) * amount_dict[stock_id]
|
||||
return value
|
||||
|
||||
def round_amount_by_trade_unit(self, deal_amount, factor):
|
||||
@@ -384,7 +388,7 @@ class Exchange:
|
||||
:return: trade_val, trade_cost
|
||||
"""
|
||||
|
||||
trade_price = self.get_deal_price(order.stock_id, order.trade_date)
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
|
||||
if order.direction == Order.SELL:
|
||||
# sell
|
||||
if position is not None:
|
||||
|
||||
15
qlib/contrib/backtest/interpreter.py
Normal file
15
qlib/contrib/backtest/interpreter.py
Normal file
@@ -0,0 +1,15 @@
|
||||
|
||||
class BaseInterpreter:
|
||||
@staticmethod
|
||||
def interpret(**kwargs):
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
class ActionInterpreter:
|
||||
@staticmethod
|
||||
def interpret(action, **kwargs):
|
||||
return action
|
||||
|
||||
class StateInterpreter:
|
||||
@staticmethod
|
||||
def interpret(state, **kwargs):
|
||||
return state
|
||||
@@ -7,7 +7,7 @@ class Order:
|
||||
SELL = 0
|
||||
BUY = 1
|
||||
|
||||
def __init__(self, stock_id, amount, trade_date, direction, factor):
|
||||
def __init__(self, stock_id, amount, start_time, end_time, direction, factor):
|
||||
"""Parameter
|
||||
direction : Order.SELL for sell; Order.BUY for buy
|
||||
stock_id : str
|
||||
@@ -24,6 +24,7 @@ class Order:
|
||||
self.amount = amount
|
||||
# amount of successfully completed orders
|
||||
self.deal_amount = 0
|
||||
self.trade_date = trade_date
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
self.direction = direction
|
||||
self.factor = factor
|
||||
|
||||
@@ -28,13 +28,13 @@ a typical example is :{
|
||||
class Position:
|
||||
"""Position"""
|
||||
|
||||
def __init__(self, cash=0, position_dict={}, today_account_value=0):
|
||||
def __init__(self, cash=0, position_dict={}, now_account_value=0):
|
||||
# NOTE: The position dict must be copied!!!
|
||||
# Otherwise the initial value
|
||||
self.init_cash = cash
|
||||
self.position = position_dict.copy()
|
||||
self.position["cash"] = cash
|
||||
self.position["today_account_value"] = today_account_value
|
||||
self.position["now_account_value"] = now_account_value
|
||||
|
||||
def init_stock(self, stock_id, amount, price=None):
|
||||
self.position[stock_id] = {}
|
||||
@@ -82,7 +82,7 @@ class Position:
|
||||
# SELL
|
||||
self.sell_stock(order.stock_id, trade_val, cost, trade_price)
|
||||
else:
|
||||
raise NotImplementedError("do not suppotr order direction {}".format(order.direction))
|
||||
raise NotImplementedError("do not support order direction {}".format(order.direction))
|
||||
|
||||
def update_stock_price(self, stock_id, price):
|
||||
self.position[stock_id]["price"] = price
|
||||
@@ -109,7 +109,7 @@ class Position:
|
||||
return value
|
||||
|
||||
def get_stock_list(self):
|
||||
stock_list = list(set(self.position.keys()) - {"cash", "today_account_value"})
|
||||
stock_list = list(set(self.position.keys()) - {"cash", "now_account_value"})
|
||||
return stock_list
|
||||
|
||||
def get_stock_price(self, code):
|
||||
@@ -163,16 +163,17 @@ class Position:
|
||||
for stock_code, weight in weight_dict.items():
|
||||
self.update_stock_weight(stock_code, weight)
|
||||
|
||||
def save_position(self, path, last_trade_date):
|
||||
def save_position(self, path, last_trade_time):
|
||||
path = pathlib.Path(path)
|
||||
p = copy.deepcopy(self.position)
|
||||
cash = pd.Series(dtype=np.float)
|
||||
cash["init_cash"] = self.init_cash
|
||||
cash["cash"] = p["cash"]
|
||||
cash["today_account_value"] = p["today_account_value"]
|
||||
cash["last_trade_date"] = str(last_trade_date.date()) if last_trade_date else None
|
||||
cash["now_account_value"] = p["now_account_value"]
|
||||
cash["last_trade_start_time"] = str(last_trade_time[0]) if last_trade_time else None
|
||||
cash["last_trade_end_time"] = str(last_trade_time[1]) if last_trade_time else None
|
||||
del p["cash"]
|
||||
del p["today_account_value"]
|
||||
del p["now_account_value"]
|
||||
positions = pd.DataFrame.from_dict(p, orient="index")
|
||||
with pd.ExcelWriter(path) as writer:
|
||||
positions.to_excel(writer, sheet_name="position")
|
||||
@@ -189,10 +190,10 @@ class Position:
|
||||
'weight': <the security weight of total position value>,
|
||||
|
||||
sheet "cash"
|
||||
index: ['init_cash', 'cash', 'today_account_value']
|
||||
index: ['init_cash', 'cash', 'now_account_value']
|
||||
'init_cash': <inital cash when account was created>,
|
||||
'cash': <current cash in account>,
|
||||
'today_account_value': <current total account value, should equal to sum(price[stock]*amount[stock])>
|
||||
'now_account_value': <current total account value, should equal to sum(price[stock]*amount[stock])>
|
||||
"""
|
||||
path = pathlib.Path(path)
|
||||
positions = pd.read_excel(open(path, "rb"), sheet_name="position", index_col=0)
|
||||
@@ -200,14 +201,17 @@ class Position:
|
||||
positions = positions.to_dict(orient="index")
|
||||
init_cash = cash_record.loc["init_cash"].values[0]
|
||||
cash = cash_record.loc["cash"].values[0]
|
||||
today_account_value = cash_record.loc["today_account_value"].values[0]
|
||||
last_trade_date = cash_record.loc["last_trade_date"].values[0]
|
||||
now_account_value = cash_record.loc["now_account_value"].values[0]
|
||||
last_trade_start_time = cash_record.loc["last_trade_start_time"].values[0]
|
||||
last_trade_end_time = cash_record.loc["last_trade_end_time"].values[0]
|
||||
|
||||
# assign values
|
||||
self.position = {}
|
||||
self.init_cash = init_cash
|
||||
self.position = positions
|
||||
self.position["cash"] = cash
|
||||
self.position["today_account_value"] = today_account_value
|
||||
self.position["now_account_value"] = now_account_value
|
||||
|
||||
return None if pd.isna(last_trade_date) else pd.Timestamp(last_trade_date)
|
||||
last_trade_start_time = None if pd.isna(last_trade_start_time) else pd.Timestamp(last_trade_start_time)
|
||||
last_trade_end_time = None if pd.isna(last_trade_end_time) else pd.Timestamp(last_trade_end_time)
|
||||
return last_trade_start_time, last_trade_end_time
|
||||
|
||||
@@ -21,20 +21,20 @@ class Report:
|
||||
self.costs = OrderedDict() # trade cost for each trade date
|
||||
self.values = OrderedDict() # value for each trade date
|
||||
self.cashes = OrderedDict()
|
||||
self.latest_report_date = None # pd.TimeStamp
|
||||
self.latest_report_time = None # pd.TimeStamp
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.accounts) == 0
|
||||
|
||||
def get_latest_date(self):
|
||||
return self.latest_report_date
|
||||
return self.latest_report_time
|
||||
|
||||
def get_latest_account_value(self):
|
||||
return self.accounts[self.latest_report_date]
|
||||
return self.accounts[self.latest_report_time]
|
||||
|
||||
def update_report_record(
|
||||
self,
|
||||
trade_date=None,
|
||||
trade_time=None,
|
||||
account_value=None,
|
||||
cash=None,
|
||||
return_rate=None,
|
||||
@@ -44,7 +44,7 @@ class Report:
|
||||
):
|
||||
# check data
|
||||
if None in [
|
||||
trade_date,
|
||||
trade_time,
|
||||
account_value,
|
||||
cash,
|
||||
return_rate,
|
||||
@@ -56,14 +56,14 @@ class Report:
|
||||
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
|
||||
)
|
||||
# update report data
|
||||
self.accounts[trade_date] = account_value
|
||||
self.returns[trade_date] = return_rate
|
||||
self.turnovers[trade_date] = turnover_rate
|
||||
self.costs[trade_date] = cost_rate
|
||||
self.values[trade_date] = stock_value
|
||||
self.cashes[trade_date] = cash
|
||||
self.accounts[trade_time] = account_value
|
||||
self.returns[trade_time] = return_rate
|
||||
self.turnovers[trade_time] = turnover_rate
|
||||
self.costs[trade_time] = cost_rate
|
||||
self.values[trade_time] = stock_value
|
||||
self.cashes[trade_time] = cash
|
||||
# update latest_report_date
|
||||
self.latest_report_date = trade_date
|
||||
self.latest_report_time = trade_time
|
||||
# finish daily report update
|
||||
|
||||
def generate_report_dataframe(self):
|
||||
@@ -74,7 +74,7 @@ class Report:
|
||||
report["cost"] = pd.Series(self.costs)
|
||||
report["value"] = pd.Series(self.values)
|
||||
report["cash"] = pd.Series(self.cashes)
|
||||
report.index.name = "date"
|
||||
report.index.name = "trade_time"
|
||||
return report
|
||||
|
||||
def save_report(self, path):
|
||||
@@ -94,13 +94,13 @@ class Report:
|
||||
|
||||
index = r.index
|
||||
self.init_vars()
|
||||
for date in index:
|
||||
for trade_time in index:
|
||||
self.update_report_record(
|
||||
trade_date=date,
|
||||
account_value=r.loc[date]["account"],
|
||||
cash=r.loc[date]["cash"],
|
||||
return_rate=r.loc[date]["return"],
|
||||
turnover_rate=r.loc[date]["turnover"],
|
||||
cost_rate=r.loc[date]["cost"],
|
||||
stock_value=r.loc[date]["value"],
|
||||
trade_time=trade_time,
|
||||
account_value=r.loc[trade_time]["account"],
|
||||
cash=r.loc[trade_time]["cash"],
|
||||
return_rate=r.loc[trade_time]["return"],
|
||||
turnover_rate=r.loc[trade_time]["turnover"],
|
||||
cost_rate=r.loc[trade_time]["cost"],
|
||||
stock_value=r.loc[trade_time]["value"],
|
||||
)
|
||||
|
||||
@@ -4,13 +4,13 @@
|
||||
|
||||
from .dl_strategy import (
|
||||
TopkDropoutStrategy,
|
||||
BaseStrategy,
|
||||
WeightStrategyBase,
|
||||
)
|
||||
|
||||
from .rule_strategy import(
|
||||
TWAPStrategy,
|
||||
SBBEMAStrategy
|
||||
SBBStrategyBase,
|
||||
SBBStrategyEMA,
|
||||
)
|
||||
|
||||
from .cost_control import (
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from .strategy import WeightStrategyBase
|
||||
from .dl_strategy import WeightStrategyBase
|
||||
import copy
|
||||
|
||||
|
||||
|
||||
@@ -4,12 +4,12 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...utils import sample_feature
|
||||
from ...strategy.base import DLStrategy
|
||||
from ...backtest.order import Order
|
||||
from ...strategy.base import ModelStrategy
|
||||
from ..backtest.order import Order
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
class TopkDropoutStrategy(DLStrategy):
|
||||
class TopkDropoutStrategy(ModelStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
@@ -53,7 +53,7 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
else:
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time)
|
||||
super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange)
|
||||
self.topk = topk
|
||||
self.n_drop = n_drop
|
||||
self.method_sell = method_sell
|
||||
@@ -67,9 +67,10 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
self.only_tradable = only_tradable
|
||||
|
||||
|
||||
def reset(trade_exchange=None, **kwargs):
|
||||
def reset(self, trade_exchange=None, **kwargs):
|
||||
super(TopkDropoutStrategy, self).reset(**kwargs)
|
||||
self.trade_exchange = trade_exchange
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def get_risk_degree(self, trade_index):
|
||||
"""get_risk_degree
|
||||
@@ -189,7 +190,7 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
# update cash
|
||||
cash += trade_val - trade_cost
|
||||
# sold
|
||||
del self.stock_count[code]
|
||||
self.stock_count[code] = 0
|
||||
else:
|
||||
# no buy signal, but the stock is kept
|
||||
self.stock_count[code] += 1
|
||||
@@ -210,10 +211,10 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
# value = value / (1+self.trade_exchange.open_cost) # set open_cost limit
|
||||
for code in buy:
|
||||
# check is stock suspended
|
||||
if not self.trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
|
||||
if not self.trade_exchange.is_stock_tradable(stock_id=code, start_time=trade_start_time, end_time=trade_end_time):
|
||||
continue
|
||||
# buy order
|
||||
buy_price = self.trade_exchange.get_deal_price(stock_id=code, trade_date=trade_date)
|
||||
buy_price = self.trade_exchange.get_deal_price(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
|
||||
buy_amount = value / buy_price
|
||||
factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
|
||||
buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
|
||||
@@ -229,8 +230,8 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
self.stock_count[code] = 1
|
||||
return sell_order_list + buy_order_list
|
||||
|
||||
class WeightStrategyBase(DLStrategy):
|
||||
def __init__(self, trade_exchange, order_generator_cls_or_obj=OrderGenWInteract, start_time=None, end_time=None, **kwargs):
|
||||
class WeightStrategyBase(ModelStrategy):
|
||||
def __init__(self, step_bar, start_time=None, end_time=None, order_generator_cls_or_obj=OrderGenWInteract, trade_exchange=None, **kwargs):
|
||||
super(WeightStrategyBase, self).__init__(step_bar, start_time, end_time)
|
||||
self.trade_exchange = trade_exchange
|
||||
if isinstance(order_generator_cls_or_obj, type):
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
"""
|
||||
This order generator is for strategies based on WeightStrategyBase
|
||||
"""
|
||||
from ...backtest.position import Position
|
||||
from ...backtest.exchange import Exchange
|
||||
from ..backtest.position import Position
|
||||
from ..backtest.exchange import Exchange
|
||||
import pandas as pd
|
||||
import copy
|
||||
|
||||
|
||||
@@ -4,18 +4,20 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...utils import sample_feature
|
||||
from ...data.data import D
|
||||
from ...strategy.base import RuleStrategy, TradingEnhancement
|
||||
from ...backtest.order import Order
|
||||
from ..backtest.order import Order
|
||||
|
||||
|
||||
class TWAPStrategy(RuleStrategy, TradingEnhancement):
|
||||
|
||||
def reset(self, trade_order_list=None, **kwargs):
|
||||
super(TWAPStrategy, self).reset(**kwargs)
|
||||
TradingEnhancement.reset(trade_order_list=trade_order_list)
|
||||
self.trade_amount = {}
|
||||
for order in self.trade_order_list:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
|
||||
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
|
||||
if trade_order_list:
|
||||
self.trade_amount = {}
|
||||
for order in self.trade_order_list:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
|
||||
|
||||
|
||||
def generate_order_list(self, **kwargs):
|
||||
@@ -43,13 +45,15 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
TREND_LONG = 2
|
||||
|
||||
def reset(self, trade_order_list=None, **kwargs):
|
||||
TradingEnhancement.reset(trade_order_list=trade_order_list)
|
||||
self.trade_amount = {}
|
||||
self.trade_delay = {}
|
||||
for order in self.trade_order_list:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
|
||||
self.trade_trend[(order.stock_id, order.direction)] = TREND_MID
|
||||
super(SBBStrategyBase, self).reset(**kwargs)
|
||||
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
|
||||
if trade_order_list:
|
||||
self.trade_amount = {}
|
||||
self.trade_trend = {}
|
||||
for order in self.trade_order_list:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
|
||||
self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID
|
||||
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
raise NotImplementedError("pred_price_trend method is not implemented!")
|
||||
@@ -64,7 +68,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
_pred_trend = self._pred_price_trend(order.stock_id)
|
||||
else:
|
||||
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
|
||||
if _pred_trend == TREND_MID:
|
||||
if _pred_trend == self.TREND_MID:
|
||||
_order = Order(
|
||||
stock_id=order.stock_id,
|
||||
amount=self.trade_amount[(order.stock_id, order.direction)],
|
||||
@@ -97,7 +101,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
if self.trade_index % 2 == 1
|
||||
if self.trade_index % 2 == 1:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
|
||||
|
||||
return order_list
|
||||
@@ -110,8 +114,8 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time,
|
||||
end_time,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
instruments="csi300",
|
||||
freq="day",
|
||||
**kwargs,
|
||||
@@ -121,21 +125,23 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
self.instruments = "all"
|
||||
if isinstance(instruments, str):
|
||||
self.instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
|
||||
self.instruments = D.instruments(instruments)
|
||||
self.freq = freq
|
||||
|
||||
|
||||
def _reset_trade_calendar(self, start_time=None, end_time=None, _calendar=None):
|
||||
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time, _calendar=_calendar)
|
||||
fields = [("EMA($close, 10) - EMA($close, 20)", "signal")]
|
||||
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
|
||||
self.signal = D.features(instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq)
|
||||
def _reset_trade_calendar(self, start_time=None, end_time=None):
|
||||
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
if self.start_time and self.end_time:
|
||||
fields = ["EMA($close, 10)-EMA($close, 20)"]
|
||||
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
|
||||
self.signal = D.features(self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq)
|
||||
self.signal.columns = ["signal"]
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
_sample_signal = sample_feature(self.signal, stock_id, start_time=pred_start_time, end_time=pred_end_time, fields="signal", method="last")
|
||||
if _sample_signal.empty:
|
||||
return SBBStrategy.TREND_MID
|
||||
elif _sample_signal.iloc[0, 0] > 0:
|
||||
return SBBStrategy.TREND_LONG
|
||||
return self.TREND_MID
|
||||
elif _sample_signal.iloc[0] > 0:
|
||||
return self.TREND_LONG
|
||||
else:
|
||||
return SBBStrategy.TREND_SHORT
|
||||
return self.TREND_SHORT
|
||||
@@ -117,6 +117,7 @@ class CalendarProvider(abc.ABC):
|
||||
flag = f"{freq}_sam_{freq_sam}_future_{future}"
|
||||
if flag in H["c"]:
|
||||
_calendar, _calendar_index = H["c"][flag]
|
||||
return _calendar, _calendar_index
|
||||
else:
|
||||
flag_raw = f"{freq}_sam_{None}_future_{future}"
|
||||
if flag_raw in H["c"]:
|
||||
@@ -125,6 +126,7 @@ class CalendarProvider(abc.ABC):
|
||||
_calendar = np.array(self.load_calendar(freq, future))
|
||||
_calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search
|
||||
H["c"][flag_raw] = _calendar, _calendar_index
|
||||
|
||||
if freq_sam is None:
|
||||
return _calendar, _calendar_index
|
||||
else:
|
||||
@@ -132,6 +134,7 @@ class CalendarProvider(abc.ABC):
|
||||
_calendar_sam_index = {x: i for i, x in enumerate(_calendar_sam)}
|
||||
H["c"][flag] = _calendar_sam, _calendar_sam_index
|
||||
return _calendar_sam, _calendar_sam_index
|
||||
|
||||
|
||||
def _uri(self, start_time, end_time, freq, future=False):
|
||||
"""Get the uri of calendar generation task."""
|
||||
@@ -541,8 +544,8 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
with open(fname) as f:
|
||||
return [pd.Timestamp(x.strip()) for x in f]
|
||||
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False, freq_sam=None):
|
||||
_calendar, _ = self._get_calendar(freq=freq, future=future)
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
|
||||
_calendar, _ = self._get_calendar(freq=freq, freq_sam=freq_sam, future=future)
|
||||
# strip
|
||||
if start_time:
|
||||
start_time = pd.Timestamp(start_time)
|
||||
@@ -764,6 +767,7 @@ class ClientCalendarProvider(CalendarProvider):
|
||||
self.conn = conn
|
||||
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
|
||||
self.conn.send_request(
|
||||
request_type="calendar",
|
||||
request_content={
|
||||
|
||||
@@ -10,8 +10,9 @@ import pandas as pd
|
||||
|
||||
from ..utils import get_sample_freq_calendar
|
||||
from ..data.dataset import DatasetH
|
||||
from ..backtest.order import Order
|
||||
from ..backtest.env import TradeCalendarBase
|
||||
from ..data.dataset.utils import get_level_index
|
||||
from ..contrib.backtest.order import Order
|
||||
from ..contrib.backtest.env import TradeCalendarBase
|
||||
|
||||
"""
|
||||
1. BaseStrategy 的粒度一定是数据粒度的整数倍
|
||||
@@ -24,26 +25,14 @@ class BaseStrategy(TradeCalendarBase):
|
||||
self.step_bar = step_bar
|
||||
self.reset(start_time=start_time, end_time=end_time, **kwargs)
|
||||
|
||||
def reset(self, start_time=None, end_time=None, _calendar=None, **kwargs):
|
||||
def reset(self, start_time=None, end_time=None, **kwargs):
|
||||
if start_time or end_time :
|
||||
self._reset_trade_calendar(start_time=start_time, end_time=end_time, calendar=calendar)
|
||||
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
|
||||
for k, v in kwargs:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
def _get_trade_time(self):
|
||||
if 0 < self.trade_index < self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[self.trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[self.trade_index] - pd.Timestamp(second=1)
|
||||
return trade_start_time, trade_end_time
|
||||
elif self.trade_index == self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[self.trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[self.trade_index]
|
||||
return trade_start_time, trade_end_time
|
||||
else:
|
||||
raise RuntimeError("trade_index out of range")
|
||||
|
||||
def generate_order_list(self, **kwargs):
|
||||
self.trade_index = self.trade_index + 1
|
||||
@@ -52,20 +41,26 @@ class BaseStrategy(TradeCalendarBase):
|
||||
class RuleStrategy(BaseStrategy):
|
||||
pass
|
||||
|
||||
class DLStrategy(BaseStrategy):
|
||||
def __init__(self, step_bar, model, dataset:DatasetH, start_time=None, end_time=None):
|
||||
class ModelStrategy(BaseStrategy):
|
||||
def __init__(self, step_bar, model, dataset:DatasetH, start_time=None, end_time=None, **kwargs):
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
self.pred_scores = self.model.predict(dataset)
|
||||
self.pred_scores = self._convert_index_format(self.model.predict(dataset))
|
||||
#pred_score_dates = self.pred_scores.index.get_level_values(level="datetime")
|
||||
super(DLStrategy, self).__init__(step_bar, start_time, end_time)
|
||||
super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
|
||||
def _update_model(self):
|
||||
def _convert_index_format(self, df):
|
||||
if get_level_index(df, level="datetime") == 0:
|
||||
df = df.swaplevel().sort_index()
|
||||
return df
|
||||
|
||||
def _update_model(self):
|
||||
"""update pred score
|
||||
"""
|
||||
pass
|
||||
|
||||
class TradingEnhancement:
|
||||
def reset(self, trade_order_list):
|
||||
self.trade_order_list = trade_order_list
|
||||
def reset(self, trade_order_list=None):
|
||||
if trade_order_list:
|
||||
self.trade_order_list = trade_order_list
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import bisect
|
||||
import shutil
|
||||
import difflib
|
||||
import hashlib
|
||||
import warnings
|
||||
import datetime
|
||||
import requests
|
||||
import tempfile
|
||||
@@ -918,37 +919,40 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
|
||||
else:
|
||||
raise ValueError("sample freq must be xmin, xd, xw, xm")
|
||||
|
||||
def get_sample_freq_calendar(start_time=None, end_time=None, freq, **kwargs):
|
||||
def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwargs):
|
||||
from ..data.data import Cal
|
||||
|
||||
try:
|
||||
_calendar = D.calendar(start_time=start_time, end_time=end_time, freq=freq, **kwargs)
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=freq, **kwargs)
|
||||
freq, freq_sam = freq, None
|
||||
except ValueError:
|
||||
freq_sam = freq
|
||||
if freq.endswith(("m", "month", "w", "week", "d", "day")):
|
||||
try:
|
||||
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq, **kwargs)
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
|
||||
freq = "min"
|
||||
except ValueError:
|
||||
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="day", freq_sam=freq, **kwargs)
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, **kwargs)
|
||||
freq = "day"
|
||||
elif freq.endswith(("min", "minute")):
|
||||
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq, **kwargs)
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
|
||||
freq = "min"
|
||||
else:
|
||||
raise ValueError(f"freq {freq} is not supported")
|
||||
return _calendar, freq, freq_sam
|
||||
|
||||
def sample_feature(feature, instruments=None, start_time=None, end_time=None, fields=None, method=None, method_kwargs={}):
|
||||
if instruments and type(instruments) is not list:
|
||||
if instruments and not isinstance(instruments, list):
|
||||
instruments = [instruments]
|
||||
if fields and type(fields) is not list:
|
||||
fields = [fields]
|
||||
selector_inst = slice(None) if instruments is None else instruments
|
||||
selector_datetime = slice(start_time, end_time)
|
||||
if fields is not None and type(fields) is not list:
|
||||
fields = [fields]
|
||||
selector_fields = slice(None) if fields is None else fields
|
||||
feature = feature.loc[(selector_inst, selector_datetime), selector_fields]
|
||||
if isinstance(feature, pd.Series):
|
||||
feature = feature.loc[(selector_inst, selector_datetime)]
|
||||
if fields:
|
||||
warnings.warn(f"sample series feature, {fields} is ignored!")
|
||||
elif isinstance(feature, pd.DataFrame):
|
||||
selector_fields = slice(None) if fields is None else fields
|
||||
feature = feature.loc[(selector_inst, selector_datetime), selector_fields]
|
||||
if method:
|
||||
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user