diff --git a/examples/highfreq/backtest/workflow.py b/examples/highfreq/backtest/workflow.py index bbe00ed5c..8e4f30c5f 100644 --- a/examples/highfreq/backtest/workflow.py +++ b/examples/highfreq/backtest/workflow.py @@ -10,7 +10,7 @@ from qlib.config import REG_CN from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R -from qlib.workflow.record_temp import PortAnaRecord +from qlib.workflow.record_temp import SignalRecord, PortAnaRecord from qlib.tests.data import GetData if __name__ == "__main__": @@ -64,9 +64,9 @@ if __name__ == "__main__": "kwargs": data_handler_config, }, "segments": { - "train": ("2012-01-01", "2014-12-31"), + "train": ("2008-01-01", "2014-12-31"), "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2018-01-31"), + "test": ("2017-01-01", "2020-08-01"), }, }, }, @@ -74,17 +74,16 @@ if __name__ == "__main__": # model initialization model = init_instance_by_config(task["model"]) dataset = init_instance_by_config(task["dataset"]) - model.fit(dataset) - trade_start_time = "2017-01-31" - trade_end_time = "2018-01-31" + trade_start_time = "2017-01-01" + trade_end_time = "2020-08-01" port_analysis_config = { "strategy": { "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.model_strategy", "kwargs": { - "step_bar": "week", + "step_bar": "day", "model": model, "dataset": dataset, "topk": 50, @@ -92,28 +91,12 @@ if __name__ == "__main__": }, }, "env": { - "class": "SplitEnv", + "class": "SimulatorEnv", "module_path": "qlib.contrib.backtest.env", "kwargs": { - "step_bar": "week", - "sub_env": { - "class": "SimulatorEnv", - "module_path": "qlib.contrib.backtest.env", - "kwargs": { - "step_bar": "day", - "verbose": True, - "generate_report": True, - }, - }, - "sub_strategy": { - "class": "SBBStrategyEMA", - "module_path": "qlib.contrib.strategy.rule_strategy", - "kwargs": { - "step_bar": "day", - "freq": "day", - "instruments": "csi300", - }, - }, + "step_bar": "day", + "verbose": True, + "generate_report": True, }, }, "backtest": { @@ -129,9 +112,18 @@ if __name__ == "__main__": "min_cost": 5, }, } + with R.start(experiment_name="highfreq_backtest"): + R.log_params(**flatten_dict(task)) + model.fit(dataset) + R.save_objects(**{"params.pkl": model}) + + # prediction + recorder = R.get_recorder() + sr = SignalRecord(model, dataset, recorder) + sr.generate() + # backtest. If users want to use backtest based on their own prediction, # please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template. - recorder = R.get_recorder() par = PortAnaRecord(recorder, port_analysis_config, "day") - par.generate() + par.generate() \ No newline at end of file diff --git a/qlib/contrib/backtest/account.py b/qlib/contrib/backtest/account.py index 88a695f8f..39fecbd88 100644 --- a/qlib/contrib/backtest/account.py +++ b/qlib/contrib/backtest/account.py @@ -94,7 +94,7 @@ class Account: def _sample_benchmark(self, bench, trade_start_time, trade_end_time): def cal_change(x): - return x.prod() - 1 + return (x + 1).prod() - 1 _ret = sample_feature(bench, trade_start_time, trade_end_time, method=cal_change) return 0 if _ret is None else _ret diff --git a/qlib/contrib/backtest/env.py b/qlib/contrib/backtest/env.py index f5c84169d..eb922cefd 100644 --- a/qlib/contrib/backtest/env.py +++ b/qlib/contrib/backtest/env.py @@ -49,7 +49,7 @@ class BaseTradeCalendar: def _get_calendar_time(self, trade_index=1, shift=0): trade_index = trade_index - shift calendar_index = self.start_index + trade_index - return self.calendar[calendar_index - 1], self.calendar[calendar_index] + return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1) def finished(self): return self.trade_index >= self.trade_len - 1 diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index 4d471cf89..6899a10a5 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -51,7 +51,7 @@ class TopkDropoutStrategy(ModelStrategy): strategy will make decision with the tradable state of the stock info and avoid buy and sell them. """ super(TopkDropoutStrategy, self).__init__( - step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange + step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange, **kwargs ) self.topk = topk self.n_drop = n_drop diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 1acf55314..f69dee10d 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -11,16 +11,30 @@ from ..backtest.order import Order class TWAPStrategy(RuleStrategy, TradingEnhancement): - def reset(self, trade_order_list=None, **kwargs): + def __init__( + self, + step_bar, + start_time=None, + end_time=None, + trade_exchange=None, + **kwargs, + ): + super(TWAPStrategy, self).__init__( + step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs + ) + + def reset(self, trade_order_list=None, trade_exchange=None, **kwargs): super(TWAPStrategy, self).reset(**kwargs) TradingEnhancement.reset(self, trade_order_list=trade_order_list) if trade_order_list: self.trade_amount = {} for order in self.trade_order_list: self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len + if trade_exchange: + self.trade_exchange = trade_exchange def generate_order_list(self, **kwargs): - super(TopkDropoutStrategy, self).step() + super(TWAPStrategy, self).step() trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) order_list = [] for order in self.trade_order_list: @@ -44,8 +58,19 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): TREND_MID = 0 TREND_SHORT = 1 TREND_LONG = 2 + def __init__( + self, + step_bar, + start_time=None, + end_time=None, + trade_exchange=None, + **kwargs, + ): + super(SBBStrategyBase, self).__init__( + step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs + ) - def reset(self, trade_order_list=None, **kwargs): + def reset(self, trade_order_list=None, trade_exchange=None, **kwargs): super(SBBStrategyBase, self).reset(**kwargs) TradingEnhancement.reset(self, trade_order_list=trade_order_list) if trade_order_list: @@ -54,6 +79,8 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): for order in self.trade_order_list: self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID + if trade_exchange: + self.trade_exchange = trade_exchange def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): raise NotImplementedError("pred_price_trend method is not implemented!") @@ -127,11 +154,12 @@ class SBBStrategyEMA(SBBStrategyBase): step_bar, start_time=None, end_time=None, + trade_exchange=None, instruments="csi300", freq="day", **kwargs, ): - super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, **kwargs) + super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs) if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") self.instruments = "all"