1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

black format

This commit is contained in:
bxdd
2021-04-29 02:29:29 +08:00
parent 49cdaf8f5d
commit f404a031f3
16 changed files with 275 additions and 172 deletions

View File

@@ -28,7 +28,7 @@ if __name__ == "__main__":
###################################
# train model
###################################
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
@@ -70,7 +70,7 @@ if __name__ == "__main__":
},
},
}
# model initialization
# model initialization
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
model.fit(dataset)
@@ -78,7 +78,7 @@ if __name__ == "__main__":
trade_start_time = "2017-01-31"
trade_end_time = "2018-01-31"
backtest_config={
backtest_config = {
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
@@ -90,7 +90,7 @@ if __name__ == "__main__":
"n_drop": 5,
},
},
"env":{
"env": {
"class": "SplitEnv",
"module_path": "qlib.contrib.backtest.env",
"kwargs": {
@@ -101,7 +101,7 @@ if __name__ == "__main__":
"kwargs": {
"step_bar": "day",
"verbose": True,
}
},
},
"sub_strategy": {
"class": "SBBStrategyEMA",
@@ -110,11 +110,11 @@ if __name__ == "__main__":
"step_bar": "day",
"freq": "day",
"instruments": "csi300",
}
}
}
},
},
},
},
"backtest":{
"backtest": {
"start_time": trade_start_time,
"end_time": trade_end_time,
"verbose": False,
@@ -125,8 +125,14 @@ if __name__ == "__main__":
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
}
},
}
report_dict = backtest(start_time=trade_start_time, end_time=trade_end_time, **backtest_config, account=1e8, deal_price="$close", verbose=False)
report_dict = backtest(
start_time=trade_start_time,
end_time=trade_end_time,
**backtest_config,
account=1e8,
deal_price="$close",
verbose=False,
)

View File

@@ -22,7 +22,7 @@ def get_exchange(
freq="day",
start_time=None,
end_time=None,
codes = "all",
codes="all",
subscribe_fields=[],
open_cost=0.0015,
close_cost=0.0025,
@@ -89,6 +89,7 @@ def get_exchange(
else:
return init_instance_by_config(exchange, accept_types=Exchange)
def init_env_instance_by_config(env):
if isinstance(env, dict):
env_config = copy.copy(env)
@@ -103,6 +104,7 @@ def init_env_instance_by_config(env):
else:
return env
def setup_exchange(root_instance, trade_exchange=None, force=False):
if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args:
if force:
@@ -114,8 +116,8 @@ def setup_exchange(root_instance, trade_exchange=None, force=False):
setup_exchange(root_instance.sub_env, trade_exchange)
if hasattr(root_instance, "sub_strategy"):
setup_exchange(root_instance.sub_strategy, trade_exchange)
def backtest(start_time, end_time, strategy, env, benchmark=None, account=1e9, **kwargs):
trade_strategy = init_instance_by_config(strategy)
trade_env = init_env_instance_by_config(env)

View File

@@ -11,7 +11,6 @@ from .order import Order
from ...utils import parse_freq, sample_feature
"""
rtn & earning in the Account
rtn:
@@ -87,7 +86,7 @@ class Account:
elif norm_freq == "minute":
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
else:
raise ValueError(f"benchmark freq {freq} is not supported")
raise ValueError(f"benchmark freq {freq} is not supported")
if len(_temp_result) == 0:
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
@@ -95,20 +94,20 @@ class Account:
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
def cal_change(x):
return x.prod() - 1
return sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
def reset(self, benchmark=None, freq=None,**kwargs):
def reset(self, benchmark=None, freq=None, **kwargs):
if benchmark:
self.benchmark = benchmark
if freq:
self.freq = freq
if self.freq and self.benchmark and (freq or benchmark)
if self.freq and self.benchmark and (freq or benchmark):
self.bench = self._cal_benchmark(self.benchmark, self.start_time, self.end_time, self.freq)
for k, v in kwargs:
if hasattr(k):
setattr(k, v)
def get_positions(self):
return self.positions
@@ -203,7 +202,7 @@ class Account:
turnover_rate=self.to / last_account_value,
cost_rate=self.ct / last_account_value,
stock_value=now_stock_value,
bench_value=self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
bench_value=self._sample_benchmark(self.bench, trade_start_time, trade_end_time),
)
# set now_account_value to position
self.current.position["now_account_value"] = now_account_value

View File

@@ -7,6 +7,7 @@ import pandas as pd
from .account import Account
def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account):
trade_account = Account(init_cash=account, benchmark=benchmark, start_time=start_time, end_time=end_time)
@@ -17,10 +18,9 @@ def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account
while not trade_env.finished():
_order_list = trade_strategy.generate_order_list(**trade_state)
trade_state, trade_info = trade_env.execute(_order_list)
report_df = trade_account.report.generate_report_dataframe()
positions = trade_account.get_positions()
report_dict = {"report_df": report_df, "positions": positions}
return report_dict

View File

@@ -1,5 +1,3 @@
import re
import json
import copy
@@ -14,15 +12,8 @@ from .report import Report
from .order import Order
class BaseTradeCalendar:
def __init__(
self,
step_bar,
start_time=None,
end_time=None,
**kwargs
):
def __init__(self, step_bar, start_time=None, end_time=None, **kwargs):
self.step_bar = step_bar
self.reset(start_time=start_time, end_time=end_time)
@@ -36,8 +27,10 @@ class BaseTradeCalendar:
if self.start_time and self.end_time:
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar)
self.calendar = _calendar
_start_time, _end_time, _start_index, _end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq, freq_sam=freq_sam)
_trade_calendar = self.calendar[_start_index: _end_index + 1]
_start_time, _end_time, _start_index, _end_index = Cal.locate_index(
self.start_time, self.end_time, freq=freq, freq_sam=freq_sam
)
_trade_calendar = self.calendar[_start_index : _end_index + 1]
self.start_index = _start_index
self.end_index = _end_index
self.trade_len = _end_index - _start_index + 1
@@ -52,7 +45,7 @@ class BaseTradeCalendar:
for k, v in kwargs:
if hasattr(self, k):
setattr(self, k, v)
def _get_calendar_time(self, trade_index=1, shift=0):
trade_index = trade_index - shift
calendar_index = self.start_index + trade_index
@@ -64,6 +57,7 @@ class BaseTradeCalendar:
def step(self):
self.trade_index = self.trade_index + 1
class BaseEnv(BaseTradeCalendar):
"""
# Strategy framework document
@@ -83,8 +77,10 @@ class BaseEnv(BaseTradeCalendar):
):
self.generate_report = update_report
self.verbose = verbose
super(BaseEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs)
super(BaseEnv, self).__init__(
step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs
)
def reset(self, trade_account=None, **kwargs):
super(BaseEnv, self).reset(**kwargs)
if trade_account:
@@ -94,7 +90,7 @@ class BaseEnv(BaseTradeCalendar):
def get_init_state(self):
init_state = {"current": self.trade_account.current}
return init_state
def execute(self, **kwargs):
raise NotImplementedError("execute is not implemented!")
@@ -104,23 +100,32 @@ class BaseEnv(BaseTradeCalendar):
def get_report(self):
raise NotImplementedError("get_report is not implemented!")
class SplitEnv(BaseEnv):
def __init__(
self,
step_bar,
self,
step_bar,
sub_env,
sub_strategy,
start_time=None,
end_time=None,
start_time=None,
end_time=None,
trade_account=None,
update_report=False,
verbose=False,
**kwargs
**kwargs,
):
self.sub_env = sub_env
self.sub_strategy = sub_strategy
super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, update_report=update_report, verbose=verbose, **kwargs)
super(SplitEnv, self).__init__(
step_bar=step_bar,
start_time=start_time,
end_time=end_time,
trade_account=trade_account,
update_report=update_report,
verbose=verbose,
**kwargs,
)
def reset(self, trade_account=None, **kwargs):
super(SplitEnv, self).reset(trade_account=trade_account, **kwargs)
if trade_account:
@@ -129,9 +134,9 @@ class SplitEnv(BaseEnv):
def execute(self, order_list, **kwargs):
if self.finished():
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
#if self.track:
# if self.track:
# yield action
#episode_reward = 0
# episode_reward = 0
super(SplitEnv, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time)
@@ -140,9 +145,11 @@ class SplitEnv(BaseEnv):
while not self.sub_env.finished():
_order_list = self.sub_strategy.generate_order_list(**trade_state)
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
if self.generate_report:
self.trade_account.update_report(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
self.trade_account.update_report(
trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange
)
_obs = {"current": self.trade_account.current}
_info = {}
return _obs, _info
@@ -150,31 +157,40 @@ class SplitEnv(BaseEnv):
def get_report(self):
_report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None
_positions = self.trade_account.get_positions() if self.generate_report else None
return [(_report,_positions), *sub_env.get_report()]
class SimulatorEnv(BaseEnv):
return [(_report, _positions), *sub_env.get_report()]
class SimulatorEnv(BaseEnv):
def __init__(
self,
step_bar,
start_time=None,
end_time=None,
trade_account=None,
self,
step_bar,
start_time=None,
end_time=None,
trade_account=None,
trade_exchange=None,
update_report=False,
verbose=False,
**kwargs,
):
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, update_report=update_report, verbose=verbose, **kwargs)
super(SimulatorEnv, self).__init__(
step_bar=step_bar,
start_time=start_time,
end_time=end_time,
trade_account=trade_account,
trade_exchange=trade_exchange,
update_report=update_report,
verbose=verbose,
**kwargs,
)
def reset(self, trade_exchange=None, **kwargs):
super(SimulatorEnv, self).reset(**kwargs)
if trade_exchange:
self.trade_exchange=trade_exchange
self.trade_exchange = trade_exchange
def execute(self, order_list, **kwargs):
"""
Return: obs, done, info
Return: obs, done, info
"""
if self.finished():
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
@@ -184,7 +200,9 @@ class SimulatorEnv(BaseEnv):
for order in order_list:
if self.trade_exchange.check_order(order) is True:
# execute the order
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(order, trade_account=self.trade_account)
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
order, trade_account=self.trade_account
)
trade_info.append((order, trade_val, trade_cost, trade_price))
if self.verbose:
if order.direction == Order.SELL: # sell
@@ -214,7 +232,9 @@ class SimulatorEnv(BaseEnv):
# do nothing
pass
if self.generate_report:
self.trade_account.update_report(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
self.trade_account.update_report(
trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange
)
_obs = {"current": self.trade_account.current}
_info = {"trade_info": trade_info}
return _obs, _info
@@ -222,9 +242,4 @@ class SimulatorEnv(BaseEnv):
def get_report(self):
_report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None
_positions = self.trade_account.get_positions() if self.generate_report else None
return [
{
"report": _report,
"positions": _positions
}
]
return [{"report": _report, "positions": _positions}]

View File

@@ -16,7 +16,6 @@ from ...log import get_module_logger
from .order import Order
class Exchange:
def __init__(
self,
@@ -101,14 +100,15 @@ class Exchange:
self.min_cost = min_cost
self.limit_threshold = limit_threshold
self.extra_quote = extra_quote
self.set_quote(codes, start_time, end_time)
def set_quote(self, codes, start_time, end_time):
if len(codes) == 0:
codes = D.instruments()
self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna(subset=["$close"])
self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna(
subset=["$close"]
)
self.quote.columns = self.all_fields
if self.quote[self.deal_price].isna().any():
@@ -168,7 +168,6 @@ class Exchange:
is limtited
"""
return sample_feature(self.quote[stock_id], start_time, end_time, fields="limit", method="all").iloc[0]
def check_stock_suspended(self, stock_id, start_time, end_time):
# is suspended
@@ -180,7 +179,9 @@ class Exchange:
def is_stock_tradable(self, stock_id, start_time, end_time):
# check if stock can be traded
# same as check in check_order
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(stock_id, start_time, end_time):
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(
stock_id, start_time, end_time
):
return False
else:
return True
@@ -235,9 +236,13 @@ class Exchange:
return sample_feature(self.quote[stock_id], start_time, end_time, fields="$close", method="last").iloc[0]
def get_deal_price(self, stock_id, start_time, end_time):
deal_price = sample_feature(self.quote[stock_id], start_time, end_time, fields=self.deal_price, method="last").iloc[0]
deal_price = sample_feature(
self.quote[stock_id], start_time, end_time, fields=self.deal_price, method="last"
).iloc[0]
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!")
self.logger.warning(
f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!"
)
self.logger.warning(f"setting deal_price to close price")
deal_price = self.get_close(stock_id, start_time, end_time)
return deal_price
@@ -274,7 +279,9 @@ class Exchange:
amount_dict = {}
for stock_id in weight_position:
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(
stock_id=stock_id, start_time=start_time, end_time=end_time
):
amount_dict[stock_id] = (
cash
* weight_position[stock_id]
@@ -377,7 +384,10 @@ class Exchange:
self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
):
value += self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) * amount_dict[stock_id]
value += (
self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time)
* amount_dict[stock_id]
)
return value
def round_amount_by_trade_unit(self, deal_amount, factor):

View File

@@ -1,15 +1,16 @@
class BaseInterpreter:
@staticmethod
def interpret(**kwargs):
raise NotImplementedError("interpret is not implemented!")
class ActionInterpreter:
@staticmethod
def interpret(action, **kwargs):
return action
class StateInterpreter:
@staticmethod
def interpret(state, **kwargs):
return state
return state

View File

@@ -45,16 +45,7 @@ class Report:
bench_value=None,
):
# check data
if None in [
trade_time,
account_value,
cash,
return_rate,
turnover_rate,
cost_rate,
stock_value,
bench_value
]:
if None in [trade_time, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]:
raise ValueError(
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]"
)
@@ -108,5 +99,5 @@ class Report:
turnover_rate=r.loc[trade_time]["turnover"],
cost_rate=r.loc[trade_time]["cost"],
stock_value=r.loc[trade_time]["value"],
bench_value=r.loc[trade_time]["bench"]
bench_value=r.loc[trade_time]["bench"],
)

View File

@@ -7,12 +7,10 @@ from .model_strategy import (
WeightStrategyBase,
)
from .rule_strategy import(
from .rule_strategy import (
TWAPStrategy,
SBBStrategyBase,
SBBStrategyEMA,
)
from .cost_control import (
SoftTopkStrategy
)
from .cost_control import SoftTopkStrategy

View File

@@ -53,7 +53,9 @@ class TopkDropoutStrategy(ModelStrategy):
else:
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
"""
super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange)
super(TopkDropoutStrategy, self).__init__(
step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange
)
self.topk = topk
self.n_drop = n_drop
self.method_sell = method_sell
@@ -65,8 +67,7 @@ class TopkDropoutStrategy(ModelStrategy):
self.stock_count = {}
self.hold_thresh = hold_thresh
self.only_tradable = only_tradable
def reset(self, trade_exchange=None, **kwargs):
super(TopkDropoutStrategy, self).reset(**kwargs)
if trade_exchange:
@@ -94,7 +95,9 @@ class TopkDropoutStrategy(ModelStrategy):
cur_n = 0
res = []
for si in reversed(l) if reverse else l:
if self.trade_exchange.is_stock_tradable(stock_id=si, start_time=trade_start_time, end_time=trade_end_time):
if self.trade_exchange.is_stock_tradable(
stock_id=si, start_time=trade_start_time, end_time=trade_end_time
):
res.append(si)
cur_n += 1
if cur_n >= n:
@@ -105,7 +108,13 @@ class TopkDropoutStrategy(ModelStrategy):
return get_first_n(l, n, reverse=True)
def filter_stock(l):
return [si for si in l if self.trade_exchange.is_stock_tradable(stock_id=si, start_time=trade_start_time, end_time=trade_end_time)]
return [
si
for si in l
if self.trade_exchange.is_stock_tradable(
stock_id=si, start_time=trade_start_time, end_time=trade_end_time
)
]
else:
# Otherwise, the stock will make decision with out the stock tradable info
@@ -166,11 +175,16 @@ class TopkDropoutStrategy(ModelStrategy):
buy_signal = pred_score.sort_values(ascending=False).iloc[: self.topk].index
for code in current_stock_list:
if not self.trade_exchange.is_stock_tradable(stock_id=code, start_time=trade_start_time, end_time=trade_end_time):
if not self.trade_exchange.is_stock_tradable(
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
):
continue
if code in sell:
# check hold limit
if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh:
if (
self.stock_count[code] < self.thresh
or current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh
):
# can not sell this code
# no buy signal, but the stock is kept
self.stock_count[code] += 1
@@ -188,7 +202,9 @@ class TopkDropoutStrategy(ModelStrategy):
# is order executable
if self.trade_exchange.check_order(sell_order):
sell_order_list.append(sell_order)
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(sell_order, position=current_temp)
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
sell_order, position=current_temp
)
# update cash
cash += trade_val - trade_cost
# sold
@@ -213,10 +229,14 @@ class TopkDropoutStrategy(ModelStrategy):
# value = value / (1+self.trade_exchange.open_cost) # set open_cost limit
for code in buy:
# check is stock suspended
if not self.trade_exchange.is_stock_tradable(stock_id=code, start_time=trade_start_time, end_time=trade_end_time):
if not self.trade_exchange.is_stock_tradable(
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
):
continue
# buy order
buy_price = self.trade_exchange.get_deal_price(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
buy_price = self.trade_exchange.get_deal_price(
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
)
buy_amount = value / buy_price
factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
@@ -231,17 +251,24 @@ class TopkDropoutStrategy(ModelStrategy):
buy_order_list.append(buy_order)
self.stock_count[code] = 1
return sell_order_list + buy_order_list
class WeightStrategyBase(ModelStrategy):
def __init__(self, step_bar, start_time=None, end_time=None, order_generator_cls_or_obj=OrderGenWInteract, trade_exchange=None, **kwargs):
def __init__(
self,
step_bar,
start_time=None,
end_time=None,
order_generator_cls_or_obj=OrderGenWInteract,
trade_exchange=None,
**kwargs,
):
super(WeightStrategyBase, self).__init__(step_bar, start_time, end_time)
self.trade_exchange = trade_exchange
if isinstance(order_generator_cls_or_obj, type):
self.order_generator = order_generator_cls_or_obj()
else:
self.order_generator = order_generator_cls_or_obj
def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
"""

View File

@@ -81,10 +81,16 @@ class OrderGenWInteract(OrderGenerator):
# calculate current_tradable_value
current_amount_dict = current.get_stock_amount_dict()
current_total_value = trade_exchange.calculate_amount_position_value(
amount_dict=current_amount_dict, trade_start_time=trade_start_time, trade_end_time=trade_end_time, only_tradable=False
amount_dict=current_amount_dict,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
only_tradable=False,
)
current_tradable_value = trade_exchange.calculate_amount_position_value(
amount_dict=current_amount_dict, trade_start_time=trade_start_time, trade_end_time=trade_end_time, only_tradable=True
amount_dict=current_amount_dict,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
only_tradable=True,
)
# add cash
current_tradable_value += current.get_cash()
@@ -97,7 +103,9 @@ class OrderGenWInteract(OrderGenerator):
# value. Then just sell all the stocks
target_amount_dict = copy.deepcopy(current_amount_dict.copy())
for stock_id in list(target_amount_dict.keys()):
if trade_exchange.is_stock_tradable(stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time):
if trade_exchange.is_stock_tradable(
stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
):
del target_amount_dict[stock_id]
else:
# consider cost rate
@@ -108,13 +116,13 @@ class OrderGenWInteract(OrderGenerator):
target_amount_dict = trade_exchange.generate_amount_position_from_weight_position(
weight_position=target_weight_position,
cash=current_tradable_value,
trade_start_time=trade_start_time,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
)
order_list = trade_exchange.generate_order_for_target_amount_position(
target_position=target_amount_dict,
current_position=current_amount_dict,
trade_start_time=trade_start_time,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
)
return order_list
@@ -161,7 +169,9 @@ class OrderGenWOInteract(OrderGenerator):
amount_dict = {}
for stock_id in target_weight_position:
# Current rule will ignore the stock that not hold and cannot be traded at predict date
if trade_exchange.is_stock_tradable(stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time):
if trade_exchange.is_stock_tradable(
stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
):
amount_dict[stock_id] = (
risk_total_value * target_weight_position[stock_id] / trade_exchange.get_close(stock_id, pred_date)
)

View File

@@ -11,7 +11,6 @@ from ..backtest.order import Order
class TWAPStrategy(RuleStrategy, TradingEnhancement):
def reset(self, trade_order_list=None, **kwargs):
super(TWAPStrategy, self).reset(**kwargs)
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
@@ -19,7 +18,6 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
self.trade_amount = {}
for order in self.trade_order_list:
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
def generate_order_list(self, **kwargs):
super(TopkDropoutStrategy, self).step()
@@ -37,10 +35,12 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
order_list.append(_order)
return order_list
class SBBStrategyBase(RuleStrategy, TradingEnhancement):
"""
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
"""
TREND_MID = 0
TREND_SHORT = 1
TREND_LONG = 2
@@ -50,11 +50,10 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
if trade_order_list:
self.trade_amount = {}
self.trade_trend = {}
self.trade_trend = {}
for order in self.trade_order_list:
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
raise NotImplementedError("pred_price_trend method is not implemented!")
@@ -81,10 +80,15 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
order_list.append(_order)
else:
if self.trade_index % 2 == 1:
if _pred_trend == self.TREND_SHORT and order.direction == order.SELL or _pred_trend == self.TREND_LONG and order.direction == order.BUY:
if (
_pred_trend == self.TREND_SHORT
and order.direction == order.SELL
or _pred_trend == self.TREND_LONG
and order.direction == order.BUY
):
_order = Order(
stock_id=order.stock_id,
amount=2*self.trade_amount[(order.stock_id, order.direction)],
amount=2 * self.trade_amount[(order.stock_id, order.direction)],
start_time=trade_start_time,
end_time=trade_end_time,
direction=order.direction, # 1 for buy
@@ -92,31 +96,37 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
)
order_list.append(_order)
else:
if _pred_trend == self.TREND_SHORT and order.direction == order.BUY or _pred_trend == self.TREND_LONG and order.direction == order.SELL:
if (
_pred_trend == self.TREND_SHORT
and order.direction == order.BUY
or _pred_trend == self.TREND_LONG
and order.direction == order.SELL
):
_order = Order(
stock_id=order.stock_id,
amount=2*self.trade_amount[(order.stock_id, order.direction)],
amount=2 * self.trade_amount[(order.stock_id, order.direction)],
start_time=trade_start_time,
end_time=trade_end_time,
direction=order.direction, # 1 for buy
factor=order.factor,
)
)
order_list.append(_order)
if self.trade_index % 2 == 1:
if self.trade_index % 2 == 1:
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
return order_list
class SBBStrategyEMA(SBBStrategyBase):
"""
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA).
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA).
"""
def __init__(
self,
step_bar,
start_time=None,
end_time=None,
self,
step_bar,
start_time=None,
end_time=None,
instruments="csi300",
freq="day",
**kwargs,
@@ -139,22 +149,25 @@ class SBBStrategyEMA(SBBStrategyBase):
if self.start_time and self.end_time:
fields = ["EMA($close, 10)-EMA($close, 20)"]
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
signal_df = D.features(self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq)
signal_df = D.features(
self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq
)
signal_df = self._convert_index_format(signal_df)
signal_df.columns = ["signal"]
self.signal = {}
for stock_id, stock_val in signal_df.groupby(level="instrument"):
self.signal[stock_id] = stock_val
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
if stock_id not in self.signal:
return self.TREND_MID
else:
_sample_signal = sample_feature(self.signal[stock_id], pred_start_time, pred_end_time, fields="signal", method="last")
_sample_signal = sample_feature(
self.signal[stock_id], pred_start_time, pred_end_time, fields="signal", method="last"
)
if _sample_signal is None or _sample_signal.iloc[0] == 0:
return self.TREND_MID
elif _sample_signal.iloc[0] > 0:
return self.TREND_LONG
else:
return self.TREND_SHORT

View File

@@ -126,7 +126,7 @@ class CalendarProvider(abc.ABC):
_calendar = np.array(self.load_calendar(freq, future))
_calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search
H["c"][flag_raw] = _calendar, _calendar_index
if freq_sam is None:
return _calendar, _calendar_index
else:
@@ -134,7 +134,6 @@ class CalendarProvider(abc.ABC):
_calendar_sam_index = {x: i for i, x in enumerate(_calendar_sam)}
H["c"][flag] = _calendar_sam, _calendar_sam_index
return _calendar_sam, _calendar_sam_index
def _uri(self, start_time, end_time, freq, future=False):
"""Get the uri of calendar generation task."""
@@ -560,7 +559,8 @@ class LocalCalendarProvider(CalendarProvider):
else:
end_time = _calendar[-1]
st, et, si, ei = self.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam, future=future)
return _calendar[si : ei + 1]
return _calendar[si : ei + 1]
class LocalInstrumentProvider(InstrumentProvider):
"""Local instrument data provider class
@@ -767,7 +767,7 @@ class ClientCalendarProvider(CalendarProvider):
self.conn = conn
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
self.conn.send_request(
request_type="calendar",
request_content={

View File

@@ -20,8 +20,9 @@ from ..contrib.backtest.env import BaseTradeCalendar
- adjust_dates这个东西啥用
- label和freq和strategy的bar分离这个如何决策呢
"""
class BaseStrategy(BaseTradeCalendar):
def generate_order_list(self, **kwargs):
raise NotImplementedError("generator_order_list is not implemented!")
@@ -29,12 +30,13 @@ class BaseStrategy(BaseTradeCalendar):
class RuleStrategy(BaseStrategy):
pass
class ModelStrategy(BaseStrategy):
def __init__(self, step_bar, model, dataset:DatasetH, start_time=None, end_time=None, **kwargs):
def __init__(self, step_bar, model, dataset: DatasetH, start_time=None, end_time=None, **kwargs):
self.model = model
self.dataset = dataset
self.pred_scores = self._convert_index_format(self.model.predict(dataset))
#pred_score_dates = self.pred_scores.index.get_level_values(level="datetime")
# pred_score_dates = self.pred_scores.index.get_level_values(level="datetime")
super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
def _convert_index_format(self, df):
@@ -43,12 +45,11 @@ class ModelStrategy(BaseStrategy):
return df
def _update_model(self):
"""update pred score
"""
"""update pred score"""
raise NotImplementedError("_update_model is not implemented!")
class TradingEnhancement:
def reset(self, trade_order_list=None):
if trade_order_list:
self.trade_order_list = trade_order_list

View File

@@ -801,6 +801,7 @@ def fname_to_code(fname: str):
fname = fname.lstrip(prefix)
return fname
########################## Sample ############################
def sample_calendar_bac(calendar_raw, freq_raw, freq_sam):
"""
@@ -810,16 +811,17 @@ def sample_calendar_bac(calendar_raw, freq_raw, freq_sam):
freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam
if freq_sam.endswith(("minute", "min")):
def cal_next_sam_minute(x, sam_minutes):
hour = x.hour
minute = x.minute
if 9 <= hour <= 11:
minute_index = (11 - hour)*60 + 30 - minute + 120
minute_index = (11 - hour) * 60 + 30 - minute + 120
elif 13 <= hour <= 15:
minute_index = (15 - hour)*60 - minute
minute_index = (15 - hour) * 60 - minute
else:
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
minute_index = minute_index // sam_minutes * sam_minutes
if 0 <= minute_index < 120:
@@ -838,32 +840,40 @@ def sample_calendar_bac(calendar_raw, freq_raw, freq_sam):
if raw_minutes > sam_minutes:
raise ValueError("raw freq must be higher than sample freq")
_calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 59), calendar_raw)))
_calendar_minute = np.unique(
list(
map(
lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 59),
calendar_raw,
)
)
)
return _calendar_minute
else:
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 23, 59, 59), calendar_raw)))
if freq_sam.endswith(("day", "d")):
sam_days = int(freq_sam[:-1]) if freq_sam.endswith("d") else int(freq_sam[:-3])
return _calendar_day[(len(_calendar_day) + sam_days - 1)%sam_days::sam_days]
return _calendar_day[(len(_calendar_day) + sam_days - 1) % sam_days :: sam_days]
elif freq_sam.endswith(("week", "w")):
sam_weeks = int(freq_sam[:-1]) if freq_sam.endswith("w") else int(freq_sam[:-4])
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
_calendar_week = _calendar_day[np.ediff1d(_day_in_week[::-1], to_begin=1)[::-1] > 0]
return _calendar_week[(len(_calendar_week) + sam_weeks - 1)%sam_weeks::sam_weeks]
return _calendar_week[(len(_calendar_week) + sam_weeks - 1) % sam_weeks :: sam_weeks]
elif freq_sam.endswith(("month", "m")):
sam_months = int(freq_sam[:-1]) if freq_sam.endswith("m") else int(freq_sam[:-5])
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
_calendar_month = _calendar_day[np.ediff1d(_day_in_month[::-1], to_begin=1)[::-1] > 0]
return _calendar_month[(len(_calendar_month) + sam_months - 1)%sam_months::sam_months]
return _calendar_month[(len(_calendar_month) + sam_months - 1) % sam_months :: sam_months]
else:
raise ValueError("sample freq must be xmin, xd, xw, xm")
def parse_freq(freq):
freq = freq.lower()
search_obj =re.search("^([0-9]*)([a-z]+)", freq)
search_obj = re.search("^([0-9]*)([a-z]+)", freq)
if search_obj is None:
raise ValueError("freq format is not supported")
_count = int(search_obj.group(1) if search_obj.group(1) else "1")
@@ -881,9 +891,12 @@ def parse_freq(freq):
try:
_freq = _freq_format_dict.get(_freq)
except KeyError:
raise ValueError("freq format is not supported, the supported freq includes (x)month/m, (x)day/d, (x)minute/min")
raise ValueError(
"freq format is not supported, the supported freq includes (x)month/m, (x)day/d, (x)minute/min"
)
return _count, _freq
def sample_calendar(calendar_raw, freq_raw, freq_sam):
"""
freq_raw : "min" or "day"
@@ -893,16 +906,17 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
if not len(calendar_raw):
return calendar_raw
if freq_sam == "minute":
def cal_next_sam_minute(x, sam_minutes):
hour = x.hour
minute = x.minute
if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30):
minute_index = (hour - 9)*60 + minute - 30
minute_index = (hour - 9) * 60 + minute - 30
elif 13 <= hour < 15:
minute_index = (hour - 13)*60 + minute + 120
minute_index = (hour - 13) * 60 + minute + 120
else:
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
minute_index = minute_index // sam_minutes * sam_minutes
if 0 <= minute_index < 120:
@@ -917,7 +931,11 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
else:
if raw_count > sam_count:
raise ValueError("raw freq must be higher than sample freq")
_calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw)))
_calendar_minute = np.unique(
list(
map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw)
)
)
if calendar_raw[0] > _calendar_minute[0]:
_calendar_minute[0] = calendar_raw[0]
return _calendar_minute
@@ -937,7 +955,8 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
return _calendar_month[::sam_count]
else:
raise ValueError("sample freq must be xmin, xd, xw, xm")
def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwargs):
_, norm_freq = parse_freq(freq)
@@ -963,23 +982,28 @@ def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwarg
raise ValueError(f"freq {freq} is not supported")
return _calendar, freq, freq_sam
def sample_feature(feature, start_time=None, end_time=None, fields=None, method="last", method_kwargs={}):
selector_datetime = slice(start_time, end_time)
fields = fields if fields else slice(None)
from ..data.dataset.utils import get_level_index
datetime_level = get_level_index(feature, level="datetime") == 0
if isinstance(feature, pd.Series):
feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)]
elif isinstance(feature, pd.DataFrame):
feature = feature.loc[selector_datetime, fields] if datetime_level else feature.loc[(slice(None), selector_datetime), fields]
feature = (
feature.loc[selector_datetime, fields]
if datetime_level
else feature.loc[(slice(None), selector_datetime), fields]
)
if feature.empty:
return None
if isinstance(feature.index, pd.MultiIndex):
if callable(method):
method_func = method
return feature.groupby(level="instrument").apply(lambda x:method_func(x, **method_kwargs))
return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs))
elif isinstance(method, str):
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
else:
@@ -988,7 +1012,5 @@ def sample_feature(feature, start_time=None, end_time=None, fields=None, method=
return method_func(feature, **method_kwargs)
elif isinstance(method, str):
return getattr(feature, method)(**method_kwargs)
return feature
return feature

View File

@@ -254,13 +254,19 @@ class PortAnaRecord(SignalRecord):
for report_dep, (report_normal, positions_normal) in enumerate(report_list):
if report_dict is None:
if self.risk_analysis_dep == report_dep:
warnings.warn(f"the report in dep {risk_analysis_dep} is None, please set the corresponding env with `generate_report==True`")
warnings.warn(
f"the report in dep {risk_analysis_dep} is None, please set the corresponding env with `generate_report==True`"
)
continue
self.recorder.save_objects(**{f"report_normal_{report_dep}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(**{f"positions_norma_{report_dep}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(
**{f"report_normal_{report_dep}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
)
self.recorder.save_objects(
**{f"positions_norma_{report_dep}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
)
# analysis
self.risk_analysis_dep == report_dep:
if self.risk_analysis_dep == report_dep:
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["excess_return_with_cost"] = risk_analysis(
@@ -270,7 +276,9 @@ class PortAnaRecord(SignalRecord):
# log metrics
self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
# save results
self.recorder.save_objects(**{f"port_analysis.pkl_{report_dep}": analysis_df}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(
**{f"port_analysis.pkl_{report_dep}": analysis_df}, artifact_path=PortAnaRecord.get_path()
)
logger.info(
f"Portfolio analysis record 'port_analysis_{report_dep}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)