1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00

trade_account support multi bar report

This commit is contained in:
bxdd
2021-04-29 02:15:34 +08:00
parent 8920c1967f
commit 86a6f565e8
15 changed files with 362 additions and 209 deletions

View File

@@ -81,7 +81,7 @@ if __name__ == "__main__":
backtest_config={
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.dl_strategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"kwargs": {
"step_bar": "week",
"model": model,
@@ -113,6 +113,18 @@ if __name__ == "__main__":
}
}
}
},
"backtest":{
"start_time": trade_start_time,
"end_time": trade_end_time,
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": benchmark,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
}
}

View File

@@ -19,6 +19,7 @@ logger = get_module_logger("backtest caller")
def get_exchange(
exchange=None,
freq="day",
start_time=None,
end_time=None,
codes = "all",
@@ -72,6 +73,7 @@ def get_exchange(
deal_price = "$" + deal_price
exchange = Exchange(
freq=freq,
start_time=start_time,
end_time=end_time,
codes=codes,

View File

@@ -3,10 +3,13 @@
import copy
import pandas as pd
from .position import Position
from .report import Report
from .order import Order
from ...utils import parse_freq, sample_feature
"""
@@ -26,21 +29,86 @@ rtn & earning in the Account
class Account:
def __init__(self, init_cash, last_trade_time=None):
self.init_vars(init_cash, last_trade_time)
def __init__(self, init_cash, benchmark=None, start_time=None, end_time=None, freq=None):
self.init_vars(init_cash, benchmark, start_time, end_time)
def init_vars(self, init_cash, last_trade_time=None):
def init_vars(self, init_cash, benchmark=None, start_time=None, end_time=None, freq=None):
"""
Parameters
----------
- benchmark: str/list/pd.Series
`benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
example:
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
2017-01-04 0.011693
2017-01-05 0.000721
2017-01-06 -0.004322
2017-01-09 0.006874
2017-01-10 -0.003350
`benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
`benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000905 CSI500
"""
# init cash
self.init_cash = init_cash
self.benchmark = benchmark
self.start_time = start_time
self.end_time = end_time
self.freq = freq
self.current = Position(cash=init_cash)
self.positions = {}
self.rtn = 0
self.ct = 0
self.to = 0
self.val = 0
self.report = Report()
self.earning = 0
self.last_trade_time = last_trade_time
self.report = Report()
if freq and benchmark:
self.bench = self._cal_benchmark(benchmark, start_time, end_time, freq)
def _cal_benchmark(self, benchmark, start_time=None, end_time=None, freq=None):
if isinstance(benchmark, pd.Series):
return benchmark
else:
if freq is None:
raise ValueError("benchmark freq can't be None!")
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
fields = ["$close/Ref($close,1)-1"]
try:
_temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1)
except ValueError:
_, norm_freq = parse_freq(freq)
if norm_freq in ["month", "week", "day"]:
try:
_temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1)
except ValueError:
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
elif norm_freq == "minute":
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
else:
raise ValueError(f"benchmark freq {freq} is not supported")
if len(_temp_result) == 0:
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
def cal_change(x):
return x.prod() - 1
return sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
def reset(self, benchmark=None, freq=None,**kwargs):
if benchmark:
self.benchmark = benchmark
if freq:
self.freq = freq
if self.freq and self.benchmark and (freq or benchmark)
self.bench = self._cal_benchmark(self.benchmark, self.start_time, self.end_time, self.freq)
for k, v in kwargs:
if hasattr(k):
setattr(k, v)
def get_positions(self):
return self.positions
@@ -83,7 +151,7 @@ class Account:
self.current.update_order(order, trade_val, cost, trade_price)
self.update_state_from_order(order, trade_val, cost, trade_price)
def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange):
def update_report(self, trade_start_time, trade_end_time, trade_exchange):
"""
start_time: pd.TimeStamp
end_time: pd.TimeStamp
@@ -100,20 +168,17 @@ class Account:
"""
# update price for stock in the position and the profit from changed_price
stock_list = self.current.get_stock_list()
profit = 0
for code in stock_list:
# if suspend, no new price to be updated, profit is 0
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
continue
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
profit += (bar_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
self.current.update_stock_price(stock_id=code, price=bar_close)
self.rtn += profit
# update holding day count
self.current.add_count_all()
self.current.add_count_all(bar=self.freq)
# update value
self.val = self.current.calculate_value()
# update earning (2nd view of return)
# update earning
# account_value - last_account_value
# for the first trade date, account_value - init_cash
# self.report.is_empty() to judge is_first_trade_date
@@ -138,6 +203,7 @@ class Account:
turnover_rate=self.to / last_account_value,
cost_rate=self.ct / last_account_value,
stock_value=now_stock_value,
bench_value=self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
)
# set now_account_value to position
self.current.position["now_account_value"] = now_account_value
@@ -148,23 +214,20 @@ class Account:
# finish today's updation
# reset the daily variables
self.rtn = 0
self.ct = 0
self.to = 0
self.last_trade_time = (trade_start_time, trade_end_time)
def load_account(self, account_path):
report = Report()
position = Position()
last_trade_time = position.load_position(account_path / "position.xlsx")
report.load_report(account_path / "report.csv")
position.load_position(account_path / "position.xlsx")
# assign values
self.init_vars(position.init_cash)
self.current = position
self.report = report
self.last_trade_time = last_trade_time
def save_account(self, account_path):
self.current.save_position(account_path / "position.xlsx", self.last_trade_time)
self.current.save_position(account_path / "position.xlsx")
self.report.save_report(account_path / "report.csv")

View File

@@ -9,12 +9,26 @@ import numpy as np
import pandas as pd
from ...data.data import Cal
from ...utils import get_sample_freq_calendar
from .position import Position
from .report import Report
from .order import Order
class TradeCalendarBase:
class BaseTradeCalendar:
def __init__(
self,
step_bar,
start_time=None,
end_time=None,
**kwargs
):
self.step_bar = step_bar
self.reset(start_time=start_time, end_time=end_time)
def _reset_trade_calendar(self, start_time, end_time):
if not start_time and not end_time:
return
if start_time:
self.start_time = pd.Timestamp(start_time)
if end_time:
@@ -24,37 +38,33 @@ class TradeCalendarBase:
self.calendar = _calendar
_start_time, _end_time, _start_index, _end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq, freq_sam=freq_sam)
_trade_calendar = self.calendar[_start_index: _end_index + 1]
if _start_time != self.start_time:
self.trade_calendar = np.hstack((self.start_time, _trade_calendar, self.end_time))
self.start_index = _start_index - 1
else:
self.trade_calendar = np.hstack((_trade_calendar, self.end_time))
self.start_index = _start_index
self.start_index = _start_index
self.end_index = _end_index
self.trade_len = _end_index - _start_index + 1
self.trade_index = 0
self.trade_len = len(self.trade_calendar)
else:
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
def _get_trade_time(self, trade_index=1, shift=0):
trade_index = trade_index - shift
if 0 < trade_index < self.trade_len - 1:
trade_start_time = self.trade_calendar[trade_index - 1]
trade_end_time = self.trade_calendar[trade_index] - pd.Timedelta(seconds=1)
return trade_start_time, trade_end_time
elif trade_index == self.trade_len - 1:
trade_start_time = self.trade_calendar[trade_index - 1]
trade_end_time = self.trade_calendar[trade_index]
return trade_start_time, trade_end_time
else:
raise RuntimeError("trade_index out of range")
def reset(self, start_time=None, end_time=None, **kwargs):
if start_time or end_time:
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
for k, v in kwargs:
if hasattr(self, k):
setattr(self, k, v)
def _get_calendar_time(self, trade_index=1, shift=1):
def _get_calendar_time(self, trade_index=1, shift=0):
trade_index = trade_index - shift
calendar_index = self.start_index + trade_index
return self.calendar[calendar_index - 1], self.calendar[calendar_index]
class BaseEnv(TradeCalendarBase):
def finished(self):
return self.trade_index >= self.trade_len
def step(self):
self.trade_index = self.trade_index + 1
class BaseEnv(BaseTradeCalendar):
"""
# Strategy framework document
@@ -67,38 +77,32 @@ class BaseEnv(TradeCalendarBase):
start_time=None,
end_time=None,
trade_account=None,
update_report=False,
verbose=False,
**kwargs,
):
self.step_bar = step_bar
self.generate_report = update_report
self.verbose = verbose
self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs)
def _get_position(self):
return self.trade_account.current
super(BaseEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs)
def reset(self, start_time=None, end_time=None, trade_account=None, **kwargs):
if start_time or end_time:
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
def reset(self, trade_account=None, **kwargs):
super(BaseEnv, self).reset(**kwargs)
if trade_account:
self.trade_account = trade_account
for k, v in kwargs:
if hasattr(self, k):
setattr(self, k, v)
self.trade_account.reset(freq=self.step_bar, report=Report(), positions={})
def get_init_state(self):
init_state = {"current": self._get_position()}
init_state = {"current": self.trade_account.current}
return init_state
def execute(self, **kwargs):
raise NotImplementedError("execute is not implemented!")
def execute(self, order_list=None, **kwargs):
self.trade_index = self.trade_index + 1
def finished(self):
return self.trade_index >= self.trade_len - 1
def get_trade_account(self):
raise NotImplementedError("get_trade_account is not implemented!")
def get_report(self):
raise NotImplementedError("get_report is not implemented!")
class SplitEnv(BaseEnv):
def __init__(
@@ -109,33 +113,44 @@ class SplitEnv(BaseEnv):
start_time=None,
end_time=None,
trade_account=None,
update_report=False,
verbose=False,
**kwargs
):
self.sub_env = sub_env
self.sub_strategy = sub_strategy
super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, verbose=verbose)
super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, update_report=update_report, verbose=verbose, **kwargs)
def reset(self, trade_account=None, **kwargs):
super(SplitEnv, self).reset(trade_account=trade_account, **kwargs)
if trade_account:
self.sub_env.reset(trade_account=copy.copy(trade_account))
def execute(self, order_list, **kwargs):
if self.finished():
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
#if self.track:
# yield action
#episode_reward = 0
super(SplitEnv, self).execute(**kwargs)
trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time, trade_account=self.trade_account)
super(SplitEnv, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time)
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
trade_state = self.sub_env.get_init_state()
while not self.sub_env.finished():
_order_list = self.sub_strategy.generate_order_list(**trade_state)
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
#episode_reward += sub_reward
_obs = {"current": self._get_position()}
if self.generate_report:
self.trade_account.update_report(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
_obs = {"current": self.trade_account.current}
_info = {}
return _obs, _info
def get_report(self):
_report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None
_positions = self.trade_account.get_positions() if self.generate_report else None
return [(_report,_positions), *sub_env.get_report()]
class SimulatorEnv(BaseEnv):
@@ -146,10 +161,11 @@ class SimulatorEnv(BaseEnv):
end_time=None,
trade_account=None,
trade_exchange=None,
update_report=False,
verbose=False,
**kwargs,
):
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, verbose=verbose, **kwargs)
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, update_report=update_report, verbose=verbose, **kwargs)
def reset(self, trade_exchange=None, **kwargs):
super(SimulatorEnv, self).reset(**kwargs)
@@ -162,8 +178,8 @@ class SimulatorEnv(BaseEnv):
"""
if self.finished():
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
super(SimulatorEnv, self).execute(**kwargs)
trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
super(SimulatorEnv, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
trade_info = []
for order in order_list:
if self.trade_exchange.check_order(order) is True:
@@ -197,7 +213,18 @@ class SimulatorEnv(BaseEnv):
print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id))
# do nothing
pass
self.trade_account.update_bar_end(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
_obs = {"current": self._get_position()}
if self.generate_report:
self.trade_account.update_report(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
_obs = {"current": self.trade_account.current}
_info = {"trade_info": trade_info}
return _obs, _info
return _obs, _info
def get_report(self):
_report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None
_positions = self.trade_account.get_positions() if self.generate_report else None
return [
{
"report": _report,
"positions": _positions
}
]

View File

@@ -9,6 +9,7 @@ import numpy as np
import pandas as pd
from ...data.data import D
from ...data.dataset.utils import get_level_index
from ...config import C, REG_CN
from ...utils import sample_feature
from ...log import get_module_logger
@@ -19,6 +20,7 @@ from .order import Order
class Exchange:
def __init__(
self,
freq="day",
start_time=None,
end_time=None,
codes="all",
@@ -55,6 +57,7 @@ class Exchange:
target on this day).
index: MultipleIndex(instrument, pd.Datetime)
"""
self.freq = freq
self.start_time = start_time
self.end_time = end_time
if trade_unit is None:
@@ -105,7 +108,7 @@ class Exchange:
def set_quote(self, codes, start_time, end_time):
if len(codes) == 0:
codes = D.instruments()
self.quote = D.features(codes, self.all_fields, start_time, end_time, disk_cache=True).dropna(subset=["$close"])
self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna(subset=["$close"])
self.quote.columns = self.all_fields
if self.quote[self.deal_price].isna().any():
@@ -146,7 +149,14 @@ class Exchange:
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
# update quote: pd.DataFrame to dict, for search use
self.quote = quote_df
if get_level_index(quote_df, level="datetime") == 1:
quote_df = quote_df.swaplevel().sort_index()
quote_dict = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"):
quote_dict[stock_id] = stock_val
self.quote = quote_dict
def _update_limit(self, buy_limit, sell_limit):
self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False)
@@ -157,13 +167,15 @@ class Exchange:
trade_date
is limtited
"""
return sample_feature(self.quote, stock_id, start_time, end_time, fields="limit", method="any").iloc[0]
return sample_feature(self.quote[stock_id], start_time, end_time, fields="limit", method="all").iloc[0]
def check_stock_suspended(self, stock_id, start_time, end_time):
# is suspended
return sample_feature(self.quote, stock_id, start_time, end_time).empty
if stock_id in self.quote:
return sample_feature(self.quote[stock_id], start_time, end_time, method=None) is None
else:
return True
def is_stock_tradable(self, stock_id, start_time, end_time):
# check if stock can be traded
@@ -217,13 +229,13 @@ class Exchange:
return trade_val, trade_cost, trade_price
def get_quote_info(self, stock_id, start_time, end_time):
return sample_feature(self.quote, stock_id, start_time, end_time)
return sample_feature(self.quote[stock_id], start_time, end_time, method="last").iloc[0]
def get_close(self, stock_id, start_time, end_time):
return sample_feature(self.quote, stock_id, start_time, end_time, fields="$close", method="last").iloc[0]
return sample_feature(self.quote[stock_id], start_time, end_time, fields="$close", method="last").iloc[0]
def get_deal_price(self, stock_id, start_time, end_time):
deal_price = sample_feature(self.quote, stock_id, start_time, end_time, fields=self.deal_price, method="last").iloc[0]
deal_price = sample_feature(self.quote[stock_id], start_time, end_time, fields=self.deal_price, method="last").iloc[0]
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!")
self.logger.warning(f"setting deal_price to close price")
@@ -231,7 +243,7 @@ class Exchange:
return deal_price
def get_factor(self, stock_id, start_time, end_time):
return sample_feature(self.quote, stock_id, start_time, end_time, fields="$factor", method="last").iloc[0]
return sample_feature(self.quote[stock_id], start_time, end_time, fields="$factor", method="last").iloc[0]
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
"""

View File

@@ -38,7 +38,6 @@ class Position:
def init_stock(self, stock_id, amount, price=None):
self.position[stock_id] = {}
self.position[stock_id]["count"] = 0 # update count in the end of this date
self.position[stock_id]["amount"] = amount
self.position[stock_id]["price"] = price
self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
@@ -87,8 +86,8 @@ class Position:
def update_stock_price(self, stock_id, price):
self.position[stock_id]["price"] = price
def update_stock_count(self, stock_id, count):
self.position[stock_id]["count"] = count
def update_stock_count(self, stock_id, bar, count):
self.position[stock_id][f"count_{bar}"] = count
def update_stock_weight(self, stock_id, weight):
self.position[stock_id]["weight"] = weight
@@ -118,8 +117,11 @@ class Position:
def get_stock_amount(self, code):
return self.position[code]["amount"]
def get_stock_count(self, code):
return self.position[code]["count"]
def get_stock_count(self, code, bar):
if f"count_{bar}" in self.position[code]:
return self.position[code][f"count_{bar}"]
else:
return 0
def get_stock_weight(self, code):
return self.position[code]["weight"]
@@ -153,25 +155,26 @@ class Position:
d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value
return d
def add_count_all(self):
def add_count_all(self, bar):
stock_list = self.get_stock_list()
for code in stock_list:
self.position[code]["count"] += 1
if f"count_{bar}" in self.position[code]:
self.position[code][f"count_{bar}"] += 1
else:
self.position[code][f"count_{bar}"] = 1
def update_weight_all(self):
weight_dict = self.get_stock_weight_dict()
for stock_code, weight in weight_dict.items():
self.update_stock_weight(stock_code, weight)
def save_position(self, path, last_trade_time):
def save_position(self, path):
path = pathlib.Path(path)
p = copy.deepcopy(self.position)
cash = pd.Series(dtype=np.float)
cash["init_cash"] = self.init_cash
cash["cash"] = p["cash"]
cash["now_account_value"] = p["now_account_value"]
cash["last_trade_start_time"] = str(last_trade_time[0]) if last_trade_time else None
cash["last_trade_end_time"] = str(last_trade_time[1]) if last_trade_time else None
del p["cash"]
del p["now_account_value"]
positions = pd.DataFrame.from_dict(p, orient="index")
@@ -183,8 +186,8 @@ class Position:
"""load position information from a file
should have format below
sheet "position"
columns: ['stock', 'count', 'amount', 'price', 'weight']
'count': <how many days the security has been hold>,
columns: ['stock', f'count_{bar}', 'amount', 'price', 'weight']
f'count_{bar}': <how many bars the security has been hold>,
'amount': <the amount of the security>,
'price': <the close price of security in the last trading day>,
'weight': <the security weight of total position value>,
@@ -202,16 +205,9 @@ class Position:
init_cash = cash_record.loc["init_cash"].values[0]
cash = cash_record.loc["cash"].values[0]
now_account_value = cash_record.loc["now_account_value"].values[0]
last_trade_start_time = cash_record.loc["last_trade_start_time"].values[0]
last_trade_end_time = cash_record.loc["last_trade_end_time"].values[0]
# assign values
self.position = {}
self.init_cash = init_cash
self.position = positions
self.position["cash"] = cash
self.position["now_account_value"] = now_account_value
last_trade_start_time = None if pd.isna(last_trade_start_time) else pd.Timestamp(last_trade_start_time)
last_trade_end_time = None if pd.isna(last_trade_end_time) else pd.Timestamp(last_trade_end_time)
return last_trade_start_time, last_trade_end_time

View File

@@ -21,6 +21,7 @@ class Report:
self.costs = OrderedDict() # trade cost for each trade date
self.values = OrderedDict() # value for each trade date
self.cashes = OrderedDict()
self.benches = OrderedDict()
self.latest_report_time = None # pd.TimeStamp
def is_empty(self):
@@ -41,6 +42,7 @@ class Report:
turnover_rate=None,
cost_rate=None,
stock_value=None,
bench_value=None,
):
# check data
if None in [
@@ -51,9 +53,10 @@ class Report:
turnover_rate,
cost_rate,
stock_value,
bench_value
]:
raise ValueError(
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]"
)
# update report data
self.accounts[trade_time] = account_value
@@ -62,6 +65,7 @@ class Report:
self.costs[trade_time] = cost_rate
self.values[trade_time] = stock_value
self.cashes[trade_time] = cash
self.benches[trade_time] = bench_value
# update latest_report_date
self.latest_report_time = trade_time
# finish daily report update
@@ -74,7 +78,8 @@ class Report:
report["cost"] = pd.Series(self.costs)
report["value"] = pd.Series(self.values)
report["cash"] = pd.Series(self.cashes)
report.index.name = "trade_time"
report["bench"] = pd.Series(self.benches)
report.index.name = "datetime"
return report
def save_report(self, path):
@@ -84,7 +89,7 @@ class Report:
def load_report(self, path):
"""load report from a file
should have format like
columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash']
columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash', 'bench']
:param
path: str/ pathlib.Path()
"""
@@ -103,4 +108,5 @@ class Report:
turnover_rate=r.loc[trade_time]["turnover"],
cost_rate=r.loc[trade_time]["cost"],
stock_value=r.loc[trade_time]["value"],
bench_value=r.loc[trade_time]["bench"]
)

View File

@@ -41,7 +41,7 @@ def parse_position(position: dict = None) -> pd.DataFrame:
for _trading_date, _value in position.items():
# pd_date type: pd.Timestamp
_cash = _value.pop("cash")
for _item in ["today_account_value"]:
for _item in ["now_account_value"]:
if _item in _value:
_value.pop(_item)

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from .dl_strategy import (
from .model_strategy import (
TopkDropoutStrategy,
WeightStrategyBase,
)

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from .dl_strategy import WeightStrategyBase
from .model_strategy import WeightStrategyBase
import copy

View File

@@ -81,10 +81,12 @@ class TopkDropoutStrategy(ModelStrategy):
return self.risk_degree
def generate_order_list(self, current, **kwargs):
super(TopkDropoutStrategy, self).generate_order_list()
trade_start_time, trade_end_time = self._get_trade_time(self.trade_index)
super(TopkDropoutStrategy, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if pred_score is None:
return []
if self.only_tradable:
# If The strategy only consider tradable stock when make decision
# It needs following actions to filter stocks
@@ -168,7 +170,7 @@ class TopkDropoutStrategy(ModelStrategy):
continue
if code in sell:
# check hold limit
if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code) < self.hold_thresh:
if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh:
# can not sell this code
# no buy signal, but the stock is kept
self.stock_count[code] += 1
@@ -271,10 +273,12 @@ class WeightStrategyBase(ModelStrategy):
"""
# generate_order_list
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
super(WeightStrategyBase, self).generate_order_list()
trade_start_time, trade_end_time = self._get_trade_time(self.trade_index)
pred_start_time, pred_end_time = self._get_pred_time()
super(WeightStrategyBase, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if pred_score is None:
return []
current_temp = copy.deepcopy(trade_account.current)
target_weight_position = self.generate_target_weight_position(
score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time

View File

@@ -5,6 +5,7 @@ import pandas as pd
from ...utils import sample_feature
from ...data.data import D
from ...data.dataset.utils import get_level_index
from ...strategy.base import RuleStrategy, TradingEnhancement
from ..backtest.order import Order
@@ -21,8 +22,8 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
def generate_order_list(self, **kwargs):
super(TopkDropoutStrategy, self).generate_order_list()
trade_start_time, trade_end_time = self._get_trade_time()
super(TopkDropoutStrategy, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
order_list = []
for order in self.trade_order_list:
_order = Order(
@@ -59,8 +60,8 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
raise NotImplementedError("pred_price_trend method is not implemented!")
def generate_order_list(self, **kwargs):
super(SBBStrategyBase, self).generate_order_list()
trade_start_time, trade_end_time = self._get_trade_time()
super(SBBStrategyBase, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
order_list = []
for order in self.trade_order_list:
@@ -127,21 +128,33 @@ class SBBStrategyEMA(SBBStrategyBase):
if isinstance(instruments, str):
self.instruments = D.instruments(instruments)
self.freq = freq
def _convert_index_format(self, df):
if get_level_index(df, level="datetime") == 1:
df = df.swaplevel().sort_index()
return df
def _reset_trade_calendar(self, start_time=None, end_time=None):
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time)
if self.start_time and self.end_time:
fields = ["EMA($close, 10)-EMA($close, 20)"]
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
self.signal = D.features(self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq)
self.signal.columns = ["signal"]
signal_df = D.features(self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq)
signal_df = self._convert_index_format(signal_df)
signal_df.columns = ["signal"]
self.signal = {}
for stock_id, stock_val in signal_df.groupby(level="instrument"):
self.signal[stock_id] = stock_val
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
_sample_signal = sample_feature(self.signal, stock_id, start_time=pred_start_time, end_time=pred_end_time, fields="signal", method="last")
if _sample_signal.empty:
if stock_id not in self.signal:
return self.TREND_MID
elif _sample_signal.iloc[0] > 0:
return self.TREND_LONG
else:
return self.TREND_SHORT
_sample_signal = sample_feature(self.signal[stock_id], pred_start_time, pred_end_time, fields="signal", method="last")
if _sample_signal is None or _sample_signal.iloc[0] == 0:
return self.TREND_MID
elif _sample_signal.iloc[0] > 0:
return self.TREND_LONG
else:
return self.TREND_SHORT

View File

@@ -12,7 +12,7 @@ from ..utils import get_sample_freq_calendar
from ..data.dataset import DatasetH
from ..data.dataset.utils import get_level_index
from ..contrib.backtest.order import Order
from ..contrib.backtest.env import TradeCalendarBase
from ..contrib.backtest.env import BaseTradeCalendar
"""
1. BaseStrategy 的粒度一定是数据粒度的整数倍
@@ -20,22 +20,10 @@ from ..contrib.backtest.env import TradeCalendarBase
- adjust_dates这个东西啥用
- label和freq和strategy的bar分离这个如何决策呢
"""
class BaseStrategy(TradeCalendarBase):
def __init__(self, step_bar, start_time=None, end_time=None, **kwargs):
self.step_bar = step_bar
self.reset(start_time=start_time, end_time=end_time, **kwargs)
def reset(self, start_time=None, end_time=None, **kwargs):
if start_time or end_time :
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
for k, v in kwargs:
if hasattr(self, k):
setattr(self, k, v)
class BaseStrategy(BaseTradeCalendar):
def generate_order_list(self, **kwargs):
self.trade_index = self.trade_index + 1
raise NotImplementedError("generator_order_list is not implemented!")
class RuleStrategy(BaseStrategy):
@@ -50,14 +38,14 @@ class ModelStrategy(BaseStrategy):
super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
def _convert_index_format(self, df):
if get_level_index(df, level="datetime") == 0:
if get_level_index(df, level="datetime") == 1:
df = df.swaplevel().sort_index()
return df
def _update_model(self):
"""update pred score
"""
pass
raise NotImplementedError("_update_model is not implemented!")
class TradingEnhancement:
def reset(self, trade_order_list=None):

View File

@@ -861,15 +861,38 @@ def sample_calendar_bac(calendar_raw, freq_raw, freq_sam):
else:
raise ValueError("sample freq must be xmin, xd, xw, xm")
def parse_freq(freq):
freq = freq.lower()
search_obj =re.search("^([0-9]*)([a-z]+)", freq)
if search_obj is None:
raise ValueError("freq format is not supported")
_count = int(search_obj.group(1) if search_obj.group(1) else "1")
_freq = search_obj.group(2)
_freq_format_dict = {
"month": "month",
"mon": "month",
"week": "week",
"w": "week",
"day": "day",
"d": "day",
"minute": "minute",
"min": "minute",
}
try:
_freq = _freq_format_dict.get(_freq)
except KeyError:
raise ValueError("freq format is not supported, the supported freq includes (x)month/m, (x)day/d, (x)minute/min")
return _count, _freq
def sample_calendar(calendar_raw, freq_raw, freq_sam):
"""
freq_raw : "min" or "day"
"""
freq_raw = "1" + freq_raw if re.match("^[0-9]", freq_raw) is None else freq_raw
freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam
raw_count, freq_raw = parse_freq(freq_raw)
sam_count, freq_sam = parse_freq(freq_sam)
if not len(calendar_raw):
return calendar_raw
if freq_sam.endswith(("minute", "min")):
if freq_sam == "minute":
def cal_next_sam_minute(x, sam_minutes):
hour = x.hour
minute = x.minute
@@ -888,38 +911,36 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60
else:
raise ValueError("calendar minute_index error")
sam_minutes = int(freq_sam[:-3]) if freq_sam.endswith("min") else int(freq_sam[:-6])
if not freq_raw.endswith(("minute", "min")):
if req_raw != "minute":
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
else:
raw_minutes = int(freq_raw[:-3]) if freq_raw.endswith("min") else int(freq_raw[:-6])
if raw_minutes > sam_minutes:
if raw_count > sam_count:
raise ValueError("raw freq must be higher than sample freq")
_calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 0), calendar_raw)))
_calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw)))
if calendar_raw[0] > _calendar_minute[0]:
_calendar_minute[0] = calendar_raw[0]
return _calendar_minute
else:
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
if freq_sam.endswith(("day", "d")):
sam_days = int(freq_sam[:-1]) if freq_sam.endswith("d") else int(freq_sam[:-3])
return _calendar_day[::sam_days]
if freq_sam == "day":
return _calendar_day[::sam_count]
elif freq_sam.endswith(("week", "w")):
sam_weeks = int(freq_sam[:-1]) if freq_sam.endswith("w") else int(freq_sam[:-4])
elif freq_sam == "week":
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
_calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]
return _calendar_week[::sam_weeks]
return _calendar_week[::sam_count]
elif freq_sam.endswith(("month", "m")):
sam_months = int(freq_sam[:-1]) if freq_sam.endswith("m") else int(freq_sam[:-5])
elif freq_sam == "month":
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
_calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]
return _calendar_month[::sam_months]
return _calendar_month[::sam_count]
else:
raise ValueError("sample freq must be xmin, xd, xw, xm")
def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwargs):
_, norm_freq = parse_freq(freq)
from ..data.data import Cal
try:
@@ -927,34 +948,47 @@ def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwarg
freq, freq_sam = freq, None
except ValueError:
freq_sam = freq
if freq.endswith(("m", "month", "w", "week", "d", "day")):
if norm_freq in ["month", "week", "day"]:
try:
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
freq = "min"
except ValueError:
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, **kwargs)
freq = "day"
elif freq.endswith(("min", "minute")):
except ValueError:
raise
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
freq = "min"
elif norm_freq == "minute":
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
freq = "min"
else:
raise ValueError(f"freq {freq} is not supported")
return _calendar, freq, freq_sam
def sample_feature(feature, instruments=None, start_time=None, end_time=None, fields=None, method=None, method_kwargs={}):
if instruments and not isinstance(instruments, list):
instruments = [instruments]
selector_inst = slice(None) if instruments is None else instruments
def sample_feature(feature, start_time=None, end_time=None, fields=None, method="last", method_kwargs={}):
selector_datetime = slice(start_time, end_time)
if isinstance(feature, pd.Series):
feature = feature.loc[(selector_inst, selector_datetime)]
if fields:
warnings.warn(f"sample series feature, {fields} is ignored!")
elif isinstance(feature, pd.DataFrame):
selector_fields = slice(None) if fields is None else fields
feature = feature.loc[(selector_inst, selector_datetime), selector_fields]
if method:
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
else:
return feature
fields = fields if fields else slice(None)
from ..data.dataset.utils import get_level_index
datetime_level = get_level_index(feature, level="datetime") == 0
if isinstance(feature, pd.Series):
feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)]
elif isinstance(feature, pd.DataFrame):
feature = feature.loc[selector_datetime, fields] if datetime_level else feature.loc[(slice(None), selector_datetime), fields]
if feature.empty:
return None
if isinstance(feature.index, pd.MultiIndex):
if callable(method):
method_func = method
return feature.groupby(level="instrument").apply(lambda x:method_func(x, **method_kwargs))
elif isinstance(method, str):
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
else:
if callable(method):
method_func = method
return method_func(feature, **method_kwargs)
elif isinstance(method, str):
return getattr(feature, method)(**method_kwargs)
return feature

View File

@@ -233,8 +233,8 @@ class PortAnaRecord(SignalRecord):
super().__init__(recorder=recorder, **kwargs)
self.strategy_config = config["strategy"]
self.env_config = config["env"]
self.backtest_config = config["backtest"]
self.strategy = init_instance_by_config(self.strategy_config, accept_types=BaseStrategy)
def generate(self, **kwargs):
# check previously stored prediction results
@@ -244,36 +244,32 @@ class PortAnaRecord(SignalRecord):
super().generate()
# custom strategy and get backtest
pred_score = super().load("pred.pkl")
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
report_normal = report_dict.get("report_df")
positions_normal = report_dict.get("positions")
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
order_normal = report_dict.get("order_list")
if order_normal:
self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path())
# analysis
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["excess_return_with_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
)
# save portfolio analysis results
analysis_df = pd.concat(analysis) # type: pd.DataFrame
# log metrics
self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
# save results
self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path())
logger.info(
f"Portfolio analysis record 'port_analysis.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
# print out results
pprint("The following are analysis results of the excess return without cost.")
pprint(analysis["excess_return_without_cost"])
pprint("The following are analysis results of the excess return with cost.")
pprint(analysis["excess_return_with_cost"])
report_list = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config)
for report_id, (report_normal, positions_normal) in enumerate(report_list):
if report_dict is None:
continue
self.recorder.save_objects(**{f"report_normal_{report_id}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(**{f"positions_norma_{report_id}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
# analysis
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["excess_return_with_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
)
analysis_df = pd.concat(analysis) # type: pd.DataFrame
# log metrics
self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
# save results
self.recorder.save_objects(**{f"port_analysis.pkl_{report_id}": analysis_df}, artifact_path=PortAnaRecord.get_path())
logger.info(
f"Portfolio analysis record 'port_analysis_{report_id}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
# print out results
pprint("The following are analysis results of the excess return without cost.")
pprint(analysis["excess_return_without_cost"])
pprint("The following are analysis results of the excess return with cost.")
pprint(analysis["excess_return_with_cost"])
def list(self):
return [