1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 10:01:19 +08:00

Add a example to collecting all the decisions

This commit is contained in:
Young
2021-08-15 15:22:48 +00:00
parent 735153a50d
commit 309dfa36cc
5 changed files with 245 additions and 16 deletions

View File

@@ -9,11 +9,11 @@ from .account import Account
if TYPE_CHECKING:
from ..strategy.base import BaseStrategy
from .executor import BaseExecutor
from .order import BaseTradeDecision
from .position import Position
from .exchange import Exchange
from .backtest import backtest_loop
from .backtest import collect_data_loop
from .order import Order
from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
from ..utils import init_instance_by_config
from ..log import get_module_logger
@@ -228,10 +228,13 @@ def backtest(
Returns
-------
report_dict: Report
report: Report
it records the trading report information
indicator_dict: Indicator
It is organized in a dict format
indicator: Indicator
it computes the trading indicator
It is organized in a dict format
"""
trade_strategy, trade_executor = get_strategy_executor(
start_time,
@@ -243,9 +246,9 @@ def backtest(
exchange_kwargs,
pos_type=pos_type,
)
report_dict, indicator_dict = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
report, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
return report_dict, indicator_dict
return report, indicator
def collect_data(
@@ -257,6 +260,7 @@ def collect_data(
account=1e9,
exchange_kwargs={},
pos_type: str = "Position",
return_value: dict = None,
):
"""initialize the strategy and executor, then collect the trade decision data for rl training
@@ -277,4 +281,41 @@ def collect_data(
exchange_kwargs,
pos_type=pos_type,
)
yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor)
yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value=return_value)
def format_decisions(
decisions: List[BaseTradeDecision],
) -> Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]:
"""
format the decisions collected by `qlib.backtest.collect_data`
The decisions will be organized into a tree-like structure.
Parameters
----------
decisions : List[BaseTradeDecision]
decisions collected by `qlib.backtest.collect_data`
Returns
-------
Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]:
reformat the list of decisions into a more user-friendly format
<decisions> := Tuple[<freq>, List[Tuple[<decision>, <sub decisions>]]]
- <sub decisions> := `<decisions> in lower level` | None
- <freq> := "day" | "30min" | "1min" | ...
- <decision> := <instance of BaseTradeDecision>
"""
if len(decisions) == 0:
return None
cur_freq = decisions[0].strategy.trade_calendar.get_freq()
res = (cur_freq, [])
last_dec_idx = 0
for i, dec in enumerate(decisions[1:], 1):
if dec.strategy.trade_calendar.get_freq() == cur_freq:
res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 : i])))
last_dec_idx = i
res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 :])))
return res

View File

@@ -171,7 +171,7 @@ class BaseSingleMetric:
@property
def empty(self) -> bool:
"""If metric is empyt, return True."""
"""If metric is empty, return True."""
raise NotImplementedError(f"Please implement the `empty` method")
@@ -357,17 +357,17 @@ class PandasSingleMetric:
def __gt__(self, other):
if isinstance(other, (int, float)):
return PandasSingleMetric(self.metric < other)
return PandasSingleMetric(self.metric > other)
elif isinstance(other, PandasSingleMetric):
return PandasSingleMetric(self.metric < other.metric)
return PandasSingleMetric(self.metric > other.metric)
else:
return NotImplemented
def __lt__(self, other):
if isinstance(other, (int, float)):
return PandasSingleMetric(self.metric > other)
return PandasSingleMetric(self.metric < other)
elif isinstance(other, PandasSingleMetric):
return PandasSingleMetric(self.metric > other.metric)
return PandasSingleMetric(self.metric < other.metric)
else:
return NotImplemented

View File

@@ -8,17 +8,72 @@ class TestAutoData(unittest.TestCase):
_setup_kwargs = {}
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
provider_uri_1day = "~/.qlib/qlib_data/cn_data" # target_dir
provider_uri_1min = "~/.qlib/qlib_data/cn_data_1min"
@classmethod
def setUpClass(cls) -> None:
def setUpClass(cls, enable_1d_type="simple", enable_1min=False) -> None:
# use default data
if enable_1d_type == "simple":
provider_uri_day = cls.provider_uri
name_day = "qlib_data_simple"
elif enable_1d_type == "full":
provider_uri_day = cls.provider_uri_1day
name_day = "qlib_data"
else:
raise NotImplementedError(f"This type of input is not supported")
GetData().qlib_data(
name="qlib_data_simple",
name=name_day,
region=REG_CN,
interval="1d",
target_dir=cls.provider_uri,
target_dir=provider_uri_day,
delete_old=False,
exists_skip=True,
)
init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs)
if enable_1min:
GetData().qlib_data(
name="qlib_data",
region=REG_CN,
interval="1min",
target_dir=cls.provider_uri_1min,
delete_old=False,
exists_skip=True,
)
provider_uri_map = {"1min": cls.provider_uri_1min, "day": provider_uri_day}
client_config = {
"calendar_provider": {
"class": "LocalCalendarProvider",
"module_path": "qlib.data.data",
"kwargs": {
"backend": {
"class": "FileCalendarStorage",
"module_path": "qlib.data.storage.file_storage",
"kwargs": {"provider_uri_map": provider_uri_map},
}
},
},
"feature_provider": {
"class": "LocalFeatureProvider",
"module_path": "qlib.data.data",
"kwargs": {
"backend": {
"class": "FileFeatureStorage",
"module_path": "qlib.data.storage.file_storage",
"kwargs": {"provider_uri_map": provider_uri_map},
}
},
},
}
init(
provider_uri=cls.provider_uri,
region=REG_CN,
expression_cache=None,
dataset_cache=None,
**client_config,
**cls._setup_kwargs,
)

View File

@@ -14,7 +14,7 @@ from qlib.utils import exists_qlib_data
class GetData:
DATASET_VERSION = "v1"
DATASET_VERSION = "v2"
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip"

View File

@@ -0,0 +1,133 @@
from typing import List, Tuple, Union
from qlib.backtest.position import Position
from qlib.backtest import collect_data, format_decisions
from qlib.backtest.order import BaseTradeDecision, TradeRangeByTime
import qlib
from qlib.tests import TestAutoData
import unittest
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
import pandas as pd
@unittest.skip("This test takes a lot of time due to the large size of high-frequency data")
class TestHFBacktest(TestAutoData):
@classmethod
def setUpClass(cls) -> None:
super().setUpClass(enable_1min=True, enable_1d_type="full")
def _gen_orders(self, inst, date, pos) -> pd.DataFrame:
headers = [
"datetime",
"instrument",
"amount",
"direction",
]
orders = [
[date, inst, pos, "sell"],
]
return pd.DataFrame(orders, columns=headers)
def test_trading(self):
# date = "2020-02-03"
# inst = "SH600068"
# pos = 2.0167
pos = 100000
inst, date = "SH600519", "2021-01-18"
market = [inst]
start_time = f"{date}"
end_time = f"{date} 15:00" # include the high-freq data on the end day
freq_l0 = "day"
freq_l1 = "30min"
freq_l2 = "1min"
orders = self._gen_orders(inst=inst, date=date, pos=pos * 0.90)
strategy_config = {
"class": "FileOrderStrategy",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {
"trade_range": TradeRangeByTime("10:45", "14:44"),
"file": orders,
},
}
backtest_config = {
"start_time": start_time,
"end_time": end_time,
"account": {
"cash": 0,
inst: pos,
},
"benchmark": None, # benchmark is not required here for trading
"exchange_kwargs": {
"freq": freq_l2, # use the most fine-grained data as the exchange
"limit_threshold": 0.095,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
"codes": market,
"trade_unit": 100,
},
# "pos_type": "InfPosition" # Position with infinitive position
}
executor_config = {
"class": "NestedExecutor", # Level 1 Order execution
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": freq_l0,
"inner_executor": {
"class": "NestedExecutor", # Leve 2 Order Execution
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": freq_l1,
"inner_executor": {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": freq_l2,
"generate_report": False,
"verbose": True,
"indicator_config": {
"show_indicator": False,
},
"track_data": True,
},
},
"inner_strategy": {
"class": "TWAPStrategy",
"module_path": "qlib.contrib.strategy.rule_strategy",
},
"generate_report": False,
"indicator_config": {
"show_indicator": True,
},
"track_data": True,
},
},
"inner_strategy": {
"class": "TWAPStrategy",
"module_path": "qlib.contrib.strategy.rule_strategy",
},
"generate_report": False,
"indicator_config": {
"show_indicator": True,
},
"track_data": True,
},
}
ret_val = {}
decisions = list(
collect_data(executor=executor_config, strategy=strategy_config, **backtest_config, return_value=ret_val)
)
report, indicator = ret_val["report"], ret_val["indicator"]
# NOTE: please refer to the docs of format_decisions
# NOTE: `"track_data": True,` is very NECESSARY for collecting the decision!!!!!
f_dec = format_decisions(decisions)
print(indicator["1day"])
if __name__ == "__main__":
unittest.main()