mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 01:51:18 +08:00
fix comments
This commit is contained in:
@@ -122,7 +122,6 @@ if __name__ == "__main__":
|
||||
"benchmark": benchmark,
|
||||
"exchange_kwargs": {
|
||||
"freq": "day",
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from .account import Account
|
||||
from .exchange import Exchange
|
||||
from .executor import BaseExecutor
|
||||
from .backtest import backtest as backtest_func
|
||||
|
||||
import inspect
|
||||
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...utils import init_instance_by_config
|
||||
from ...log import get_module_logger
|
||||
from ...config import C
|
||||
from .faculty import common_faculty
|
||||
|
||||
|
||||
logger = get_module_logger("backtest caller")
|
||||
|
||||
@@ -28,7 +30,6 @@ def get_exchange(
|
||||
trade_unit=None,
|
||||
limit_threshold=None,
|
||||
deal_price=None,
|
||||
shift=1,
|
||||
):
|
||||
"""get_exchange
|
||||
|
||||
@@ -88,28 +89,26 @@ def get_exchange(
|
||||
return init_instance_by_config(exchange, accept_types=Exchange)
|
||||
|
||||
|
||||
def setup_exchange(root_instance, trade_exchange=None, force=False):
|
||||
if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args:
|
||||
if force:
|
||||
root_instance.reset(trade_exchange=trade_exchange)
|
||||
else:
|
||||
if not hasattr(root_instance, "trade_exchange") or root_instance.trade_exchange is None:
|
||||
root_instance.reset(trade_exchange=trade_exchange)
|
||||
if hasattr(root_instance, "sub_env"):
|
||||
setup_exchange(root_instance.sub_env, trade_exchange)
|
||||
if hasattr(root_instance, "sub_strategy"):
|
||||
setup_exchange(root_instance.sub_strategy, trade_exchange)
|
||||
def backtest(start_time, end_time, strategy, env, benchmark="SH000300", account=1e9, exchange_kwargs={}):
|
||||
|
||||
trade_account = Account(
|
||||
init_cash=account,
|
||||
benchmark_config={
|
||||
"benchmark": benchmark,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
},
|
||||
)
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
common_faculty.update(
|
||||
trade_account=trade_account,
|
||||
trade_exchange=trade_exchange,
|
||||
)
|
||||
|
||||
def backtest(start_time, end_time, strategy, env, benchmark="SH000905", account=1e9, exchange_kwargs={}):
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
|
||||
trade_env = init_instance_by_config(env, accept_types=BaseExecutor)
|
||||
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
setup_exchange(trade_env, trade_exchange)
|
||||
setup_exchange(trade_strategy, trade_exchange)
|
||||
|
||||
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env, benchmark, account)
|
||||
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env)
|
||||
|
||||
return report_dict
|
||||
|
||||
@@ -30,48 +30,53 @@ rtn & earning in the Account
|
||||
|
||||
|
||||
class Account:
|
||||
def __init__(self, init_cash, benchmark=None, start_time=None, end_time=None, freq=None):
|
||||
self.init_vars(init_cash, benchmark, start_time, end_time)
|
||||
def __init__(self, init_cash, freq: str = "day", benchmark_config: dict = {}):
|
||||
self.init_vars(init_cash, freq, benchmark_config)
|
||||
|
||||
def init_vars(self, init_cash, benchmark=None, start_time=None, end_time=None, freq=None):
|
||||
def init_vars(self, init_cash, freq: str, benchmark_config: dict):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
- benchmark: str/list/pd.Series
|
||||
`benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
|
||||
example:
|
||||
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
|
||||
2017-01-04 0.011693
|
||||
2017-01-05 0.000721
|
||||
2017-01-06 -0.004322
|
||||
2017-01-09 0.006874
|
||||
2017-01-10 -0.003350
|
||||
`benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
|
||||
`benchmark` is str, will use the daily change as the 'bench'.
|
||||
benchmark code, default is SH000905 CSI500
|
||||
freq : str
|
||||
frequency of trading bar, used for updating hold count of trading bar
|
||||
benchmark_config : dict
|
||||
config of benchmark, may including the following arguments:
|
||||
- benchmark : Union[str, list, pd.Series]
|
||||
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
|
||||
example:
|
||||
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
|
||||
2017-01-04 0.011693
|
||||
2017-01-05 0.000721
|
||||
2017-01-06 -0.004322
|
||||
2017-01-09 0.006874
|
||||
2017-01-10 -0.003350
|
||||
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
|
||||
- If `benchmark` is str, will use the daily change as the 'bench'.
|
||||
benchmark code, default is SH000300 CSI300
|
||||
- start_time : Union[str, pd.Timestamp], optional
|
||||
- If `benchmark` is pd.Series, it will be ignored
|
||||
- Else, it represent start time of benchmark, by default None
|
||||
- end_time : Union[str, pd.Timestamp], optional
|
||||
- If `benchmark` is pd.Series, it will be ignored
|
||||
- Else, it represent end time of benchmark, by default None
|
||||
|
||||
"""
|
||||
# init cash
|
||||
self.init_cash = init_cash
|
||||
self.benchmark = benchmark
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
self.freq = freq
|
||||
self.benchmark_config = benchmark_config
|
||||
self.bench = self._cal_benchmark(benchmark_config, freq)
|
||||
self.current = Position(cash=init_cash)
|
||||
self.positions = {}
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
self.to = 0
|
||||
self.val = 0
|
||||
self.earning = 0
|
||||
self.report = Report()
|
||||
if freq and benchmark:
|
||||
self.bench = self._cal_benchmark(benchmark, start_time, end_time, freq)
|
||||
self._reset_report()
|
||||
|
||||
def _cal_benchmark(self, benchmark, start_time=None, end_time=None, freq=None):
|
||||
def _cal_benchmark(self, benchmark_config, freq):
|
||||
benchmark = benchmark_config.get("benchmark", "SH000300")
|
||||
if isinstance(benchmark, pd.Series):
|
||||
return benchmark
|
||||
else:
|
||||
start_time = benchmark_config.get("start_time", None)
|
||||
end_time = benchmark_config.get("end_time", None)
|
||||
|
||||
if freq is None:
|
||||
raise ValueError("benchmark freq can't be None!")
|
||||
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
|
||||
@@ -100,19 +105,25 @@ class Account:
|
||||
_ret = sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
|
||||
return 0 if _ret is None else _ret
|
||||
|
||||
def reset(self, benchmark=None, freq=None, **kwargs):
|
||||
if benchmark:
|
||||
self.benchmark = benchmark
|
||||
if freq:
|
||||
def _reset_freq(self, freq):
|
||||
"""reset frequency"""
|
||||
if freq != self.freq:
|
||||
self.freq = freq
|
||||
if self.freq and self.benchmark and (freq or benchmark):
|
||||
self.bench = self._cal_benchmark(self.benchmark, self.start_time, self.end_time, self.freq)
|
||||
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
warnings.warn(f"reser error, attribute {k} is not found!")
|
||||
def _reset_report(self):
|
||||
self.report = Report()
|
||||
self.positions = {}
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
self.to = 0
|
||||
self.val = 0
|
||||
self.earning = 0
|
||||
|
||||
def reset(self, freq=None, init_report: bool = False):
|
||||
self._reset_freq(freq)
|
||||
if init_report:
|
||||
self._reset_report()
|
||||
|
||||
def get_positions(self):
|
||||
return self.positions
|
||||
@@ -155,7 +166,10 @@ class Account:
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
self.update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange, update_report):
|
||||
def update_bar_count(self):
|
||||
self.current.add_count_all(bar=self.freq)
|
||||
|
||||
def update_bar_report(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""
|
||||
start_time: pd.TimeStamp
|
||||
end_time: pd.TimeStamp
|
||||
@@ -171,9 +185,6 @@ class Account:
|
||||
:return: None
|
||||
"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
self.current.add_count_all(bar=self.freq)
|
||||
if update_report is None:
|
||||
return
|
||||
stock_list = self.current.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .account import Account
|
||||
|
||||
def backtest(start_time, end_time, trade_strategy, trade_env):
|
||||
|
||||
def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account):
|
||||
|
||||
trade_account = Account(init_cash=account, benchmark=benchmark, start_time=start_time, end_time=end_time)
|
||||
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
|
||||
trade_env.reset(start_time=start_time, end_time=end_time)
|
||||
trade_strategy.reset(start_time=start_time, end_time=end_time)
|
||||
|
||||
_execute_state = trade_env.get_init_state()
|
||||
|
||||
@@ -1,18 +1,25 @@
|
||||
import copy
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from typing import Tuple, List, Union, Optional, Callable
|
||||
from typing import Union
|
||||
from ...data.data import Cal
|
||||
from ...strategy.base import BaseStrategy
|
||||
|
||||
from ...utils import init_instance_by_config
|
||||
from ...utils.sample import get_sample_freq_calendar, parse_freq
|
||||
from .report import Report
|
||||
|
||||
|
||||
from .order import Order
|
||||
from .account import Account
|
||||
from .exchange import Exchange
|
||||
from .faculty import common_faculty
|
||||
|
||||
|
||||
class BaseTradeCalendar:
|
||||
"""
|
||||
Base class providing trading calendar
|
||||
- BaseStrategy and BaseExecutor should inherited from this class
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
|
||||
):
|
||||
@@ -30,16 +37,13 @@ class BaseTradeCalendar:
|
||||
"""
|
||||
|
||||
self.step_bar = step_bar
|
||||
self.start_time = pd.Timestamp(start_time)
|
||||
self.end_time = pd.Timestamp(end_time)
|
||||
self.reset(start_time=start_time, end_time=end_time)
|
||||
|
||||
def _reset_trade_calendar(self, start_time, end_time):
|
||||
if not start_time and not end_time:
|
||||
return
|
||||
if start_time:
|
||||
self.start_time = pd.Timestamp(start_time)
|
||||
if end_time:
|
||||
self.end_time = pd.Timestamp(end_time)
|
||||
if self.start_time and self.end_time:
|
||||
"""reset trade calendar"""
|
||||
if start_time and end_time:
|
||||
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar)
|
||||
self.calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(
|
||||
@@ -50,17 +54,19 @@ class BaseTradeCalendar:
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
self.trade_index = 0
|
||||
else:
|
||||
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
|
||||
raise ValueError("failed to reset trade calendar, param `start_time` or `end_time` is None.")
|
||||
|
||||
def reset(self, start_time=None, end_time=None, **kwargs):
|
||||
if start_time or end_time:
|
||||
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
def reset(self, start_time=None, end_time=None):
|
||||
"""
|
||||
Reset start\end time of trading, and reset trading calendar
|
||||
"""
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
warnings.warn(f"reser error, attribute {k} is not found!")
|
||||
if start_time:
|
||||
self.start_time = pd.Timestamp(start_time)
|
||||
if end_time:
|
||||
self.end_time = pd.Timestamp(end_time)
|
||||
if self.start_time and self.end_time and (start_time or end_time):
|
||||
self._reset_trade_calendar(start_time=self.start_time, end_time=self.end_time)
|
||||
|
||||
def _get_calendar_time(self, trade_index=1, shift=0):
|
||||
trade_index = trade_index - shift
|
||||
@@ -87,6 +93,7 @@ class BaseExecutor(BaseTradeCalendar):
|
||||
trade_account: Account = None,
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -94,23 +101,30 @@ class BaseExecutor(BaseTradeCalendar):
|
||||
----------
|
||||
trade_account : Account, optional
|
||||
trade account for trading, by default None
|
||||
If `trade_account` is None, it must be reset before trading
|
||||
- If `trade_account` is None, self.trade_account will be set with common_faculty
|
||||
generate_report : bool, optional
|
||||
whether to generate report, by default False
|
||||
verbose : bool, optional
|
||||
whether to print log, by default False
|
||||
whether to print trading info, by default False
|
||||
track_data : bool, optional
|
||||
whether to generate order_list, will be used when making data for multi-level training
|
||||
- If `self.track_data` is true, when making data for training, the input `order_list` of `execute` will be generated by `get_data`
|
||||
- Else, `order_list` will not be generated
|
||||
"""
|
||||
super(BaseExecutor, self).__init__(
|
||||
step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs
|
||||
)
|
||||
super(BaseExecutor, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, **kwargs)
|
||||
self.trade_account = copy.copy(common_faculty.trade_account if trade_account is None else trade_account)
|
||||
self.trade_account.reset(freq=self.step_bar, init_report=True)
|
||||
self.generate_report = generate_report
|
||||
self.verbose = verbose
|
||||
self.track_data = track_data
|
||||
|
||||
def reset(self, trade_account=None, **kwargs):
|
||||
def reset(self, track_data: bool = None, **kwargs):
|
||||
"""
|
||||
Reset `track_data`, will be used when making data for multi-level training
|
||||
"""
|
||||
super(BaseExecutor, self).reset(**kwargs)
|
||||
if trade_account:
|
||||
self.trade_account = trade_account
|
||||
self.trade_account.reset(freq=self.step_bar, report=Report(), positions={})
|
||||
if track_data is not None:
|
||||
self.track_data = track_data
|
||||
|
||||
def get_init_state(self):
|
||||
init_state = {"current": self.trade_account.current}
|
||||
@@ -127,6 +141,8 @@ class BaseExecutor(BaseTradeCalendar):
|
||||
|
||||
|
||||
class SplitExecutor(BaseExecutor):
|
||||
from ...strategy.base import BaseStrategy
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar: str,
|
||||
@@ -138,6 +154,7 @@ class SplitExecutor(BaseExecutor):
|
||||
trade_exchange: Exchange = None,
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -155,40 +172,55 @@ class SplitExecutor(BaseExecutor):
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
trade_account=trade_account,
|
||||
trade_exchange=trade_exchange,
|
||||
generate_report=generate_report,
|
||||
verbose=verbose,
|
||||
track_data=track_data,
|
||||
**kwargs,
|
||||
)
|
||||
if generate_report:
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
self.sub_env = init_instance_by_config(sub_env, accept_types=BaseExecutor)
|
||||
self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=BaseStrategy)
|
||||
|
||||
def reset(self, trade_account=None, trade_exchange=None, **kwargs):
|
||||
self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=self.BaseStrategy)
|
||||
|
||||
super(SplitExecutor, self).reset(trade_account=trade_account, **kwargs)
|
||||
if trade_account:
|
||||
self.sub_env.reset(trade_account=copy.copy(trade_account))
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def execute(self, order_list):
|
||||
super(SplitExecutor, self).step()
|
||||
def _init_sub_trading(self, order_list):
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
|
||||
_execute_state = self.sub_env.get_init_state()
|
||||
while not self.sub_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order_list(_execute_state)
|
||||
_execute_state = self.sub_env.execute(order_list=_order_list)
|
||||
sub_execute_state = self.sub_env.get_init_state()
|
||||
return sub_execute_state
|
||||
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
trade_exchange=self.trade_exchange,
|
||||
update_report=self.generate_report,
|
||||
)
|
||||
_execute_state = {"current": self.trade_account.current}
|
||||
return _execute_state
|
||||
def _update_trade_account(self):
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
self.trade_account.update_bar_count()
|
||||
if self.generate_report:
|
||||
self.trade_account.update_bar_report(
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
trade_exchange=self.trade_exchange,
|
||||
)
|
||||
|
||||
def execute(self, order_list):
|
||||
super(SplitExecutor, self).step()
|
||||
self._init_sub_trading(order_list)
|
||||
sub_execute_state = self.sub_env.get_init_state()
|
||||
while not self.sub_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order_list(sub_execute_state)
|
||||
sub_execute_state = self.sub_env.execute(order_list=_order_list)
|
||||
self._update_trade_account()
|
||||
return {"current": self.trade_account.current}
|
||||
|
||||
def get_data(self, order_list):
|
||||
if self.track_data:
|
||||
yield order_list
|
||||
super(SplitExecutor, self).step()
|
||||
self._init_sub_trading(order_list)
|
||||
sub_execute_state = self.sub_env.get_init_state()
|
||||
while not self.sub_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order_list(sub_execute_state)
|
||||
sub_execute_state = yield from self.sub_env.get_data(order_list=_order_list)
|
||||
self._update_trade_account()
|
||||
return {"current": self.trade_account.current}
|
||||
|
||||
def get_report(self):
|
||||
sub_env_report_dict = self.sub_env.get_report()
|
||||
@@ -203,13 +235,14 @@ class SplitExecutor(BaseExecutor):
|
||||
class SimulatorExecutor(BaseExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
trade_exchange=None,
|
||||
generate_report=False,
|
||||
verbose=False,
|
||||
step_bar: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
trade_account: Account = None,
|
||||
trade_exchange: Exchange = None,
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -223,16 +256,12 @@ class SimulatorExecutor(BaseExecutor):
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
trade_account=trade_account,
|
||||
trade_exchange=trade_exchange,
|
||||
generate_report=generate_report,
|
||||
verbose=verbose,
|
||||
track_data=track_data,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def reset(self, trade_exchange=None, **kwargs):
|
||||
super(SimulatorExecutor, self).reset(**kwargs)
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
|
||||
def execute(self, order_list):
|
||||
super(SimulatorExecutor, self).step()
|
||||
@@ -276,14 +305,17 @@ class SimulatorExecutor(BaseExecutor):
|
||||
print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id))
|
||||
# do nothing
|
||||
pass
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
trade_exchange=self.trade_exchange,
|
||||
update_report=self.generate_report,
|
||||
)
|
||||
_execute_state = {"current": self.trade_account.current, "trade_info": trade_info}
|
||||
return _execute_state
|
||||
|
||||
self.trade_account.update_bar_count()
|
||||
|
||||
if self.generate_report:
|
||||
self.trade_account.update_bar_report(
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
trade_exchange=self.trade_exchange,
|
||||
)
|
||||
|
||||
return {"current": self.trade_account.current, "trade_info": trade_info}
|
||||
|
||||
def get_report(self):
|
||||
if self.generate_report:
|
||||
|
||||
28
qlib/contrib/backtest/faculty.py
Normal file
28
qlib/contrib/backtest/faculty.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
class Faculty:
|
||||
def __init__(self):
|
||||
self.__dict__["_faculty"] = dict()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.__dict__["_faculty"][key]
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.__dict__["_faculty"]:
|
||||
return self.__dict__["_faculty"][attr]
|
||||
|
||||
raise AttributeError(f"No such {attr} in self._faculty")
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.__dict__["_faculty"][key] = value
|
||||
|
||||
def __setattr__(self, attr, value):
|
||||
self.__dict__["_faculty"][attr] = value
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
self.__dict__["_faculty"].update(*args, **kwargs)
|
||||
|
||||
|
||||
common_faculty = Faculty()
|
||||
@@ -2,12 +2,27 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
from .model_strategy import WeightStrategyBase
|
||||
import copy
|
||||
|
||||
|
||||
class SoftTopkStrategy(WeightStrategyBase):
|
||||
def __init__(self, topk, max_sold_weight=1.0, risk_degree=0.95, buy_method="first_fill"):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
model,
|
||||
dataset,
|
||||
topk,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
order_generator_cls_or_obj=OrderGenWInteract,
|
||||
trade_exchange=None,
|
||||
max_sold_weight=1.0,
|
||||
risk_degree=0.95,
|
||||
buy_method="first_fill",
|
||||
**kwargs,
|
||||
):
|
||||
"""Parameter
|
||||
topk : int
|
||||
top-N stocks to buy
|
||||
@@ -17,13 +32,15 @@ class SoftTopkStrategy(WeightStrategyBase):
|
||||
rank_fill: assign the weight stocks that rank high first(1/topk max)
|
||||
average_fill: assign the weight to the stocks rank high averagely.
|
||||
"""
|
||||
super().__init__()
|
||||
super(SoftTopkStrategy, self).__init__(
|
||||
step_bar, model, dataset, start_time, end_time, order_generator_cls_or_obj, trade_exchange
|
||||
)
|
||||
self.topk = topk
|
||||
self.max_sold_weight = max_sold_weight
|
||||
self.risk_degree = risk_degree
|
||||
self.buy_method = buy_method
|
||||
|
||||
def get_risk_degree(self, trade_index):
|
||||
def get_risk_degree(self, trade_index=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
Dynamically risk_degree will result in Market timing
|
||||
|
||||
@@ -6,6 +6,7 @@ import pandas as pd
|
||||
from ...utils.sample import sample_feature
|
||||
from ...strategy.base import ModelStrategy
|
||||
from ..backtest.order import Order
|
||||
from ..backtest.faculty import common_faculty
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
@@ -50,9 +51,8 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
else:
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__(
|
||||
step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time, **kwargs)
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
self.topk = topk
|
||||
self.n_drop = n_drop
|
||||
self.method_sell = method_sell
|
||||
@@ -61,11 +61,6 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
self.hold_thresh = hold_thresh
|
||||
self.only_tradable = only_tradable
|
||||
|
||||
def reset(self, trade_exchange=None, **kwargs):
|
||||
super(TopkDropoutStrategy, self).reset(**kwargs)
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def get_risk_degree(self, trade_index=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
@@ -164,7 +159,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
|
||||
# Get the stock list we really want to buy
|
||||
buy = today[: len(sell) + self.topk - len(last)]
|
||||
print("INTRANEL BAR", len(sell), len(sell) + self.topk - len(last), len(last))
|
||||
# print("INTRANEL BAR", len(sell), len(sell) + self.topk - len(last), len(last))
|
||||
# print("flag", len(sell), len(buy), self.topk, len(last))
|
||||
for code in current_stock_list:
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
@@ -242,20 +237,13 @@ class WeightStrategyBase(ModelStrategy):
|
||||
trade_exchange=None,
|
||||
**kwargs,
|
||||
):
|
||||
super(WeightStrategyBase, self).__init__(
|
||||
step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
|
||||
super(WeightStrategyBase, self).__init__(step_bar, model, dataset, start_time, end_time, **kwargs)
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
if isinstance(order_generator_cls_or_obj, type):
|
||||
self.order_generator = order_generator_cls_or_obj()
|
||||
else:
|
||||
self.order_generator = order_generator_cls_or_obj
|
||||
|
||||
def reset(self, trade_exchange=None, **kwargs):
|
||||
super(WeightStrategyBase, self).reset(**kwargs)
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def get_risk_degree(self, trade_index=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
|
||||
@@ -173,7 +173,9 @@ class OrderGenWOInteract(OrderGenerator):
|
||||
stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
|
||||
):
|
||||
amount_dict[stock_id] = (
|
||||
risk_total_value * target_weight_position[stock_id] / trade_exchange.get_close(stock_id, pred_date)
|
||||
risk_total_value
|
||||
* target_weight_position[stock_id]
|
||||
/ trade_exchange.get_close(stock_id, trade_start_time=pred_start_time, trade_end_time=pred_end_time)
|
||||
)
|
||||
elif stock_id in current_stock:
|
||||
amount_dict[stock_id] = (
|
||||
|
||||
@@ -9,6 +9,7 @@ from ...data.data import D
|
||||
from ...data.dataset.utils import convert_index_format
|
||||
from ...strategy.base import RuleStrategy, OrderEnhancement
|
||||
from ..backtest.order import Order
|
||||
from ..backtest.faculty import common_faculty
|
||||
|
||||
|
||||
class TWAPStrategy(RuleStrategy, OrderEnhancement):
|
||||
@@ -18,16 +19,17 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_exchange=None,
|
||||
trade_order_list=[],
|
||||
**kwargs,
|
||||
):
|
||||
super(TWAPStrategy, self).__init__(step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs)
|
||||
super(TWAPStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
self.trade_order_list = trade_order_list
|
||||
|
||||
def reset(self, trade_order_list=None, trade_exchange=None, **kwargs):
|
||||
def reset(self, trade_order_list: list = None, **kwargs):
|
||||
super(TWAPStrategy, self).reset(**kwargs)
|
||||
OrderEnhancement.reset(self, trade_order_list=trade_order_list)
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
if trade_order_list:
|
||||
if trade_order_list is not None:
|
||||
self.trade_amount = {}
|
||||
for order in self.trade_order_list:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount
|
||||
@@ -82,15 +84,16 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_exchange=None,
|
||||
trade_order_list=[],
|
||||
**kwargs,
|
||||
):
|
||||
super(SBBStrategyBase, self).__init__(step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs)
|
||||
super(SBBStrategyBase, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
self.trade_order_list = trade_order_list
|
||||
|
||||
def reset(self, trade_order_list=None, trade_exchange=None, **kwargs):
|
||||
def reset(self, trade_order_list=None, **kwargs):
|
||||
super(SBBStrategyBase, self).reset(**kwargs)
|
||||
OrderEnhancement.reset(self, trade_order_list=trade_order_list)
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
if trade_order_list is not None:
|
||||
self.trade_trend = {}
|
||||
self.trade_amount = {}
|
||||
@@ -217,11 +220,12 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_exchange=None,
|
||||
trade_order_list=[],
|
||||
instruments="csi300",
|
||||
freq="day",
|
||||
**kwargs,
|
||||
):
|
||||
super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs)
|
||||
super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange, trade_order_list, **kwargs)
|
||||
if instruments is None:
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
self.instruments = "all"
|
||||
@@ -229,9 +233,9 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
self.instruments = D.instruments(instruments)
|
||||
self.freq = freq
|
||||
|
||||
def reset(self, start_time=None, end_time=None, **kwargs):
|
||||
super(SBBStrategyEMA, self).reset(start_time=start_time, end_time=end_time, **kwargs)
|
||||
if self.start_time and self.end_time:
|
||||
def _reset_trade_calendar(self, start_time=None, end_time=None):
|
||||
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
if start_time and end_time:
|
||||
fields = ["EMA($close, 10)-EMA($close, 20)"]
|
||||
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
|
||||
signal_df = D.features(
|
||||
|
||||
@@ -7,6 +7,8 @@ from ..contrib.backtest.executor import BaseExecutor
|
||||
|
||||
|
||||
class BaseRLEnv:
|
||||
"""Base environment for reinforcement learning"""
|
||||
|
||||
def reset(self, **kwargs):
|
||||
raise NotImplementedError("reset is not implemented!")
|
||||
|
||||
|
||||
@@ -3,18 +3,48 @@
|
||||
|
||||
|
||||
class BaseInterpreter:
|
||||
"""Base Interpreter"""
|
||||
|
||||
@staticmethod
|
||||
def interpret(**kwargs):
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
|
||||
class ActionInterpreter(BaseInterpreter):
|
||||
"""Action Interpreter that interpret rl agent action into qlib orders"""
|
||||
|
||||
@staticmethod
|
||||
def interpret(action, **kwargs):
|
||||
"""interpret method
|
||||
|
||||
Parameters
|
||||
----------
|
||||
action :
|
||||
rl agent action
|
||||
|
||||
Returns
|
||||
-------
|
||||
qlib orders
|
||||
|
||||
"""
|
||||
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
|
||||
class StateInterpreter(BaseInterpreter):
|
||||
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
|
||||
|
||||
@staticmethod
|
||||
def interpret(execute_result, **kwargs):
|
||||
"""interpret method
|
||||
|
||||
Parameters
|
||||
----------
|
||||
execute_result :
|
||||
qlib execution result
|
||||
|
||||
Returns
|
||||
----------
|
||||
rl env state
|
||||
"""
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
from typing import Tuple, List, Union, Optional, Callable
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
from ..model.base import BaseModel
|
||||
@@ -14,7 +14,7 @@ from ..rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
|
||||
|
||||
class BaseStrategy(BaseTradeCalendar):
|
||||
"""Base strategy"""
|
||||
"""Base strategy for trading"""
|
||||
|
||||
def generate_order_list(self, execute_state):
|
||||
"""Generate order list in each trading bar"""
|
||||
@@ -22,13 +22,13 @@ class BaseStrategy(BaseTradeCalendar):
|
||||
|
||||
|
||||
class RuleStrategy(BaseStrategy):
|
||||
"""Trading strategy with rules"""
|
||||
"""Rule-based Trading strategy"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelStrategy(BaseStrategy):
|
||||
"""Trading Strategy by using Model to make predictions"""
|
||||
"""Model-based trading strategy, use model to make predictions for trading"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -57,7 +57,7 @@ class ModelStrategy(BaseStrategy):
|
||||
|
||||
def _update_model(self):
|
||||
"""
|
||||
Update model in each bar when using online data as the following steps:
|
||||
When using online data, pdate model in each bar as the following steps:
|
||||
- update dataset with online data, the dataset should support online update
|
||||
- make the latest prediction scores of the new bar
|
||||
- update the pred score into the latest prediction
|
||||
@@ -66,7 +66,7 @@ class ModelStrategy(BaseStrategy):
|
||||
|
||||
|
||||
class RLStrategy(BaseStrategy):
|
||||
"""RL-based Strategy"""
|
||||
"""RL-based strategy"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -335,10 +335,10 @@ class PortAnaRecord(RecordTemp):
|
||||
report_normal, _ = report_dict.get(self.risk_analysis_freq)
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"], self.risk_analysis_freq
|
||||
report_normal["return"] - report_normal["bench"], freq=self.risk_analysis_freq
|
||||
)
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], self.risk_analysis_freq
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=self.risk_analysis_freq
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
# log metrics
|
||||
|
||||
Reference in New Issue
Block a user