1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 18:11:18 +08:00
Files
qlib/qlib/strategy/base.py
2021-06-28 17:47:30 +08:00

215 lines
8.0 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List, Union
from ..model.base import BaseModel
from ..data.dataset import DatasetH
from ..data.dataset.utils import convert_index_format
from ..rl.interpreter import ActionInterpreter, StateInterpreter
from ..utils import init_instance_by_config
from ..backtest.utils import BaseTradeDecision, CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, TradeDecison
class BaseStrategy:
"""Base strategy for trading"""
def __init__(
self,
outer_trade_decision: TradeDecison = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
):
"""
Parameters
----------
outer_trade_decision : TradeDecison, optional
the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None
- If the strategy is used to split trade decison, it will be used
- If the strategy is used for portfolio management, it can be ignored
level_infra : LevelInfrastructure, optional
level shared infrastructure for backtesting, including trade calendar
common_infra : CommonInfrastructure, optional
common infrastructure for backtesting, including trade_account, trade_exchange, .etc
"""
self.reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)
def reset_level_infra(self, level_infra):
if not hasattr(self, "level_infra"):
self.level_infra = level_infra
else:
self.level_infra.update(level_infra)
if level_infra.has("trade_calendar"):
self.trade_calendar = level_infra.get("trade_calendar")
def reset_common_infra(self, common_infra: CommonInfrastructure):
if not hasattr(self, "common_infra"):
self.common_infra: CommonInfrastructure = common_infra
else:
self.common_infra.update(common_infra)
if common_infra.has("trade_account"):
self.trade_position = common_infra.get("trade_account").current
def reset(
self,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
outer_trade_decision=None,
**kwargs,
):
"""
- reset `level_infra`, used to reset trade calendar, .etc
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
- reset `outer_trade_decision`, used to make split decison
"""
if level_infra is not None:
self.reset_level_infra(level_infra)
if common_infra is not None:
self.reset_common_infra(common_infra)
if outer_trade_decision is not None:
self.outer_trade_decision = outer_trade_decision
def generate_trade_decision(self, execute_result=None):
"""Generate trade decision in each trading bar
Parameters
----------
execute_result : List[object], optional
the executed result for trade decison, by default None
- When call the generate_trade_decision firstly, `execute_result` could be None
"""
raise NotImplementedError("generate_trade_decision is not implemented!")
def update_trade_decision(self, trade_decison: BaseTradeDecision, trade_calendar: TradeCalendarManager) -> Union[BaseTradeDecision, None]:
"""
update trade decision in each step of inner execution, this method enable all order
Parameters
----------
trade_decison : TradeDecison
the trade decison that will be updated
trade_calendar : TradeCalendarManager
The calendar of the **inner strategy**!!!!!
Returns
-------
BaseTradeDecision:
"""
# default to return None, which indicates that the trade decision is not changed
return None
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision):
"""
A method for updating the outer_trade_decision.
The outer strategy may change its decision during updating.
Parameters
----------
outer_trade_decision : BaseTradeDecision
the decision updated by the outer strategy
"""
# default to reset the decision directly
# NOTE: normally, user should do something to the strategy due to the change of outer decision
self.outer_trade_decision = outer_trade_decision
class ModelStrategy(BaseStrategy):
"""Model-based trading strategy, use model to make predictions for trading"""
def __init__(
self,
model: BaseModel,
dataset: DatasetH,
outer_trade_decision: TradeDecison = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
Parameters
----------
model : BaseModel
the model used in when making predictions
dataset : DatasetH
provide test data for model
kwargs : dict
arguments that will be passed into `reset` method
"""
super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
self.model = model
self.dataset = dataset
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
def _update_model(self):
"""
When using online data, pdate model in each bar as the following steps:
- update dataset with online data, the dataset should support online update
- make the latest prediction scores of the new bar
- update the pred score into the latest prediction
"""
raise NotImplementedError("_update_model is not implemented!")
class RLStrategy(BaseStrategy):
"""RL-based strategy"""
def __init__(
self,
policy,
outer_trade_decision: TradeDecison = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
Parameters
----------
policy :
RL policy for generate action
"""
super(RLStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
self.policy = policy
class RLIntStrategy(RLStrategy):
"""(RL)-based (Strategy) with (Int)erpreter"""
def __init__(
self,
policy,
state_interpreter: Union[dict, StateInterpreter],
action_interpreter: Union[dict, ActionInterpreter],
outer_trade_decision: TradeDecison = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
Parameters
----------
state_interpreter : Union[dict, StateInterpreter]
interpretor that interprets the qlib execute result into rl env state
action_interpreter : Union[dict, ActionInterpreter]
interpretor that interprets the rl agent action into qlib order list
start_time : Union[str, pd.Timestamp], optional
start time of trading, by default None
end_time : Union[str, pd.Timestamp], optional
end time of trading, by default None
"""
super(RLIntStrategy, self).__init__(policy, outer_trade_decision, level_infra, common_infra, **kwargs)
self.policy = policy
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
def generate_trade_decision(self, execute_result=None):
_interpret_state = self.state_interpreter.interpret(execute_result=execute_result)
_action = self.policy.step(_interpret_state)
_trade_decision = self.action_interpreter.interpret(action=_action)
return _trade_decision