mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
add infra interface & fix no KeyboardInterpret bug
This commit is contained in:
@@ -1,13 +1,6 @@
|
||||
# Multi-level Trading
|
||||
|
||||
This worflow is an example for multi-level trading.
|
||||
|
||||
## Introduction
|
||||
|
||||
Qlib supports backtesting of various strategies, including portfolio management strategies, order split strategies, model-based strategies (such as deep learning models), rule-based strategies, and RL-based strategies.
|
||||
|
||||
And, Qlib also supports multi-level trading and backtesting. It means that users can use different strategies to trade at different frequencies.
|
||||
# Nested Decision Execution
|
||||
|
||||
This worflow is an example for nested decision execution in backtesting. Qlib supports nested decision execution in backtesting. It means that users can use different strategies to make trade decision in different frequencies.
|
||||
|
||||
## Weekly Portfolio Generation and Daily Order Execution
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
from qlib import backtest
|
||||
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
|
||||
from qlib.data import D
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
@@ -14,7 +13,7 @@ from qlib.tests.data import GetData
|
||||
from qlib.backtest import collect_data
|
||||
|
||||
|
||||
class MultiLevelTradingWorkflow:
|
||||
class NestedDecisonExecutionWorkflow:
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
@@ -172,7 +171,7 @@ class MultiLevelTradingWorkflow:
|
||||
print(f"Qlib data is not found in {provider_uri_1min}")
|
||||
GetData().qlib_data(target_dir=provider_uri_1min, interval="1min", region=REG_CN)
|
||||
|
||||
# TODO: update new data
|
||||
# TODO: update latest data
|
||||
provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri_day):
|
||||
print(f"Qlib data is not found in {provider_uri_day}")
|
||||
@@ -260,4 +259,4 @@ class MultiLevelTradingWorkflow:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(MultiLevelTradingWorkflow)
|
||||
fire.Fire(NestedDecisonExecutionWorkflow)
|
||||
@@ -7,6 +7,7 @@ from .executor import BaseExecutor
|
||||
from .backtest import backtest as backtest_func
|
||||
from .backtest import collect_data as data_generator
|
||||
|
||||
from .utils import CommonInfrastructure
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
@@ -101,10 +102,7 @@ def get_strategy_executor(
|
||||
)
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
common_infra = {
|
||||
"trade_account": trade_account,
|
||||
"trade_exchange": trade_exchange,
|
||||
}
|
||||
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
|
||||
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy, common_infra=common_infra)
|
||||
trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra)
|
||||
|
||||
@@ -9,7 +9,7 @@ from ..utils.resam import parse_freq
|
||||
|
||||
from .order import Order
|
||||
from .exchange import Exchange
|
||||
from .utils import TradeCalendarManager
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure
|
||||
|
||||
|
||||
class BaseExecutor:
|
||||
@@ -23,7 +23,7 @@ class BaseExecutor:
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
common_infra: dict = {},
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -39,7 +39,7 @@ class BaseExecutor:
|
||||
whether to generate trade_decision, will be used when making data for multi-level training
|
||||
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
|
||||
- Else, `trade_decision` will not be generated
|
||||
common_infra : dict, optional:
|
||||
common_infra : CommonInfrastructure, optional:
|
||||
common infrastructure for backtesting, may including:
|
||||
- trade_account : Account, optional
|
||||
trade account for trading
|
||||
@@ -63,11 +63,11 @@ class BaseExecutor:
|
||||
else:
|
||||
self.common_infra.update(common_infra)
|
||||
|
||||
if "trade_account" in common_infra:
|
||||
if common_infra.has("trade_account"):
|
||||
self.trade_account = copy.copy(common_infra.get("trade_account"))
|
||||
self.trade_account.reset(freq=self.time_per_step, init_report=True)
|
||||
|
||||
def reset(self, track_data: bool = None, common_infra: dict = None, **kwargs):
|
||||
def reset(self, track_data: bool = None, common_infra: CommonInfrastructure = None, **kwargs):
|
||||
"""
|
||||
- reset `start_time` and `end_time`, used in trade calendar
|
||||
- reset `track_data`, used when making data for multi-level training
|
||||
@@ -88,7 +88,7 @@ class BaseExecutor:
|
||||
self.reset_common_infra(common_infra)
|
||||
|
||||
def get_level_infra(self):
|
||||
return {"trade_calendar": self.trade_calendar}
|
||||
return LevelInfrastructure(trade_calendar=self.trade_calendar)
|
||||
|
||||
def finished(self):
|
||||
return self.trade_calendar.finished()
|
||||
@@ -138,7 +138,7 @@ class NestedExecutor(BaseExecutor):
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: dict = {},
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -182,7 +182,7 @@ class NestedExecutor(BaseExecutor):
|
||||
"""
|
||||
super(NestedExecutor, self).reset_common_infra(common_infra)
|
||||
|
||||
if self.generate_report and "trade_exchange" in common_infra:
|
||||
if self.generate_report and common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
self.inner_executor.reset_common_infra(common_infra)
|
||||
@@ -257,7 +257,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: dict = {},
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -286,7 +286,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
- reset trade_exchange
|
||||
"""
|
||||
super(SimulatorExecutor, self).reset_common_infra(common_infra)
|
||||
if "trade_exchange" in common_infra:
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def execute(self, trade_decision):
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
import warnings
|
||||
from typing import Union
|
||||
|
||||
from ..utils.resam import get_resam_calendar
|
||||
@@ -96,3 +97,46 @@ class TradeCalendarManager:
|
||||
def get_all_time(self):
|
||||
"""Get the start_time and end_time for trading"""
|
||||
return self.start_time, self.end_time
|
||||
|
||||
|
||||
class BaseInfrastructure:
|
||||
def __init__(self, **kwargs):
|
||||
self.reset_infra(**kwargs)
|
||||
|
||||
def get_support_infra(self):
|
||||
raise NotImplementedError("`get_support_infra` is not implemented!")
|
||||
|
||||
def reset_infra(self, **kwargs):
|
||||
support_infra = self.get_support_infra()
|
||||
for k, v in kwargs.items():
|
||||
if k in support_infra:
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
warnings.warn(f"{k} is ignored in `reset_infra`!")
|
||||
|
||||
def get(self, infra_name):
|
||||
if hasattr(self, infra_name):
|
||||
return getattr(self, infra_name)
|
||||
else:
|
||||
warnings.warn(f"infra {infra_name} is not found!")
|
||||
|
||||
def has(self, infra_name):
|
||||
if infra_name in self.get_support_infra() and hasattr(self, infra_name):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def update(self, other):
|
||||
support_infra = other.get_support_infra()
|
||||
infra_dict = {_infra: getattr(other, _infra) for _infra in support_infra if hasattr(other, _infra)}
|
||||
self.reset_infra(**infra_dict)
|
||||
|
||||
|
||||
class CommonInfrastructure(BaseInfrastructure):
|
||||
def get_support_infra(self):
|
||||
return ["trade_account", "trade_exchange"]
|
||||
|
||||
|
||||
class LevelInfrastructure(BaseInfrastructure):
|
||||
def get_support_infra(self):
|
||||
return ["trade_calendar"]
|
||||
|
||||
@@ -18,8 +18,8 @@ class SoftTopkStrategy(WeightStrategyBase):
|
||||
risk_degree=0.95,
|
||||
buy_method="first_fill",
|
||||
trade_exchange=None,
|
||||
level_infra={},
|
||||
common_infra={},
|
||||
level_infra=None,
|
||||
common_infra=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Parameter
|
||||
|
||||
@@ -22,8 +22,8 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
hold_thresh=1,
|
||||
only_tradable=False,
|
||||
trade_exchange=None,
|
||||
level_infra={},
|
||||
common_infra={},
|
||||
level_infra=None,
|
||||
common_infra=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -76,7 +76,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).reset_common_infra(common_infra)
|
||||
|
||||
if "trade_exchange" in common_infra:
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
@@ -249,8 +249,8 @@ class WeightStrategyBase(ModelStrategy):
|
||||
dataset,
|
||||
order_generator_cls_or_obj=OrderGenWInteract,
|
||||
trade_exchange=None,
|
||||
level_infra={},
|
||||
common_infra={},
|
||||
level_infra=None,
|
||||
common_infra=None,
|
||||
**kwargs,
|
||||
):
|
||||
super(WeightStrategyBase, self).__init__(
|
||||
@@ -274,7 +274,7 @@ class WeightStrategyBase(ModelStrategy):
|
||||
"""
|
||||
super(WeightStrategyBase, self).reset_common_infra(common_infra)
|
||||
|
||||
if "trade_exchange" in common_infra:
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import warnings
|
||||
from typing import List, Union
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...data.data import D
|
||||
@@ -6,6 +7,7 @@ from ...data.dataset.utils import convert_index_format
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...backtest.order import Order
|
||||
from ...backtest.exchange import Exchange
|
||||
from ...backtest.utils import CommonInfrastructure, LevelInfrastructure
|
||||
|
||||
|
||||
class TWAPStrategy(BaseStrategy):
|
||||
@@ -13,17 +15,20 @@ class TWAPStrategy(BaseStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: object = None,
|
||||
outer_trade_decision: List[Order] = None,
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : List[Order]
|
||||
the trade decison of outer strategy which this startegy relies, it should be List[Order] in TWAPStrategy
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
|
||||
"""
|
||||
super(TWAPStrategy, self).__init__(
|
||||
outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra
|
||||
@@ -36,21 +41,21 @@ class TWAPStrategy(BaseStrategy):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
common_infra : dict, optional
|
||||
common_infra : CommonInfrastructure, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(TWAPStrategy, self).reset_common_infra(common_infra)
|
||||
if common_infra is not None:
|
||||
if "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, outer_trade_decision: object = None, **kwargs):
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : object, optional
|
||||
outer_trade_decision : List[Order], optional
|
||||
"""
|
||||
|
||||
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
@@ -127,14 +132,16 @@ class SBBStrategyBase(BaseStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: object = None,
|
||||
outer_trade_decision: List[Order] = None,
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : List[Order]
|
||||
the trade decison of outer strategy which this startegy relies, it should be List[Order] in SBBStrategyBase
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
@@ -156,15 +163,14 @@ class SBBStrategyBase(BaseStrategy):
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(SBBStrategyBase, self).reset_common_infra(common_infra)
|
||||
if common_infra is not None:
|
||||
if "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, outer_trade_decision=None, **kwargs):
|
||||
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : object, optional
|
||||
outer_trade_decision : List[Order], optional
|
||||
"""
|
||||
super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
@@ -324,18 +330,18 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision=[],
|
||||
instruments="csi300",
|
||||
freq="day",
|
||||
outer_trade_decision: List[Order] = None,
|
||||
instruments: Union[List, str] = "csi300",
|
||||
freq: str = "day",
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra={},
|
||||
common_infra={},
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
instruments : str, optional
|
||||
instruments : Union[List, str], optional
|
||||
instruments of EMA signal, by default "csi300"
|
||||
freq : str, optional
|
||||
freq of EMA signal, by default "day"
|
||||
@@ -375,7 +381,7 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if "trade_calendar" in level_infra:
|
||||
if level_infra.has("trade_calendar"):
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
self._reset_signal()
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..data.dataset import DatasetH
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from ..utils import init_instance_by_config
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure
|
||||
|
||||
|
||||
class BaseStrategy:
|
||||
@@ -15,8 +16,8 @@ class BaseStrategy:
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
@@ -25,9 +26,9 @@ class BaseStrategy:
|
||||
the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None
|
||||
- If the strategy is used to split trade decison, it will be used
|
||||
- If the strategy is used for portfolio management, it can be ignored
|
||||
level_infra : dict, optional
|
||||
level_infra : LevelInfrastructure, optional
|
||||
level shared infrastructure for backtesting, including trade calendar
|
||||
common_infra : dict, optional
|
||||
common_infra : CommonInfrastructure, optional
|
||||
common infrastructure for backtesting, including trade_account, trade_exchange, .etc
|
||||
"""
|
||||
|
||||
@@ -39,7 +40,7 @@ class BaseStrategy:
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if "trade_calendar" in level_infra:
|
||||
if level_infra.has("trade_calendar"):
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
@@ -48,10 +49,16 @@ class BaseStrategy:
|
||||
else:
|
||||
self.common_infra.update(common_infra)
|
||||
|
||||
if "trade_account" in common_infra:
|
||||
if common_infra.has("trade_account"):
|
||||
self.trade_position = common_infra.get("trade_account").current
|
||||
|
||||
def reset(self, level_infra: dict = None, common_infra: dict = None, outer_trade_decision=None, **kwargs):
|
||||
def reset(
|
||||
self,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
outer_trade_decision=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
- reset `level_infra`, used to reset trade calendar, .etc
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
@@ -86,8 +93,8 @@ class ModelStrategy(BaseStrategy):
|
||||
model: BaseModel,
|
||||
dataset: DatasetH,
|
||||
outer_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -122,8 +129,8 @@ class RLStrategy(BaseStrategy):
|
||||
self,
|
||||
policy,
|
||||
outer_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -145,8 +152,8 @@ class RLIntStrategy(RLStrategy):
|
||||
state_interpreter: Union[dict, StateInterpreter],
|
||||
action_interpreter: Union[dict, ActionInterpreter],
|
||||
outer_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -46,3 +46,4 @@ def experiment_kill_signal_handler(signum, frame):
|
||||
End an experiment when user kill the program through keyboard (CTRL+C, etc.).
|
||||
"""
|
||||
R.end_exp(recorder_status=Recorder.STATUS_FA)
|
||||
raise KeyboardInterrupt
|
||||
|
||||
Reference in New Issue
Block a user