From a109df3f467841eb32952ef924c19fc8373097bd Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 30 Apr 2021 01:06:05 +0800 Subject: [PATCH] fix bug in recorder --- examples/highfreq/backtest/workflow.py | 29 ++++++++++++------ qlib/contrib/backtest/__init__.py | 2 +- qlib/contrib/backtest/account.py | 19 +++++++----- qlib/contrib/backtest/backtest.py | 6 +--- qlib/contrib/backtest/env.py | 40 +++++++++++++------------ qlib/contrib/evaluate.py | 2 +- qlib/contrib/strategy/model_strategy.py | 35 ++-------------------- qlib/workflow/record_temp.py | 13 +++----- 8 files changed, 63 insertions(+), 83 deletions(-) diff --git a/examples/highfreq/backtest/workflow.py b/examples/highfreq/backtest/workflow.py index d031d40f2..a4d163ce5 100644 --- a/examples/highfreq/backtest/workflow.py +++ b/examples/highfreq/backtest/workflow.py @@ -10,6 +10,8 @@ from qlib.config import REG_CN from qlib.contrib.strategy import TopkDropoutStrategy from qlib.contrib.backtest import backtest from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict +from qlib.workflow import R +from qlib.workflow.record_temp import PortAnaRecord from qlib.tests.data import GetData if __name__ == "__main__": @@ -78,7 +80,7 @@ if __name__ == "__main__": trade_start_time = "2017-01-31" trade_end_time = "2018-01-31" - backtest_config = { + port_analysis_config = { "strategy": { "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.model_strategy", @@ -101,6 +103,7 @@ if __name__ == "__main__": "kwargs": { "step_bar": "day", "verbose": True, + "generate_report": True, }, }, "sub_strategy": { @@ -128,11 +131,19 @@ if __name__ == "__main__": }, } - report_dict = backtest( - start_time=trade_start_time, - end_time=trade_end_time, - **backtest_config, - account=1e8, - deal_price="$close", - verbose=False, - ) + #report_dict = backtest( + # start_time=trade_start_time, + # end_time=trade_end_time, + # **backtest_config, + # account=1e8, + # benchmark=benchmark, + # deal_price="$close", + # verbose=False, + #) + + with R.start(experiment_name="highfreq_backtest"): + # backtest. If users want to use backtest based on their own prediction, + # please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template. + recorder = R.get_recorder() + par = PortAnaRecord(recorder, port_analysis_config, 1) + par.generate() \ No newline at end of file diff --git a/qlib/contrib/backtest/__init__.py b/qlib/contrib/backtest/__init__.py index 21d3913e5..dacbdfefc 100644 --- a/qlib/contrib/backtest/__init__.py +++ b/qlib/contrib/backtest/__init__.py @@ -118,7 +118,7 @@ def setup_exchange(root_instance, trade_exchange=None, force=False): setup_exchange(root_instance.sub_strategy, trade_exchange) -def backtest(start_time, end_time, strategy, env, benchmark=None, account=1e9, **kwargs): +def backtest(start_time, end_time, strategy, env, benchmark="SH000905", account=1e9, **kwargs): trade_strategy = init_instance_by_config(strategy) trade_env = init_env_instance_by_config(env) diff --git a/qlib/contrib/backtest/account.py b/qlib/contrib/backtest/account.py index ad88e274a..5a35ffc08 100644 --- a/qlib/contrib/backtest/account.py +++ b/qlib/contrib/backtest/account.py @@ -8,6 +8,7 @@ import pandas as pd from .position import Position from .report import Report from .order import Order +from ...data import D from ...utils import parse_freq, sample_feature @@ -95,7 +96,8 @@ class Account: def cal_change(x): return x.prod() - 1 - return sample_feature(bench, trade_start_time, trade_end_time, method=cal_change) + _ret = sample_feature(bench, trade_start_time, trade_end_time, method=cal_change) + return 0 if _ret is None else _ret def reset(self, benchmark=None, freq=None, **kwargs): if benchmark: @@ -105,9 +107,9 @@ class Account: if self.freq and self.benchmark and (freq or benchmark): self.bench = self._cal_benchmark(self.benchmark, self.start_time, self.end_time, self.freq) - for k, v in kwargs: - if hasattr(k): - setattr(k, v) + for k, v in kwargs.items(): + if hasattr(self, k): + setattr(self, k, v) def get_positions(self): return self.positions @@ -150,7 +152,7 @@ class Account: self.current.update_order(order, trade_val, cost, trade_price) self.update_state_from_order(order, trade_val, cost, trade_price) - def update_report(self, trade_start_time, trade_end_time, trade_exchange): + def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange, update_report): """ start_time: pd.TimeStamp end_time: pd.TimeStamp @@ -166,6 +168,9 @@ class Account: :return: None """ # update price for stock in the position and the profit from changed_price + self.current.add_count_all(bar=self.freq) + if update_report is None: + return stock_list = self.current.get_stock_list() for code in stock_list: # if suspend, no new price to be updated, profit is 0 @@ -174,7 +179,7 @@ class Account: bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time) self.current.update_stock_price(stock_id=code, price=bar_close) # update holding day count - self.current.add_count_all(bar=self.freq) + # update value self.val = self.current.calculate_value() # update earning @@ -212,7 +217,7 @@ class Account: self.positions[trade_start_time] = copy.deepcopy(self.current) # finish today's updation - # reset the daily variables + # reset the bar variables self.rtn = 0 self.ct = 0 self.to = 0 diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index d6fcb509d..d67d6782b 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -19,8 +19,4 @@ def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account _order_list = trade_strategy.generate_order_list(**trade_state) trade_state, trade_info = trade_env.execute(_order_list) - report_df = trade_account.report.generate_report_dataframe() - positions = trade_account.get_positions() - report_dict = {"report_df": report_df, "positions": positions} - - return report_dict + return trade_env.get_report() diff --git a/qlib/contrib/backtest/env.py b/qlib/contrib/backtest/env.py index ade5caf24..ea2618977 100644 --- a/qlib/contrib/backtest/env.py +++ b/qlib/contrib/backtest/env.py @@ -42,7 +42,7 @@ class BaseTradeCalendar: if start_time or end_time: self._reset_trade_calendar(start_time=start_time, end_time=end_time) - for k, v in kwargs: + for k, v in kwargs.items(): if hasattr(self, k): setattr(self, k, v) @@ -52,7 +52,7 @@ class BaseTradeCalendar: return self.calendar[calendar_index - 1], self.calendar[calendar_index] def finished(self): - return self.trade_index >= self.trade_len + return self.trade_index >= self.trade_len - 1 def step(self): self.trade_index = self.trade_index + 1 @@ -71,11 +71,11 @@ class BaseEnv(BaseTradeCalendar): start_time=None, end_time=None, trade_account=None, - update_report=False, + generate_report=False, verbose=False, **kwargs, ): - self.generate_report = update_report + self.generate_report = generate_report self.verbose = verbose super(BaseEnv, self).__init__( step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs @@ -110,7 +110,8 @@ class SplitEnv(BaseEnv): start_time=None, end_time=None, trade_account=None, - update_report=False, + trade_exchange=None, + generate_report=False, verbose=False, **kwargs, ): @@ -121,15 +122,18 @@ class SplitEnv(BaseEnv): start_time=start_time, end_time=end_time, trade_account=trade_account, - update_report=update_report, + trade_exchange=trade_exchange, + generate_report=generate_report, verbose=verbose, **kwargs, ) - def reset(self, trade_account=None, **kwargs): + def reset(self, trade_account=None, trade_exchange=None, **kwargs): super(SplitEnv, self).reset(trade_account=trade_account, **kwargs) if trade_account: self.sub_env.reset(trade_account=copy.copy(trade_account)) + if trade_exchange: + self.trade_exchange = trade_exchange def execute(self, order_list, **kwargs): if self.finished(): @@ -146,10 +150,9 @@ class SplitEnv(BaseEnv): _order_list = self.sub_strategy.generate_order_list(**trade_state) trade_state, trade_info = self.sub_env.execute(order_list=_order_list) - if self.generate_report: - self.trade_account.update_report( - trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange - ) + self.trade_account.update_bar_end( + trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange, update_report=self.generate_report + ) _obs = {"current": self.trade_account.current} _info = {} return _obs, _info @@ -157,7 +160,7 @@ class SplitEnv(BaseEnv): def get_report(self): _report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None _positions = self.trade_account.get_positions() if self.generate_report else None - return [(_report, _positions), *sub_env.get_report()] + return [(_report, _positions), *self.sub_env.get_report()] class SimulatorEnv(BaseEnv): @@ -168,7 +171,7 @@ class SimulatorEnv(BaseEnv): end_time=None, trade_account=None, trade_exchange=None, - update_report=False, + generate_report=False, verbose=False, **kwargs, ): @@ -178,7 +181,7 @@ class SimulatorEnv(BaseEnv): end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, - update_report=update_report, + generate_report=generate_report, verbose=verbose, **kwargs, ) @@ -231,10 +234,9 @@ class SimulatorEnv(BaseEnv): print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id)) # do nothing pass - if self.generate_report: - self.trade_account.update_report( - trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange - ) + self.trade_account.update_bar_end( + trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange, update_report=self.generate_report + ) _obs = {"current": self.trade_account.current} _info = {"trade_info": trade_info} return _obs, _info @@ -242,4 +244,4 @@ class SimulatorEnv(BaseEnv): def get_report(self): _report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None _positions = self.trade_account.get_positions() if self.generate_report else None - return [{"report": _report, "positions": _positions}] + return [(_report, _positions)] diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index 4aa5b5515..91cfc1d89 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -9,7 +9,7 @@ import pandas as pd import warnings from ..log import get_module_logger from .backtest import get_exchange, backtest as backtest_func -from .backtest.backtest import get_date_range +from ..utils import get_date_range from ..data import D from ..config import C diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index 95280dc2f..0bd0b9e0c 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -23,7 +23,6 @@ class TopkDropoutStrategy(ModelStrategy): method_sell="bottom", method_buy="top", risk_degree=0.95, - thresh=1, hold_thresh=1, only_tradable=False, **kwargs, @@ -41,11 +40,9 @@ class TopkDropoutStrategy(ModelStrategy): dropout method_buy, random/top. risk_degree : float position percentage of total value. - thresh : int - minimun holding days since last buy singal of the stock. hold_thresh : int minimum holding days - before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh. + before sell stock , will check current.get_stock_count(order.stock_id) >= self.hold_thresh. only_tradable : bool will the strategy only consider the tradable stock when buying and selling. if only_tradable: @@ -61,10 +58,6 @@ class TopkDropoutStrategy(ModelStrategy): self.method_sell = method_sell self.method_buy = method_buy self.risk_degree = risk_degree - self.thresh = thresh - # self.stock_count['code'] will be the days the stock has been hold - # since last buy signal. This is designed for thresh - self.stock_count = {} self.hold_thresh = hold_thresh self.only_tradable = only_tradable @@ -170,10 +163,7 @@ class TopkDropoutStrategy(ModelStrategy): # Get the stock list we really want to buy buy = today[: len(sell) + self.topk - len(last)] - - # buy singal: if a stock falls into topk, it appear in the buy_sinal - buy_signal = pred_score.sort_values(ascending=False).iloc[: self.topk].index - + #print("flag", len(sell), len(buy), self.topk, len(last)) for code in current_stock_list: if not self.trade_exchange.is_stock_tradable( stock_id=code, start_time=trade_start_time, end_time=trade_end_time @@ -181,13 +171,7 @@ class TopkDropoutStrategy(ModelStrategy): continue if code in sell: # check hold limit - if ( - self.stock_count[code] < self.thresh - or current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh - ): - # can not sell this code - # no buy signal, but the stock is kept - self.stock_count[code] += 1 + if current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh: continue # sell order sell_amount = current_temp.get_stock_amount(code=code) @@ -207,18 +191,6 @@ class TopkDropoutStrategy(ModelStrategy): ) # update cash cash += trade_val - trade_cost - # sold - self.stock_count[code] = 0 - else: - # no buy signal, but the stock is kept - self.stock_count[code] += 1 - elif code in buy_signal: - # NOTE: This is different from the original version - # get new buy signal - # Only the stock fall in to topk will produce buy signal - self.stock_count[code] = 1 - else: - self.stock_count[code] += 1 # buy new stock # note the current has been changed current_stock_list = current_temp.get_stock_list() @@ -249,7 +221,6 @@ class TopkDropoutStrategy(ModelStrategy): factor=factor, ) buy_order_list.append(buy_order) - self.stock_count[code] = 1 return sell_order_list + buy_order_list diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index b7935ae08..546fb5a60 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -14,8 +14,9 @@ from ..data.dataset.handler import DataHandlerLP from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger from ..utils import flatten_dict +from ..strategy.base import BaseStrategy from ..contrib.eva.alpha import calc_ic, calc_long_short_return -from ..contrib.strategy.strategy import BaseStrategy + logger = get_module_logger("workflow", "INFO") @@ -212,7 +213,7 @@ class SigAnaRecord(SignalRecord): return paths -class PortAnaRecord(SignalRecord): +class PortAnaRecord(RecordTemp): """ This is the Portfolio Analysis Record class that generates the analysis results such as those of backtest. This class inherits the ``RecordTemp`` class. @@ -243,16 +244,10 @@ class PortAnaRecord(SignalRecord): self.risk_analysis_dep = risk_analysis_dep def generate(self, **kwargs): - # check previously stored prediction results - try: - self.check(parent=True) # "Make sure the parent process is completed and store the data properly." - except FileExistsError: - super().generate() - # custom strategy and get backtest report_list = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config) for report_dep, (report_normal, positions_normal) in enumerate(report_list): - if report_dict is None: + if report_normal is None: if self.risk_analysis_dep == report_dep: warnings.warn( f"the report in dep {risk_analysis_dep} is None, please set the corresponding env with `generate_report==True`"