1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00

Merge branch 'nested_decision_exe' into rl-dummy

This commit is contained in:
Yuge Zhang
2021-06-01 11:34:45 +08:00
13 changed files with 301 additions and 132 deletions

View File

@@ -1,24 +0,0 @@
# Multi-level Trading
This worflow is an example for multi-level trading.
## Introduction
Qlib supports backtesting of various strategies, including portfolio management strategies, order split strategies, model-based strategies (such as deep learning models), rule-based strategies, and RL-based strategies.
And, Qlib also supports multi-level trading and backtesting. It means that users can use different strategies to trade at different frequencies.
This example uses a DropoutTopkStrategy (a strategy based on the daily frequency Lightgbm model) in weekly frequency for portfolio generation. And, at the daily frequency level, this example uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to split orders.
## Usage
Start backtesting by running the following command:
```bash
python workflow.py backtest
```
Start collecting data by running the following command:
```bash
python workflow.py collect_data
```

View File

@@ -0,0 +1,30 @@
# Nested Decision Execution
This worflow is an example for nested decision execution in backtesting. Qlib supports nested decision execution in backtesting. It means that users can use different strategies to make trade decision in different frequencies.
## Weekly Portfolio Generation and Daily Order Execution
This workflow provides an example that uses a DropoutTopkStrategy (a strategy based on the daily frequency Lightgbm model) in weekly frequency for portfolio generation and uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to execute orders in daily frequency.
### Usage
Start backtesting by running the following command:
```bash
python workflow.py backtest
```
Start collecting data by running the following command:
```bash
python workflow.py collect_data
```
## Daily Portfolio Generation and Minutely Order Execution
This workflow also provides a high-frequency example that uses a DropoutTopkStrategy for portfolio generation in daily frequency and uses SBBStrategyEMA to execute orders in minutely frequency.
### Usage
Start backtesting by running the following command:
```bash
python workflow.py backtest_highfreq
```

View File

@@ -5,8 +5,8 @@ from typing import Optional
import qlib
import fire
from qlib.config import REG_CN
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
from qlib.data import D
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
@@ -14,14 +14,14 @@ from qlib.tests.data import GetData
from qlib.backtest import collect_data
class MultiLevelTradingWorkflow:
class NestedDecisonExecutionWorkflow:
market = "csi300"
benchmark = "SH000300"
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"end_time": "2021-01-20",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
@@ -55,15 +55,12 @@ class MultiLevelTradingWorkflow:
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
"test": ("2017-01-01", "2021-01-20"),
},
},
},
}
trade_start_time = "2017-01-01"
trade_end_time = "2020-08-01"
port_analysis_config = {
"executor": {
"class": "NestedExecutor",
@@ -87,12 +84,13 @@ class MultiLevelTradingWorkflow:
"instruments": market,
},
},
"generate_report": True,
"track_data": True,
},
},
"backtest": {
"start_time": trade_start_time,
"end_time": trade_end_time,
"start_time": "2017-01-01",
"end_time": "2020-08-01",
"account": 100000000,
"benchmark": benchmark,
"exchange_kwargs": {
@@ -174,6 +172,98 @@ class MultiLevelTradingWorkflow:
for trade_decision in data_generator:
print(trade_decision)
def _init_qlib_with_backend(self):
provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
if not exists_qlib_data(provider_uri_1min):
print(f"Qlib data is not found in {provider_uri_1min}")
GetData().qlib_data(target_dir=provider_uri_1min, interval="1min", region=REG_CN)
# TODO: update latest data
provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri_day):
print(f"Qlib data is not found in {provider_uri_day}")
GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN)
provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
client_config = {
"calendar_provider": {
"class": "LocalCalendarProvider",
"module_path": "qlib.data.data",
"kwargs": {
"backend": {
"class": "FileCalendarStorage",
"module_path": "qlib.data.storage.file_storage",
"kwargs": {"provider_uri_map": provider_uri_map},
}
},
},
"feature_provider": {
"class": "LocalFeatureProvider",
"module_path": "qlib.data.data",
"kwargs": {
"backend": {
"class": "FileFeatureStorage",
"module_path": "qlib.data.storage.file_storage",
"kwargs": {"provider_uri_map": provider_uri_map},
}
},
},
}
qlib.init(provider_uri=provider_uri_day, **client_config)
def _get_highfreq_config(self, model, dataset):
executor_config = self.port_analysis_config["executor"]
# update executor with hierarchical decison freq ["day", "1min"]
executor_config["kwargs"]["time_per_step"] = "day"
executor_config["kwargs"]["inner_executor"]["kwargs"]["time_per_step"] = "15min"
backtest_config = self.port_analysis_config["backtest"]
# yahoo highfreq data time
backtest_config["start_time"] = "2020-09-20"
backtest_config["end_time"] = "2021-01-20"
# update benchmark, yahoo data don't have SH000300
instruments = D.instruments(market="csi300")
instrument_list = D.list_instruments(instruments=instruments, as_list=True)
backtest_config["benchmark"] = instrument_list
# update exchange config
backtest_config["exchange_kwargs"]["freq"] = "1min"
# set strategy
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"topk": 50,
"n_drop": 5,
},
}
return executor_config, strategy_config, backtest_config
def backtest_highfreq(self):
self._init_qlib_with_backend()
model = init_instance_by_config(self.task["model"])
dataset = init_instance_by_config(self.task["dataset"])
self._train_model(model, dataset)
executor_config, strategy_config, backtest_config = self._get_highfreq_config(model, dataset)
highfreq_port_analysis_config = {
"executor": executor_config,
"strategy": strategy_config,
"backtest": backtest_config,
}
with R.start(experiment_name="backtest_highfreq"):
recorder = R.get_recorder()
par = PortAnaRecord(recorder, highfreq_port_analysis_config, "day")
par.generate()
if __name__ == "__main__":
fire.Fire(MultiLevelTradingWorkflow)
fire.Fire(NestedDecisonExecutionWorkflow)

View File

@@ -7,6 +7,7 @@ from .executor import BaseExecutor
from .backtest import backtest as backtest_func
from .backtest import collect_data as data_generator
from .utils import CommonInfrastructure
from ..strategy.base import BaseStrategy
from ..utils import init_instance_by_config
from ..log import get_module_logger
@@ -101,10 +102,7 @@ def get_strategy_executor(
)
trade_exchange = get_exchange(**exchange_kwargs)
common_infra = {
"trade_account": trade_account,
"trade_exchange": trade_exchange,
}
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy, common_infra=common_infra)
trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra)

View File

@@ -9,7 +9,7 @@ from ..utils.resam import parse_freq
from .order import Order
from .exchange import Exchange
from .utils import TradeCalendarManager
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure
class BaseExecutor:
@@ -23,7 +23,7 @@ class BaseExecutor:
generate_report: bool = False,
verbose: bool = False,
track_data: bool = False,
common_infra: dict = {},
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
@@ -39,7 +39,7 @@ class BaseExecutor:
whether to generate trade_decision, will be used when making data for multi-level training
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
- Else, `trade_decision` will not be generated
common_infra : dict, optional:
common_infra : CommonInfrastructure, optional:
common infrastructure for backtesting, may including:
- trade_account : Account, optional
trade account for trading
@@ -63,11 +63,11 @@ class BaseExecutor:
else:
self.common_infra.update(common_infra)
if "trade_account" in common_infra:
if common_infra.has("trade_account"):
self.trade_account = copy.copy(common_infra.get("trade_account"))
self.trade_account.reset(freq=self.time_per_step, init_report=True)
def reset(self, track_data: bool = None, common_infra: dict = None, **kwargs):
def reset(self, track_data: bool = None, common_infra: CommonInfrastructure = None, **kwargs):
"""
- reset `start_time` and `end_time`, used in trade calendar
- reset `track_data`, used when making data for multi-level training
@@ -88,7 +88,7 @@ class BaseExecutor:
self.reset_common_infra(common_infra)
def get_level_infra(self):
return {"trade_calendar": self.trade_calendar}
return LevelInfrastructure(trade_calendar=self.trade_calendar)
def finished(self):
return self.trade_calendar.finished()
@@ -138,7 +138,7 @@ class NestedExecutor(BaseExecutor):
verbose: bool = False,
track_data: bool = False,
trade_exchange: Exchange = None,
common_infra: dict = {},
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
@@ -182,7 +182,7 @@ class NestedExecutor(BaseExecutor):
"""
super(NestedExecutor, self).reset_common_infra(common_infra)
if self.generate_report and "trade_exchange" in common_infra:
if self.generate_report and common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
self.inner_executor.reset_common_infra(common_infra)
@@ -257,7 +257,7 @@ class SimulatorExecutor(BaseExecutor):
verbose: bool = False,
track_data: bool = False,
trade_exchange: Exchange = None,
common_infra: dict = {},
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
@@ -286,7 +286,7 @@ class SimulatorExecutor(BaseExecutor):
- reset trade_exchange
"""
super(SimulatorExecutor, self).reset_common_infra(common_infra)
if "trade_exchange" in common_infra:
if common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
def execute(self, trade_decision):
@@ -304,7 +304,7 @@ class SimulatorExecutor(BaseExecutor):
if self.verbose:
if order.direction == Order.SELL: # sell
print(
"[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
"[I {:%Y-%m-%d %H:%M:%S}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
trade_start_time,
order.stock_id,
trade_price,
@@ -316,7 +316,7 @@ class SimulatorExecutor(BaseExecutor):
)
else:
print(
"[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
"[I {:%Y-%m-%d %H:%M:%S}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
trade_start_time,
order.stock_id,
trade_price,

View File

@@ -80,12 +80,12 @@ class Report:
fields = ["$close/Ref($close,1)-1"]
try:
_temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1)
except ValueError:
except (ValueError, KeyError):
_, norm_freq = parse_freq(freq)
if norm_freq in ["month", "week", "day"]:
try:
_temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1)
except ValueError:
except (ValueError, KeyError):
_temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1)
elif norm_freq == "minute":
_temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1)

View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import pandas as pd
import warnings
from typing import Union
from ..utils.resam import get_resam_calendar
@@ -41,7 +42,7 @@ class TradeCalendarManager:
- self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1]
"""
_calendar, freq, freq_sam = get_resam_calendar(freq=freq)
self.trade_calendar = _calendar
self._calendar = _calendar
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
self.start_index = _start_index
self.end_index = _end_index
@@ -91,8 +92,51 @@ class TradeCalendarManager:
"""
trade_step = trade_step - shift
calendar_index = self.start_index + trade_step
return self.trade_calendar[calendar_index], self.trade_calendar[calendar_index + 1] - pd.Timedelta(seconds=1)
return self._calendar[calendar_index], self._calendar[calendar_index + 1] - pd.Timedelta(seconds=1)
def get_all_time(self):
"""Get the start_time and end_time for trading"""
return self.start_time, self.end_time
class BaseInfrastructure:
def __init__(self, **kwargs):
self.reset_infra(**kwargs)
def get_support_infra(self):
raise NotImplementedError("`get_support_infra` is not implemented!")
def reset_infra(self, **kwargs):
support_infra = self.get_support_infra()
for k, v in kwargs.items():
if k in support_infra:
setattr(self, k, v)
else:
warnings.warn(f"{k} is ignored in `reset_infra`!")
def get(self, infra_name):
if hasattr(self, infra_name):
return getattr(self, infra_name)
else:
warnings.warn(f"infra {infra_name} is not found!")
def has(self, infra_name):
if infra_name in self.get_support_infra() and hasattr(self, infra_name):
return True
else:
return False
def update(self, other):
support_infra = other.get_support_infra()
infra_dict = {_infra: getattr(other, _infra) for _infra in support_infra if hasattr(other, _infra)}
self.reset_infra(**infra_dict)
class CommonInfrastructure(BaseInfrastructure):
def get_support_infra(self):
return ["trade_account", "trade_exchange"]
class LevelInfrastructure(BaseInfrastructure):
def get_support_infra(self):
return ["trade_calendar"]

View File

@@ -18,8 +18,8 @@ class SoftTopkStrategy(WeightStrategyBase):
risk_degree=0.95,
buy_method="first_fill",
trade_exchange=None,
level_infra={},
common_infra={},
level_infra=None,
common_infra=None,
**kwargs,
):
"""Parameter

View File

@@ -22,8 +22,8 @@ class TopkDropoutStrategy(ModelStrategy):
hold_thresh=1,
only_tradable=False,
trade_exchange=None,
level_infra={},
common_infra={},
level_infra=None,
common_infra=None,
**kwargs,
):
"""
@@ -76,7 +76,7 @@ class TopkDropoutStrategy(ModelStrategy):
"""
super(TopkDropoutStrategy, self).reset_common_infra(common_infra)
if "trade_exchange" in common_infra:
if common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
def get_risk_degree(self, trade_step=None):
@@ -177,8 +177,6 @@ class TopkDropoutStrategy(ModelStrategy):
# Get the stock list we really want to buy
buy = today[: len(sell) + self.topk - len(last)]
# print("INTRANEL BAR", len(sell), len(sell) + self.topk - len(last), len(last))
# print("flag", len(sell), len(buy), self.topk, len(last))
for code in current_stock_list:
if not self.trade_exchange.is_stock_tradable(
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
@@ -251,8 +249,8 @@ class WeightStrategyBase(ModelStrategy):
dataset,
order_generator_cls_or_obj=OrderGenWInteract,
trade_exchange=None,
level_infra={},
common_infra={},
level_infra=None,
common_infra=None,
**kwargs,
):
super(WeightStrategyBase, self).__init__(
@@ -276,7 +274,7 @@ class WeightStrategyBase(ModelStrategy):
"""
super(WeightStrategyBase, self).reset_common_infra(common_infra)
if "trade_exchange" in common_infra:
if common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
def get_risk_degree(self, trade_step=None):

View File

@@ -1,4 +1,5 @@
import warnings
from typing import List, Union
from ...utils.resam import resam_ts_data
from ...data.data import D
@@ -6,6 +7,7 @@ from ...data.dataset.utils import convert_index_format
from ...strategy.base import BaseStrategy
from ...backtest.order import Order
from ...backtest.exchange import Exchange
from ...backtest.utils import CommonInfrastructure, LevelInfrastructure
class TWAPStrategy(BaseStrategy):
@@ -13,17 +15,20 @@ class TWAPStrategy(BaseStrategy):
def __init__(
self,
outer_trade_decision: object = None,
outer_trade_decision: List[Order] = None,
trade_exchange: Exchange = None,
level_infra: dict = {},
common_infra: dict = {},
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
):
"""
Parameters
----------
outer_trade_decision : List[Order]
the trade decison of outer strategy which this startegy relies, it should be List[Order] in TWAPStrategy
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
"""
super(TWAPStrategy, self).__init__(
outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra
@@ -36,24 +41,24 @@ class TWAPStrategy(BaseStrategy):
"""
Parameters
----------
common_infra : dict, optional
common_infra : CommonInfrastructure, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(TWAPStrategy, self).reset_common_infra(common_infra)
if common_infra is not None:
if "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
def reset(self, outer_trade_decision: object = None, **kwargs):
if common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
"""
Parameters
----------
outer_trade_decision : object, optional
outer_trade_decision : List[Order], optional
"""
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, common_infra=common_infra, **kwargs)
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
if outer_trade_decision is not None:
self.trade_amount = {}
for order in outer_trade_decision:
@@ -73,21 +78,24 @@ class TWAPStrategy(BaseStrategy):
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
order_list = []
for order in self.outer_trade_decision:
# if not tradable, continue
if not self.trade_exchange.is_stock_tradable(
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
):
continue
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
_order_amount = None
# consider trade unit
# considering trade unit
if _amount_trade_unit is None:
# divide the order equally
# divide the order into equal parts, and trade one part
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
# without considering trade unit
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
# divide the order equally
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
# divide the order into equal parts, and trade one part
# calculate the total count of trade units to trade
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
# calculate the amount of one part, ceil the amount
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
_order_amount = (
(trade_unit_cnt + trade_len - trade_step) // (trade_len - trade_step + 1) * _amount_trade_unit
)
@@ -124,14 +132,16 @@ class SBBStrategyBase(BaseStrategy):
def __init__(
self,
outer_trade_decision: object = None,
outer_trade_decision: List[Order] = None,
trade_exchange: Exchange = None,
level_infra: dict = {},
common_infra: dict = {},
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
):
"""
Parameters
----------
outer_trade_decision : List[Order]
the trade decison of outer strategy which this startegy relies, it should be List[Order] in SBBStrategyBase
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
@@ -144,21 +154,24 @@ class SBBStrategyBase(BaseStrategy):
self.trade_exchange = trade_exchange
def reset_common_infra(self, common_infra):
super(SBBStrategyBase, self).reset_common_infra(common_infra)
if common_infra is not None:
if "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
def reset(self, outer_trade_decision=None, **kwargs):
"""
Parameters
----------
outer_trade_decision : object, optional
common_infra : None, optional
common_infra : dict, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(SBBStrategyBase, self).reset_common_infra(common_infra)
if common_infra.has("trade_exchange"):
self.trade_exchange = common_infra.get("trade_exchange")
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
"""
Parameters
----------
outer_trade_decision : List[Order], optional
"""
super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
if outer_trade_decision is not None:
self.trade_trend = {}
@@ -186,10 +199,12 @@ class SBBStrategyBase(BaseStrategy):
order_list = []
# for each order in in self.outer_trade_decision
for order in self.outer_trade_decision:
# predict the price trend
# get the price trend
if trade_step % 2 == 0:
# in the first of two adjacent bars, predict the price trend
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
else:
# in the second of two adjacent bars, use the trend predicted in the first one
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
# if not tradable, continue
if not self.trade_exchange.is_stock_tradable(
@@ -204,13 +219,14 @@ class SBBStrategyBase(BaseStrategy):
_order_amount = None
# considering trade unit
if _amount_trade_unit is None:
# divide the order equally
# divide the order into equal parts, and trade one part
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
# without considering trade unit
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
# cal how many trade unit
# divide the order into equal parts, and trade one part
# calculate the total count of trade units to trade
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
# divide the order equally
# calculate the amount of one part, ceil the amount
# floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
_order_amount = (
(trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit
@@ -262,9 +278,9 @@ class SBBStrategyBase(BaseStrategy):
if _order_amount:
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
if trade_step % 2 == 0:
# in the first of two adjacent bar
# in the first one of two adjacent bars
# if look short on the price, sell the stock more
# if look long on the price, sell the stock more
# if look long on the price, buy the stock more
if (
_pred_trend == self.TREND_SHORT
and order.direction == order.SELL
@@ -281,7 +297,7 @@ class SBBStrategyBase(BaseStrategy):
)
order_list.append(_order)
else:
# in the second of two adjacent bar
# in the second one of two adjacent bars
# if look short on the price, buy the stock more
# if look long on the price, sell the stock more
if (
@@ -301,6 +317,7 @@ class SBBStrategyBase(BaseStrategy):
order_list.append(_order)
if trade_step % 2 == 0:
# in the first one of two adjacent bars, store the trend for the second one to use
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
return order_list
@@ -313,22 +330,22 @@ class SBBStrategyEMA(SBBStrategyBase):
def __init__(
self,
outer_trade_decision=[],
instruments="csi300",
freq="day",
outer_trade_decision: List[Order] = None,
instruments: Union[List, str] = "csi300",
freq: str = "day",
trade_exchange: Exchange = None,
level_infra={},
common_infra={},
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
Parameters
----------
instruments : str, optional
instruments : Union[List, str], optional
instruments of EMA signal, by default "csi300"
freq : str, optional
freq of EMA signal, by default "day"
Note: `freq` may be different from `steb_bar`
Note: `freq` may be different from `time_per_step`
"""
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
@@ -349,8 +366,10 @@ class SBBStrategyEMA(SBBStrategyBase):
signal_df = convert_index_format(signal_df)
signal_df.columns = ["signal"]
self.signal = {}
for stock_id, stock_val in signal_df.groupby(level="instrument"):
self.signal[stock_id] = stock_val
if not signal_df.empty:
for stock_id, stock_val in signal_df.groupby(level="instrument"):
self.signal[stock_id] = stock_val
def reset_level_infra(self, level_infra):
"""
@@ -362,21 +381,24 @@ class SBBStrategyEMA(SBBStrategyBase):
else:
self.level_infra.update(level_infra)
if "trade_calendar" in level_infra:
if level_infra.has("trade_calendar"):
self.trade_calendar = level_infra.get("trade_calendar")
self._reset_signal()
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
# if no signal, return mid trend
if stock_id not in self.signal:
return self.TREND_MID
else:
_sample_signal = resam_ts_data(
self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last"
)
# if EMA signal == 0 or None, return mid trend
if _sample_signal is None or _sample_signal.iloc[0] == 0:
return self.TREND_MID
# if EMA signal > 0, return long trend
elif _sample_signal.iloc[0] > 0:
return self.TREND_LONG
# if EMA signal > 0, return short trend
else:
return self.TREND_SHORT

View File

@@ -1,11 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
from ..model.base import BaseModel
from ..data.dataset import DatasetH
from ..data.dataset.utils import convert_index_format
from ..rl.interpreter import ActionInterpreter, StateInterpreter
from ..utils import init_instance_by_config
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure
__all__ = ['BaseStrategy', 'ModelStrategy', 'RLStrategy', 'RLIntStrategy']
@@ -16,8 +18,8 @@ class BaseStrategy:
def __init__(
self,
outer_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
):
"""
Parameters
@@ -26,9 +28,9 @@ class BaseStrategy:
the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None
- If the strategy is used to split trade decison, it will be used
- If the strategy is used for portfolio management, it can be ignored
level_infra : dict, optional
level_infra : LevelInfrastructure, optional
level shared infrastructure for backtesting, including trade calendar
common_infra : dict, optional
common_infra : CommonInfrastructure, optional
common infrastructure for backtesting, including trade_account, trade_exchange, .etc
"""
@@ -40,7 +42,7 @@ class BaseStrategy:
else:
self.level_infra.update(level_infra)
if "trade_calendar" in level_infra:
if level_infra.has("trade_calendar"):
self.trade_calendar = level_infra.get("trade_calendar")
def reset_common_infra(self, common_infra):
@@ -49,10 +51,16 @@ class BaseStrategy:
else:
self.common_infra.update(common_infra)
if "trade_account" in common_infra:
if common_infra.has("trade_account"):
self.trade_position = common_infra.get("trade_account").current
def reset(self, level_infra: dict = None, common_infra: dict = None, outer_trade_decision=None, **kwargs):
def reset(
self,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
outer_trade_decision=None,
**kwargs,
):
"""
- reset `level_infra`, used to reset trade calendar, .etc
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
@@ -87,8 +95,8 @@ class ModelStrategy(BaseStrategy):
model: BaseModel,
dataset: DatasetH,
outer_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
@@ -123,8 +131,8 @@ class RLStrategy(BaseStrategy):
self,
policy,
outer_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
@@ -143,19 +151,19 @@ class RLIntStrategy(RLStrategy):
def __init__(
self,
policy,
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
state_interpreter: Union[dict, StateInterpreter],
action_interpreter: Union[dict, ActionInterpreter],
outer_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
Parameters
----------
state_interpreter : StateInterpreter
interpretor that interprets the qlib execute result into rl env state.
action_interpreter : ActionInterpreter
state_interpreter : Union[dict, StateInterpreter]
interpretor that interprets the qlib execute result into rl env state
action_interpreter : Union[dict, ActionInterpreter]
interpretor that interprets the rl agent action into qlib order list
start_time : Union[str, pd.Timestamp], optional
start time of trading, by default None
@@ -165,8 +173,8 @@ class RLIntStrategy(RLStrategy):
super(RLIntStrategy, self).__init__(policy, outer_trade_decision, level_infra, common_infra, **kwargs)
self.policy = policy
self.state_interpreter = init_instance_by_config(state_interpreter)
self.action_interpreter = init_instance_by_config(action_interpreter)
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
def generate_trade_decision(self, execute_result=None):
_interpret_state = self.state_interpretor.interpret(execute_result=execute_result)

View File

@@ -182,7 +182,7 @@ def get_resam_calendar(
try:
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=freq, future=future)
freq, freq_sam = freq, None
except ValueError:
except (ValueError, KeyError):
freq_sam = freq
if norm_freq in ["month", "week", "day"]:
try:
@@ -190,16 +190,16 @@ def get_resam_calendar(
start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, future=future
)
freq = "day"
except ValueError:
except (ValueError, KeyError):
_calendar = Cal.calendar(
start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future
)
freq = "min"
freq = "1min"
elif norm_freq == "minute":
_calendar = Cal.calendar(
start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future
)
freq = "min"
freq = "1min"
else:
raise ValueError(f"freq {freq} is not supported")
return _calendar, freq, freq_sam
@@ -288,11 +288,13 @@ def resam_ts_data(
from ..data.dataset.utils import get_level_index
feature = lazy_sort_index(ts_feature)
datetime_level = get_level_index(feature, level="datetime") == 0
if datetime_level:
feature = feature.loc[selector_datetime]
else:
feature = feature.loc[(slice(None), selector_datetime)]
if feature.empty:
return None
if isinstance(feature.index, pd.MultiIndex):

View File

@@ -46,3 +46,4 @@ def experiment_kill_signal_handler(signum, frame):
End an experiment when user kill the program through keyboard (CTRL+C, etc.).
"""
R.end_exp(recorder_status=Recorder.STATUS_FA)
raise KeyboardInterrupt