diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py index ef6906018..72b6067b3 100644 --- a/examples/nested_decision_execution/workflow.py +++ b/examples/nested_decision_execution/workflow.py @@ -151,7 +151,7 @@ class NestedDecisionExecutionWorkflow: self._train_model(model, dataset) strategy_config = { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", + "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { "model": model, "dataset": dataset, @@ -189,7 +189,7 @@ class NestedDecisionExecutionWorkflow: backtest_config["benchmark"] = self.benchmark strategy_config = { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", + "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { "model": model, "dataset": dataset, diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index 486e694a7..248f0423f 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -31,7 +31,7 @@ if __name__ == "__main__": }, "strategy": { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", + "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { "model": model, "dataset": dataset, diff --git a/qlib/backtest/signal.py b/qlib/backtest/signal.py new file mode 100644 index 000000000..192f690c4 --- /dev/null +++ b/qlib/backtest/signal.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Union +from ..model.base import BaseModel +from ..data.dataset import Dataset +from ..data.dataset.utils import convert_index_format +from ..utils.resam import resam_ts_data +import pandas as pd +import abc + + +class Signal(metaclass=abc.ABCMeta): + """ + Some trading strategy make decisions based on other prediction signals + The signals may comes from different sources(e.g. prepared data, online prediction from model and dataset) + + This + """ + + @abc.abstractmethod + def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]: + """ + get the signal at the end of the decision step(from `start_time` to `end_time`) + + Returns + ------- + Union[pd.Series, pd.DataFrame, None]: + returns None if no signal in the specific day + """ + ... + + +class SignalWCache(Signal): + """ + Signal With pandas with based Cache + SignalWCache will store the prepared signal as a attribute and give the according signal based on input query + """ + + def __init__(self, signal: Union[pd.Series, pd.DataFrame]): + """ + + Parameters + ---------- + signal : Union[pd.Series, pd.DataFrame] + The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted) + + instrument datetime + SH600000 2008-01-02 0.079704 + 2008-01-03 0.120125 + 2008-01-04 0.878860 + 2008-01-07 0.505539 + 2008-01-08 0.395004 + """ + self.signal_cache = convert_index_format(signal, level="datetime") + + def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]: + # the frequency of the signal may not algin with the decision frequency of strategy + # so resampling from the data is necessary + # the latest signal leverage more recent data and therefore is used in trading. + signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last") + return signal + + +class ModelSignal(SignalWCache): + ... + + def __init__(self, model: BaseModel, dataset: Dataset): + self.model = model + self.dataset = dataset + pred_scores = self.model.predict(dataset) + if isinstance(pred_scores, pd.DataFrame): + pred_scores = pred_scores.iloc[:, 0] + super().__init__(pred_scores) + + def _update_model(self): + """ + When using online data, update model in each bar as the following steps: + - update dataset with online data, the dataset should support online update + - make the latest prediction scores of the new bar + - update the pred score into the latest prediction + """ + # TODO: this method is not included in the framework and could be refactor later + raise NotImplementedError("_update_model is not implemented!") diff --git a/qlib/contrib/strategy/__init__.py b/qlib/contrib/strategy/__init__.py index e308c1a05..adc1679c1 100644 --- a/qlib/contrib/strategy/__init__.py +++ b/qlib/contrib/strategy/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. -from .model_strategy import ( +from .signal_strategy import ( TopkDropoutStrategy, WeightStrategyBase, ) diff --git a/qlib/contrib/strategy/cost_control.py b/qlib/contrib/strategy/cost_control.py index b45c03ae9..aaebe3543 100644 --- a/qlib/contrib/strategy/cost_control.py +++ b/qlib/contrib/strategy/cost_control.py @@ -6,7 +6,7 @@ This strategy is not well maintained from .order_generator import OrderGenWInteract -from .model_strategy import WeightStrategyBase +from .signal_strategy import WeightStrategyBase import copy diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 23fdd2991..dcf4667ff 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from pathlib import Path import warnings import numpy as np diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/signal_strategy.py similarity index 91% rename from qlib/contrib/strategy/model_strategy.py rename to qlib/contrib/strategy/signal_strategy.py index 1d22153a7..1adfc517e 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/signal_strategy.py @@ -1,27 +1,33 @@ import copy +from qlib.backtest.signal import ModelSignal, Signal, SignalWCache +from typing import Union +from qlib.data.dataset import Dataset +from qlib.model.base import BaseModel from qlib.backtest.position import Position import warnings import numpy as np import pandas as pd from ...utils.resam import resam_ts_data -from ...strategy.base import ModelStrategy +from ...strategy.base import BaseStrategy from ...backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO from .order_generator import OrderGenWInteract -class TopkDropoutStrategy(ModelStrategy): +class TopkDropoutStrategy(BaseStrategy): # TODO: # 1. Supporting leverage the get_range_limit result from the decision # 2. Supporting alter_outer_trade_decision # 3. Supporting checking the availability of trade decision def __init__( self, - model, - dataset, + *, topk, n_drop, + model: BaseModel = None, + dataset: Dataset = None, + signal: Union[pd.DataFrame, pd.Series] = None, method_sell="bottom", method_buy="top", risk_degree=0.95, @@ -64,7 +70,7 @@ class TopkDropoutStrategy(ModelStrategy): """ super(TopkDropoutStrategy, self).__init__( - model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs + level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs ) self.topk = topk self.n_drop = n_drop @@ -73,6 +79,8 @@ class TopkDropoutStrategy(ModelStrategy): self.risk_degree = risk_degree self.hold_thresh = hold_thresh self.only_tradable = only_tradable + assert signal is not None or dataset is not None and model is not None + self.signal: Signal = ModelSignal(model=model, dataset=dataset) if signal is None else SignalWCache(signal) def get_risk_degree(self, trade_step=None): """get_risk_degree @@ -87,7 +95,7 @@ class TopkDropoutStrategy(ModelStrategy): trade_step = self.trade_calendar.get_trade_step() trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) - pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") + pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time) if pred_score is None: return TradeDecisionWO([], self) if self.only_tradable: @@ -235,15 +243,17 @@ class TopkDropoutStrategy(ModelStrategy): return TradeDecisionWO(sell_order_list + buy_order_list, self) -class WeightStrategyBase(ModelStrategy): +class WeightStrategyBase(BaseStrategy): # TODO: # 1. Supporting leverage the get_range_limit result from the decision # 2. Supporting alter_outer_trade_decision # 3. Supporting checking the availability of trade decision def __init__( self, - model, - dataset, + *, + model: BaseModel = None, + dataset: Dataset = None, + signal: Union[pd.DataFrame, pd.Series] = None, order_generator_cls_or_obj=OrderGenWInteract, trade_exchange=None, level_infra=None, @@ -260,12 +270,14 @@ class WeightStrategyBase(ModelStrategy): - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended. """ super(WeightStrategyBase, self).__init__( - model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs + level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs ) if isinstance(order_generator_cls_or_obj, type): self.order_generator = order_generator_cls_or_obj() else: self.order_generator = order_generator_cls_or_obj + assert signal is not None or dataset is not None and model is not None + self.signal: Signal = ModelSignal(model=model, dataset=dataset) if signal is None else SignalWCache(signal) def get_risk_degree(self, trade_step=None): """get_risk_degree @@ -298,7 +310,7 @@ class WeightStrategyBase(ModelStrategy): trade_step = self.trade_calendar.get_trade_step() trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) - pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") + pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time) if pred_score is None: return TradeDecisionWO([], self) current_temp = copy.deepcopy(self.trade_position) diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index bd5d3dbd3..860477544 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -17,7 +17,7 @@ from ..utils import init_instance_by_config from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager from ..backtest.decision import BaseTradeDecision -__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"] +__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"] class BaseStrategy: @@ -194,45 +194,6 @@ class BaseStrategy: return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1]) -class ModelStrategy(BaseStrategy): - """Model-based trading strategy, use model to make predictions for trading""" - - def __init__( - self, - model: BaseModel, - dataset: DatasetH, - outer_trade_decision: BaseTradeDecision = None, - level_infra: LevelInfrastructure = None, - common_infra: CommonInfrastructure = None, - **kwargs, - ): - """ - Parameters - ---------- - model : BaseModel - the model used in when making predictions - dataset : DatasetH - provide test data for model - kwargs : dict - arguments that will be passed into `reset` method - """ - super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs) - self.model = model - self.dataset = dataset - self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime") - if isinstance(self.pred_scores, pd.DataFrame): - self.pred_scores = self.pred_scores.iloc[:, 0] - - def _update_model(self): - """ - When using online data, pdate model in each bar as the following steps: - - update dataset with online data, the dataset should support online update - - make the latest prediction scores of the new bar - - update the pred score into the latest prediction - """ - raise NotImplementedError("_update_model is not implemented!") - - class RLStrategy(BaseStrategy): """RL-based strategy""" diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index da68139a8..69de8b129 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -144,7 +144,7 @@ def backtest_analysis(pred, rid, uri_path: str = None): }, "strategy": { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", + "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { "model": model, "dataset": dataset,