1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00

Make static prediction easier

This commit is contained in:
Young
2021-10-15 11:21:03 +00:00
parent 2e49a5f7c0
commit ac08468330
9 changed files with 115 additions and 57 deletions

View File

@@ -151,7 +151,7 @@ class NestedDecisionExecutionWorkflow:
self._train_model(model, dataset)
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
@@ -189,7 +189,7 @@ class NestedDecisionExecutionWorkflow:
backtest_config["benchmark"] = self.benchmark
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,

View File

@@ -31,7 +31,7 @@ if __name__ == "__main__":
},
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,

83
qlib/backtest/signal.py Normal file
View File

@@ -0,0 +1,83 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
from ..model.base import BaseModel
from ..data.dataset import Dataset
from ..data.dataset.utils import convert_index_format
from ..utils.resam import resam_ts_data
import pandas as pd
import abc
class Signal(metaclass=abc.ABCMeta):
    """
    Abstract interface for the prediction signals consumed by trading strategies.

    Some trading strategies make decisions based on prediction signals, and those
    signals may come from different sources (e.g. prepared data, or online
    prediction from a model and a dataset).  Concrete subclasses implement
    `get_signal` so that a strategy can query any of them in the same way.
    """

    @abc.abstractmethod
    def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
        """
        Get the signal at the end of the decision step (from `start_time` to `end_time`).

        Parameters
        ----------
        start_time :
            start of the decision step
        end_time :
            end of the decision step

        Returns
        -------
        Union[pd.Series, pd.DataFrame, None]:
            None when there is no signal for the requested step
        """
        # The abstract placeholder has no logic of its own; it simply yields None.
        return None
class SignalWCache(Signal):
    """
    Signal backed by a pandas-based cache.

    The prepared signal is kept as an attribute (`signal_cache`) and each query
    is answered from that cached data.
    """

    def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
        """
        Parameters
        ----------
        signal : Union[pd.Series, pd.DataFrame]
            The prepared signal.  The expected format is like the data below
            (the order of the index levels is not important and can be
            adjusted automatically)

                instrument  datetime
                SH600000    2008-01-02    0.079704
                            2008-01-03    0.120125
                            2008-01-04    0.878860
                            2008-01-07    0.505539
                            2008-01-08    0.395004
        """
        # NOTE(review): presumably `convert_index_format` arranges the index around the
        # `datetime` level -- confirm against its implementation.
        prepared = convert_index_format(signal, level="datetime")
        self.signal_cache = prepared

    def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
        """
        Return the cached signal for the step from `start_time` to `end_time`.

        The signal frequency may differ from the decision frequency of the
        strategy, so the cached data is resampled for the requested range.
        The "last" method is used because the latest signal leverages the most
        recent data and is therefore the one used in trading.
        """
        return resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
class ModelSignal(SignalWCache):
    """
    Signal obtained by running a model's prediction on a dataset.

    The prediction is made once, at construction time, and then handed to
    `SignalWCache` to be served as the cached signal.
    """

    def __init__(self, model: BaseModel, dataset: Dataset):
        """
        Parameters
        ----------
        model : BaseModel
            the model used to make the predictions
        dataset : Dataset
            the dataset passed to `model.predict`
        """
        self.model = model
        self.dataset = dataset
        scores = self.model.predict(dataset)
        # When the prediction is a DataFrame, only its first column is used as the signal.
        first_column_or_series = scores.iloc[:, 0] if isinstance(scores, pd.DataFrame) else scores
        super().__init__(first_column_or_series)

    def _update_model(self):
        """
        Placeholder for online updating; always raises NotImplementedError.

        When using online data, the model is meant to be updated in each bar as follows:
            - update the dataset with online data (the dataset should support online update)
            - make the latest prediction scores for the new bar
            - merge the new prediction scores into the latest prediction
        """
        # TODO: this method is not included in the framework and could be refactored later
        raise NotImplementedError("_update_model is not implemented!")

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from .model_strategy import (
from .signal_strategy import (
TopkDropoutStrategy,
WeightStrategyBase,
)

View File

@@ -6,7 +6,7 @@ This strategy is not well maintained
from .order_generator import OrderGenWInteract
from .model_strategy import WeightStrategyBase
from .signal_strategy import WeightStrategyBase
import copy

View File

@@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
import warnings
import numpy as np

View File

@@ -1,27 +1,33 @@
import copy
from qlib.backtest.signal import ModelSignal, Signal, SignalWCache
from typing import Union
from qlib.data.dataset import Dataset
from qlib.model.base import BaseModel
from qlib.backtest.position import Position
import warnings
import numpy as np
import pandas as pd
from ...utils.resam import resam_ts_data
from ...strategy.base import ModelStrategy
from ...strategy.base import BaseStrategy
from ...backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO
from .order_generator import OrderGenWInteract
class TopkDropoutStrategy(ModelStrategy):
class TopkDropoutStrategy(BaseStrategy):
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
model,
dataset,
*,
topk,
n_drop,
model: BaseModel = None,
dataset: Dataset = None,
signal: Union[pd.DataFrame, pd.Series] = None,
method_sell="bottom",
method_buy="top",
risk_degree=0.95,
@@ -64,7 +70,7 @@ class TopkDropoutStrategy(ModelStrategy):
"""
super(TopkDropoutStrategy, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
)
self.topk = topk
self.n_drop = n_drop
@@ -73,6 +79,8 @@ class TopkDropoutStrategy(ModelStrategy):
self.risk_degree = risk_degree
self.hold_thresh = hold_thresh
self.only_tradable = only_tradable
assert signal is not None or dataset is not None and model is not None
self.signal: Signal = ModelSignal(model=model, dataset=dataset) if signal is None else SignalWCache(signal)
def get_risk_degree(self, trade_step=None):
"""get_risk_degree
@@ -87,7 +95,7 @@ class TopkDropoutStrategy(ModelStrategy):
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
if pred_score is None:
return TradeDecisionWO([], self)
if self.only_tradable:
@@ -235,15 +243,17 @@ class TopkDropoutStrategy(ModelStrategy):
return TradeDecisionWO(sell_order_list + buy_order_list, self)
class WeightStrategyBase(ModelStrategy):
class WeightStrategyBase(BaseStrategy):
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
model,
dataset,
*,
model: BaseModel = None,
dataset: Dataset = None,
signal: Union[pd.DataFrame, pd.Series] = None,
order_generator_cls_or_obj=OrderGenWInteract,
trade_exchange=None,
level_infra=None,
@@ -260,12 +270,14 @@ class WeightStrategyBase(ModelStrategy):
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
"""
super(WeightStrategyBase, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
)
if isinstance(order_generator_cls_or_obj, type):
self.order_generator = order_generator_cls_or_obj()
else:
self.order_generator = order_generator_cls_or_obj
assert signal is not None or dataset is not None and model is not None
self.signal: Signal = ModelSignal(model=model, dataset=dataset) if signal is None else SignalWCache(signal)
def get_risk_degree(self, trade_step=None):
"""get_risk_degree
@@ -298,7 +310,7 @@ class WeightStrategyBase(ModelStrategy):
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
if pred_score is None:
return TradeDecisionWO([], self)
current_temp = copy.deepcopy(self.trade_position)

View File

@@ -17,7 +17,7 @@ from ..utils import init_instance_by_config
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
from ..backtest.decision import BaseTradeDecision
__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"]
__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"]
class BaseStrategy:
@@ -194,45 +194,6 @@ class BaseStrategy:
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
class ModelStrategy(BaseStrategy):
    """Model-based trading strategy: predictions made by a model drive the trading"""

    def __init__(
        self,
        model: BaseModel,
        dataset: DatasetH,
        outer_trade_decision: BaseTradeDecision = None,
        level_infra: LevelInfrastructure = None,
        common_infra: CommonInfrastructure = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        model : BaseModel
            the model used when making predictions
        dataset : DatasetH
            provides the test data for the model
        kwargs : dict
            arguments that will be passed into the `reset` method
        """
        super().__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
        self.model = model
        self.dataset = dataset
        scores = convert_index_format(self.model.predict(dataset), level="datetime")
        # When the prediction is a DataFrame, only its first column is kept as the score.
        self.pred_scores = scores.iloc[:, 0] if isinstance(scores, pd.DataFrame) else scores

    def _update_model(self):
        """
        Placeholder for online updating; always raises NotImplementedError.

        When using online data, the model is meant to be updated in each bar as follows:
            - update the dataset with online data (the dataset should support online update)
            - make the latest prediction scores for the new bar
            - merge the new prediction scores into the latest prediction
        """
        raise NotImplementedError("_update_model is not implemented!")
class RLStrategy(BaseStrategy):
"""RL-based strategy"""

View File

@@ -144,7 +144,7 @@ def backtest_analysis(pred, rid, uri_path: str = None):
},
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,