mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 10:01:19 +08:00
Add a example to collecting all the decisions
This commit is contained in:
@@ -9,11 +9,11 @@ from .account import Account
|
||||
if TYPE_CHECKING:
|
||||
from ..strategy.base import BaseStrategy
|
||||
from .executor import BaseExecutor
|
||||
from .order import BaseTradeDecision
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .backtest import backtest_loop
|
||||
from .backtest import collect_data_loop
|
||||
from .order import Order
|
||||
from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
@@ -228,10 +228,13 @@ def backtest(
|
||||
|
||||
Returns
|
||||
-------
|
||||
report_dict: Report
|
||||
report: Report
|
||||
it records the trading report information
|
||||
indicator_dict: Indicator
|
||||
It is organized in a dict format
|
||||
indicator: Indicator
|
||||
it computes the trading indicator
|
||||
It is organized in a dict format
|
||||
|
||||
"""
|
||||
trade_strategy, trade_executor = get_strategy_executor(
|
||||
start_time,
|
||||
@@ -243,9 +246,9 @@ def backtest(
|
||||
exchange_kwargs,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
report_dict, indicator_dict = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
report, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
|
||||
return report_dict, indicator_dict
|
||||
return report, indicator
|
||||
|
||||
|
||||
def collect_data(
|
||||
@@ -257,6 +260,7 @@ def collect_data(
|
||||
account=1e9,
|
||||
exchange_kwargs={},
|
||||
pos_type: str = "Position",
|
||||
return_value: dict = None,
|
||||
):
|
||||
"""initialize the strategy and executor, then collect the trade decision data for rl training
|
||||
|
||||
@@ -277,4 +281,41 @@ def collect_data(
|
||||
exchange_kwargs,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value=return_value)
|
||||
|
||||
|
||||
def format_decisions(
|
||||
decisions: List[BaseTradeDecision],
|
||||
) -> Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]:
|
||||
"""
|
||||
format the decisions collected by `qlib.backtest.collect_data`
|
||||
The decisions will be organized into a tree-like structure.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decisions : List[BaseTradeDecision]
|
||||
decisions collected by `qlib.backtest.collect_data`
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]:
|
||||
|
||||
reformat the list of decisions into a more user-friendly format
|
||||
<decisions> := Tuple[<freq>, List[Tuple[<decision>, <sub decisions>]]]
|
||||
- <sub decisions> := `<decisions> in lower level` | None
|
||||
- <freq> := "day" | "30min" | "1min" | ...
|
||||
- <decision> := <instance of BaseTradeDecision>
|
||||
"""
|
||||
if len(decisions) == 0:
|
||||
return None
|
||||
|
||||
cur_freq = decisions[0].strategy.trade_calendar.get_freq()
|
||||
|
||||
res = (cur_freq, [])
|
||||
last_dec_idx = 0
|
||||
for i, dec in enumerate(decisions[1:], 1):
|
||||
if dec.strategy.trade_calendar.get_freq() == cur_freq:
|
||||
res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 : i])))
|
||||
last_dec_idx = i
|
||||
res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 :])))
|
||||
return res
|
||||
|
||||
@@ -171,7 +171,7 @@ class BaseSingleMetric:
|
||||
|
||||
@property
|
||||
def empty(self) -> bool:
|
||||
"""If metric is empyt, return True."""
|
||||
"""If metric is empty, return True."""
|
||||
|
||||
raise NotImplementedError(f"Please implement the `empty` method")
|
||||
|
||||
@@ -357,17 +357,17 @@ class PandasSingleMetric:
|
||||
|
||||
def __gt__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
return PandasSingleMetric(self.metric < other)
|
||||
return PandasSingleMetric(self.metric > other)
|
||||
elif isinstance(other, PandasSingleMetric):
|
||||
return PandasSingleMetric(self.metric < other.metric)
|
||||
return PandasSingleMetric(self.metric > other.metric)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
return PandasSingleMetric(self.metric > other)
|
||||
return PandasSingleMetric(self.metric < other)
|
||||
elif isinstance(other, PandasSingleMetric):
|
||||
return PandasSingleMetric(self.metric > other.metric)
|
||||
return PandasSingleMetric(self.metric < other.metric)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
|
||||
@@ -8,17 +8,72 @@ class TestAutoData(unittest.TestCase):
|
||||
|
||||
_setup_kwargs = {}
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
|
||||
provider_uri_1day = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
provider_uri_1min = "~/.qlib/qlib_data/cn_data_1min"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
def setUpClass(cls, enable_1d_type="simple", enable_1min=False) -> None:
|
||||
# use default data
|
||||
|
||||
if enable_1d_type == "simple":
|
||||
provider_uri_day = cls.provider_uri
|
||||
name_day = "qlib_data_simple"
|
||||
elif enable_1d_type == "full":
|
||||
provider_uri_day = cls.provider_uri_1day
|
||||
name_day = "qlib_data"
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple",
|
||||
name=name_day,
|
||||
region=REG_CN,
|
||||
interval="1d",
|
||||
target_dir=cls.provider_uri,
|
||||
target_dir=provider_uri_day,
|
||||
delete_old=False,
|
||||
exists_skip=True,
|
||||
)
|
||||
init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs)
|
||||
|
||||
if enable_1min:
|
||||
GetData().qlib_data(
|
||||
name="qlib_data",
|
||||
region=REG_CN,
|
||||
interval="1min",
|
||||
target_dir=cls.provider_uri_1min,
|
||||
delete_old=False,
|
||||
exists_skip=True,
|
||||
)
|
||||
|
||||
provider_uri_map = {"1min": cls.provider_uri_1min, "day": provider_uri_day}
|
||||
|
||||
client_config = {
|
||||
"calendar_provider": {
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileCalendarStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
"feature_provider": {
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileFeatureStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
init(
|
||||
provider_uri=cls.provider_uri,
|
||||
region=REG_CN,
|
||||
expression_cache=None,
|
||||
dataset_cache=None,
|
||||
**client_config,
|
||||
**cls._setup_kwargs,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
class GetData:
|
||||
DATASET_VERSION = "v1"
|
||||
DATASET_VERSION = "v2"
|
||||
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
|
||||
QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip"
|
||||
|
||||
|
||||
133
tests/backtest/test_high_freq_trading.py
Normal file
133
tests/backtest/test_high_freq_trading.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from typing import List, Tuple, Union
|
||||
from qlib.backtest.position import Position
|
||||
from qlib.backtest import collect_data, format_decisions
|
||||
from qlib.backtest.order import BaseTradeDecision, TradeRangeByTime
|
||||
import qlib
|
||||
from qlib.tests import TestAutoData
|
||||
import unittest
|
||||
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@unittest.skip("This test takes a lot of time due to the large size of high-frequency data")
|
||||
class TestHFBacktest(TestAutoData):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
super().setUpClass(enable_1min=True, enable_1d_type="full")
|
||||
|
||||
def _gen_orders(self, inst, date, pos) -> pd.DataFrame:
|
||||
headers = [
|
||||
"datetime",
|
||||
"instrument",
|
||||
"amount",
|
||||
"direction",
|
||||
]
|
||||
orders = [
|
||||
[date, inst, pos, "sell"],
|
||||
]
|
||||
return pd.DataFrame(orders, columns=headers)
|
||||
|
||||
def test_trading(self):
|
||||
|
||||
# date = "2020-02-03"
|
||||
# inst = "SH600068"
|
||||
# pos = 2.0167
|
||||
pos = 100000
|
||||
inst, date = "SH600519", "2021-01-18"
|
||||
market = [inst]
|
||||
|
||||
start_time = f"{date}"
|
||||
end_time = f"{date} 15:00" # include the high-freq data on the end day
|
||||
freq_l0 = "day"
|
||||
freq_l1 = "30min"
|
||||
freq_l2 = "1min"
|
||||
|
||||
orders = self._gen_orders(inst=inst, date=date, pos=pos * 0.90)
|
||||
|
||||
strategy_config = {
|
||||
"class": "FileOrderStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"trade_range": TradeRangeByTime("10:45", "14:44"),
|
||||
"file": orders,
|
||||
},
|
||||
}
|
||||
backtest_config = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"account": {
|
||||
"cash": 0,
|
||||
inst: pos,
|
||||
},
|
||||
"benchmark": None, # benchmark is not required here for trading
|
||||
"exchange_kwargs": {
|
||||
"freq": freq_l2, # use the most fine-grained data as the exchange
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
"codes": market,
|
||||
"trade_unit": 100,
|
||||
},
|
||||
# "pos_type": "InfPosition" # Position with infinitive position
|
||||
}
|
||||
executor_config = {
|
||||
"class": "NestedExecutor", # Level 1 Order execution
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": freq_l0,
|
||||
"inner_executor": {
|
||||
"class": "NestedExecutor", # Leve 2 Order Execution
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": freq_l1,
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": freq_l2,
|
||||
"generate_report": False,
|
||||
"verbose": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": False,
|
||||
},
|
||||
"track_data": True,
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "TWAPStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
},
|
||||
"generate_report": False,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
},
|
||||
"track_data": True,
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "TWAPStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
},
|
||||
"generate_report": False,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
},
|
||||
"track_data": True,
|
||||
},
|
||||
}
|
||||
|
||||
ret_val = {}
|
||||
decisions = list(
|
||||
collect_data(executor=executor_config, strategy=strategy_config, **backtest_config, return_value=ret_val)
|
||||
)
|
||||
report, indicator = ret_val["report"], ret_val["indicator"]
|
||||
# NOTE: please refer to the docs of format_decisions
|
||||
# NOTE: `"track_data": True,` is very NECESSARY for collecting the decision!!!!!
|
||||
f_dec = format_decisions(decisions)
|
||||
print(indicator["1day"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user