From 309dfa36cc5175541579de4e35aef377ad5bfed4 Mon Sep 17 00:00:00 2001 From: Young Date: Sun, 15 Aug 2021 15:22:48 +0000 Subject: [PATCH] Add a example to collecting all the decisions --- qlib/backtest/__init__.py | 53 ++++++++- qlib/backtest/high_performance_ds.py | 10 +- qlib/tests/__init__.py | 63 ++++++++++- qlib/tests/data.py | 2 +- tests/backtest/test_high_freq_trading.py | 133 +++++++++++++++++++++++ 5 files changed, 245 insertions(+), 16 deletions(-) create mode 100644 tests/backtest/test_high_freq_trading.py diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index bcd07fa23..cd113c8ab 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -9,11 +9,11 @@ from .account import Account if TYPE_CHECKING: from ..strategy.base import BaseStrategy from .executor import BaseExecutor + from .order import BaseTradeDecision from .position import Position from .exchange import Exchange from .backtest import backtest_loop from .backtest import collect_data_loop -from .order import Order from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager from ..utils import init_instance_by_config from ..log import get_module_logger @@ -228,10 +228,13 @@ def backtest( Returns ------- - report_dict: Report + report: Report it records the trading report information - indicator_dict: Indicator + It is organized in a dict format + indicator: Indicator it computes the trading indicator + It is organized in a dict format + """ trade_strategy, trade_executor = get_strategy_executor( start_time, @@ -243,9 +246,9 @@ def backtest( exchange_kwargs, pos_type=pos_type, ) - report_dict, indicator_dict = backtest_loop(start_time, end_time, trade_strategy, trade_executor) + report, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor) - return report_dict, indicator_dict + return report, indicator def collect_data( @@ -257,6 +260,7 @@ def collect_data( account=1e9, exchange_kwargs={}, pos_type: str = "Position", 
+ return_value: dict = None, ): """initialize the strategy and executor, then collect the trade decision data for rl training @@ -277,4 +281,41 @@ def collect_data( exchange_kwargs, pos_type=pos_type, ) - yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor) + yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value=return_value) + + +def format_decisions( + decisions: List[BaseTradeDecision], +) -> Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]: + """ + format the decisions collected by `qlib.backtest.collect_data` + The decisions will be organized into a tree-like structure. + + Parameters + ---------- + decisions : List[BaseTradeDecision] + decisions collected by `qlib.backtest.collect_data` + + Returns + ------- + Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]: + + reformat the list of decisions into a more user-friendly format + <decisions> := Tuple[<freq>, List[Tuple[<decision>, <sub decisions>]]] + - <sub decisions> := `<decisions> in lower level` | None + - <freq> := "day" | "30min" | "1min" | ... 
+ - <decision> := <instance of BaseTradeDecision> + """ + if len(decisions) == 0: + return None + + cur_freq = decisions[0].strategy.trade_calendar.get_freq() + + res = (cur_freq, []) + last_dec_idx = 0 + for i, dec in enumerate(decisions[1:], 1): + if dec.strategy.trade_calendar.get_freq() == cur_freq: + res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 : i]))) + last_dec_idx = i + res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 :]))) + return res diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index c60d3f97e..eabe84a0a 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -171,7 +171,7 @@ class BaseSingleMetric: @property def empty(self) -> bool: - """If metric is empyt, return True.""" + """If metric is empty, return True.""" raise NotImplementedError(f"Please implement the `empty` method") @@ -357,17 +357,17 @@ class PandasSingleMetric: def __gt__(self, other): if isinstance(other, (int, float)): - return PandasSingleMetric(self.metric < other) + return PandasSingleMetric(self.metric > other) elif isinstance(other, PandasSingleMetric): - return PandasSingleMetric(self.metric < other.metric) + return PandasSingleMetric(self.metric > other.metric) else: return NotImplemented def __lt__(self, other): if isinstance(other, (int, float)): - return PandasSingleMetric(self.metric > other) + return PandasSingleMetric(self.metric < other) elif isinstance(other, PandasSingleMetric): - return PandasSingleMetric(self.metric > other.metric) + return PandasSingleMetric(self.metric < other.metric) else: return NotImplemented diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index 7f43cd99a..cc452ae0f 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -8,17 +8,72 @@ class TestAutoData(unittest.TestCase): _setup_kwargs = {} provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir + provider_uri_1day = "~/.qlib/qlib_data/cn_data" # 
target_dir + provider_uri_1min = "~/.qlib/qlib_data/cn_data_1min" @classmethod - def setUpClass(cls) -> None: + def setUpClass(cls, enable_1d_type="simple", enable_1min=False) -> None: # use default data + if enable_1d_type == "simple": + provider_uri_day = cls.provider_uri + name_day = "qlib_data_simple" + elif enable_1d_type == "full": + provider_uri_day = cls.provider_uri_1day + name_day = "qlib_data" + else: + raise NotImplementedError(f"This type of input is not supported") + GetData().qlib_data( - name="qlib_data_simple", + name=name_day, region=REG_CN, interval="1d", - target_dir=cls.provider_uri, + target_dir=provider_uri_day, delete_old=False, exists_skip=True, ) - init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs) + + if enable_1min: + GetData().qlib_data( + name="qlib_data", + region=REG_CN, + interval="1min", + target_dir=cls.provider_uri_1min, + delete_old=False, + exists_skip=True, + ) + + provider_uri_map = {"1min": cls.provider_uri_1min, "day": provider_uri_day} + + client_config = { + "calendar_provider": { + "class": "LocalCalendarProvider", + "module_path": "qlib.data.data", + "kwargs": { + "backend": { + "class": "FileCalendarStorage", + "module_path": "qlib.data.storage.file_storage", + "kwargs": {"provider_uri_map": provider_uri_map}, + } + }, + }, + "feature_provider": { + "class": "LocalFeatureProvider", + "module_path": "qlib.data.data", + "kwargs": { + "backend": { + "class": "FileFeatureStorage", + "module_path": "qlib.data.storage.file_storage", + "kwargs": {"provider_uri_map": provider_uri_map}, + } + }, + }, + } + init( + provider_uri=cls.provider_uri, + region=REG_CN, + expression_cache=None, + dataset_cache=None, + **client_config, + **cls._setup_kwargs, + ) diff --git a/qlib/tests/data.py b/qlib/tests/data.py index 2bfe43590..b38fd7eee 100644 --- a/qlib/tests/data.py +++ b/qlib/tests/data.py @@ -14,7 +14,7 @@ from qlib.utils import exists_qlib_data class GetData: - DATASET_VERSION = "v1" + DATASET_VERSION = "v2" 
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads" QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip" diff --git a/tests/backtest/test_high_freq_trading.py b/tests/backtest/test_high_freq_trading.py new file mode 100644 index 000000000..628ec1e78 --- /dev/null +++ b/tests/backtest/test_high_freq_trading.py @@ -0,0 +1,133 @@ +from typing import List, Tuple, Union +from qlib.backtest.position import Position +from qlib.backtest import collect_data, format_decisions +from qlib.backtest.order import BaseTradeDecision, TradeRangeByTime +import qlib +from qlib.tests import TestAutoData +import unittest +from qlib.config import REG_CN, HIGH_FREQ_CONFIG +import pandas as pd + + +@unittest.skip("This test takes a lot of time due to the large size of high-frequency data") +class TestHFBacktest(TestAutoData): + @classmethod + def setUpClass(cls) -> None: + super().setUpClass(enable_1min=True, enable_1d_type="full") + + def _gen_orders(self, inst, date, pos) -> pd.DataFrame: + headers = [ + "datetime", + "instrument", + "amount", + "direction", + ] + orders = [ + [date, inst, pos, "sell"], + ] + return pd.DataFrame(orders, columns=headers) + + def test_trading(self): + + # date = "2020-02-03" + # inst = "SH600068" + # pos = 2.0167 + pos = 100000 + inst, date = "SH600519", "2021-01-18" + market = [inst] + + start_time = f"{date}" + end_time = f"{date} 15:00" # include the high-freq data on the end day + freq_l0 = "day" + freq_l1 = "30min" + freq_l2 = "1min" + + orders = self._gen_orders(inst=inst, date=date, pos=pos * 0.90) + + strategy_config = { + "class": "FileOrderStrategy", + "module_path": "qlib.contrib.strategy.rule_strategy", + "kwargs": { + "trade_range": TradeRangeByTime("10:45", "14:44"), + "file": orders, + }, + } + backtest_config = { + "start_time": start_time, + "end_time": end_time, + "account": { + "cash": 0, + inst: pos, + }, + "benchmark": None, # benchmark is not required here for trading + "exchange_kwargs": { + "freq": 
freq_l2, # use the most fine-grained data as the exchange + "limit_threshold": 0.095, + "deal_price": "close", + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5, + "codes": market, + "trade_unit": 100, + }, + # "pos_type": "InfPosition" # Position with infinite position + } + executor_config = { + "class": "NestedExecutor", # Level 1 Order execution + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": freq_l0, + "inner_executor": { + "class": "NestedExecutor", # Level 2 Order Execution + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": freq_l1, + "inner_executor": { + "class": "SimulatorExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": freq_l2, + "generate_report": False, + "verbose": True, + "indicator_config": { + "show_indicator": False, + }, + "track_data": True, + }, + }, + "inner_strategy": { + "class": "TWAPStrategy", + "module_path": "qlib.contrib.strategy.rule_strategy", + }, + "generate_report": False, + "indicator_config": { + "show_indicator": True, + }, + "track_data": True, + }, + }, + "inner_strategy": { + "class": "TWAPStrategy", + "module_path": "qlib.contrib.strategy.rule_strategy", + }, + "generate_report": False, + "indicator_config": { + "show_indicator": True, + }, + "track_data": True, + }, + } + + ret_val = {} + decisions = list( + collect_data(executor=executor_config, strategy=strategy_config, **backtest_config, return_value=ret_val) + ) + report, indicator = ret_val["report"], ret_val["indicator"] + # NOTE: please refer to the docs of format_decisions + # NOTE: `"track_data": True,` is very NECESSARY for collecting the decision!!!!! + f_dec = format_decisions(decisions) + print(indicator["1day"]) + + +if __name__ == "__main__": + unittest.main()