1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 01:51:18 +08:00

fix comments

This commit is contained in:
bxdd
2021-05-13 00:33:57 +08:00
parent 621cb243c2
commit 07eaada31e
14 changed files with 294 additions and 185 deletions

View File

@@ -122,7 +122,6 @@ if __name__ == "__main__":
"benchmark": benchmark,
"exchange_kwargs": {
"freq": "day",
"verbose": False,
"limit_threshold": 0.095,
"deal_price": "close",
"open_cost": 0.0005,

View File

@@ -1,16 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .account import Account
from .exchange import Exchange
from .executor import BaseExecutor
from .backtest import backtest as backtest_func
import inspect
from ...strategy.base import BaseStrategy
from ...utils import init_instance_by_config
from ...log import get_module_logger
from ...config import C
from .faculty import common_faculty
logger = get_module_logger("backtest caller")
@@ -28,7 +30,6 @@ def get_exchange(
trade_unit=None,
limit_threshold=None,
deal_price=None,
shift=1,
):
"""get_exchange
@@ -88,28 +89,26 @@ def get_exchange(
return init_instance_by_config(exchange, accept_types=Exchange)
def setup_exchange(root_instance, trade_exchange=None, force=False):
if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args:
if force:
root_instance.reset(trade_exchange=trade_exchange)
else:
if not hasattr(root_instance, "trade_exchange") or root_instance.trade_exchange is None:
root_instance.reset(trade_exchange=trade_exchange)
if hasattr(root_instance, "sub_env"):
setup_exchange(root_instance.sub_env, trade_exchange)
if hasattr(root_instance, "sub_strategy"):
setup_exchange(root_instance.sub_strategy, trade_exchange)
def backtest(start_time, end_time, strategy, env, benchmark="SH000300", account=1e9, exchange_kwargs={}):
trade_account = Account(
init_cash=account,
benchmark_config={
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
},
)
trade_exchange = get_exchange(**exchange_kwargs)
common_faculty.update(
trade_account=trade_account,
trade_exchange=trade_exchange,
)
def backtest(start_time, end_time, strategy, env, benchmark="SH000905", account=1e9, exchange_kwargs={}):
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
trade_env = init_instance_by_config(env, accept_types=BaseExecutor)
trade_exchange = get_exchange(**exchange_kwargs)
setup_exchange(trade_env, trade_exchange)
setup_exchange(trade_strategy, trade_exchange)
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env, benchmark, account)
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env)
return report_dict

View File

@@ -30,48 +30,53 @@ rtn & earning in the Account
class Account:
def __init__(self, init_cash, benchmark=None, start_time=None, end_time=None, freq=None):
self.init_vars(init_cash, benchmark, start_time, end_time)
def __init__(self, init_cash, freq: str = "day", benchmark_config: dict = {}):
self.init_vars(init_cash, freq, benchmark_config)
def init_vars(self, init_cash, benchmark=None, start_time=None, end_time=None, freq=None):
def init_vars(self, init_cash, freq: str, benchmark_config: dict):
"""
Parameters
----------
- benchmark: str/list/pd.Series
`benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
example:
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
2017-01-04 0.011693
2017-01-05 0.000721
2017-01-06 -0.004322
2017-01-09 0.006874
2017-01-10 -0.003350
`benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
`benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000905 CSI500
freq : str
frequency of trading bar, used for updating hold count of trading bar
benchmark_config : dict
config of benchmark, may including the following arguments:
- benchmark : Union[str, list, pd.Series]
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
example:
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
2017-01-04 0.011693
2017-01-05 0.000721
2017-01-06 -0.004322
2017-01-09 0.006874
2017-01-10 -0.003350
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
- If `benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000300 CSI300
- start_time : Union[str, pd.Timestamp], optional
- If `benchmark` is pd.Series, it will be ignored
- Else, it represent start time of benchmark, by default None
- end_time : Union[str, pd.Timestamp], optional
- If `benchmark` is pd.Series, it will be ignored
- Else, it represent end time of benchmark, by default None
"""
# init cash
self.init_cash = init_cash
self.benchmark = benchmark
self.start_time = start_time
self.end_time = end_time
self.freq = freq
self.benchmark_config = benchmark_config
self.bench = self._cal_benchmark(benchmark_config, freq)
self.current = Position(cash=init_cash)
self.positions = {}
self.rtn = 0
self.ct = 0
self.to = 0
self.val = 0
self.earning = 0
self.report = Report()
if freq and benchmark:
self.bench = self._cal_benchmark(benchmark, start_time, end_time, freq)
self._reset_report()
def _cal_benchmark(self, benchmark, start_time=None, end_time=None, freq=None):
def _cal_benchmark(self, benchmark_config, freq):
benchmark = benchmark_config.get("benchmark", "SH000300")
if isinstance(benchmark, pd.Series):
return benchmark
else:
start_time = benchmark_config.get("start_time", None)
end_time = benchmark_config.get("end_time", None)
if freq is None:
raise ValueError("benchmark freq can't be None!")
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
@@ -100,19 +105,25 @@ class Account:
_ret = sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
return 0 if _ret is None else _ret
def reset(self, benchmark=None, freq=None, **kwargs):
if benchmark:
self.benchmark = benchmark
if freq:
def _reset_freq(self, freq):
"""reset frequency"""
if freq != self.freq:
self.freq = freq
if self.freq and self.benchmark and (freq or benchmark):
self.bench = self._cal_benchmark(self.benchmark, self.start_time, self.end_time, self.freq)
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)
else:
warnings.warn(f"reser error, attribute {k} is not found!")
def _reset_report(self):
self.report = Report()
self.positions = {}
self.rtn = 0
self.ct = 0
self.to = 0
self.val = 0
self.earning = 0
def reset(self, freq=None, init_report: bool = False):
self._reset_freq(freq)
if init_report:
self._reset_report()
def get_positions(self):
return self.positions
@@ -155,7 +166,10 @@ class Account:
self.current.update_order(order, trade_val, cost, trade_price)
self.update_state_from_order(order, trade_val, cost, trade_price)
def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange, update_report):
def update_bar_count(self):
self.current.add_count_all(bar=self.freq)
def update_bar_report(self, trade_start_time, trade_end_time, trade_exchange):
"""
start_time: pd.TimeStamp
end_time: pd.TimeStamp
@@ -171,9 +185,6 @@ class Account:
:return: None
"""
# update price for stock in the position and the profit from changed_price
self.current.add_count_all(bar=self.freq)
if update_report is None:
return
stock_list = self.current.get_stock_list()
for code in stock_list:
# if suspend, no new price to be updated, profit is 0

View File

@@ -1,13 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .account import Account
def backtest(start_time, end_time, trade_strategy, trade_env):
def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account):
trade_account = Account(init_cash=account, benchmark=benchmark, start_time=start_time, end_time=end_time)
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
trade_env.reset(start_time=start_time, end_time=end_time)
trade_strategy.reset(start_time=start_time, end_time=end_time)
_execute_state = trade_env.get_init_state()

View File

@@ -1,18 +1,25 @@
import copy
import warnings
import pandas as pd
from typing import Tuple, List, Union, Optional, Callable
from typing import Union
from ...data.data import Cal
from ...strategy.base import BaseStrategy
from ...utils import init_instance_by_config
from ...utils.sample import get_sample_freq_calendar, parse_freq
from .report import Report
from .order import Order
from .account import Account
from .exchange import Exchange
from .faculty import common_faculty
class BaseTradeCalendar:
"""
Base class providing trading calendar
- BaseStrategy and BaseExecutor should inherited from this class
"""
def __init__(
self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
):
@@ -30,16 +37,13 @@ class BaseTradeCalendar:
"""
self.step_bar = step_bar
self.start_time = pd.Timestamp(start_time)
self.end_time = pd.Timestamp(end_time)
self.reset(start_time=start_time, end_time=end_time)
def _reset_trade_calendar(self, start_time, end_time):
if not start_time and not end_time:
return
if start_time:
self.start_time = pd.Timestamp(start_time)
if end_time:
self.end_time = pd.Timestamp(end_time)
if self.start_time and self.end_time:
"""reset trade calendar"""
if start_time and end_time:
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar)
self.calendar = _calendar
_, _, _start_index, _end_index = Cal.locate_index(
@@ -50,17 +54,19 @@ class BaseTradeCalendar:
self.trade_len = _end_index - _start_index + 1
self.trade_index = 0
else:
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
raise ValueError("failed to reset trade calendar, param `start_time` or `end_time` is None.")
def reset(self, start_time=None, end_time=None, **kwargs):
if start_time or end_time:
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
def reset(self, start_time=None, end_time=None):
"""
Reset start\end time of trading, and reset trading calendar
"""
for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)
else:
warnings.warn(f"reser error, attribute {k} is not found!")
if start_time:
self.start_time = pd.Timestamp(start_time)
if end_time:
self.end_time = pd.Timestamp(end_time)
if self.start_time and self.end_time and (start_time or end_time):
self._reset_trade_calendar(start_time=self.start_time, end_time=self.end_time)
def _get_calendar_time(self, trade_index=1, shift=0):
trade_index = trade_index - shift
@@ -87,6 +93,7 @@ class BaseExecutor(BaseTradeCalendar):
trade_account: Account = None,
generate_report: bool = False,
verbose: bool = False,
track_data: bool = False,
**kwargs,
):
"""
@@ -94,23 +101,30 @@ class BaseExecutor(BaseTradeCalendar):
----------
trade_account : Account, optional
trade account for trading, by default None
If `trade_account` is None, it must be reset before trading
- If `trade_account` is None, self.trade_account will be set with common_faculty
generate_report : bool, optional
whether to generate report, by default False
verbose : bool, optional
whether to print log, by default False
whether to print trading info, by default False
track_data : bool, optional
whether to generate order_list, will be used when making data for multi-level training
- If `self.track_data` is true, when making data for training, the input `order_list` of `execute` will be generated by `get_data`
- Else, `order_list` will not be generated
"""
super(BaseExecutor, self).__init__(
step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs
)
super(BaseExecutor, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, **kwargs)
self.trade_account = copy.copy(common_faculty.trade_account if trade_account is None else trade_account)
self.trade_account.reset(freq=self.step_bar, init_report=True)
self.generate_report = generate_report
self.verbose = verbose
self.track_data = track_data
def reset(self, trade_account=None, **kwargs):
def reset(self, track_data: bool = None, **kwargs):
"""
Reset `track_data`, will be used when making data for multi-level training
"""
super(BaseExecutor, self).reset(**kwargs)
if trade_account:
self.trade_account = trade_account
self.trade_account.reset(freq=self.step_bar, report=Report(), positions={})
if track_data is not None:
self.track_data = track_data
def get_init_state(self):
init_state = {"current": self.trade_account.current}
@@ -127,6 +141,8 @@ class BaseExecutor(BaseTradeCalendar):
class SplitExecutor(BaseExecutor):
from ...strategy.base import BaseStrategy
def __init__(
self,
step_bar: str,
@@ -138,6 +154,7 @@ class SplitExecutor(BaseExecutor):
trade_exchange: Exchange = None,
generate_report: bool = False,
verbose: bool = False,
track_data: bool = False,
**kwargs,
):
"""
@@ -155,40 +172,55 @@ class SplitExecutor(BaseExecutor):
start_time=start_time,
end_time=end_time,
trade_account=trade_account,
trade_exchange=trade_exchange,
generate_report=generate_report,
verbose=verbose,
track_data=track_data,
**kwargs,
)
if generate_report:
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
self.sub_env = init_instance_by_config(sub_env, accept_types=BaseExecutor)
self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=BaseStrategy)
def reset(self, trade_account=None, trade_exchange=None, **kwargs):
self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=self.BaseStrategy)
super(SplitExecutor, self).reset(trade_account=trade_account, **kwargs)
if trade_account:
self.sub_env.reset(trade_account=copy.copy(trade_account))
if trade_exchange:
self.trade_exchange = trade_exchange
def execute(self, order_list):
super(SplitExecutor, self).step()
def _init_sub_trading(self, order_list):
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time)
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
_execute_state = self.sub_env.get_init_state()
while not self.sub_env.finished():
_order_list = self.sub_strategy.generate_order_list(_execute_state)
_execute_state = self.sub_env.execute(order_list=_order_list)
sub_execute_state = self.sub_env.get_init_state()
return sub_execute_state
self.trade_account.update_bar_end(
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
trade_exchange=self.trade_exchange,
update_report=self.generate_report,
)
_execute_state = {"current": self.trade_account.current}
return _execute_state
def _update_trade_account(self):
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
self.trade_account.update_bar_count()
if self.generate_report:
self.trade_account.update_bar_report(
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
trade_exchange=self.trade_exchange,
)
def execute(self, order_list):
super(SplitExecutor, self).step()
self._init_sub_trading(order_list)
sub_execute_state = self.sub_env.get_init_state()
while not self.sub_env.finished():
_order_list = self.sub_strategy.generate_order_list(sub_execute_state)
sub_execute_state = self.sub_env.execute(order_list=_order_list)
self._update_trade_account()
return {"current": self.trade_account.current}
def get_data(self, order_list):
if self.track_data:
yield order_list
super(SplitExecutor, self).step()
self._init_sub_trading(order_list)
sub_execute_state = self.sub_env.get_init_state()
while not self.sub_env.finished():
_order_list = self.sub_strategy.generate_order_list(sub_execute_state)
sub_execute_state = yield from self.sub_env.get_data(order_list=_order_list)
self._update_trade_account()
return {"current": self.trade_account.current}
def get_report(self):
sub_env_report_dict = self.sub_env.get_report()
@@ -203,13 +235,14 @@ class SplitExecutor(BaseExecutor):
class SimulatorExecutor(BaseExecutor):
def __init__(
self,
step_bar,
start_time=None,
end_time=None,
trade_account=None,
trade_exchange=None,
generate_report=False,
verbose=False,
step_bar: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
trade_account: Account = None,
trade_exchange: Exchange = None,
generate_report: bool = False,
verbose: bool = False,
track_data: bool = False,
**kwargs,
):
"""
@@ -223,16 +256,12 @@ class SimulatorExecutor(BaseExecutor):
start_time=start_time,
end_time=end_time,
trade_account=trade_account,
trade_exchange=trade_exchange,
generate_report=generate_report,
verbose=verbose,
track_data=track_data,
**kwargs,
)
def reset(self, trade_exchange=None, **kwargs):
super(SimulatorExecutor, self).reset(**kwargs)
if trade_exchange:
self.trade_exchange = trade_exchange
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
def execute(self, order_list):
super(SimulatorExecutor, self).step()
@@ -276,14 +305,17 @@ class SimulatorExecutor(BaseExecutor):
print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id))
# do nothing
pass
self.trade_account.update_bar_end(
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
trade_exchange=self.trade_exchange,
update_report=self.generate_report,
)
_execute_state = {"current": self.trade_account.current, "trade_info": trade_info}
return _execute_state
self.trade_account.update_bar_count()
if self.generate_report:
self.trade_account.update_bar_report(
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
trade_exchange=self.trade_exchange,
)
return {"current": self.trade_account.current, "trade_info": trade_info}
def get_report(self):
if self.generate_report:

View File

@@ -0,0 +1,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
class Faculty:
def __init__(self):
self.__dict__["_faculty"] = dict()
def __getitem__(self, key):
return self.__dict__["_faculty"][key]
def __getattr__(self, attr):
if attr in self.__dict__["_faculty"]:
return self.__dict__["_faculty"][attr]
raise AttributeError(f"No such {attr} in self._faculty")
def __setitem__(self, key, value):
self.__dict__["_faculty"][key] = value
def __setattr__(self, attr, value):
self.__dict__["_faculty"][attr] = value
def update(self, *args, **kwargs):
self.__dict__["_faculty"].update(*args, **kwargs)
common_faculty = Faculty()

View File

@@ -2,12 +2,27 @@
# Licensed under the MIT License.
from .order_generator import OrderGenWInteract
from .model_strategy import WeightStrategyBase
import copy
class SoftTopkStrategy(WeightStrategyBase):
def __init__(self, topk, max_sold_weight=1.0, risk_degree=0.95, buy_method="first_fill"):
def __init__(
self,
step_bar,
model,
dataset,
topk,
start_time=None,
end_time=None,
order_generator_cls_or_obj=OrderGenWInteract,
trade_exchange=None,
max_sold_weight=1.0,
risk_degree=0.95,
buy_method="first_fill",
**kwargs,
):
"""Parameter
topk : int
top-N stocks to buy
@@ -17,13 +32,15 @@ class SoftTopkStrategy(WeightStrategyBase):
rank_fill: assign the weight stocks that rank high first(1/topk max)
average_fill: assign the weight to the stocks rank high averagely.
"""
super().__init__()
super(SoftTopkStrategy, self).__init__(
step_bar, model, dataset, start_time, end_time, order_generator_cls_or_obj, trade_exchange
)
self.topk = topk
self.max_sold_weight = max_sold_weight
self.risk_degree = risk_degree
self.buy_method = buy_method
def get_risk_degree(self, trade_index):
def get_risk_degree(self, trade_index=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
Dynamically risk_degree will result in Market timing

View File

@@ -6,6 +6,7 @@ import pandas as pd
from ...utils.sample import sample_feature
from ...strategy.base import ModelStrategy
from ..backtest.order import Order
from ..backtest.faculty import common_faculty
from .order_generator import OrderGenWInteract
@@ -50,9 +51,8 @@ class TopkDropoutStrategy(ModelStrategy):
else:
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
"""
super(TopkDropoutStrategy, self).__init__(
step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange, **kwargs
)
super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time, **kwargs)
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
self.topk = topk
self.n_drop = n_drop
self.method_sell = method_sell
@@ -61,11 +61,6 @@ class TopkDropoutStrategy(ModelStrategy):
self.hold_thresh = hold_thresh
self.only_tradable = only_tradable
def reset(self, trade_exchange=None, **kwargs):
super(TopkDropoutStrategy, self).reset(**kwargs)
if trade_exchange:
self.trade_exchange = trade_exchange
def get_risk_degree(self, trade_index=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
@@ -164,7 +159,7 @@ class TopkDropoutStrategy(ModelStrategy):
# Get the stock list we really want to buy
buy = today[: len(sell) + self.topk - len(last)]
print("INTRANEL BAR", len(sell), len(sell) + self.topk - len(last), len(last))
# print("INTRANEL BAR", len(sell), len(sell) + self.topk - len(last), len(last))
# print("flag", len(sell), len(buy), self.topk, len(last))
for code in current_stock_list:
if not self.trade_exchange.is_stock_tradable(
@@ -242,20 +237,13 @@ class WeightStrategyBase(ModelStrategy):
trade_exchange=None,
**kwargs,
):
super(WeightStrategyBase, self).__init__(
step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange, **kwargs
)
super(WeightStrategyBase, self).__init__(step_bar, model, dataset, start_time, end_time, **kwargs)
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
if isinstance(order_generator_cls_or_obj, type):
self.order_generator = order_generator_cls_or_obj()
else:
self.order_generator = order_generator_cls_or_obj
def reset(self, trade_exchange=None, **kwargs):
super(WeightStrategyBase, self).reset(**kwargs)
if trade_exchange:
self.trade_exchange = trade_exchange
def get_risk_degree(self, trade_index=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.

View File

@@ -173,7 +173,9 @@ class OrderGenWOInteract(OrderGenerator):
stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
):
amount_dict[stock_id] = (
risk_total_value * target_weight_position[stock_id] / trade_exchange.get_close(stock_id, pred_date)
risk_total_value
* target_weight_position[stock_id]
/ trade_exchange.get_close(stock_id, trade_start_time=pred_start_time, trade_end_time=pred_end_time)
)
elif stock_id in current_stock:
amount_dict[stock_id] = (

View File

@@ -9,6 +9,7 @@ from ...data.data import D
from ...data.dataset.utils import convert_index_format
from ...strategy.base import RuleStrategy, OrderEnhancement
from ..backtest.order import Order
from ..backtest.faculty import common_faculty
class TWAPStrategy(RuleStrategy, OrderEnhancement):
@@ -18,16 +19,17 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement):
start_time=None,
end_time=None,
trade_exchange=None,
trade_order_list=[],
**kwargs,
):
super(TWAPStrategy, self).__init__(step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs)
super(TWAPStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
self.trade_order_list = trade_order_list
def reset(self, trade_order_list=None, trade_exchange=None, **kwargs):
def reset(self, trade_order_list: list = None, **kwargs):
super(TWAPStrategy, self).reset(**kwargs)
OrderEnhancement.reset(self, trade_order_list=trade_order_list)
if trade_exchange:
self.trade_exchange = trade_exchange
if trade_order_list:
if trade_order_list is not None:
self.trade_amount = {}
for order in self.trade_order_list:
self.trade_amount[(order.stock_id, order.direction)] = order.amount
@@ -82,15 +84,16 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
start_time=None,
end_time=None,
trade_exchange=None,
trade_order_list=[],
**kwargs,
):
super(SBBStrategyBase, self).__init__(step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs)
super(SBBStrategyBase, self).__init__(step_bar, start_time, end_time, **kwargs)
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
self.trade_order_list = trade_order_list
def reset(self, trade_order_list=None, trade_exchange=None, **kwargs):
def reset(self, trade_order_list=None, **kwargs):
super(SBBStrategyBase, self).reset(**kwargs)
OrderEnhancement.reset(self, trade_order_list=trade_order_list)
if trade_exchange:
self.trade_exchange = trade_exchange
if trade_order_list is not None:
self.trade_trend = {}
self.trade_amount = {}
@@ -217,11 +220,12 @@ class SBBStrategyEMA(SBBStrategyBase):
start_time=None,
end_time=None,
trade_exchange=None,
trade_order_list=[],
instruments="csi300",
freq="day",
**kwargs,
):
super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs)
super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange, trade_order_list, **kwargs)
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
self.instruments = "all"
@@ -229,9 +233,9 @@ class SBBStrategyEMA(SBBStrategyBase):
self.instruments = D.instruments(instruments)
self.freq = freq
def reset(self, start_time=None, end_time=None, **kwargs):
super(SBBStrategyEMA, self).reset(start_time=start_time, end_time=end_time, **kwargs)
if self.start_time and self.end_time:
def _reset_trade_calendar(self, start_time=None, end_time=None):
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time)
if start_time and end_time:
fields = ["EMA($close, 10)-EMA($close, 20)"]
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
signal_df = D.features(

View File

@@ -7,6 +7,8 @@ from ..contrib.backtest.executor import BaseExecutor
class BaseRLEnv:
"""Base environment for reinforcement learning"""
def reset(self, **kwargs):
raise NotImplementedError("reset is not implemented!")

View File

@@ -3,18 +3,48 @@
class BaseInterpreter:
"""Base Interpreter"""
@staticmethod
def interpret(**kwargs):
raise NotImplementedError("interpret is not implemented!")
class ActionInterpreter(BaseInterpreter):
"""Action Interpreter that interpret rl agent action into qlib orders"""
@staticmethod
def interpret(action, **kwargs):
"""interpret method
Parameters
----------
action :
rl agent action
Returns
-------
qlib orders
"""
raise NotImplementedError("interpret is not implemented!")
class StateInterpreter(BaseInterpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
@staticmethod
def interpret(execute_result, **kwargs):
"""interpret method
Parameters
----------
execute_result :
qlib execution result
Returns
----------
rl env state
"""
raise NotImplementedError("interpret is not implemented!")

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
import pandas as pd
from typing import Tuple, List, Union, Optional, Callable
from typing import List, Union
from ..model.base import BaseModel
@@ -14,7 +14,7 @@ from ..rl.interpreter import ActionInterpreter, StateInterpreter
class BaseStrategy(BaseTradeCalendar):
"""Base strategy"""
"""Base strategy for trading"""
def generate_order_list(self, execute_state):
"""Generate order list in each trading bar"""
@@ -22,13 +22,13 @@ class BaseStrategy(BaseTradeCalendar):
class RuleStrategy(BaseStrategy):
"""Trading strategy with rules"""
"""Rule-based Trading strategy"""
pass
class ModelStrategy(BaseStrategy):
"""Trading Strategy by using Model to make predictions"""
"""Model-based trading strategy, use model to make predictions for trading"""
def __init__(
self,
@@ -57,7 +57,7 @@ class ModelStrategy(BaseStrategy):
def _update_model(self):
"""
Update model in each bar when using online data as the following steps:
When using online data, pdate model in each bar as the following steps:
- update dataset with online data, the dataset should support online update
- make the latest prediction scores of the new bar
- update the pred score into the latest prediction
@@ -66,7 +66,7 @@ class ModelStrategy(BaseStrategy):
class RLStrategy(BaseStrategy):
"""RL-based Strategy"""
"""RL-based strategy"""
def __init__(
self,

View File

@@ -335,10 +335,10 @@ class PortAnaRecord(RecordTemp):
report_normal, _ = report_dict.get(self.risk_analysis_freq)
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"], self.risk_analysis_freq
report_normal["return"] - report_normal["bench"], freq=self.risk_analysis_freq
)
analysis["excess_return_with_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"] - report_normal["cost"], self.risk_analysis_freq
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=self.risk_analysis_freq
)
analysis_df = pd.concat(analysis) # type: pd.DataFrame
# log metrics