mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
trade_account support multi bar report
This commit is contained in:
@@ -81,7 +81,7 @@ if __name__ == "__main__":
|
||||
backtest_config={
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.dl_strategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"kwargs": {
|
||||
"step_bar": "week",
|
||||
"model": model,
|
||||
@@ -113,6 +113,18 @@ if __name__ == "__main__":
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"backtest":{
|
||||
"start_time": trade_start_time,
|
||||
"end_time": trade_end_time,
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ logger = get_module_logger("backtest caller")
|
||||
|
||||
def get_exchange(
|
||||
exchange=None,
|
||||
freq="day",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes = "all",
|
||||
@@ -72,6 +73,7 @@ def get_exchange(
|
||||
deal_price = "$" + deal_price
|
||||
|
||||
exchange = Exchange(
|
||||
freq=freq,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
codes=codes,
|
||||
|
||||
@@ -3,10 +3,13 @@
|
||||
|
||||
|
||||
import copy
|
||||
import pandas as pd
|
||||
|
||||
from .position import Position
|
||||
from .report import Report
|
||||
from .order import Order
|
||||
from ...utils import parse_freq, sample_feature
|
||||
|
||||
|
||||
|
||||
"""
|
||||
@@ -26,21 +29,86 @@ rtn & earning in the Account
|
||||
|
||||
|
||||
class Account:
|
||||
def __init__(self, init_cash, last_trade_time=None):
|
||||
self.init_vars(init_cash, last_trade_time)
|
||||
def __init__(self, init_cash, benchmark=None, start_time=None, end_time=None, freq=None):
|
||||
self.init_vars(init_cash, benchmark, start_time, end_time)
|
||||
|
||||
def init_vars(self, init_cash, last_trade_time=None):
|
||||
def init_vars(self, init_cash, benchmark=None, start_time=None, end_time=None, freq=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
- benchmark: str/list/pd.Series
|
||||
`benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
|
||||
example:
|
||||
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
|
||||
2017-01-04 0.011693
|
||||
2017-01-05 0.000721
|
||||
2017-01-06 -0.004322
|
||||
2017-01-09 0.006874
|
||||
2017-01-10 -0.003350
|
||||
`benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
|
||||
`benchmark` is str, will use the daily change as the 'bench'.
|
||||
benchmark code, default is SH000905 CSI500
|
||||
|
||||
"""
|
||||
# init cash
|
||||
self.init_cash = init_cash
|
||||
self.benchmark = benchmark
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
self.freq = freq
|
||||
self.current = Position(cash=init_cash)
|
||||
self.positions = {}
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
self.to = 0
|
||||
self.val = 0
|
||||
self.report = Report()
|
||||
self.earning = 0
|
||||
self.last_trade_time = last_trade_time
|
||||
self.report = Report()
|
||||
if freq and benchmark:
|
||||
self.bench = self._cal_benchmark(benchmark, start_time, end_time, freq)
|
||||
|
||||
def _cal_benchmark(self, benchmark, start_time=None, end_time=None, freq=None):
|
||||
if isinstance(benchmark, pd.Series):
|
||||
return benchmark
|
||||
else:
|
||||
if freq is None:
|
||||
raise ValueError("benchmark freq can't be None!")
|
||||
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
|
||||
fields = ["$close/Ref($close,1)-1"]
|
||||
try:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1)
|
||||
except ValueError:
|
||||
_, norm_freq = parse_freq(freq)
|
||||
if norm_freq in ["month", "week", "day"]:
|
||||
try:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1)
|
||||
except ValueError:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
|
||||
elif norm_freq == "minute":
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
|
||||
else:
|
||||
raise ValueError(f"benchmark freq {freq} is not supported")
|
||||
if len(_temp_result) == 0:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
|
||||
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
|
||||
def cal_change(x):
|
||||
return x.prod() - 1
|
||||
return sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
|
||||
|
||||
def reset(self, benchmark=None, freq=None,**kwargs):
|
||||
if benchmark:
|
||||
self.benchmark = benchmark
|
||||
if freq:
|
||||
self.freq = freq
|
||||
if self.freq and self.benchmark and (freq or benchmark)
|
||||
self.bench = self._cal_benchmark(self.benchmark, self.start_time, self.end_time, self.freq)
|
||||
|
||||
for k, v in kwargs:
|
||||
if hasattr(k):
|
||||
setattr(k, v)
|
||||
|
||||
|
||||
def get_positions(self):
|
||||
return self.positions
|
||||
@@ -83,7 +151,7 @@ class Account:
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
self.update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
def update_report(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""
|
||||
start_time: pd.TimeStamp
|
||||
end_time: pd.TimeStamp
|
||||
@@ -100,20 +168,17 @@ class Account:
|
||||
"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
stock_list = self.current.get_stock_list()
|
||||
profit = 0
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
profit += (bar_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
|
||||
self.current.update_stock_price(stock_id=code, price=bar_close)
|
||||
self.rtn += profit
|
||||
# update holding day count
|
||||
self.current.add_count_all()
|
||||
self.current.add_count_all(bar=self.freq)
|
||||
# update value
|
||||
self.val = self.current.calculate_value()
|
||||
# update earning (2nd view of return)
|
||||
# update earning
|
||||
# account_value - last_account_value
|
||||
# for the first trade date, account_value - init_cash
|
||||
# self.report.is_empty() to judge is_first_trade_date
|
||||
@@ -138,6 +203,7 @@ class Account:
|
||||
turnover_rate=self.to / last_account_value,
|
||||
cost_rate=self.ct / last_account_value,
|
||||
stock_value=now_stock_value,
|
||||
bench_value=self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
|
||||
)
|
||||
# set now_account_value to position
|
||||
self.current.position["now_account_value"] = now_account_value
|
||||
@@ -148,23 +214,20 @@ class Account:
|
||||
|
||||
# finish today's updation
|
||||
# reset the daily variables
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
self.to = 0
|
||||
self.last_trade_time = (trade_start_time, trade_end_time)
|
||||
|
||||
def load_account(self, account_path):
|
||||
report = Report()
|
||||
position = Position()
|
||||
last_trade_time = position.load_position(account_path / "position.xlsx")
|
||||
report.load_report(account_path / "report.csv")
|
||||
position.load_position(account_path / "position.xlsx")
|
||||
|
||||
# assign values
|
||||
self.init_vars(position.init_cash)
|
||||
self.current = position
|
||||
self.report = report
|
||||
self.last_trade_time = last_trade_time
|
||||
|
||||
def save_account(self, account_path):
|
||||
self.current.save_position(account_path / "position.xlsx", self.last_trade_time)
|
||||
self.current.save_position(account_path / "position.xlsx")
|
||||
self.report.save_report(account_path / "report.csv")
|
||||
|
||||
@@ -9,12 +9,26 @@ import numpy as np
|
||||
import pandas as pd
|
||||
from ...data.data import Cal
|
||||
from ...utils import get_sample_freq_calendar
|
||||
from .position import Position
|
||||
from .report import Report
|
||||
from .order import Order
|
||||
|
||||
|
||||
class TradeCalendarBase:
|
||||
|
||||
class BaseTradeCalendar:
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
**kwargs
|
||||
):
|
||||
self.step_bar = step_bar
|
||||
self.reset(start_time=start_time, end_time=end_time)
|
||||
|
||||
def _reset_trade_calendar(self, start_time, end_time):
|
||||
if not start_time and not end_time:
|
||||
return
|
||||
if start_time:
|
||||
self.start_time = pd.Timestamp(start_time)
|
||||
if end_time:
|
||||
@@ -24,37 +38,33 @@ class TradeCalendarBase:
|
||||
self.calendar = _calendar
|
||||
_start_time, _end_time, _start_index, _end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq, freq_sam=freq_sam)
|
||||
_trade_calendar = self.calendar[_start_index: _end_index + 1]
|
||||
if _start_time != self.start_time:
|
||||
self.trade_calendar = np.hstack((self.start_time, _trade_calendar, self.end_time))
|
||||
self.start_index = _start_index - 1
|
||||
else:
|
||||
self.trade_calendar = np.hstack((_trade_calendar, self.end_time))
|
||||
self.start_index = _start_index
|
||||
self.start_index = _start_index
|
||||
self.end_index = _end_index
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
self.trade_index = 0
|
||||
self.trade_len = len(self.trade_calendar)
|
||||
else:
|
||||
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
|
||||
|
||||
def _get_trade_time(self, trade_index=1, shift=0):
|
||||
trade_index = trade_index - shift
|
||||
if 0 < trade_index < self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[trade_index] - pd.Timedelta(seconds=1)
|
||||
return trade_start_time, trade_end_time
|
||||
elif trade_index == self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[trade_index]
|
||||
return trade_start_time, trade_end_time
|
||||
else:
|
||||
raise RuntimeError("trade_index out of range")
|
||||
def reset(self, start_time=None, end_time=None, **kwargs):
|
||||
if start_time or end_time:
|
||||
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
|
||||
for k, v in kwargs:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
def _get_calendar_time(self, trade_index=1, shift=1):
|
||||
def _get_calendar_time(self, trade_index=1, shift=0):
|
||||
trade_index = trade_index - shift
|
||||
calendar_index = self.start_index + trade_index
|
||||
return self.calendar[calendar_index - 1], self.calendar[calendar_index]
|
||||
|
||||
class BaseEnv(TradeCalendarBase):
|
||||
def finished(self):
|
||||
return self.trade_index >= self.trade_len
|
||||
|
||||
def step(self):
|
||||
self.trade_index = self.trade_index + 1
|
||||
|
||||
class BaseEnv(BaseTradeCalendar):
|
||||
"""
|
||||
# Strategy framework document
|
||||
|
||||
@@ -67,38 +77,32 @@ class BaseEnv(TradeCalendarBase):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
update_report=False,
|
||||
verbose=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.step_bar = step_bar
|
||||
self.generate_report = update_report
|
||||
self.verbose = verbose
|
||||
self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs)
|
||||
|
||||
def _get_position(self):
|
||||
return self.trade_account.current
|
||||
super(BaseEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs)
|
||||
|
||||
|
||||
def reset(self, start_time=None, end_time=None, trade_account=None, **kwargs):
|
||||
if start_time or end_time:
|
||||
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
def reset(self, trade_account=None, **kwargs):
|
||||
super(BaseEnv, self).reset(**kwargs)
|
||||
if trade_account:
|
||||
self.trade_account = trade_account
|
||||
|
||||
for k, v in kwargs:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
self.trade_account.reset(freq=self.step_bar, report=Report(), positions={})
|
||||
|
||||
def get_init_state(self):
|
||||
init_state = {"current": self._get_position()}
|
||||
init_state = {"current": self.trade_account.current}
|
||||
return init_state
|
||||
|
||||
def execute(self, **kwargs):
|
||||
raise NotImplementedError("execute is not implemented!")
|
||||
|
||||
def execute(self, order_list=None, **kwargs):
|
||||
self.trade_index = self.trade_index + 1
|
||||
|
||||
def finished(self):
|
||||
return self.trade_index >= self.trade_len - 1
|
||||
def get_trade_account(self):
|
||||
raise NotImplementedError("get_trade_account is not implemented!")
|
||||
|
||||
def get_report(self):
|
||||
raise NotImplementedError("get_report is not implemented!")
|
||||
|
||||
class SplitEnv(BaseEnv):
|
||||
def __init__(
|
||||
@@ -109,33 +113,44 @@ class SplitEnv(BaseEnv):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
update_report=False,
|
||||
verbose=False,
|
||||
**kwargs
|
||||
):
|
||||
self.sub_env = sub_env
|
||||
self.sub_strategy = sub_strategy
|
||||
super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, verbose=verbose)
|
||||
super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, update_report=update_report, verbose=verbose, **kwargs)
|
||||
|
||||
def reset(self, trade_account=None, **kwargs):
|
||||
super(SplitEnv, self).reset(trade_account=trade_account, **kwargs)
|
||||
if trade_account:
|
||||
self.sub_env.reset(trade_account=copy.copy(trade_account))
|
||||
|
||||
def execute(self, order_list, **kwargs):
|
||||
if self.finished():
|
||||
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
|
||||
#if self.track:
|
||||
# yield action
|
||||
#episode_reward = 0
|
||||
super(SplitEnv, self).execute(**kwargs)
|
||||
trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
|
||||
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time, trade_account=self.trade_account)
|
||||
super(SplitEnv, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
|
||||
trade_state = self.sub_env.get_init_state()
|
||||
while not self.sub_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order_list(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
|
||||
#episode_reward += sub_reward
|
||||
_obs = {"current": self._get_position()}
|
||||
|
||||
if self.generate_report:
|
||||
self.trade_account.update_report(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
|
||||
_obs = {"current": self.trade_account.current}
|
||||
_info = {}
|
||||
return _obs, _info
|
||||
|
||||
|
||||
def get_report(self):
|
||||
_report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None
|
||||
_positions = self.trade_account.get_positions() if self.generate_report else None
|
||||
return [(_report,_positions), *sub_env.get_report()]
|
||||
|
||||
class SimulatorEnv(BaseEnv):
|
||||
|
||||
@@ -146,10 +161,11 @@ class SimulatorEnv(BaseEnv):
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
trade_exchange=None,
|
||||
update_report=False,
|
||||
verbose=False,
|
||||
**kwargs,
|
||||
):
|
||||
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, verbose=verbose, **kwargs)
|
||||
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, update_report=update_report, verbose=verbose, **kwargs)
|
||||
|
||||
def reset(self, trade_exchange=None, **kwargs):
|
||||
super(SimulatorEnv, self).reset(**kwargs)
|
||||
@@ -162,8 +178,8 @@ class SimulatorEnv(BaseEnv):
|
||||
"""
|
||||
if self.finished():
|
||||
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
|
||||
super(SimulatorEnv, self).execute(**kwargs)
|
||||
trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
|
||||
super(SimulatorEnv, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
trade_info = []
|
||||
for order in order_list:
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
@@ -197,7 +213,18 @@ class SimulatorEnv(BaseEnv):
|
||||
print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id))
|
||||
# do nothing
|
||||
pass
|
||||
self.trade_account.update_bar_end(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
|
||||
_obs = {"current": self._get_position()}
|
||||
if self.generate_report:
|
||||
self.trade_account.update_report(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
|
||||
_obs = {"current": self.trade_account.current}
|
||||
_info = {"trade_info": trade_info}
|
||||
return _obs, _info
|
||||
return _obs, _info
|
||||
|
||||
def get_report(self):
|
||||
_report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None
|
||||
_positions = self.trade_account.get_positions() if self.generate_report else None
|
||||
return [
|
||||
{
|
||||
"report": _report,
|
||||
"positions": _positions
|
||||
}
|
||||
]
|
||||
@@ -9,6 +9,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...data.data import D
|
||||
from ...data.dataset.utils import get_level_index
|
||||
from ...config import C, REG_CN
|
||||
from ...utils import sample_feature
|
||||
from ...log import get_module_logger
|
||||
@@ -19,6 +20,7 @@ from .order import Order
|
||||
class Exchange:
|
||||
def __init__(
|
||||
self,
|
||||
freq="day",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes="all",
|
||||
@@ -55,6 +57,7 @@ class Exchange:
|
||||
target on this day).
|
||||
index: MultipleIndex(instrument, pd.Datetime)
|
||||
"""
|
||||
self.freq = freq
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
if trade_unit is None:
|
||||
@@ -105,7 +108,7 @@ class Exchange:
|
||||
def set_quote(self, codes, start_time, end_time):
|
||||
if len(codes) == 0:
|
||||
codes = D.instruments()
|
||||
self.quote = D.features(codes, self.all_fields, start_time, end_time, disk_cache=True).dropna(subset=["$close"])
|
||||
self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna(subset=["$close"])
|
||||
self.quote.columns = self.all_fields
|
||||
|
||||
if self.quote[self.deal_price].isna().any():
|
||||
@@ -146,7 +149,14 @@ class Exchange:
|
||||
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
|
||||
|
||||
# update quote: pd.DataFrame to dict, for search use
|
||||
self.quote = quote_df
|
||||
if get_level_index(quote_df, level="datetime") == 1:
|
||||
quote_df = quote_df.swaplevel().sort_index()
|
||||
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
quote_dict[stock_id] = stock_val
|
||||
|
||||
self.quote = quote_dict
|
||||
|
||||
def _update_limit(self, buy_limit, sell_limit):
|
||||
self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False)
|
||||
@@ -157,13 +167,15 @@ class Exchange:
|
||||
trade_date
|
||||
is limtited
|
||||
"""
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time, fields="limit", method="any").iloc[0]
|
||||
return sample_feature(self.quote[stock_id], start_time, end_time, fields="limit", method="all").iloc[0]
|
||||
|
||||
|
||||
def check_stock_suspended(self, stock_id, start_time, end_time):
|
||||
# is suspended
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time).empty
|
||||
|
||||
if stock_id in self.quote:
|
||||
return sample_feature(self.quote[stock_id], start_time, end_time, method=None) is None
|
||||
else:
|
||||
return True
|
||||
|
||||
def is_stock_tradable(self, stock_id, start_time, end_time):
|
||||
# check if stock can be traded
|
||||
@@ -217,13 +229,13 @@ class Exchange:
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time):
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time)
|
||||
return sample_feature(self.quote[stock_id], start_time, end_time, method="last").iloc[0]
|
||||
|
||||
def get_close(self, stock_id, start_time, end_time):
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time, fields="$close", method="last").iloc[0]
|
||||
return sample_feature(self.quote[stock_id], start_time, end_time, fields="$close", method="last").iloc[0]
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time):
|
||||
deal_price = sample_feature(self.quote, stock_id, start_time, end_time, fields=self.deal_price, method="last").iloc[0]
|
||||
deal_price = sample_feature(self.quote[stock_id], start_time, end_time, fields=self.deal_price, method="last").iloc[0]
|
||||
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
|
||||
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!")
|
||||
self.logger.warning(f"setting deal_price to close price")
|
||||
@@ -231,7 +243,7 @@ class Exchange:
|
||||
return deal_price
|
||||
|
||||
def get_factor(self, stock_id, start_time, end_time):
|
||||
return sample_feature(self.quote, stock_id, start_time, end_time, fields="$factor", method="last").iloc[0]
|
||||
return sample_feature(self.quote[stock_id], start_time, end_time, fields="$factor", method="last").iloc[0]
|
||||
|
||||
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
|
||||
"""
|
||||
|
||||
@@ -38,7 +38,6 @@ class Position:
|
||||
|
||||
def init_stock(self, stock_id, amount, price=None):
|
||||
self.position[stock_id] = {}
|
||||
self.position[stock_id]["count"] = 0 # update count in the end of this date
|
||||
self.position[stock_id]["amount"] = amount
|
||||
self.position[stock_id]["price"] = price
|
||||
self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
|
||||
@@ -87,8 +86,8 @@ class Position:
|
||||
def update_stock_price(self, stock_id, price):
|
||||
self.position[stock_id]["price"] = price
|
||||
|
||||
def update_stock_count(self, stock_id, count):
|
||||
self.position[stock_id]["count"] = count
|
||||
def update_stock_count(self, stock_id, bar, count):
|
||||
self.position[stock_id][f"count_{bar}"] = count
|
||||
|
||||
def update_stock_weight(self, stock_id, weight):
|
||||
self.position[stock_id]["weight"] = weight
|
||||
@@ -118,8 +117,11 @@ class Position:
|
||||
def get_stock_amount(self, code):
|
||||
return self.position[code]["amount"]
|
||||
|
||||
def get_stock_count(self, code):
|
||||
return self.position[code]["count"]
|
||||
def get_stock_count(self, code, bar):
|
||||
if f"count_{bar}" in self.position[code]:
|
||||
return self.position[code][f"count_{bar}"]
|
||||
else:
|
||||
return 0
|
||||
|
||||
def get_stock_weight(self, code):
|
||||
return self.position[code]["weight"]
|
||||
@@ -153,25 +155,26 @@ class Position:
|
||||
d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value
|
||||
return d
|
||||
|
||||
def add_count_all(self):
|
||||
def add_count_all(self, bar):
|
||||
stock_list = self.get_stock_list()
|
||||
for code in stock_list:
|
||||
self.position[code]["count"] += 1
|
||||
if f"count_{bar}" in self.position[code]:
|
||||
self.position[code][f"count_{bar}"] += 1
|
||||
else:
|
||||
self.position[code][f"count_{bar}"] = 1
|
||||
|
||||
def update_weight_all(self):
|
||||
weight_dict = self.get_stock_weight_dict()
|
||||
for stock_code, weight in weight_dict.items():
|
||||
self.update_stock_weight(stock_code, weight)
|
||||
|
||||
def save_position(self, path, last_trade_time):
|
||||
def save_position(self, path):
|
||||
path = pathlib.Path(path)
|
||||
p = copy.deepcopy(self.position)
|
||||
cash = pd.Series(dtype=np.float)
|
||||
cash["init_cash"] = self.init_cash
|
||||
cash["cash"] = p["cash"]
|
||||
cash["now_account_value"] = p["now_account_value"]
|
||||
cash["last_trade_start_time"] = str(last_trade_time[0]) if last_trade_time else None
|
||||
cash["last_trade_end_time"] = str(last_trade_time[1]) if last_trade_time else None
|
||||
del p["cash"]
|
||||
del p["now_account_value"]
|
||||
positions = pd.DataFrame.from_dict(p, orient="index")
|
||||
@@ -183,8 +186,8 @@ class Position:
|
||||
"""load position information from a file
|
||||
should have format below
|
||||
sheet "position"
|
||||
columns: ['stock', 'count', 'amount', 'price', 'weight']
|
||||
'count': <how many days the security has been hold>,
|
||||
columns: ['stock', f'count_{bar}', 'amount', 'price', 'weight']
|
||||
f'count_{bar}': <how many bars the security has been hold>,
|
||||
'amount': <the amount of the security>,
|
||||
'price': <the close price of security in the last trading day>,
|
||||
'weight': <the security weight of total position value>,
|
||||
@@ -202,16 +205,9 @@ class Position:
|
||||
init_cash = cash_record.loc["init_cash"].values[0]
|
||||
cash = cash_record.loc["cash"].values[0]
|
||||
now_account_value = cash_record.loc["now_account_value"].values[0]
|
||||
last_trade_start_time = cash_record.loc["last_trade_start_time"].values[0]
|
||||
last_trade_end_time = cash_record.loc["last_trade_end_time"].values[0]
|
||||
|
||||
# assign values
|
||||
self.position = {}
|
||||
self.init_cash = init_cash
|
||||
self.position = positions
|
||||
self.position["cash"] = cash
|
||||
self.position["now_account_value"] = now_account_value
|
||||
|
||||
last_trade_start_time = None if pd.isna(last_trade_start_time) else pd.Timestamp(last_trade_start_time)
|
||||
last_trade_end_time = None if pd.isna(last_trade_end_time) else pd.Timestamp(last_trade_end_time)
|
||||
return last_trade_start_time, last_trade_end_time
|
||||
|
||||
@@ -21,6 +21,7 @@ class Report:
|
||||
self.costs = OrderedDict() # trade cost for each trade date
|
||||
self.values = OrderedDict() # value for each trade date
|
||||
self.cashes = OrderedDict()
|
||||
self.benches = OrderedDict()
|
||||
self.latest_report_time = None # pd.TimeStamp
|
||||
|
||||
def is_empty(self):
|
||||
@@ -41,6 +42,7 @@ class Report:
|
||||
turnover_rate=None,
|
||||
cost_rate=None,
|
||||
stock_value=None,
|
||||
bench_value=None,
|
||||
):
|
||||
# check data
|
||||
if None in [
|
||||
@@ -51,9 +53,10 @@ class Report:
|
||||
turnover_rate,
|
||||
cost_rate,
|
||||
stock_value,
|
||||
bench_value
|
||||
]:
|
||||
raise ValueError(
|
||||
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
|
||||
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]"
|
||||
)
|
||||
# update report data
|
||||
self.accounts[trade_time] = account_value
|
||||
@@ -62,6 +65,7 @@ class Report:
|
||||
self.costs[trade_time] = cost_rate
|
||||
self.values[trade_time] = stock_value
|
||||
self.cashes[trade_time] = cash
|
||||
self.benches[trade_time] = bench_value
|
||||
# update latest_report_date
|
||||
self.latest_report_time = trade_time
|
||||
# finish daily report update
|
||||
@@ -74,7 +78,8 @@ class Report:
|
||||
report["cost"] = pd.Series(self.costs)
|
||||
report["value"] = pd.Series(self.values)
|
||||
report["cash"] = pd.Series(self.cashes)
|
||||
report.index.name = "trade_time"
|
||||
report["bench"] = pd.Series(self.benches)
|
||||
report.index.name = "datetime"
|
||||
return report
|
||||
|
||||
def save_report(self, path):
|
||||
@@ -84,7 +89,7 @@ class Report:
|
||||
def load_report(self, path):
|
||||
"""load report from a file
|
||||
should have format like
|
||||
columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash']
|
||||
columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash', 'bench']
|
||||
:param
|
||||
path: str/ pathlib.Path()
|
||||
"""
|
||||
@@ -103,4 +108,5 @@ class Report:
|
||||
turnover_rate=r.loc[trade_time]["turnover"],
|
||||
cost_rate=r.loc[trade_time]["cost"],
|
||||
stock_value=r.loc[trade_time]["value"],
|
||||
bench_value=r.loc[trade_time]["bench"]
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ def parse_position(position: dict = None) -> pd.DataFrame:
|
||||
for _trading_date, _value in position.items():
|
||||
# pd_date type: pd.Timestamp
|
||||
_cash = _value.pop("cash")
|
||||
for _item in ["today_account_value"]:
|
||||
for _item in ["now_account_value"]:
|
||||
if _item in _value:
|
||||
_value.pop(_item)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from .dl_strategy import (
|
||||
from .model_strategy import (
|
||||
TopkDropoutStrategy,
|
||||
WeightStrategyBase,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from .dl_strategy import WeightStrategyBase
|
||||
from .model_strategy import WeightStrategyBase
|
||||
import copy
|
||||
|
||||
|
||||
|
||||
@@ -81,10 +81,12 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
return self.risk_degree
|
||||
|
||||
def generate_order_list(self, current, **kwargs):
|
||||
super(TopkDropoutStrategy, self).generate_order_list()
|
||||
trade_start_time, trade_end_time = self._get_trade_time(self.trade_index)
|
||||
super(TopkDropoutStrategy, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
|
||||
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if pred_score is None:
|
||||
return []
|
||||
if self.only_tradable:
|
||||
# If The strategy only consider tradable stock when make decision
|
||||
# It needs following actions to filter stocks
|
||||
@@ -168,7 +170,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
continue
|
||||
if code in sell:
|
||||
# check hold limit
|
||||
if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code) < self.hold_thresh:
|
||||
if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh:
|
||||
# can not sell this code
|
||||
# no buy signal, but the stock is kept
|
||||
self.stock_count[code] += 1
|
||||
@@ -271,10 +273,12 @@ class WeightStrategyBase(ModelStrategy):
|
||||
"""
|
||||
# generate_order_list
|
||||
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
|
||||
super(WeightStrategyBase, self).generate_order_list()
|
||||
trade_start_time, trade_end_time = self._get_trade_time(self.trade_index)
|
||||
pred_start_time, pred_end_time = self._get_pred_time()
|
||||
super(WeightStrategyBase, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
|
||||
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if pred_score is None:
|
||||
return []
|
||||
current_temp = copy.deepcopy(trade_account.current)
|
||||
target_weight_position = self.generate_target_weight_position(
|
||||
score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time
|
||||
@@ -5,6 +5,7 @@ import pandas as pd
|
||||
|
||||
from ...utils import sample_feature
|
||||
from ...data.data import D
|
||||
from ...data.dataset.utils import get_level_index
|
||||
from ...strategy.base import RuleStrategy, TradingEnhancement
|
||||
from ..backtest.order import Order
|
||||
|
||||
@@ -21,8 +22,8 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
|
||||
|
||||
|
||||
def generate_order_list(self, **kwargs):
|
||||
super(TopkDropoutStrategy, self).generate_order_list()
|
||||
trade_start_time, trade_end_time = self._get_trade_time()
|
||||
super(TopkDropoutStrategy, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
order_list = []
|
||||
for order in self.trade_order_list:
|
||||
_order = Order(
|
||||
@@ -59,8 +60,8 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
raise NotImplementedError("pred_price_trend method is not implemented!")
|
||||
|
||||
def generate_order_list(self, **kwargs):
|
||||
super(SBBStrategyBase, self).generate_order_list()
|
||||
trade_start_time, trade_end_time = self._get_trade_time()
|
||||
super(SBBStrategyBase, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
|
||||
order_list = []
|
||||
for order in self.trade_order_list:
|
||||
@@ -127,21 +128,33 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
if isinstance(instruments, str):
|
||||
self.instruments = D.instruments(instruments)
|
||||
self.freq = freq
|
||||
|
||||
|
||||
def _convert_index_format(self, df):
|
||||
if get_level_index(df, level="datetime") == 1:
|
||||
df = df.swaplevel().sort_index()
|
||||
return df
|
||||
|
||||
def _reset_trade_calendar(self, start_time=None, end_time=None):
|
||||
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
if self.start_time and self.end_time:
|
||||
fields = ["EMA($close, 10)-EMA($close, 20)"]
|
||||
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
|
||||
self.signal = D.features(self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq)
|
||||
self.signal.columns = ["signal"]
|
||||
|
||||
signal_df = D.features(self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq)
|
||||
signal_df = self._convert_index_format(signal_df)
|
||||
signal_df.columns = ["signal"]
|
||||
self.signal = {}
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
self.signal[stock_id] = stock_val
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
_sample_signal = sample_feature(self.signal, stock_id, start_time=pred_start_time, end_time=pred_end_time, fields="signal", method="last")
|
||||
if _sample_signal.empty:
|
||||
if stock_id not in self.signal:
|
||||
return self.TREND_MID
|
||||
elif _sample_signal.iloc[0] > 0:
|
||||
return self.TREND_LONG
|
||||
else:
|
||||
return self.TREND_SHORT
|
||||
_sample_signal = sample_feature(self.signal[stock_id], pred_start_time, pred_end_time, fields="signal", method="last")
|
||||
if _sample_signal is None or _sample_signal.iloc[0] == 0:
|
||||
return self.TREND_MID
|
||||
elif _sample_signal.iloc[0] > 0:
|
||||
return self.TREND_LONG
|
||||
else:
|
||||
return self.TREND_SHORT
|
||||
|
||||
@@ -12,7 +12,7 @@ from ..utils import get_sample_freq_calendar
|
||||
from ..data.dataset import DatasetH
|
||||
from ..data.dataset.utils import get_level_index
|
||||
from ..contrib.backtest.order import Order
|
||||
from ..contrib.backtest.env import TradeCalendarBase
|
||||
from ..contrib.backtest.env import BaseTradeCalendar
|
||||
|
||||
"""
|
||||
1. BaseStrategy 的粒度一定是数据粒度的整数倍
|
||||
@@ -20,22 +20,10 @@ from ..contrib.backtest.env import TradeCalendarBase
|
||||
- adjust_dates这个东西啥用
|
||||
- label和freq和strategy的bar分离,这个如何决策呢
|
||||
"""
|
||||
class BaseStrategy(TradeCalendarBase):
|
||||
def __init__(self, step_bar, start_time=None, end_time=None, **kwargs):
|
||||
self.step_bar = step_bar
|
||||
self.reset(start_time=start_time, end_time=end_time, **kwargs)
|
||||
|
||||
def reset(self, start_time=None, end_time=None, **kwargs):
|
||||
if start_time or end_time :
|
||||
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
|
||||
for k, v in kwargs:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
class BaseStrategy(BaseTradeCalendar):
|
||||
|
||||
def generate_order_list(self, **kwargs):
|
||||
self.trade_index = self.trade_index + 1
|
||||
raise NotImplementedError("generator_order_list is not implemented!")
|
||||
|
||||
|
||||
class RuleStrategy(BaseStrategy):
|
||||
@@ -50,14 +38,14 @@ class ModelStrategy(BaseStrategy):
|
||||
super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
|
||||
def _convert_index_format(self, df):
|
||||
if get_level_index(df, level="datetime") == 0:
|
||||
if get_level_index(df, level="datetime") == 1:
|
||||
df = df.swaplevel().sort_index()
|
||||
return df
|
||||
|
||||
def _update_model(self):
|
||||
"""update pred score
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError("_update_model is not implemented!")
|
||||
|
||||
class TradingEnhancement:
|
||||
def reset(self, trade_order_list=None):
|
||||
|
||||
@@ -861,15 +861,38 @@ def sample_calendar_bac(calendar_raw, freq_raw, freq_sam):
|
||||
else:
|
||||
raise ValueError("sample freq must be xmin, xd, xw, xm")
|
||||
|
||||
def parse_freq(freq):
|
||||
freq = freq.lower()
|
||||
search_obj =re.search("^([0-9]*)([a-z]+)", freq)
|
||||
if search_obj is None:
|
||||
raise ValueError("freq format is not supported")
|
||||
_count = int(search_obj.group(1) if search_obj.group(1) else "1")
|
||||
_freq = search_obj.group(2)
|
||||
_freq_format_dict = {
|
||||
"month": "month",
|
||||
"mon": "month",
|
||||
"week": "week",
|
||||
"w": "week",
|
||||
"day": "day",
|
||||
"d": "day",
|
||||
"minute": "minute",
|
||||
"min": "minute",
|
||||
}
|
||||
try:
|
||||
_freq = _freq_format_dict.get(_freq)
|
||||
except KeyError:
|
||||
raise ValueError("freq format is not supported, the supported freq includes (x)month/m, (x)day/d, (x)minute/min")
|
||||
return _count, _freq
|
||||
|
||||
def sample_calendar(calendar_raw, freq_raw, freq_sam):
|
||||
"""
|
||||
freq_raw : "min" or "day"
|
||||
"""
|
||||
freq_raw = "1" + freq_raw if re.match("^[0-9]", freq_raw) is None else freq_raw
|
||||
freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam
|
||||
raw_count, freq_raw = parse_freq(freq_raw)
|
||||
sam_count, freq_sam = parse_freq(freq_sam)
|
||||
if not len(calendar_raw):
|
||||
return calendar_raw
|
||||
if freq_sam.endswith(("minute", "min")):
|
||||
if freq_sam == "minute":
|
||||
def cal_next_sam_minute(x, sam_minutes):
|
||||
hour = x.hour
|
||||
minute = x.minute
|
||||
@@ -888,38 +911,36 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
|
||||
return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60
|
||||
else:
|
||||
raise ValueError("calendar minute_index error")
|
||||
sam_minutes = int(freq_sam[:-3]) if freq_sam.endswith("min") else int(freq_sam[:-6])
|
||||
if not freq_raw.endswith(("minute", "min")):
|
||||
|
||||
if req_raw != "minute":
|
||||
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
|
||||
else:
|
||||
raw_minutes = int(freq_raw[:-3]) if freq_raw.endswith("min") else int(freq_raw[:-6])
|
||||
if raw_minutes > sam_minutes:
|
||||
if raw_count > sam_count:
|
||||
raise ValueError("raw freq must be higher than sample freq")
|
||||
_calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 0), calendar_raw)))
|
||||
_calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw)))
|
||||
if calendar_raw[0] > _calendar_minute[0]:
|
||||
_calendar_minute[0] = calendar_raw[0]
|
||||
return _calendar_minute
|
||||
else:
|
||||
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
|
||||
if freq_sam.endswith(("day", "d")):
|
||||
sam_days = int(freq_sam[:-1]) if freq_sam.endswith("d") else int(freq_sam[:-3])
|
||||
return _calendar_day[::sam_days]
|
||||
if freq_sam == "day":
|
||||
return _calendar_day[::sam_count]
|
||||
|
||||
elif freq_sam.endswith(("week", "w")):
|
||||
sam_weeks = int(freq_sam[:-1]) if freq_sam.endswith("w") else int(freq_sam[:-4])
|
||||
elif freq_sam == "week":
|
||||
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
|
||||
_calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]
|
||||
return _calendar_week[::sam_weeks]
|
||||
return _calendar_week[::sam_count]
|
||||
|
||||
elif freq_sam.endswith(("month", "m")):
|
||||
sam_months = int(freq_sam[:-1]) if freq_sam.endswith("m") else int(freq_sam[:-5])
|
||||
elif freq_sam == "month":
|
||||
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
|
||||
_calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]
|
||||
return _calendar_month[::sam_months]
|
||||
return _calendar_month[::sam_count]
|
||||
else:
|
||||
raise ValueError("sample freq must be xmin, xd, xw, xm")
|
||||
|
||||
def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwargs):
|
||||
_, norm_freq = parse_freq(freq)
|
||||
|
||||
from ..data.data import Cal
|
||||
|
||||
try:
|
||||
@@ -927,34 +948,47 @@ def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwarg
|
||||
freq, freq_sam = freq, None
|
||||
except ValueError:
|
||||
freq_sam = freq
|
||||
if freq.endswith(("m", "month", "w", "week", "d", "day")):
|
||||
if norm_freq in ["month", "week", "day"]:
|
||||
try:
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
|
||||
freq = "min"
|
||||
except ValueError:
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, **kwargs)
|
||||
freq = "day"
|
||||
elif freq.endswith(("min", "minute")):
|
||||
except ValueError:
|
||||
raise
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
|
||||
freq = "min"
|
||||
elif norm_freq == "minute":
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
|
||||
freq = "min"
|
||||
else:
|
||||
raise ValueError(f"freq {freq} is not supported")
|
||||
return _calendar, freq, freq_sam
|
||||
|
||||
def sample_feature(feature, instruments=None, start_time=None, end_time=None, fields=None, method=None, method_kwargs={}):
|
||||
if instruments and not isinstance(instruments, list):
|
||||
instruments = [instruments]
|
||||
selector_inst = slice(None) if instruments is None else instruments
|
||||
def sample_feature(feature, start_time=None, end_time=None, fields=None, method="last", method_kwargs={}):
|
||||
selector_datetime = slice(start_time, end_time)
|
||||
if isinstance(feature, pd.Series):
|
||||
feature = feature.loc[(selector_inst, selector_datetime)]
|
||||
if fields:
|
||||
warnings.warn(f"sample series feature, {fields} is ignored!")
|
||||
elif isinstance(feature, pd.DataFrame):
|
||||
selector_fields = slice(None) if fields is None else fields
|
||||
feature = feature.loc[(selector_inst, selector_datetime), selector_fields]
|
||||
if method:
|
||||
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
|
||||
else:
|
||||
return feature
|
||||
fields = fields if fields else slice(None)
|
||||
|
||||
from ..data.dataset.utils import get_level_index
|
||||
|
||||
datetime_level = get_level_index(feature, level="datetime") == 0
|
||||
if isinstance(feature, pd.Series):
|
||||
feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)]
|
||||
elif isinstance(feature, pd.DataFrame):
|
||||
feature = feature.loc[selector_datetime, fields] if datetime_level else feature.loc[(slice(None), selector_datetime), fields]
|
||||
if feature.empty:
|
||||
return None
|
||||
if isinstance(feature.index, pd.MultiIndex):
|
||||
if callable(method):
|
||||
method_func = method
|
||||
return feature.groupby(level="instrument").apply(lambda x:method_func(x, **method_kwargs))
|
||||
elif isinstance(method, str):
|
||||
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
|
||||
else:
|
||||
if callable(method):
|
||||
method_func = method
|
||||
return method_func(feature, **method_kwargs)
|
||||
elif isinstance(method, str):
|
||||
return getattr(feature, method)(**method_kwargs)
|
||||
|
||||
return feature
|
||||
|
||||
|
||||
@@ -233,8 +233,8 @@ class PortAnaRecord(SignalRecord):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
|
||||
self.strategy_config = config["strategy"]
|
||||
self.env_config = config["env"]
|
||||
self.backtest_config = config["backtest"]
|
||||
self.strategy = init_instance_by_config(self.strategy_config, accept_types=BaseStrategy)
|
||||
|
||||
def generate(self, **kwargs):
|
||||
# check previously stored prediction results
|
||||
@@ -244,36 +244,32 @@ class PortAnaRecord(SignalRecord):
|
||||
super().generate()
|
||||
|
||||
# custom strategy and get backtest
|
||||
pred_score = super().load("pred.pkl")
|
||||
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
report_normal = report_dict.get("report_df")
|
||||
positions_normal = report_dict.get("positions")
|
||||
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
order_normal = report_dict.get("order_list")
|
||||
if order_normal:
|
||||
self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
# save portfolio analysis results
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
# log metrics
|
||||
self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
|
||||
# save results
|
||||
self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path())
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
# print out results
|
||||
pprint("The following are analysis results of the excess return without cost.")
|
||||
pprint(analysis["excess_return_without_cost"])
|
||||
pprint("The following are analysis results of the excess return with cost.")
|
||||
pprint(analysis["excess_return_with_cost"])
|
||||
report_list = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config)
|
||||
for report_id, (report_normal, positions_normal) in enumerate(report_list):
|
||||
if report_dict is None:
|
||||
continue
|
||||
|
||||
self.recorder.save_objects(**{f"report_normal_{report_id}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(**{f"positions_norma_{report_id}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
# analysis
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
# log metrics
|
||||
self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
|
||||
# save results
|
||||
self.recorder.save_objects(**{f"port_analysis.pkl_{report_id}": analysis_df}, artifact_path=PortAnaRecord.get_path())
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis_{report_id}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
# print out results
|
||||
pprint("The following are analysis results of the excess return without cost.")
|
||||
pprint(analysis["excess_return_without_cost"])
|
||||
pprint("The following are analysis results of the excess return with cost.")
|
||||
pprint(analysis["excess_return_with_cost"])
|
||||
|
||||
def list(self):
|
||||
return [
|
||||
|
||||
Reference in New Issue
Block a user