From 07eaada31e670e9322febb8bf7b269eb76fb020a Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 13 May 2021 00:33:57 +0800 Subject: [PATCH] fix comments --- examples/multi_level_trading/workflow.py | 1 - qlib/contrib/backtest/__init__.py | 41 +++--- qlib/contrib/backtest/account.py | 97 +++++++------ qlib/contrib/backtest/backtest.py | 7 +- qlib/contrib/backtest/executor.py | 176 +++++++++++++---------- qlib/contrib/backtest/faculty.py | 28 ++++ qlib/contrib/strategy/cost_control.py | 23 ++- qlib/contrib/strategy/model_strategy.py | 24 +--- qlib/contrib/strategy/order_generator.py | 4 +- qlib/contrib/strategy/rule_strategy.py | 30 ++-- qlib/rl/env.py | 2 + qlib/rl/interpreter.py | 30 ++++ qlib/strategy/base.py | 12 +- qlib/workflow/record_temp.py | 4 +- 14 files changed, 294 insertions(+), 185 deletions(-) create mode 100644 qlib/contrib/backtest/faculty.py diff --git a/examples/multi_level_trading/workflow.py b/examples/multi_level_trading/workflow.py index 9b0e6dc77..77689b3f7 100644 --- a/examples/multi_level_trading/workflow.py +++ b/examples/multi_level_trading/workflow.py @@ -122,7 +122,6 @@ if __name__ == "__main__": "benchmark": benchmark, "exchange_kwargs": { "freq": "day", - "verbose": False, "limit_threshold": 0.095, "deal_price": "close", "open_cost": 0.0005, diff --git a/qlib/contrib/backtest/__init__.py b/qlib/contrib/backtest/__init__.py index c8114d852..8cfbf9674 100644 --- a/qlib/contrib/backtest/__init__.py +++ b/qlib/contrib/backtest/__init__.py @@ -1,16 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - +from .account import Account from .exchange import Exchange from .executor import BaseExecutor from .backtest import backtest as backtest_func -import inspect + from ...strategy.base import BaseStrategy from ...utils import init_instance_by_config from ...log import get_module_logger from ...config import C +from .faculty import common_faculty + logger = get_module_logger("backtest caller") @@ -28,7 +30,6 @@ def get_exchange( trade_unit=None, limit_threshold=None, deal_price=None, - shift=1, ): """get_exchange @@ -88,28 +89,26 @@ def get_exchange( return init_instance_by_config(exchange, accept_types=Exchange) -def setup_exchange(root_instance, trade_exchange=None, force=False): - if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args: - if force: - root_instance.reset(trade_exchange=trade_exchange) - else: - if not hasattr(root_instance, "trade_exchange") or root_instance.trade_exchange is None: - root_instance.reset(trade_exchange=trade_exchange) - if hasattr(root_instance, "sub_env"): - setup_exchange(root_instance.sub_env, trade_exchange) - if hasattr(root_instance, "sub_strategy"): - setup_exchange(root_instance.sub_strategy, trade_exchange) +def backtest(start_time, end_time, strategy, env, benchmark="SH000300", account=1e9, exchange_kwargs={}): + trade_account = Account( + init_cash=account, + benchmark_config={ + "benchmark": benchmark, + "start_time": start_time, + "end_time": end_time, + }, + ) + trade_exchange = get_exchange(**exchange_kwargs) + + common_faculty.update( + trade_account=trade_account, + trade_exchange=trade_exchange, + ) -def backtest(start_time, end_time, strategy, env, benchmark="SH000905", account=1e9, exchange_kwargs={}): trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy) trade_env = init_instance_by_config(env, accept_types=BaseExecutor) - trade_exchange = get_exchange(**exchange_kwargs) - - setup_exchange(trade_env, trade_exchange) - setup_exchange(trade_strategy, trade_exchange) - - report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env, benchmark, account) + report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env) return report_dict diff --git a/qlib/contrib/backtest/account.py b/qlib/contrib/backtest/account.py index 7e37c1093..5e2e03ea0 100644 --- a/qlib/contrib/backtest/account.py +++ b/qlib/contrib/backtest/account.py @@ -30,48 +30,53 @@ rtn & earning in the Account class Account: - def __init__(self, init_cash, benchmark=None, start_time=None, end_time=None, freq=None): - self.init_vars(init_cash, benchmark, start_time, end_time) + def __init__(self, init_cash, freq: str = "day", benchmark_config: dict = {}): + self.init_vars(init_cash, freq, benchmark_config) - def init_vars(self, init_cash, benchmark=None, start_time=None, end_time=None, freq=None): + def init_vars(self, init_cash, freq: str, benchmark_config: dict): """ Parameters ---------- - - benchmark: str/list/pd.Series - `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T. - example: - print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()) - 2017-01-04 0.011693 - 2017-01-05 0.000721 - 2017-01-06 -0.004322 - 2017-01-09 0.006874 - 2017-01-10 -0.003350 - `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'. - `benchmark` is str, will use the daily change as the 'bench'. - benchmark code, default is SH000905 CSI500 + freq : str + frequency of trading bar, used for updating hold count of trading bar + benchmark_config : dict + config of benchmark, may including the following arguments: + - benchmark : Union[str, list, pd.Series] + - If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T. + example: + print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()) + 2017-01-04 0.011693 + 2017-01-05 0.000721 + 2017-01-06 -0.004322 + 2017-01-09 0.006874 + 2017-01-10 -0.003350 + - If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'. + - If `benchmark` is str, will use the daily change as the 'bench'. + benchmark code, default is SH000300 CSI300 + - start_time : Union[str, pd.Timestamp], optional + - If `benchmark` is pd.Series, it will be ignored + - Else, it represent start time of benchmark, by default None + - end_time : Union[str, pd.Timestamp], optional + - If `benchmark` is pd.Series, it will be ignored + - Else, it represent end time of benchmark, by default None """ # init cash self.init_cash = init_cash - self.benchmark = benchmark - self.start_time = start_time - self.end_time = end_time self.freq = freq + self.benchmark_config = benchmark_config + self.bench = self._cal_benchmark(benchmark_config, freq) self.current = Position(cash=init_cash) - self.positions = {} - self.rtn = 0 - self.ct = 0 - self.to = 0 - self.val = 0 - self.earning = 0 - self.report = Report() - if freq and benchmark: - self.bench = self._cal_benchmark(benchmark, start_time, end_time, freq) + self._reset_report() - def _cal_benchmark(self, benchmark, start_time=None, end_time=None, freq=None): + def _cal_benchmark(self, benchmark_config, freq): + benchmark = benchmark_config.get("benchmark", "SH000300") if isinstance(benchmark, pd.Series): return benchmark else: + start_time = benchmark_config.get("start_time", None) + end_time = benchmark_config.get("end_time", None) + if freq is None: raise ValueError("benchmark freq can't be None!") _codes = benchmark if isinstance(benchmark, list) else [benchmark] @@ -100,19 +105,25 @@ class Account: _ret = sample_feature(bench, trade_start_time, trade_end_time, method=cal_change) return 0 if _ret is None else _ret - def reset(self, benchmark=None, freq=None, **kwargs): - if benchmark: - self.benchmark = benchmark - if freq: + def _reset_freq(self, freq): + """reset frequency""" + if freq != self.freq: self.freq = freq - if self.freq and self.benchmark and (freq or benchmark): - self.bench = self._cal_benchmark(self.benchmark, self.start_time, self.end_time, self.freq) + self.bench = self._cal_benchmark(self.benchmark_config, self.freq) - for k, v in kwargs.items(): - if hasattr(self, k): - setattr(self, k, v) - else: - warnings.warn(f"reser error, attribute {k} is not found!") + def _reset_report(self): + self.report = Report() + self.positions = {} + self.rtn = 0 + self.ct = 0 + self.to = 0 + self.val = 0 + self.earning = 0 + + def reset(self, freq=None, init_report: bool = False): + self._reset_freq(freq) + if init_report: + self._reset_report() def get_positions(self): return self.positions @@ -155,7 +166,10 @@ class Account: self.current.update_order(order, trade_val, cost, trade_price) self.update_state_from_order(order, trade_val, cost, trade_price) - def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange, update_report): + def update_bar_count(self): + self.current.add_count_all(bar=self.freq) + + def update_bar_report(self, trade_start_time, trade_end_time, trade_exchange): """ start_time: pd.TimeStamp end_time: pd.TimeStamp @@ -171,9 +185,6 @@ class Account: :return: None """ # update price for stock in the position and the profit from changed_price - self.current.add_count_all(bar=self.freq) - if update_report is None: - return stock_list = self.current.get_stock_list() for code in stock_list: # if suspend, no new price to be updated, profit is 0 diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index d5f92ebae..73785c771 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -1,13 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .account import Account +def backtest(start_time, end_time, trade_strategy, trade_env): -def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account): - - trade_account = Account(init_cash=account, benchmark=benchmark, start_time=start_time, end_time=end_time) - trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account) + trade_env.reset(start_time=start_time, end_time=end_time) trade_strategy.reset(start_time=start_time, end_time=end_time) _execute_state = trade_env.get_init_state() diff --git a/qlib/contrib/backtest/executor.py b/qlib/contrib/backtest/executor.py index 935af7361..65d9cfaea 100644 --- a/qlib/contrib/backtest/executor.py +++ b/qlib/contrib/backtest/executor.py @@ -1,18 +1,25 @@ import copy import warnings import pandas as pd -from typing import Tuple, List, Union, Optional, Callable +from typing import Union from ...data.data import Cal -from ...strategy.base import BaseStrategy + from ...utils import init_instance_by_config from ...utils.sample import get_sample_freq_calendar, parse_freq -from .report import Report + + from .order import Order from .account import Account from .exchange import Exchange +from .faculty import common_faculty class BaseTradeCalendar: + """ + Base class providing trading calendar + - BaseStrategy and BaseExecutor should inherited from this class + """ + def __init__( self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None ): @@ -30,16 +37,13 @@ class BaseTradeCalendar: """ self.step_bar = step_bar + self.start_time = pd.Timestamp(start_time) + self.end_time = pd.Timestamp(end_time) self.reset(start_time=start_time, end_time=end_time) def _reset_trade_calendar(self, start_time, end_time): - if not start_time and not end_time: - return - if start_time: - self.start_time = pd.Timestamp(start_time) - if end_time: - self.end_time = pd.Timestamp(end_time) - if self.start_time and self.end_time: + """reset trade calendar""" + if start_time and end_time: _calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar) self.calendar = _calendar _, _, _start_index, _end_index = Cal.locate_index( @@ -50,17 +54,19 @@ class BaseTradeCalendar: self.trade_len = _end_index - _start_index + 1 self.trade_index = 0 else: - raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.") + raise ValueError("failed to reset trade calendar, param `start_time` or `end_time` is None.") - def reset(self, start_time=None, end_time=None, **kwargs): - if start_time or end_time: - self._reset_trade_calendar(start_time=start_time, end_time=end_time) + def reset(self, start_time=None, end_time=None): + """ + Reset start\end time of trading, and reset trading calendar + """ - for k, v in kwargs.items(): - if hasattr(self, k): - setattr(self, k, v) - else: - warnings.warn(f"reser error, attribute {k} is not found!") + if start_time: + self.start_time = pd.Timestamp(start_time) + if end_time: + self.end_time = pd.Timestamp(end_time) + if self.start_time and self.end_time and (start_time or end_time): + self._reset_trade_calendar(start_time=self.start_time, end_time=self.end_time) def _get_calendar_time(self, trade_index=1, shift=0): trade_index = trade_index - shift @@ -87,6 +93,7 @@ class BaseExecutor(BaseTradeCalendar): trade_account: Account = None, generate_report: bool = False, verbose: bool = False, + track_data: bool = False, **kwargs, ): """ @@ -94,23 +101,30 @@ class BaseExecutor(BaseTradeCalendar): ---------- trade_account : Account, optional trade account for trading, by default None - If `trade_account` is None, it must be reset before trading + - If `trade_account` is None, self.trade_account will be set with common_faculty generate_report : bool, optional whether to generate report, by default False verbose : bool, optional - whether to print log, by default False + whether to print trading info, by default False + track_data : bool, optional + whether to generate order_list, will be used when making data for multi-level training + - If `self.track_data` is true, when making data for training, the input `order_list` of `execute` will be generated by `get_data` + - Else, `order_list` will not be generated """ - super(BaseExecutor, self).__init__( - step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs - ) + super(BaseExecutor, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, **kwargs) + self.trade_account = copy.copy(common_faculty.trade_account if trade_account is None else trade_account) + self.trade_account.reset(freq=self.step_bar, init_report=True) self.generate_report = generate_report self.verbose = verbose + self.track_data = track_data - def reset(self, trade_account=None, **kwargs): + def reset(self, track_data: bool = None, **kwargs): + """ + Reset `track_data`, will be used when making data for multi-level training + """ super(BaseExecutor, self).reset(**kwargs) - if trade_account: - self.trade_account = trade_account - self.trade_account.reset(freq=self.step_bar, report=Report(), positions={}) + if track_data is not None: + self.track_data = track_data def get_init_state(self): init_state = {"current": self.trade_account.current} @@ -127,6 +141,8 @@ class BaseExecutor(BaseTradeCalendar): class SplitExecutor(BaseExecutor): + from ...strategy.base import BaseStrategy + def __init__( self, step_bar: str, @@ -138,6 +154,7 @@ class SplitExecutor(BaseExecutor): trade_exchange: Exchange = None, generate_report: bool = False, verbose: bool = False, + track_data: bool = False, **kwargs, ): """ @@ -155,40 +172,55 @@ class SplitExecutor(BaseExecutor): start_time=start_time, end_time=end_time, trade_account=trade_account, - trade_exchange=trade_exchange, generate_report=generate_report, verbose=verbose, + track_data=track_data, **kwargs, ) + if generate_report: + self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange self.sub_env = init_instance_by_config(sub_env, accept_types=BaseExecutor) - self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=BaseStrategy) - def reset(self, trade_account=None, trade_exchange=None, **kwargs): + self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=self.BaseStrategy) - super(SplitExecutor, self).reset(trade_account=trade_account, **kwargs) - if trade_account: - self.sub_env.reset(trade_account=copy.copy(trade_account)) - if trade_exchange: - self.trade_exchange = trade_exchange - - def execute(self, order_list): - super(SplitExecutor, self).step() + def _init_sub_trading(self, order_list): trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time) self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list) - _execute_state = self.sub_env.get_init_state() - while not self.sub_env.finished(): - _order_list = self.sub_strategy.generate_order_list(_execute_state) - _execute_state = self.sub_env.execute(order_list=_order_list) + sub_execute_state = self.sub_env.get_init_state() + return sub_execute_state - self.trade_account.update_bar_end( - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, - trade_exchange=self.trade_exchange, - update_report=self.generate_report, - ) - _execute_state = {"current": self.trade_account.current} - return _execute_state + def _update_trade_account(self): + trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) + self.trade_account.update_bar_count() + if self.generate_report: + self.trade_account.update_bar_report( + trade_start_time=trade_start_time, + trade_end_time=trade_end_time, + trade_exchange=self.trade_exchange, + ) + + def execute(self, order_list): + super(SplitExecutor, self).step() + self._init_sub_trading(order_list) + sub_execute_state = self.sub_env.get_init_state() + while not self.sub_env.finished(): + _order_list = self.sub_strategy.generate_order_list(sub_execute_state) + sub_execute_state = self.sub_env.execute(order_list=_order_list) + self._update_trade_account() + return {"current": self.trade_account.current} + + def get_data(self, order_list): + if self.track_data: + yield order_list + super(SplitExecutor, self).step() + self._init_sub_trading(order_list) + sub_execute_state = self.sub_env.get_init_state() + while not self.sub_env.finished(): + _order_list = self.sub_strategy.generate_order_list(sub_execute_state) + sub_execute_state = yield from self.sub_env.get_data(order_list=_order_list) + self._update_trade_account() + return {"current": self.trade_account.current} def get_report(self): sub_env_report_dict = self.sub_env.get_report() @@ -203,13 +235,14 @@ class SplitExecutor(BaseExecutor): class SimulatorExecutor(BaseExecutor): def __init__( self, - step_bar, - start_time=None, - end_time=None, - trade_account=None, - trade_exchange=None, - generate_report=False, - verbose=False, + step_bar: str, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + trade_account: Account = None, + trade_exchange: Exchange = None, + generate_report: bool = False, + verbose: bool = False, + track_data: bool = False, **kwargs, ): """ @@ -223,16 +256,12 @@ class SimulatorExecutor(BaseExecutor): start_time=start_time, end_time=end_time, trade_account=trade_account, - trade_exchange=trade_exchange, generate_report=generate_report, verbose=verbose, + track_data=track_data, **kwargs, ) - - def reset(self, trade_exchange=None, **kwargs): - super(SimulatorExecutor, self).reset(**kwargs) - if trade_exchange: - self.trade_exchange = trade_exchange + self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange def execute(self, order_list): super(SimulatorExecutor, self).step() @@ -276,14 +305,17 @@ class SimulatorExecutor(BaseExecutor): print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id)) # do nothing pass - self.trade_account.update_bar_end( - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, - trade_exchange=self.trade_exchange, - update_report=self.generate_report, - ) - _execute_state = {"current": self.trade_account.current, "trade_info": trade_info} - return _execute_state + + self.trade_account.update_bar_count() + + if self.generate_report: + self.trade_account.update_bar_report( + trade_start_time=trade_start_time, + trade_end_time=trade_end_time, + trade_exchange=self.trade_exchange, + ) + + return {"current": self.trade_account.current, "trade_info": trade_info} def get_report(self): if self.generate_report: diff --git a/qlib/contrib/backtest/faculty.py b/qlib/contrib/backtest/faculty.py new file mode 100644 index 000000000..34ad14cbc --- /dev/null +++ b/qlib/contrib/backtest/faculty.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +class Faculty: + def __init__(self): + self.__dict__["_faculty"] = dict() + + def __getitem__(self, key): + return self.__dict__["_faculty"][key] + + def __getattr__(self, attr): + if attr in self.__dict__["_faculty"]: + return self.__dict__["_faculty"][attr] + + raise AttributeError(f"No such {attr} in self._faculty") + + def __setitem__(self, key, value): + self.__dict__["_faculty"][key] = value + + def __setattr__(self, attr, value): + self.__dict__["_faculty"][attr] = value + + def update(self, *args, **kwargs): + self.__dict__["_faculty"].update(*args, **kwargs) + + +common_faculty = Faculty() diff --git a/qlib/contrib/strategy/cost_control.py b/qlib/contrib/strategy/cost_control.py index 111cc276a..8b3e3db18 100644 --- a/qlib/contrib/strategy/cost_control.py +++ b/qlib/contrib/strategy/cost_control.py @@ -2,12 +2,27 @@ # Licensed under the MIT License. +from .order_generator import OrderGenWInteract from .model_strategy import WeightStrategyBase import copy class SoftTopkStrategy(WeightStrategyBase): - def __init__(self, topk, max_sold_weight=1.0, risk_degree=0.95, buy_method="first_fill"): + def __init__( + self, + step_bar, + model, + dataset, + topk, + start_time=None, + end_time=None, + order_generator_cls_or_obj=OrderGenWInteract, + trade_exchange=None, + max_sold_weight=1.0, + risk_degree=0.95, + buy_method="first_fill", + **kwargs, + ): """Parameter topk : int top-N stocks to buy @@ -17,13 +32,15 @@ class SoftTopkStrategy(WeightStrategyBase): rank_fill: assign the weight stocks that rank high first(1/topk max) average_fill: assign the weight to the stocks rank high averagely. """ - super().__init__() + super(SoftTopkStrategy, self).__init__( + step_bar, model, dataset, start_time, end_time, order_generator_cls_or_obj, trade_exchange + ) self.topk = topk self.max_sold_weight = max_sold_weight self.risk_degree = risk_degree self.buy_method = buy_method - def get_risk_degree(self, trade_index): + def get_risk_degree(self, trade_index=None): """get_risk_degree Return the proportion of your total value you will used in investment. Dynamically risk_degree will result in Market timing diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index 1fc1bf070..b3bb33a88 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -6,6 +6,7 @@ import pandas as pd from ...utils.sample import sample_feature from ...strategy.base import ModelStrategy from ..backtest.order import Order +from ..backtest.faculty import common_faculty from .order_generator import OrderGenWInteract @@ -50,9 +51,8 @@ class TopkDropoutStrategy(ModelStrategy): else: strategy will make decision with the tradable state of the stock info and avoid buy and sell them. """ - super(TopkDropoutStrategy, self).__init__( - step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange, **kwargs - ) + super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time, **kwargs) + self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange self.topk = topk self.n_drop = n_drop self.method_sell = method_sell @@ -61,11 +61,6 @@ class TopkDropoutStrategy(ModelStrategy): self.hold_thresh = hold_thresh self.only_tradable = only_tradable - def reset(self, trade_exchange=None, **kwargs): - super(TopkDropoutStrategy, self).reset(**kwargs) - if trade_exchange: - self.trade_exchange = trade_exchange - def get_risk_degree(self, trade_index=None): """get_risk_degree Return the proportion of your total value you will used in investment. @@ -164,7 +159,7 @@ class TopkDropoutStrategy(ModelStrategy): # Get the stock list we really want to buy buy = today[: len(sell) + self.topk - len(last)] - print("INTRANEL BAR", len(sell), len(sell) + self.topk - len(last), len(last)) + # print("INTRANEL BAR", len(sell), len(sell) + self.topk - len(last), len(last)) # print("flag", len(sell), len(buy), self.topk, len(last)) for code in current_stock_list: if not self.trade_exchange.is_stock_tradable( @@ -242,20 +237,13 @@ class WeightStrategyBase(ModelStrategy): trade_exchange=None, **kwargs, ): - super(WeightStrategyBase, self).__init__( - step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange, **kwargs - ) - + super(WeightStrategyBase, self).__init__(step_bar, model, dataset, start_time, end_time, **kwargs) + self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange if isinstance(order_generator_cls_or_obj, type): self.order_generator = order_generator_cls_or_obj() else: self.order_generator = order_generator_cls_or_obj - def reset(self, trade_exchange=None, **kwargs): - super(WeightStrategyBase, self).reset(**kwargs) - if trade_exchange: - self.trade_exchange = trade_exchange - def get_risk_degree(self, trade_index=None): """get_risk_degree Return the proportion of your total value you will used in investment. diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index 93bf7b2fe..db2c1de0d 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -173,7 +173,9 @@ class OrderGenWOInteract(OrderGenerator): stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time ): amount_dict[stock_id] = ( - risk_total_value * target_weight_position[stock_id] / trade_exchange.get_close(stock_id, pred_date) + risk_total_value + * target_weight_position[stock_id] + / trade_exchange.get_close(stock_id, trade_start_time=pred_start_time, trade_end_time=pred_end_time) ) elif stock_id in current_stock: amount_dict[stock_id] = ( diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 073f513c7..240a61595 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -9,6 +9,7 @@ from ...data.data import D from ...data.dataset.utils import convert_index_format from ...strategy.base import RuleStrategy, OrderEnhancement from ..backtest.order import Order +from ..backtest.faculty import common_faculty class TWAPStrategy(RuleStrategy, OrderEnhancement): @@ -18,16 +19,17 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement): start_time=None, end_time=None, trade_exchange=None, + trade_order_list=[], **kwargs, ): - super(TWAPStrategy, self).__init__(step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs) + super(TWAPStrategy, self).__init__(step_bar, start_time, end_time, **kwargs) + self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange + self.trade_order_list = trade_order_list - def reset(self, trade_order_list=None, trade_exchange=None, **kwargs): + def reset(self, trade_order_list: list = None, **kwargs): super(TWAPStrategy, self).reset(**kwargs) OrderEnhancement.reset(self, trade_order_list=trade_order_list) - if trade_exchange: - self.trade_exchange = trade_exchange - if trade_order_list: + if trade_order_list is not None: self.trade_amount = {} for order in self.trade_order_list: self.trade_amount[(order.stock_id, order.direction)] = order.amount @@ -82,15 +84,16 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement): start_time=None, end_time=None, trade_exchange=None, + trade_order_list=[], **kwargs, ): - super(SBBStrategyBase, self).__init__(step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs) + super(SBBStrategyBase, self).__init__(step_bar, start_time, end_time, **kwargs) + self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange + self.trade_order_list = trade_order_list - def reset(self, trade_order_list=None, trade_exchange=None, **kwargs): + def reset(self, trade_order_list=None, **kwargs): super(SBBStrategyBase, self).reset(**kwargs) OrderEnhancement.reset(self, trade_order_list=trade_order_list) - if trade_exchange: - self.trade_exchange = trade_exchange if trade_order_list is not None: self.trade_trend = {} self.trade_amount = {} @@ -217,11 +220,12 @@ class SBBStrategyEMA(SBBStrategyBase): start_time=None, end_time=None, trade_exchange=None, + trade_order_list=[], instruments="csi300", freq="day", **kwargs, ): - super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs) + super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange, trade_order_list, **kwargs) if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") self.instruments = "all" @@ -229,9 +233,9 @@ class SBBStrategyEMA(SBBStrategyBase): self.instruments = D.instruments(instruments) self.freq = freq - def reset(self, start_time=None, end_time=None, **kwargs): - super(SBBStrategyEMA, self).reset(start_time=start_time, end_time=end_time, **kwargs) - if self.start_time and self.end_time: + def _reset_trade_calendar(self, start_time=None, end_time=None): + super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time) + if start_time and end_time: fields = ["EMA($close, 10)-EMA($close, 20)"] signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1) signal_df = D.features( diff --git a/qlib/rl/env.py b/qlib/rl/env.py index 9424aafab..fae17918d 100644 --- a/qlib/rl/env.py +++ b/qlib/rl/env.py @@ -7,6 +7,8 @@ from ..contrib.backtest.executor import BaseExecutor class BaseRLEnv: + """Base environment for reinforcement learning""" + def reset(self, **kwargs): raise NotImplementedError("reset is not implemented!") diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py index bad337f72..3c94aac09 100644 --- a/qlib/rl/interpreter.py +++ b/qlib/rl/interpreter.py @@ -3,18 +3,48 @@ class BaseInterpreter: + """Base Interpreter""" + @staticmethod def interpret(**kwargs): raise NotImplementedError("interpret is not implemented!") class ActionInterpreter(BaseInterpreter): + """Action Interpreter that interpret rl agent action into qlib orders""" + @staticmethod def interpret(action, **kwargs): + """interpret method + + Parameters + ---------- + action : + rl agent action + + Returns + ------- + qlib orders + + """ + raise NotImplementedError("interpret is not implemented!") class StateInterpreter(BaseInterpreter): + """State Interpreter that interpret execution result of qlib executor into rl env state""" + @staticmethod def interpret(execute_result, **kwargs): + """interpret method + + Parameters + ---------- + execute_result : + qlib execution result + + Returns + ---------- + rl env state + """ raise NotImplementedError("interpret is not implemented!") diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index a5e7210bd..5534998e9 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import pandas as pd -from typing import Tuple, List, Union, Optional, Callable +from typing import List, Union from ..model.base import BaseModel @@ -14,7 +14,7 @@ from ..rl.interpreter import ActionInterpreter, StateInterpreter class BaseStrategy(BaseTradeCalendar): - """Base strategy""" + """Base strategy for trading""" def generate_order_list(self, execute_state): """Generate order list in each trading bar""" @@ -22,13 +22,13 @@ class BaseStrategy(BaseTradeCalendar): class RuleStrategy(BaseStrategy): - """Trading strategy with rules""" + """Rule-based Trading strategy""" pass class ModelStrategy(BaseStrategy): - """Trading Strategy by using Model to make predictions""" + """Model-based trading strategy, use model to make predictions for trading""" def __init__( self, @@ -57,7 +57,7 @@ class ModelStrategy(BaseStrategy): def _update_model(self): """ - Update model in each bar when using online data as the following steps: + When using online data, pdate model in each bar as the following steps: - update dataset with online data, the dataset should support online update - make the latest prediction scores of the new bar - update the pred score into the latest prediction @@ -66,7 +66,7 @@ class ModelStrategy(BaseStrategy): class RLStrategy(BaseStrategy): - """RL-based Strategy""" + """RL-based strategy""" def __init__( self, diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 8a8bde7ef..6bb6341f0 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -335,10 +335,10 @@ class PortAnaRecord(RecordTemp): report_normal, _ = report_dict.get(self.risk_analysis_freq) analysis = dict() analysis["excess_return_without_cost"] = risk_analysis( - report_normal["return"] - report_normal["bench"], self.risk_analysis_freq + report_normal["return"] - report_normal["bench"], freq=self.risk_analysis_freq ) analysis["excess_return_with_cost"] = risk_analysis( - report_normal["return"] - report_normal["bench"] - report_normal["cost"], self.risk_analysis_freq + report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=self.risk_analysis_freq ) analysis_df = pd.concat(analysis) # type: pd.DataFrame # log metrics