mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Make static prediction easier
This commit is contained in:
@@ -151,7 +151,7 @@ class NestedDecisionExecutionWorkflow:
|
||||
self._train_model(model, dataset)
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
@@ -189,7 +189,7 @@ class NestedDecisionExecutionWorkflow:
|
||||
backtest_config["benchmark"] = self.benchmark
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
|
||||
@@ -31,7 +31,7 @@ if __name__ == "__main__":
|
||||
},
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
|
||||
83
qlib/backtest/signal.py
Normal file
83
qlib/backtest/signal.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from typing import Union
|
||||
from ..model.base import BaseModel
|
||||
from ..data.dataset import Dataset
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..utils.resam import resam_ts_data
|
||||
import pandas as pd
|
||||
import abc
|
||||
|
||||
|
||||
class Signal(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
Some trading strategy make decisions based on other prediction signals
|
||||
The signals may comes from different sources(e.g. prepared data, online prediction from model and dataset)
|
||||
|
||||
This
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
|
||||
"""
|
||||
get the signal at the end of the decision step(from `start_time` to `end_time`)
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[pd.Series, pd.DataFrame, None]:
|
||||
returns None if no signal in the specific day
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SignalWCache(Signal):
|
||||
"""
|
||||
Signal With pandas with based Cache
|
||||
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
|
||||
"""
|
||||
|
||||
def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signal : Union[pd.Series, pd.DataFrame]
|
||||
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
|
||||
|
||||
instrument datetime
|
||||
SH600000 2008-01-02 0.079704
|
||||
2008-01-03 0.120125
|
||||
2008-01-04 0.878860
|
||||
2008-01-07 0.505539
|
||||
2008-01-08 0.395004
|
||||
"""
|
||||
self.signal_cache = convert_index_format(signal, level="datetime")
|
||||
|
||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
|
||||
# the frequency of the signal may not algin with the decision frequency of strategy
|
||||
# so resampling from the data is necessary
|
||||
# the latest signal leverage more recent data and therefore is used in trading.
|
||||
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
|
||||
return signal
|
||||
|
||||
|
||||
class ModelSignal(SignalWCache):
|
||||
...
|
||||
|
||||
def __init__(self, model: BaseModel, dataset: Dataset):
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
pred_scores = self.model.predict(dataset)
|
||||
if isinstance(pred_scores, pd.DataFrame):
|
||||
pred_scores = pred_scores.iloc[:, 0]
|
||||
super().__init__(pred_scores)
|
||||
|
||||
def _update_model(self):
|
||||
"""
|
||||
When using online data, update model in each bar as the following steps:
|
||||
- update dataset with online data, the dataset should support online update
|
||||
- make the latest prediction scores of the new bar
|
||||
- update the pred score into the latest prediction
|
||||
"""
|
||||
# TODO: this method is not included in the framework and could be refactor later
|
||||
raise NotImplementedError("_update_model is not implemented!")
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from .model_strategy import (
|
||||
from .signal_strategy import (
|
||||
TopkDropoutStrategy,
|
||||
WeightStrategyBase,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ This strategy is not well maintained
|
||||
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
from .model_strategy import WeightStrategyBase
|
||||
from .signal_strategy import WeightStrategyBase
|
||||
import copy
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
@@ -1,27 +1,33 @@
|
||||
import copy
|
||||
from qlib.backtest.signal import ModelSignal, Signal, SignalWCache
|
||||
from typing import Union
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.model.base import BaseModel
|
||||
from qlib.backtest.position import Position
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...strategy.base import ModelStrategy
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
class TopkDropoutStrategy(ModelStrategy):
|
||||
class TopkDropoutStrategy(BaseStrategy):
|
||||
# TODO:
|
||||
# 1. Supporting leverage the get_range_limit result from the decision
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
dataset,
|
||||
*,
|
||||
topk,
|
||||
n_drop,
|
||||
model: BaseModel = None,
|
||||
dataset: Dataset = None,
|
||||
signal: Union[pd.DataFrame, pd.Series] = None,
|
||||
method_sell="bottom",
|
||||
method_buy="top",
|
||||
risk_degree=0.95,
|
||||
@@ -64,7 +70,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
self.topk = topk
|
||||
self.n_drop = n_drop
|
||||
@@ -73,6 +79,8 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
self.risk_degree = risk_degree
|
||||
self.hold_thresh = hold_thresh
|
||||
self.only_tradable = only_tradable
|
||||
assert signal is not None or dataset is not None and model is not None
|
||||
self.signal: Signal = ModelSignal(model=model, dataset=dataset) if signal is None else SignalWCache(signal)
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
@@ -87,7 +95,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
|
||||
if pred_score is None:
|
||||
return TradeDecisionWO([], self)
|
||||
if self.only_tradable:
|
||||
@@ -235,15 +243,17 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
return TradeDecisionWO(sell_order_list + buy_order_list, self)
|
||||
|
||||
|
||||
class WeightStrategyBase(ModelStrategy):
|
||||
class WeightStrategyBase(BaseStrategy):
|
||||
# TODO:
|
||||
# 1. Supporting leverage the get_range_limit result from the decision
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
dataset,
|
||||
*,
|
||||
model: BaseModel = None,
|
||||
dataset: Dataset = None,
|
||||
signal: Union[pd.DataFrame, pd.Series] = None,
|
||||
order_generator_cls_or_obj=OrderGenWInteract,
|
||||
trade_exchange=None,
|
||||
level_infra=None,
|
||||
@@ -260,12 +270,14 @@ class WeightStrategyBase(ModelStrategy):
|
||||
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
|
||||
"""
|
||||
super(WeightStrategyBase, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
if isinstance(order_generator_cls_or_obj, type):
|
||||
self.order_generator = order_generator_cls_or_obj()
|
||||
else:
|
||||
self.order_generator = order_generator_cls_or_obj
|
||||
assert signal is not None or dataset is not None and model is not None
|
||||
self.signal: Signal = ModelSignal(model=model, dataset=dataset) if signal is None else SignalWCache(signal)
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
@@ -298,7 +310,7 @@ class WeightStrategyBase(ModelStrategy):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
|
||||
if pred_score is None:
|
||||
return TradeDecisionWO([], self)
|
||||
current_temp = copy.deepcopy(self.trade_position)
|
||||
@@ -17,7 +17,7 @@ from ..utils import init_instance_by_config
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..backtest.decision import BaseTradeDecision
|
||||
|
||||
__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"]
|
||||
__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"]
|
||||
|
||||
|
||||
class BaseStrategy:
|
||||
@@ -194,45 +194,6 @@ class BaseStrategy:
|
||||
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
|
||||
|
||||
|
||||
class ModelStrategy(BaseStrategy):
|
||||
"""Model-based trading strategy, use model to make predictions for trading"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: BaseModel,
|
||||
dataset: DatasetH,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
model : BaseModel
|
||||
the model used in when making predictions
|
||||
dataset : DatasetH
|
||||
provide test data for model
|
||||
kwargs : dict
|
||||
arguments that will be passed into `reset` method
|
||||
"""
|
||||
super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
|
||||
if isinstance(self.pred_scores, pd.DataFrame):
|
||||
self.pred_scores = self.pred_scores.iloc[:, 0]
|
||||
|
||||
def _update_model(self):
|
||||
"""
|
||||
When using online data, pdate model in each bar as the following steps:
|
||||
- update dataset with online data, the dataset should support online update
|
||||
- make the latest prediction scores of the new bar
|
||||
- update the pred score into the latest prediction
|
||||
"""
|
||||
raise NotImplementedError("_update_model is not implemented!")
|
||||
|
||||
|
||||
class RLStrategy(BaseStrategy):
|
||||
"""RL-based strategy"""
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ def backtest_analysis(pred, rid, uri_path: str = None):
|
||||
},
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
|
||||
Reference in New Issue
Block a user