mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
black format
This commit is contained in:
@@ -28,7 +28,7 @@ if __name__ == "__main__":
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
@@ -70,7 +70,7 @@ if __name__ == "__main__":
|
||||
},
|
||||
},
|
||||
}
|
||||
# model initialization
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
@@ -78,7 +78,7 @@ if __name__ == "__main__":
|
||||
trade_start_time = "2017-01-31"
|
||||
trade_end_time = "2018-01-31"
|
||||
|
||||
backtest_config={
|
||||
backtest_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
@@ -90,7 +90,7 @@ if __name__ == "__main__":
|
||||
"n_drop": 5,
|
||||
},
|
||||
},
|
||||
"env":{
|
||||
"env": {
|
||||
"class": "SplitEnv",
|
||||
"module_path": "qlib.contrib.backtest.env",
|
||||
"kwargs": {
|
||||
@@ -101,7 +101,7 @@ if __name__ == "__main__":
|
||||
"kwargs": {
|
||||
"step_bar": "day",
|
||||
"verbose": True,
|
||||
}
|
||||
},
|
||||
},
|
||||
"sub_strategy": {
|
||||
"class": "SBBStrategyEMA",
|
||||
@@ -110,11 +110,11 @@ if __name__ == "__main__":
|
||||
"step_bar": "day",
|
||||
"freq": "day",
|
||||
"instruments": "csi300",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"backtest":{
|
||||
"backtest": {
|
||||
"start_time": trade_start_time,
|
||||
"end_time": trade_end_time,
|
||||
"verbose": False,
|
||||
@@ -125,8 +125,14 @@ if __name__ == "__main__":
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
report_dict = backtest(start_time=trade_start_time, end_time=trade_end_time, **backtest_config, account=1e8, deal_price="$close", verbose=False)
|
||||
report_dict = backtest(
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
**backtest_config,
|
||||
account=1e8,
|
||||
deal_price="$close",
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_exchange(
|
||||
freq="day",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes = "all",
|
||||
codes="all",
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
@@ -89,6 +89,7 @@ def get_exchange(
|
||||
else:
|
||||
return init_instance_by_config(exchange, accept_types=Exchange)
|
||||
|
||||
|
||||
def init_env_instance_by_config(env):
|
||||
if isinstance(env, dict):
|
||||
env_config = copy.copy(env)
|
||||
@@ -103,6 +104,7 @@ def init_env_instance_by_config(env):
|
||||
else:
|
||||
return env
|
||||
|
||||
|
||||
def setup_exchange(root_instance, trade_exchange=None, force=False):
|
||||
if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args:
|
||||
if force:
|
||||
@@ -114,8 +116,8 @@ def setup_exchange(root_instance, trade_exchange=None, force=False):
|
||||
setup_exchange(root_instance.sub_env, trade_exchange)
|
||||
if hasattr(root_instance, "sub_strategy"):
|
||||
setup_exchange(root_instance.sub_strategy, trade_exchange)
|
||||
|
||||
|
||||
|
||||
|
||||
def backtest(start_time, end_time, strategy, env, benchmark=None, account=1e9, **kwargs):
|
||||
trade_strategy = init_instance_by_config(strategy)
|
||||
trade_env = init_env_instance_by_config(env)
|
||||
|
||||
@@ -11,7 +11,6 @@ from .order import Order
|
||||
from ...utils import parse_freq, sample_feature
|
||||
|
||||
|
||||
|
||||
"""
|
||||
rtn & earning in the Account
|
||||
rtn:
|
||||
@@ -87,7 +86,7 @@ class Account:
|
||||
elif norm_freq == "minute":
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
|
||||
else:
|
||||
raise ValueError(f"benchmark freq {freq} is not supported")
|
||||
raise ValueError(f"benchmark freq {freq} is not supported")
|
||||
if len(_temp_result) == 0:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
@@ -95,20 +94,20 @@ class Account:
|
||||
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
|
||||
def cal_change(x):
|
||||
return x.prod() - 1
|
||||
|
||||
return sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
|
||||
|
||||
def reset(self, benchmark=None, freq=None,**kwargs):
|
||||
def reset(self, benchmark=None, freq=None, **kwargs):
|
||||
if benchmark:
|
||||
self.benchmark = benchmark
|
||||
if freq:
|
||||
self.freq = freq
|
||||
if self.freq and self.benchmark and (freq or benchmark)
|
||||
if self.freq and self.benchmark and (freq or benchmark):
|
||||
self.bench = self._cal_benchmark(self.benchmark, self.start_time, self.end_time, self.freq)
|
||||
|
||||
for k, v in kwargs:
|
||||
if hasattr(k):
|
||||
setattr(k, v)
|
||||
|
||||
|
||||
def get_positions(self):
|
||||
return self.positions
|
||||
@@ -203,7 +202,7 @@ class Account:
|
||||
turnover_rate=self.to / last_account_value,
|
||||
cost_rate=self.ct / last_account_value,
|
||||
stock_value=now_stock_value,
|
||||
bench_value=self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
|
||||
bench_value=self._sample_benchmark(self.bench, trade_start_time, trade_end_time),
|
||||
)
|
||||
# set now_account_value to position
|
||||
self.current.position["now_account_value"] = now_account_value
|
||||
|
||||
@@ -7,6 +7,7 @@ import pandas as pd
|
||||
|
||||
from .account import Account
|
||||
|
||||
|
||||
def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account):
|
||||
|
||||
trade_account = Account(init_cash=account, benchmark=benchmark, start_time=start_time, end_time=end_time)
|
||||
@@ -17,10 +18,9 @@ def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account
|
||||
while not trade_env.finished():
|
||||
_order_list = trade_strategy.generate_order_list(**trade_state)
|
||||
trade_state, trade_info = trade_env.execute(_order_list)
|
||||
|
||||
|
||||
report_df = trade_account.report.generate_report_dataframe()
|
||||
positions = trade_account.get_positions()
|
||||
report_dict = {"report_df": report_df, "positions": positions}
|
||||
|
||||
return report_dict
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
import re
|
||||
import json
|
||||
import copy
|
||||
@@ -14,15 +12,8 @@ from .report import Report
|
||||
from .order import Order
|
||||
|
||||
|
||||
|
||||
class BaseTradeCalendar:
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
**kwargs
|
||||
):
|
||||
def __init__(self, step_bar, start_time=None, end_time=None, **kwargs):
|
||||
self.step_bar = step_bar
|
||||
self.reset(start_time=start_time, end_time=end_time)
|
||||
|
||||
@@ -36,8 +27,10 @@ class BaseTradeCalendar:
|
||||
if self.start_time and self.end_time:
|
||||
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar)
|
||||
self.calendar = _calendar
|
||||
_start_time, _end_time, _start_index, _end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq, freq_sam=freq_sam)
|
||||
_trade_calendar = self.calendar[_start_index: _end_index + 1]
|
||||
_start_time, _end_time, _start_index, _end_index = Cal.locate_index(
|
||||
self.start_time, self.end_time, freq=freq, freq_sam=freq_sam
|
||||
)
|
||||
_trade_calendar = self.calendar[_start_index : _end_index + 1]
|
||||
self.start_index = _start_index
|
||||
self.end_index = _end_index
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
@@ -52,7 +45,7 @@ class BaseTradeCalendar:
|
||||
for k, v in kwargs:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
def _get_calendar_time(self, trade_index=1, shift=0):
|
||||
trade_index = trade_index - shift
|
||||
calendar_index = self.start_index + trade_index
|
||||
@@ -64,6 +57,7 @@ class BaseTradeCalendar:
|
||||
def step(self):
|
||||
self.trade_index = self.trade_index + 1
|
||||
|
||||
|
||||
class BaseEnv(BaseTradeCalendar):
|
||||
"""
|
||||
# Strategy framework document
|
||||
@@ -83,8 +77,10 @@ class BaseEnv(BaseTradeCalendar):
|
||||
):
|
||||
self.generate_report = update_report
|
||||
self.verbose = verbose
|
||||
super(BaseEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs)
|
||||
|
||||
super(BaseEnv, self).__init__(
|
||||
step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs
|
||||
)
|
||||
|
||||
def reset(self, trade_account=None, **kwargs):
|
||||
super(BaseEnv, self).reset(**kwargs)
|
||||
if trade_account:
|
||||
@@ -94,7 +90,7 @@ class BaseEnv(BaseTradeCalendar):
|
||||
def get_init_state(self):
|
||||
init_state = {"current": self.trade_account.current}
|
||||
return init_state
|
||||
|
||||
|
||||
def execute(self, **kwargs):
|
||||
raise NotImplementedError("execute is not implemented!")
|
||||
|
||||
@@ -104,23 +100,32 @@ class BaseEnv(BaseTradeCalendar):
|
||||
def get_report(self):
|
||||
raise NotImplementedError("get_report is not implemented!")
|
||||
|
||||
|
||||
class SplitEnv(BaseEnv):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
self,
|
||||
step_bar,
|
||||
sub_env,
|
||||
sub_strategy,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
update_report=False,
|
||||
verbose=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self.sub_env = sub_env
|
||||
self.sub_strategy = sub_strategy
|
||||
super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, update_report=update_report, verbose=verbose, **kwargs)
|
||||
|
||||
super(SplitEnv, self).__init__(
|
||||
step_bar=step_bar,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
trade_account=trade_account,
|
||||
update_report=update_report,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def reset(self, trade_account=None, **kwargs):
|
||||
super(SplitEnv, self).reset(trade_account=trade_account, **kwargs)
|
||||
if trade_account:
|
||||
@@ -129,9 +134,9 @@ class SplitEnv(BaseEnv):
|
||||
def execute(self, order_list, **kwargs):
|
||||
if self.finished():
|
||||
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
|
||||
#if self.track:
|
||||
# if self.track:
|
||||
# yield action
|
||||
#episode_reward = 0
|
||||
# episode_reward = 0
|
||||
super(SplitEnv, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
@@ -140,9 +145,11 @@ class SplitEnv(BaseEnv):
|
||||
while not self.sub_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order_list(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
|
||||
|
||||
|
||||
if self.generate_report:
|
||||
self.trade_account.update_report(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
|
||||
self.trade_account.update_report(
|
||||
trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange
|
||||
)
|
||||
_obs = {"current": self.trade_account.current}
|
||||
_info = {}
|
||||
return _obs, _info
|
||||
@@ -150,31 +157,40 @@ class SplitEnv(BaseEnv):
|
||||
def get_report(self):
|
||||
_report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None
|
||||
_positions = self.trade_account.get_positions() if self.generate_report else None
|
||||
return [(_report,_positions), *sub_env.get_report()]
|
||||
|
||||
class SimulatorEnv(BaseEnv):
|
||||
return [(_report, _positions), *sub_env.get_report()]
|
||||
|
||||
|
||||
class SimulatorEnv(BaseEnv):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
self,
|
||||
step_bar,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
trade_exchange=None,
|
||||
update_report=False,
|
||||
verbose=False,
|
||||
**kwargs,
|
||||
):
|
||||
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, update_report=update_report, verbose=verbose, **kwargs)
|
||||
super(SimulatorEnv, self).__init__(
|
||||
step_bar=step_bar,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
trade_account=trade_account,
|
||||
trade_exchange=trade_exchange,
|
||||
update_report=update_report,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def reset(self, trade_exchange=None, **kwargs):
|
||||
super(SimulatorEnv, self).reset(**kwargs)
|
||||
if trade_exchange:
|
||||
self.trade_exchange=trade_exchange
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def execute(self, order_list, **kwargs):
|
||||
"""
|
||||
Return: obs, done, info
|
||||
Return: obs, done, info
|
||||
"""
|
||||
if self.finished():
|
||||
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
|
||||
@@ -184,7 +200,9 @@ class SimulatorEnv(BaseEnv):
|
||||
for order in order_list:
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
# execute the order
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(order, trade_account=self.trade_account)
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
|
||||
order, trade_account=self.trade_account
|
||||
)
|
||||
trade_info.append((order, trade_val, trade_cost, trade_price))
|
||||
if self.verbose:
|
||||
if order.direction == Order.SELL: # sell
|
||||
@@ -214,7 +232,9 @@ class SimulatorEnv(BaseEnv):
|
||||
# do nothing
|
||||
pass
|
||||
if self.generate_report:
|
||||
self.trade_account.update_report(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
|
||||
self.trade_account.update_report(
|
||||
trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange
|
||||
)
|
||||
_obs = {"current": self.trade_account.current}
|
||||
_info = {"trade_info": trade_info}
|
||||
return _obs, _info
|
||||
@@ -222,9 +242,4 @@ class SimulatorEnv(BaseEnv):
|
||||
def get_report(self):
|
||||
_report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None
|
||||
_positions = self.trade_account.get_positions() if self.generate_report else None
|
||||
return [
|
||||
{
|
||||
"report": _report,
|
||||
"positions": _positions
|
||||
}
|
||||
]
|
||||
return [{"report": _report, "positions": _positions}]
|
||||
|
||||
@@ -16,7 +16,6 @@ from ...log import get_module_logger
|
||||
from .order import Order
|
||||
|
||||
|
||||
|
||||
class Exchange:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -101,14 +100,15 @@ class Exchange:
|
||||
self.min_cost = min_cost
|
||||
self.limit_threshold = limit_threshold
|
||||
|
||||
|
||||
self.extra_quote = extra_quote
|
||||
self.set_quote(codes, start_time, end_time)
|
||||
|
||||
def set_quote(self, codes, start_time, end_time):
|
||||
if len(codes) == 0:
|
||||
codes = D.instruments()
|
||||
self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna(subset=["$close"])
|
||||
self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna(
|
||||
subset=["$close"]
|
||||
)
|
||||
self.quote.columns = self.all_fields
|
||||
|
||||
if self.quote[self.deal_price].isna().any():
|
||||
@@ -168,7 +168,6 @@ class Exchange:
|
||||
is limtited
|
||||
"""
|
||||
return sample_feature(self.quote[stock_id], start_time, end_time, fields="limit", method="all").iloc[0]
|
||||
|
||||
|
||||
def check_stock_suspended(self, stock_id, start_time, end_time):
|
||||
# is suspended
|
||||
@@ -180,7 +179,9 @@ class Exchange:
|
||||
def is_stock_tradable(self, stock_id, start_time, end_time):
|
||||
# check if stock can be traded
|
||||
# same as check in check_order
|
||||
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(stock_id, start_time, end_time):
|
||||
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(
|
||||
stock_id, start_time, end_time
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -235,9 +236,13 @@ class Exchange:
|
||||
return sample_feature(self.quote[stock_id], start_time, end_time, fields="$close", method="last").iloc[0]
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time):
|
||||
deal_price = sample_feature(self.quote[stock_id], start_time, end_time, fields=self.deal_price, method="last").iloc[0]
|
||||
deal_price = sample_feature(
|
||||
self.quote[stock_id], start_time, end_time, fields=self.deal_price, method="last"
|
||||
).iloc[0]
|
||||
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
|
||||
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!")
|
||||
self.logger.warning(
|
||||
f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!"
|
||||
)
|
||||
self.logger.warning(f"setting deal_price to close price")
|
||||
deal_price = self.get_close(stock_id, start_time, end_time)
|
||||
return deal_price
|
||||
@@ -274,7 +279,9 @@ class Exchange:
|
||||
|
||||
amount_dict = {}
|
||||
for stock_id in weight_position:
|
||||
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
|
||||
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(
|
||||
stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
):
|
||||
amount_dict[stock_id] = (
|
||||
cash
|
||||
* weight_position[stock_id]
|
||||
@@ -377,7 +384,10 @@ class Exchange:
|
||||
self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
):
|
||||
value += self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) * amount_dict[stock_id]
|
||||
value += (
|
||||
self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
* amount_dict[stock_id]
|
||||
)
|
||||
return value
|
||||
|
||||
def round_amount_by_trade_unit(self, deal_amount, factor):
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
|
||||
class BaseInterpreter:
|
||||
@staticmethod
|
||||
def interpret(**kwargs):
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
|
||||
class ActionInterpreter:
|
||||
@staticmethod
|
||||
def interpret(action, **kwargs):
|
||||
return action
|
||||
|
||||
|
||||
class StateInterpreter:
|
||||
@staticmethod
|
||||
def interpret(state, **kwargs):
|
||||
return state
|
||||
return state
|
||||
|
||||
@@ -45,16 +45,7 @@ class Report:
|
||||
bench_value=None,
|
||||
):
|
||||
# check data
|
||||
if None in [
|
||||
trade_time,
|
||||
account_value,
|
||||
cash,
|
||||
return_rate,
|
||||
turnover_rate,
|
||||
cost_rate,
|
||||
stock_value,
|
||||
bench_value
|
||||
]:
|
||||
if None in [trade_time, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]:
|
||||
raise ValueError(
|
||||
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]"
|
||||
)
|
||||
@@ -108,5 +99,5 @@ class Report:
|
||||
turnover_rate=r.loc[trade_time]["turnover"],
|
||||
cost_rate=r.loc[trade_time]["cost"],
|
||||
stock_value=r.loc[trade_time]["value"],
|
||||
bench_value=r.loc[trade_time]["bench"]
|
||||
bench_value=r.loc[trade_time]["bench"],
|
||||
)
|
||||
|
||||
@@ -7,12 +7,10 @@ from .model_strategy import (
|
||||
WeightStrategyBase,
|
||||
)
|
||||
|
||||
from .rule_strategy import(
|
||||
from .rule_strategy import (
|
||||
TWAPStrategy,
|
||||
SBBStrategyBase,
|
||||
SBBStrategyEMA,
|
||||
)
|
||||
|
||||
from .cost_control import (
|
||||
SoftTopkStrategy
|
||||
)
|
||||
from .cost_control import SoftTopkStrategy
|
||||
|
||||
@@ -53,7 +53,9 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
else:
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange)
|
||||
super(TopkDropoutStrategy, self).__init__(
|
||||
step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange
|
||||
)
|
||||
self.topk = topk
|
||||
self.n_drop = n_drop
|
||||
self.method_sell = method_sell
|
||||
@@ -65,8 +67,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
self.stock_count = {}
|
||||
self.hold_thresh = hold_thresh
|
||||
self.only_tradable = only_tradable
|
||||
|
||||
|
||||
|
||||
def reset(self, trade_exchange=None, **kwargs):
|
||||
super(TopkDropoutStrategy, self).reset(**kwargs)
|
||||
if trade_exchange:
|
||||
@@ -94,7 +95,9 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
cur_n = 0
|
||||
res = []
|
||||
for si in reversed(l) if reverse else l:
|
||||
if self.trade_exchange.is_stock_tradable(stock_id=si, start_time=trade_start_time, end_time=trade_end_time):
|
||||
if self.trade_exchange.is_stock_tradable(
|
||||
stock_id=si, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
res.append(si)
|
||||
cur_n += 1
|
||||
if cur_n >= n:
|
||||
@@ -105,7 +108,13 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
return get_first_n(l, n, reverse=True)
|
||||
|
||||
def filter_stock(l):
|
||||
return [si for si in l if self.trade_exchange.is_stock_tradable(stock_id=si, start_time=trade_start_time, end_time=trade_end_time)]
|
||||
return [
|
||||
si
|
||||
for si in l
|
||||
if self.trade_exchange.is_stock_tradable(
|
||||
stock_id=si, start_time=trade_start_time, end_time=trade_end_time
|
||||
)
|
||||
]
|
||||
|
||||
else:
|
||||
# Otherwise, the stock will make decision with out the stock tradable info
|
||||
@@ -166,11 +175,16 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
buy_signal = pred_score.sort_values(ascending=False).iloc[: self.topk].index
|
||||
|
||||
for code in current_stock_list:
|
||||
if not self.trade_exchange.is_stock_tradable(stock_id=code, start_time=trade_start_time, end_time=trade_end_time):
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
continue
|
||||
if code in sell:
|
||||
# check hold limit
|
||||
if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh:
|
||||
if (
|
||||
self.stock_count[code] < self.thresh
|
||||
or current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh
|
||||
):
|
||||
# can not sell this code
|
||||
# no buy signal, but the stock is kept
|
||||
self.stock_count[code] += 1
|
||||
@@ -188,7 +202,9 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
# is order executable
|
||||
if self.trade_exchange.check_order(sell_order):
|
||||
sell_order_list.append(sell_order)
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(sell_order, position=current_temp)
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
|
||||
sell_order, position=current_temp
|
||||
)
|
||||
# update cash
|
||||
cash += trade_val - trade_cost
|
||||
# sold
|
||||
@@ -213,10 +229,14 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
# value = value / (1+self.trade_exchange.open_cost) # set open_cost limit
|
||||
for code in buy:
|
||||
# check is stock suspended
|
||||
if not self.trade_exchange.is_stock_tradable(stock_id=code, start_time=trade_start_time, end_time=trade_end_time):
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
continue
|
||||
# buy order
|
||||
buy_price = self.trade_exchange.get_deal_price(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
|
||||
buy_price = self.trade_exchange.get_deal_price(
|
||||
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
|
||||
)
|
||||
buy_amount = value / buy_price
|
||||
factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
|
||||
buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
|
||||
@@ -231,17 +251,24 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
buy_order_list.append(buy_order)
|
||||
self.stock_count[code] = 1
|
||||
return sell_order_list + buy_order_list
|
||||
|
||||
|
||||
|
||||
class WeightStrategyBase(ModelStrategy):
|
||||
def __init__(self, step_bar, start_time=None, end_time=None, order_generator_cls_or_obj=OrderGenWInteract, trade_exchange=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
order_generator_cls_or_obj=OrderGenWInteract,
|
||||
trade_exchange=None,
|
||||
**kwargs,
|
||||
):
|
||||
super(WeightStrategyBase, self).__init__(step_bar, start_time, end_time)
|
||||
self.trade_exchange = trade_exchange
|
||||
if isinstance(order_generator_cls_or_obj, type):
|
||||
self.order_generator = order_generator_cls_or_obj()
|
||||
else:
|
||||
self.order_generator = order_generator_cls_or_obj
|
||||
|
||||
|
||||
|
||||
def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
|
||||
"""
|
||||
|
||||
@@ -81,10 +81,16 @@ class OrderGenWInteract(OrderGenerator):
|
||||
# calculate current_tradable_value
|
||||
current_amount_dict = current.get_stock_amount_dict()
|
||||
current_total_value = trade_exchange.calculate_amount_position_value(
|
||||
amount_dict=current_amount_dict, trade_start_time=trade_start_time, trade_end_time=trade_end_time, only_tradable=False
|
||||
amount_dict=current_amount_dict,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
only_tradable=False,
|
||||
)
|
||||
current_tradable_value = trade_exchange.calculate_amount_position_value(
|
||||
amount_dict=current_amount_dict, trade_start_time=trade_start_time, trade_end_time=trade_end_time, only_tradable=True
|
||||
amount_dict=current_amount_dict,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
only_tradable=True,
|
||||
)
|
||||
# add cash
|
||||
current_tradable_value += current.get_cash()
|
||||
@@ -97,7 +103,9 @@ class OrderGenWInteract(OrderGenerator):
|
||||
# value. Then just sell all the stocks
|
||||
target_amount_dict = copy.deepcopy(current_amount_dict.copy())
|
||||
for stock_id in list(target_amount_dict.keys()):
|
||||
if trade_exchange.is_stock_tradable(stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time):
|
||||
if trade_exchange.is_stock_tradable(
|
||||
stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
|
||||
):
|
||||
del target_amount_dict[stock_id]
|
||||
else:
|
||||
# consider cost rate
|
||||
@@ -108,13 +116,13 @@ class OrderGenWInteract(OrderGenerator):
|
||||
target_amount_dict = trade_exchange.generate_amount_position_from_weight_position(
|
||||
weight_position=target_weight_position,
|
||||
cash=current_tradable_value,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
)
|
||||
order_list = trade_exchange.generate_order_for_target_amount_position(
|
||||
target_position=target_amount_dict,
|
||||
current_position=current_amount_dict,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
)
|
||||
return order_list
|
||||
@@ -161,7 +169,9 @@ class OrderGenWOInteract(OrderGenerator):
|
||||
amount_dict = {}
|
||||
for stock_id in target_weight_position:
|
||||
# Current rule will ignore the stock that not hold and cannot be traded at predict date
|
||||
if trade_exchange.is_stock_tradable(stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time):
|
||||
if trade_exchange.is_stock_tradable(
|
||||
stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
|
||||
):
|
||||
amount_dict[stock_id] = (
|
||||
risk_total_value * target_weight_position[stock_id] / trade_exchange.get_close(stock_id, pred_date)
|
||||
)
|
||||
|
||||
@@ -11,7 +11,6 @@ from ..backtest.order import Order
|
||||
|
||||
|
||||
class TWAPStrategy(RuleStrategy, TradingEnhancement):
|
||||
|
||||
def reset(self, trade_order_list=None, **kwargs):
|
||||
super(TWAPStrategy, self).reset(**kwargs)
|
||||
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
|
||||
@@ -19,7 +18,6 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
|
||||
self.trade_amount = {}
|
||||
for order in self.trade_order_list:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
|
||||
|
||||
|
||||
def generate_order_list(self, **kwargs):
|
||||
super(TopkDropoutStrategy, self).step()
|
||||
@@ -37,10 +35,12 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
|
||||
order_list.append(_order)
|
||||
return order_list
|
||||
|
||||
|
||||
class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
"""
|
||||
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
|
||||
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
|
||||
"""
|
||||
|
||||
TREND_MID = 0
|
||||
TREND_SHORT = 1
|
||||
TREND_LONG = 2
|
||||
@@ -50,11 +50,10 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
|
||||
if trade_order_list:
|
||||
self.trade_amount = {}
|
||||
self.trade_trend = {}
|
||||
self.trade_trend = {}
|
||||
for order in self.trade_order_list:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
|
||||
self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID
|
||||
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
raise NotImplementedError("pred_price_trend method is not implemented!")
|
||||
@@ -81,10 +80,15 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
order_list.append(_order)
|
||||
else:
|
||||
if self.trade_index % 2 == 1:
|
||||
if _pred_trend == self.TREND_SHORT and order.direction == order.SELL or _pred_trend == self.TREND_LONG and order.direction == order.BUY:
|
||||
if (
|
||||
_pred_trend == self.TREND_SHORT
|
||||
and order.direction == order.SELL
|
||||
or _pred_trend == self.TREND_LONG
|
||||
and order.direction == order.BUY
|
||||
):
|
||||
_order = Order(
|
||||
stock_id=order.stock_id,
|
||||
amount=2*self.trade_amount[(order.stock_id, order.direction)],
|
||||
amount=2 * self.trade_amount[(order.stock_id, order.direction)],
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=order.direction, # 1 for buy
|
||||
@@ -92,31 +96,37 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
)
|
||||
order_list.append(_order)
|
||||
else:
|
||||
if _pred_trend == self.TREND_SHORT and order.direction == order.BUY or _pred_trend == self.TREND_LONG and order.direction == order.SELL:
|
||||
if (
|
||||
_pred_trend == self.TREND_SHORT
|
||||
and order.direction == order.BUY
|
||||
or _pred_trend == self.TREND_LONG
|
||||
and order.direction == order.SELL
|
||||
):
|
||||
_order = Order(
|
||||
stock_id=order.stock_id,
|
||||
amount=2*self.trade_amount[(order.stock_id, order.direction)],
|
||||
amount=2 * self.trade_amount[(order.stock_id, order.direction)],
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=order.direction, # 1 for buy
|
||||
factor=order.factor,
|
||||
)
|
||||
)
|
||||
order_list.append(_order)
|
||||
if self.trade_index % 2 == 1:
|
||||
if self.trade_index % 2 == 1:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
|
||||
|
||||
return order_list
|
||||
|
||||
|
||||
|
||||
class SBBStrategyEMA(SBBStrategyBase):
|
||||
"""
|
||||
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA).
|
||||
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
self,
|
||||
step_bar,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
instruments="csi300",
|
||||
freq="day",
|
||||
**kwargs,
|
||||
@@ -139,22 +149,25 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
if self.start_time and self.end_time:
|
||||
fields = ["EMA($close, 10)-EMA($close, 20)"]
|
||||
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
|
||||
signal_df = D.features(self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq)
|
||||
signal_df = D.features(
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq
|
||||
)
|
||||
signal_df = self._convert_index_format(signal_df)
|
||||
signal_df.columns = ["signal"]
|
||||
self.signal = {}
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
self.signal[stock_id] = stock_val
|
||||
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
if stock_id not in self.signal:
|
||||
return self.TREND_MID
|
||||
else:
|
||||
_sample_signal = sample_feature(self.signal[stock_id], pred_start_time, pred_end_time, fields="signal", method="last")
|
||||
_sample_signal = sample_feature(
|
||||
self.signal[stock_id], pred_start_time, pred_end_time, fields="signal", method="last"
|
||||
)
|
||||
if _sample_signal is None or _sample_signal.iloc[0] == 0:
|
||||
return self.TREND_MID
|
||||
elif _sample_signal.iloc[0] > 0:
|
||||
return self.TREND_LONG
|
||||
else:
|
||||
return self.TREND_SHORT
|
||||
|
||||
@@ -126,7 +126,7 @@ class CalendarProvider(abc.ABC):
|
||||
_calendar = np.array(self.load_calendar(freq, future))
|
||||
_calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search
|
||||
H["c"][flag_raw] = _calendar, _calendar_index
|
||||
|
||||
|
||||
if freq_sam is None:
|
||||
return _calendar, _calendar_index
|
||||
else:
|
||||
@@ -134,7 +134,6 @@ class CalendarProvider(abc.ABC):
|
||||
_calendar_sam_index = {x: i for i, x in enumerate(_calendar_sam)}
|
||||
H["c"][flag] = _calendar_sam, _calendar_sam_index
|
||||
return _calendar_sam, _calendar_sam_index
|
||||
|
||||
|
||||
def _uri(self, start_time, end_time, freq, future=False):
|
||||
"""Get the uri of calendar generation task."""
|
||||
@@ -560,7 +559,8 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
else:
|
||||
end_time = _calendar[-1]
|
||||
st, et, si, ei = self.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam, future=future)
|
||||
return _calendar[si : ei + 1]
|
||||
return _calendar[si : ei + 1]
|
||||
|
||||
|
||||
class LocalInstrumentProvider(InstrumentProvider):
|
||||
"""Local instrument data provider class
|
||||
@@ -767,7 +767,7 @@ class ClientCalendarProvider(CalendarProvider):
|
||||
self.conn = conn
|
||||
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
|
||||
|
||||
self.conn.send_request(
|
||||
request_type="calendar",
|
||||
request_content={
|
||||
|
||||
@@ -20,8 +20,9 @@ from ..contrib.backtest.env import BaseTradeCalendar
|
||||
- adjust_dates这个东西啥用
|
||||
- label和freq和strategy的bar分离,这个如何决策呢
|
||||
"""
|
||||
|
||||
|
||||
class BaseStrategy(BaseTradeCalendar):
|
||||
|
||||
def generate_order_list(self, **kwargs):
|
||||
raise NotImplementedError("generator_order_list is not implemented!")
|
||||
|
||||
@@ -29,12 +30,13 @@ class BaseStrategy(BaseTradeCalendar):
|
||||
class RuleStrategy(BaseStrategy):
|
||||
pass
|
||||
|
||||
|
||||
class ModelStrategy(BaseStrategy):
|
||||
def __init__(self, step_bar, model, dataset:DatasetH, start_time=None, end_time=None, **kwargs):
|
||||
def __init__(self, step_bar, model, dataset: DatasetH, start_time=None, end_time=None, **kwargs):
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
self.pred_scores = self._convert_index_format(self.model.predict(dataset))
|
||||
#pred_score_dates = self.pred_scores.index.get_level_values(level="datetime")
|
||||
# pred_score_dates = self.pred_scores.index.get_level_values(level="datetime")
|
||||
super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
|
||||
def _convert_index_format(self, df):
|
||||
@@ -43,12 +45,11 @@ class ModelStrategy(BaseStrategy):
|
||||
return df
|
||||
|
||||
def _update_model(self):
|
||||
"""update pred score
|
||||
"""
|
||||
"""update pred score"""
|
||||
raise NotImplementedError("_update_model is not implemented!")
|
||||
|
||||
|
||||
class TradingEnhancement:
|
||||
def reset(self, trade_order_list=None):
|
||||
if trade_order_list:
|
||||
self.trade_order_list = trade_order_list
|
||||
|
||||
|
||||
@@ -801,6 +801,7 @@ def fname_to_code(fname: str):
|
||||
fname = fname.lstrip(prefix)
|
||||
return fname
|
||||
|
||||
|
||||
########################## Sample ############################
|
||||
def sample_calendar_bac(calendar_raw, freq_raw, freq_sam):
|
||||
"""
|
||||
@@ -810,16 +811,17 @@ def sample_calendar_bac(calendar_raw, freq_raw, freq_sam):
|
||||
freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam
|
||||
|
||||
if freq_sam.endswith(("minute", "min")):
|
||||
|
||||
def cal_next_sam_minute(x, sam_minutes):
|
||||
hour = x.hour
|
||||
minute = x.minute
|
||||
if 9 <= hour <= 11:
|
||||
minute_index = (11 - hour)*60 + 30 - minute + 120
|
||||
minute_index = (11 - hour) * 60 + 30 - minute + 120
|
||||
elif 13 <= hour <= 15:
|
||||
minute_index = (15 - hour)*60 - minute
|
||||
minute_index = (15 - hour) * 60 - minute
|
||||
else:
|
||||
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
|
||||
|
||||
|
||||
minute_index = minute_index // sam_minutes * sam_minutes
|
||||
|
||||
if 0 <= minute_index < 120:
|
||||
@@ -838,32 +840,40 @@ def sample_calendar_bac(calendar_raw, freq_raw, freq_sam):
|
||||
if raw_minutes > sam_minutes:
|
||||
raise ValueError("raw freq must be higher than sample freq")
|
||||
|
||||
_calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 59), calendar_raw)))
|
||||
_calendar_minute = np.unique(
|
||||
list(
|
||||
map(
|
||||
lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 59),
|
||||
calendar_raw,
|
||||
)
|
||||
)
|
||||
)
|
||||
return _calendar_minute
|
||||
else:
|
||||
|
||||
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 23, 59, 59), calendar_raw)))
|
||||
if freq_sam.endswith(("day", "d")):
|
||||
sam_days = int(freq_sam[:-1]) if freq_sam.endswith("d") else int(freq_sam[:-3])
|
||||
return _calendar_day[(len(_calendar_day) + sam_days - 1)%sam_days::sam_days]
|
||||
return _calendar_day[(len(_calendar_day) + sam_days - 1) % sam_days :: sam_days]
|
||||
|
||||
elif freq_sam.endswith(("week", "w")):
|
||||
sam_weeks = int(freq_sam[:-1]) if freq_sam.endswith("w") else int(freq_sam[:-4])
|
||||
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
|
||||
_calendar_week = _calendar_day[np.ediff1d(_day_in_week[::-1], to_begin=1)[::-1] > 0]
|
||||
return _calendar_week[(len(_calendar_week) + sam_weeks - 1)%sam_weeks::sam_weeks]
|
||||
return _calendar_week[(len(_calendar_week) + sam_weeks - 1) % sam_weeks :: sam_weeks]
|
||||
|
||||
elif freq_sam.endswith(("month", "m")):
|
||||
sam_months = int(freq_sam[:-1]) if freq_sam.endswith("m") else int(freq_sam[:-5])
|
||||
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
|
||||
_calendar_month = _calendar_day[np.ediff1d(_day_in_month[::-1], to_begin=1)[::-1] > 0]
|
||||
return _calendar_month[(len(_calendar_month) + sam_months - 1)%sam_months::sam_months]
|
||||
return _calendar_month[(len(_calendar_month) + sam_months - 1) % sam_months :: sam_months]
|
||||
else:
|
||||
raise ValueError("sample freq must be xmin, xd, xw, xm")
|
||||
|
||||
|
||||
def parse_freq(freq):
|
||||
freq = freq.lower()
|
||||
search_obj =re.search("^([0-9]*)([a-z]+)", freq)
|
||||
search_obj = re.search("^([0-9]*)([a-z]+)", freq)
|
||||
if search_obj is None:
|
||||
raise ValueError("freq format is not supported")
|
||||
_count = int(search_obj.group(1) if search_obj.group(1) else "1")
|
||||
@@ -881,9 +891,12 @@ def parse_freq(freq):
|
||||
try:
|
||||
_freq = _freq_format_dict.get(_freq)
|
||||
except KeyError:
|
||||
raise ValueError("freq format is not supported, the supported freq includes (x)month/m, (x)day/d, (x)minute/min")
|
||||
raise ValueError(
|
||||
"freq format is not supported, the supported freq includes (x)month/m, (x)day/d, (x)minute/min"
|
||||
)
|
||||
return _count, _freq
|
||||
|
||||
|
||||
def sample_calendar(calendar_raw, freq_raw, freq_sam):
|
||||
"""
|
||||
freq_raw : "min" or "day"
|
||||
@@ -893,16 +906,17 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
|
||||
if not len(calendar_raw):
|
||||
return calendar_raw
|
||||
if freq_sam == "minute":
|
||||
|
||||
def cal_next_sam_minute(x, sam_minutes):
|
||||
hour = x.hour
|
||||
minute = x.minute
|
||||
if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30):
|
||||
minute_index = (hour - 9)*60 + minute - 30
|
||||
minute_index = (hour - 9) * 60 + minute - 30
|
||||
elif 13 <= hour < 15:
|
||||
minute_index = (hour - 13)*60 + minute + 120
|
||||
minute_index = (hour - 13) * 60 + minute + 120
|
||||
else:
|
||||
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
|
||||
|
||||
|
||||
minute_index = minute_index // sam_minutes * sam_minutes
|
||||
|
||||
if 0 <= minute_index < 120:
|
||||
@@ -917,7 +931,11 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
|
||||
else:
|
||||
if raw_count > sam_count:
|
||||
raise ValueError("raw freq must be higher than sample freq")
|
||||
_calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw)))
|
||||
_calendar_minute = np.unique(
|
||||
list(
|
||||
map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw)
|
||||
)
|
||||
)
|
||||
if calendar_raw[0] > _calendar_minute[0]:
|
||||
_calendar_minute[0] = calendar_raw[0]
|
||||
return _calendar_minute
|
||||
@@ -937,7 +955,8 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
|
||||
return _calendar_month[::sam_count]
|
||||
else:
|
||||
raise ValueError("sample freq must be xmin, xd, xw, xm")
|
||||
|
||||
|
||||
|
||||
def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwargs):
|
||||
_, norm_freq = parse_freq(freq)
|
||||
|
||||
@@ -963,23 +982,28 @@ def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwarg
|
||||
raise ValueError(f"freq {freq} is not supported")
|
||||
return _calendar, freq, freq_sam
|
||||
|
||||
|
||||
def sample_feature(feature, start_time=None, end_time=None, fields=None, method="last", method_kwargs={}):
|
||||
selector_datetime = slice(start_time, end_time)
|
||||
fields = fields if fields else slice(None)
|
||||
|
||||
from ..data.dataset.utils import get_level_index
|
||||
|
||||
|
||||
datetime_level = get_level_index(feature, level="datetime") == 0
|
||||
if isinstance(feature, pd.Series):
|
||||
feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)]
|
||||
elif isinstance(feature, pd.DataFrame):
|
||||
feature = feature.loc[selector_datetime, fields] if datetime_level else feature.loc[(slice(None), selector_datetime), fields]
|
||||
feature = (
|
||||
feature.loc[selector_datetime, fields]
|
||||
if datetime_level
|
||||
else feature.loc[(slice(None), selector_datetime), fields]
|
||||
)
|
||||
if feature.empty:
|
||||
return None
|
||||
if isinstance(feature.index, pd.MultiIndex):
|
||||
if callable(method):
|
||||
method_func = method
|
||||
return feature.groupby(level="instrument").apply(lambda x:method_func(x, **method_kwargs))
|
||||
return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs))
|
||||
elif isinstance(method, str):
|
||||
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
|
||||
else:
|
||||
@@ -988,7 +1012,5 @@ def sample_feature(feature, start_time=None, end_time=None, fields=None, method=
|
||||
return method_func(feature, **method_kwargs)
|
||||
elif isinstance(method, str):
|
||||
return getattr(feature, method)(**method_kwargs)
|
||||
|
||||
return feature
|
||||
|
||||
|
||||
return feature
|
||||
|
||||
@@ -254,13 +254,19 @@ class PortAnaRecord(SignalRecord):
|
||||
for report_dep, (report_normal, positions_normal) in enumerate(report_list):
|
||||
if report_dict is None:
|
||||
if self.risk_analysis_dep == report_dep:
|
||||
warnings.warn(f"the report in dep {risk_analysis_dep} is None, please set the corresponding env with `generate_report==True`")
|
||||
warnings.warn(
|
||||
f"the report in dep {risk_analysis_dep} is None, please set the corresponding env with `generate_report==True`"
|
||||
)
|
||||
continue
|
||||
|
||||
self.recorder.save_objects(**{f"report_normal_{report_dep}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(**{f"positions_norma_{report_dep}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
|
||||
self.recorder.save_objects(
|
||||
**{f"report_normal_{report_dep}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.recorder.save_objects(
|
||||
**{f"positions_norma_{report_dep}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
# analysis
|
||||
self.risk_analysis_dep == report_dep:
|
||||
if self.risk_analysis_dep == report_dep:
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
@@ -270,7 +276,9 @@ class PortAnaRecord(SignalRecord):
|
||||
# log metrics
|
||||
self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
|
||||
# save results
|
||||
self.recorder.save_objects(**{f"port_analysis.pkl_{report_dep}": analysis_df}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(
|
||||
**{f"port_analysis.pkl_{report_dep}": analysis_df}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis_{report_dep}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user