diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index 3892fde41..18573115b 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -1,9 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from qlib.backtest.utils import TradeDecison +from qlib.strategy.base import BaseStrategy +from qlib.backtest.executor import BaseExecutor from ..utils.resam import parse_freq -def backtest_loop(start_time, end_time, trade_strategy, trade_executor): +def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor): """backtest funciton for the interaction of the outermost strategy and executor in the nested decison execution Returns @@ -17,7 +20,7 @@ def backtest_loop(start_time, end_time, trade_strategy, trade_executor): return return_value.get("report"), return_value.get("indicator") -def collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value: dict = None): +def collect_data_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict = None): """Generator for collecting the trade decision data for rl training Parameters @@ -44,7 +47,7 @@ def collect_data_loop(start_time, end_time, trade_strategy, trade_executor, retu _execute_result = None while not trade_executor.finished(): - _trade_decision = trade_strategy.generate_trade_decision(_execute_result) + _trade_decision: TradeDecison = trade_strategy.generate_trade_decision(_execute_result) _execute_result = yield from trade_executor.collect_data(_trade_decision) if return_value is not None: diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 5cc2c00c3..d86d5e25a 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -5,7 +5,7 @@ from typing import Union from .order import Order from .exchange import Exchange -from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, TradeDecison +from 
.utils import BaseTradeDecision, TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, TradeDecison from ..utils import init_instance_by_config from ..utils.resam import parse_freq @@ -265,7 +265,7 @@ class NestedExecutor(BaseExecutor): pass return return_value.get("execute_result") - def collect_data(self, trade_decision, return_value=None): + def collect_data(self, trade_decision: BaseTradeDecision, return_value=None): if self.track_data: yield trade_decision self._init_sub_trading(trade_decision) @@ -273,6 +273,15 @@ class NestedExecutor(BaseExecutor): inner_order_indicators = [] _inner_execute_result = None while not self.inner_executor.finished(): + # the outer strategy has a chance to update its decision at each iteration; `update` takes (trade_step, trade_len), not the calendar object + inner_calendar = self.inner_executor.trade_calendar + updated_trade_decision = trade_decision.update(inner_calendar.get_trade_step(), inner_calendar.get_trade_len()) + if updated_trade_decision is not None: + trade_decision = updated_trade_decision + # NEW UPDATE + # hook for the inner strategy to receive the (possibly updated) outer decision + self.inner_strategy.alter_outer_trade_decision(trade_decision) + _inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result) _inner_execute_result = yield from self.inner_executor.collect_data(trade_decision=_inner_trade_decision) execute_result.extend(_inner_execute_result) diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 120f80609..f524d09fe 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -1,10 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from re import L +from qlib.backtest.exchange import Exchange +from qlib.backtest.account import Account import pandas as pd import warnings -from typing import Union, List, Set +from typing import TYPE_CHECKING, Optional, Tuple, Union, List, Set +if TYPE_CHECKING: from qlib.strategy.base import BaseStrategy from ..utils.resam import get_resam_calendar from ..data.data import Cal @@ -138,6 +140,7 @@ class BaseInfrastructure: self.reset_infra(**infra_dict) + class CommonInfrastructure(BaseInfrastructure): def get_support_infra(self): return ["trade_account", "trade_exchange"] @@ -148,8 +151,65 @@ class LevelInfrastructure(BaseInfrastructure): return ["trade_calendar"] -class TradeDecison: - """trade decison that made by strategy""" +class BaseTradeDecision: + """Base interface of a trade decision; it keeps a reference to a strategy and delegates `update` to it.""" + # TODO: put it into order.py; and replace it with decision.py + # `BaseStrategy` is quoted because it is imported only under TYPE_CHECKING (qlib.strategy.base imports this module, so a runtime import would be circular) + def __init__(self, strategy: "BaseStrategy"): + self.strategy = strategy + + def get_decision(self) -> List[object]: + """ + get the concrete decision of the order + This will be called by the inner strategy + + Returns + ------- + List[object]: + The decision result. 
Typically it is some orders + Example: + []: + Decision not available + concrete_decision: + available + """ + raise NotImplementedError("This type of input is not supported") + + NOT_AVAIL = 0 + NO_UPDATE = 1 + NEW_UPDATE = 2 + def update(self, trade_step: int, trade_len: int) -> Optional["BaseTradeDecision"]: + """ + Be called at the **start** of each step; the call is forwarded to `self.strategy.update_trade_decision` + + Returns + ------- + None: + No update, use previous decision(or unavailable) + BaseTradeDecision: + New update, use new decision + """ + return self.strategy.update_trade_decision(self, trade_step, trade_len) + + def get_range_limit(self) -> Tuple[int, int]: + """ + return the expected step range for limiting the dealing time of the order + + Returns + ------- + Tuple[int, int]: + + + Raises + ------ + NotImplementedError: + If the decision can't provide a unified start and end + """ + raise NotImplementedError("This type of input is not supported") + + +class TradeDecison(BaseTradeDecision): + """trade decision that made by strategy""" def __init__(self, order_list, ori_strategy, init_enable=False): """ diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 01eb42803..ad3e06ce1 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -1,7 +1,7 @@ import warnings import numpy as np import pandas as pd -from typing import List, Union +from typing import List, Tuple, Union from ...utils.resam import resam_ts_data from ...data.data import D @@ -597,3 +597,41 @@ class ACStrategy(BaseStrategy): ) order_list.append(_order) return TradeDecison(order_list=order_list, ori_strategy=self) + + +class RandomOrderStrategy(BaseStrategy): + + def __init__(self, + time_range: Tuple = ("9:30", "15:00"), # left closed and right closed. 
+ sample_ratio: float = 1., + volume_ratio: float = 0.01, + market: str = "all", + *args, + **kwargs): + """ + Parameters + ---------- + time_range : Tuple + the intra day time range of the orders + the left and right is closed. + sample_ratio : float + the ratio of all orders are sampled + volume_ratio : float + the volume of the total day + ratio of the total volume of a specific day + market : str + stock pool for sampling + """ + + super().__init__(*args, **kwargs) + self.time_range = time_range + self.sample_ratio = sample_ratio + self.volume_ratio = volume_ratio + self.market = market + exch: Exchange = self.common_infra.get("trade_exchange") + self.volume = D.features(D.instruments(self.market), ["Mean($volume, 10)"], start_time=exch.start_time, end_time=exch.end_time) + + def generate_trade_decision(self, execute_result=None): + + + return super().generate_trade_decision(execute_result=execute_result) diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 9f9feb3b1..6c8917658 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -1,13 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from typing import Union +from typing import List, Optional, Union from ..model.base import BaseModel from ..data.dataset import DatasetH from ..data.dataset.utils import convert_index_format from ..rl.interpreter import ActionInterpreter, StateInterpreter from ..utils import init_instance_by_config -from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeDecison +from ..backtest.utils import BaseTradeDecision, CommonInfrastructure, LevelInfrastructure, TradeDecison class BaseStrategy: @@ -43,9 +43,9 @@ class BaseStrategy: if level_infra.has("trade_calendar"): self.trade_calendar = level_infra.get("trade_calendar") - def reset_common_infra(self, common_infra): + def reset_common_infra(self, common_infra: CommonInfrastructure): if not hasattr(self, "common_infra"): - self.common_infra = common_infra + self.common_infra: CommonInfrastructure = common_infra else: self.common_infra.update(common_infra) @@ -84,17 +84,33 @@ class BaseStrategy: """ raise NotImplementedError("generate_trade_decision is not implemented!") - def update_trade_decision(self, trade_decison: TradeDecison, trade_step, trade_len): + def update_trade_decision(self, trade_decison: BaseTradeDecision, trade_step: int, trade_len: int) -> Optional[BaseTradeDecision]: """update trade decision in each step of inner execution, this method enable all order Parameters ---------- trade_decison : TradeDecison the trade decison that will be updated + Returns + ------- + Optional[BaseTradeDecision]: + always None in this base implementation: the decision is enabled in place when trade_step == 0 and no new decision object is returned """ if trade_step == 0: trade_decison.enable(all_enable=True) + def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision): + """ + A method for updating the outer_trade_decision. + The outer strategy may change its decision during updating. 
+ + Parameters + ---------- + outer_trade_decision : BaseTradeDecision + the decision updated by the outer strategy + """ + self.outer_trade_decision = outer_trade_decision + class ModelStrategy(BaseStrategy): """Model-based trading strategy, use model to make predictions for trading"""