1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

move backtest to core, fix calendar bugs, add some docstring

This commit is contained in:
bxdd
2021-05-27 21:14:39 +08:00
parent 2ad61f12b3
commit 4085b447aa
27 changed files with 298 additions and 216 deletions

View File

@@ -10,7 +10,7 @@ from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.tests.data import GetData
from qlib.contrib.backtest import collect_data
from qlib.backtest import collect_data
class MultiLevelTradingWorkflow:
@@ -61,17 +61,17 @@ class MultiLevelTradingWorkflow:
}
trade_start_time = "2017-01-01"
trade_end_time = "2017-02-01"
trade_end_time = "2020-08-01"
port_analysis_config = {
"executor": {
"class": "SplitExecutor",
"module_path": "qlib.contrib.backtest.executor",
"class": "NestedExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "week",
"inner_executor": {
"class": "SimulatorExecutor",
"module_path": "qlib.contrib.backtest.executor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "day",
"verbose": True,

View File

@@ -66,7 +66,6 @@
"from qlib.config import REG_CN\n",
"from qlib.contrib.model.gbdt import LGBModel\n",
"from qlib.contrib.data.handler import Alpha158\n",
"from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
"from qlib.contrib.evaluate import (\n",
" backtest as normal_backtest,\n",
" risk_analysis,\n",

View File

@@ -1,19 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord

View File

@@ -7,10 +7,10 @@ from .executor import BaseExecutor
from .backtest import backtest as backtest_func
from .backtest import collect_data as data_generator
from ...strategy.base import BaseStrategy
from ...utils import init_instance_by_config
from ...log import get_module_logger
from ...config import C
from ..strategy.base import BaseStrategy
from ..utils import init_instance_by_config
from ..log import get_module_logger
from ..config import C
logger = get_module_logger("backtest caller")

View File

@@ -24,6 +24,8 @@ rtn & earning in the Account
**is consider cost**
while earning is the difference of two position value, so it considers cost, it is the true return rate
in the specific accomplishment for rtn, it does not consider cost, in other words, rtn - cost = earning
Now rtn has been removed in the hierarchical backtest implemention.
"""

View File

@@ -8,11 +8,11 @@ import logging
import numpy as np
import pandas as pd
from ...data.data import D
from ...data.dataset.utils import get_level_index
from ...config import C, REG_CN
from ...utils.resam import resam_ts_data
from ...log import get_module_logger
from ..data.data import D
from ..data.dataset.utils import get_level_index
from ..config import C, REG_CN
from ..utils.resam import resam_ts_data
from ..log import get_module_logger
from .order import Order
@@ -35,7 +35,7 @@ class Exchange:
"""__init__
:param freq: frequency of data
:param start_time: closed start time for backtest
:param start_time: closed start time for backtest
:param end_time: closed end time for backtest
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
:param deal_price: str, 'close', 'open', 'vwap'

View File

@@ -3,8 +3,8 @@ import warnings
import pandas as pd
from typing import Union
from ...utils import init_instance_by_config
from ...utils.resam import parse_freq
from ..utils import init_instance_by_config
from ..utils.resam import parse_freq
from .order import Order
@@ -30,7 +30,7 @@ class BaseExecutor:
Parameters
----------
time_per_step : str
trade time per trading step, used for genreate trade calendar
trade time per trading step, used for genreate the trade calendar
generate_report : bool, optional
whether to generate report, by default False
verbose : bool, optional
@@ -80,16 +80,18 @@ class BaseExecutor:
if "start_time" in kwargs or "end_time" in kwargs:
start_time = kwargs.get("start_time")
end_time = kwargs.get("end_time")
self.calendar = TradeCalendarManager(freq=self.time_per_step, start_time=start_time, end_time=end_time)
self.trade_calendar = TradeCalendarManager(
freq=self.time_per_step, start_time=start_time, end_time=end_time
)
if common_infra is not None:
self.reset_common_infra(common_infra)
def get_level_infra(self):
return {"calendar": self.calendar}
return {"trade_calendar": self.trade_calendar}
def finished(self):
return self.calendar.finished()
return self.trade_calendar.finished()
def execute(self, trade_decision):
"""execute the trade decision and return the executed result
@@ -117,8 +119,13 @@ class BaseExecutor:
raise NotImplementedError("get_report is not implemented!")
class SplitExecutor(BaseExecutor):
from ...strategy.base import BaseStrategy
class NestedExecutor(BaseExecutor):
"""
Nested Executor with inner strategy and executor
- At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher frequency env.
"""
from ..strategy.base import BaseStrategy
def __init__(
self,
@@ -127,10 +134,10 @@ class SplitExecutor(BaseExecutor):
inner_strategy: Union[BaseStrategy, dict],
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
trade_exchange: Exchange = None,
generate_report: bool = False,
verbose: bool = False,
track_data: bool = False,
trade_exchange: Exchange = None,
common_infra: dict = {},
**kwargs,
):
@@ -153,7 +160,7 @@ class SplitExecutor(BaseExecutor):
inner_strategy, common_infra=common_infra, accept_types=self.BaseStrategy
)
super(SplitExecutor, self).__init__(
super(NestedExecutor, self).__init__(
time_per_step=time_per_step,
start_time=start_time,
end_time=end_time,
@@ -173,7 +180,7 @@ class SplitExecutor(BaseExecutor):
- reset trade_exchange
- reset inner_strategyand inner_executor common infra
"""
super(SplitExecutor, self).reset_common_infra(common_infra)
super(NestedExecutor, self).reset_common_infra(common_infra)
if self.generate_report and "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
@@ -182,15 +189,15 @@ class SplitExecutor(BaseExecutor):
self.inner_strategy.reset_common_infra(common_infra)
def _init_sub_trading(self, trade_decision):
trade_index = self.calendar.get_trade_index()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
sub_level_infra = self.inner_executor.get_level_infra()
self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision)
def _update_trade_account(self):
trade_index = self.calendar.get_trade_index()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
self.trade_account.update_bar_count()
if self.generate_report:
self.trade_account.update_bar_report(
@@ -200,7 +207,6 @@ class SplitExecutor(BaseExecutor):
)
def execute(self, trade_decision):
self.calendar.step()
self._init_sub_trading(trade_decision)
execute_result = []
_inner_execute_result = None
@@ -210,13 +216,13 @@ class SplitExecutor(BaseExecutor):
execute_result.extend(_inner_execute_result)
if hasattr(self, "trade_account"):
self._update_trade_account()
self.trade_calendar.step()
return execute_result
def collect_data(self, trade_decision):
if self.track_data:
yield trade_decision
self.calendar.step()
self.trade_calendar.step()
self._init_sub_trading(trade_decision)
execute_result = []
_inner_execute_result = None
@@ -240,15 +246,17 @@ class SplitExecutor(BaseExecutor):
class SimulatorExecutor(BaseExecutor):
"""Executor that simulate the true market"""
def __init__(
self,
time_per_step: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
trade_exchange: Exchange = None,
generate_report: bool = False,
verbose: bool = False,
track_data: bool = False,
trade_exchange: Exchange = None,
common_infra: dict = {},
**kwargs,
):
@@ -282,9 +290,9 @@ class SimulatorExecutor(BaseExecutor):
self.trade_exchange = common_infra.get("trade_exchange")
def execute(self, trade_decision):
self.calendar.step()
trade_index = self.calendar.get_trade_index()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
execute_result = []
for order in trade_decision:
if self.trade_exchange.check_order(order) is True:
@@ -333,7 +341,7 @@ class SimulatorExecutor(BaseExecutor):
trade_end_time=trade_end_time,
trade_exchange=self.trade_exchange,
)
self.trade_calendar.step()
return execute_result
def get_report(self):

View File

@@ -5,8 +5,8 @@
import numpy as np
import pandas as pd
from .position import Position
from ...data import D
from ...config import C
from ..data import D
from ..config import C
import datetime
from pathlib import Path

View File

@@ -10,8 +10,8 @@ import warnings
from pandas.core.frame import DataFrame
from ...utils.resam import parse_freq, resam_ts_data
from ...data import D
from ..utils.resam import parse_freq, resam_ts_data
from ..data import D
class Report:
@@ -86,9 +86,9 @@ class Report:
try:
_temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1)
except ValueError:
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
_temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1)
elif norm_freq == "minute":
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
_temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1)
else:
raise ValueError(f"benchmark freq {freq} is not supported")
if len(_temp_result) == 0:

98
qlib/backtest/utils.py Normal file
View File

@@ -0,0 +1,98 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from typing import Union
from ..utils.resam import get_resam_calendar
from ..data.data import Cal
class TradeCalendarManager:
"""
Manager for trading calendar
- BaseStrategy and BaseExecutor will use it
"""
def __init__(
self, freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
):
"""
Parameters
----------
freq : str
frequency of trading calendar, also trade time per trading step
start_time : Union[str, pd.Timestamp], optional
closed start of the trading calendar, by default None
If `start_time` is None, it must be reset before trading.
end_time : Union[str, pd.Timestamp], optional
closed end of the trade time range, by default None
If `end_time` is None, it must be reset before trading.
"""
self.freq = freq
self.start_time = pd.Timestamp(start_time) if start_time else None
self.end_time = pd.Timestamp(end_time) if end_time else None
self._init_trade_calendar(freq=freq, start_time=start_time, end_time=end_time)
def _init_trade_calendar(self, freq, start_time, end_time):
"""
Reset the trade calendar
- self.trade_len : The total count for trading step
- self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1]
"""
_calendar, freq, freq_sam = get_resam_calendar(freq=freq)
self.trade_calendar = _calendar
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
self.start_index = _start_index
self.end_index = _end_index
self.trade_len = _end_index - _start_index + 1
self.trade_step = 0
def finished(self):
"""
Check if the trading finished
- Should check before calling strategy.generate_decisions and executor.execute
- If self.trade_step >= self.self.trade_len, it means the trading is finished
- If self.trade_step < self.self.trade_len, it means the number of trading step finished is self.trade_step
"""
return self.trade_step >= self.trade_len
def step(self):
if self.finished():
raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
self.trade_step = self.trade_step + 1
def get_freq(self):
return self.freq
def get_trade_len(self):
return self.trade_len
def get_trade_step(self):
return self.trade_step
def get_step_time(self, trade_step=0, shift=0):
"""
Get the time range of trading step
Parameters
----------
trade_step : int, optional
the number of trading step finished, by default 0
shift : int, optional
shift bars , by default 0
Returns
-------
Tuple[pd.Timestamp, pd.Timestap]
- If shift == 0, return the trading time range
- If shift > 0, return the trading time range of the earlier shift bars
- If shift < 0, return the trading time range of the later shift bar
"""
trade_step = trade_step - shift
calendar_index = self.start_index + trade_step
return self.trade_calendar[calendar_index], self.trade_calendar[calendar_index + 1] - pd.Timedelta(seconds=1)
def get_all_time(self):
"""Get the start_time and end_time for trading"""
return self.start_time, self.end_time

View File

@@ -149,9 +149,9 @@ _default_config = {
"task_db_name": "default_task_db",
},
# Shift minute for highfreq minite data, used in backtest
# if min_data_shift == 0, use default market time [9:30, 11:29, 1:30, 2:59]
# if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:30, 2:59] - shift*minute
"min_data_shift": {0},
# if min_data_shift == 0, use default market time [9:30, 11:29, 1:00, 2:59]
# if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:00, 2:59] - shift*minute
"min_data_shift": 0,
}
MODE_CONF = {

View File

@@ -1,67 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from typing import Union
from ...utils.resam import get_resam_calendar
from ...data.data import Cal
class TradeCalendarManager:
"""
Manager for trading calendar
- BaseStrategy and BaseExecutor will use it
"""
def __init__(
self, freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
):
"""
Parameters
----------
freq : str
frequency of trading calendar, also trade time per trading step
start_time : Union[str, pd.Timestamp], optional
closed start of the trading calendar, by default None
If `start_time` is None, it must be reset before trading.
end_time : Union[str, pd.Timestamp], optional
closed end of the trade time range, by default None
If `end_time` is None, it must be reset before trading.
"""
self.freq = freq
self.start_time = pd.Timestamp(start_time) if start_time else None
self.end_time = pd.Timestamp(start_time) if start_time else None
self._init_trade_calendar(freq=freq, start_time=start_time, end_time=end_time)
def _init_trade_calendar(self, freq, start_time, end_time):
"""reset trade calendar"""
_calendar, freq, freq_sam = get_resam_calendar(freq=freq)
self.calendar = _calendar
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
self.start_index = _start_index
self.end_index = _end_index
self.trade_len = _end_index - _start_index + 1
self.trade_index = 0
def finished(self):
return self.trade_index >= self.trade_len
def step(self):
if self.finished():
raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
self.trade_index = self.trade_index + 1
def get_freq(self):
return self.freq
def get_trade_len(self):
return self.trade_len
def get_trade_index(self):
return self.trade_index
def get_calendar_time(self, trade_index=1, shift=0):
trade_index = trade_index - shift
calendar_index = self.start_index + trade_index
return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1)

View File

@@ -9,7 +9,7 @@ import numpy as np
import pandas as pd
import warnings
from ..log import get_module_logger
from .backtest import get_exchange, backtest as backtest_func
from ..backtest import get_exchange, backtest as backtest_func
from ..utils import get_date_range
from ..utils.resam import parse_freq
@@ -141,9 +141,7 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
whether to print log.
"""
warnings.warn(
"this function is deprecated, please use backtest function in qlib.contrib.backtest", DeprecationWarning
)
warnings.warn("this function is deprecated, please use backtest function in qlib.backtest", DeprecationWarning)
report_dict = backtest_func(
pred=pred, account=account, shift=shift, benchmark=benchmark, verbose=verbose, return_order=False, **kwargs
)

View File

@@ -4,7 +4,7 @@
import pandas as pd
from ...backtest.profit_attribution import get_stock_weight_df
from ....backtest.profit_attribution import get_stock_weight_df
def parse_position(position: dict = None) -> pd.DataFrame:

View File

@@ -97,7 +97,7 @@ def rank_label_graph(
qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result.
:param position: position data; **qlib.backtest.backtest** result.
:param label_data: **D.features** result; index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[label]**.
**The label T is the change from T to T+1**, it is recommended to use ``close``, example: `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`.

View File

@@ -17,6 +17,7 @@ class SoftTopkStrategy(WeightStrategyBase):
max_sold_weight=1.0,
risk_degree=0.95,
buy_method="first_fill",
trade_exchange=None,
level_infra={},
common_infra={},
**kwargs,
@@ -31,14 +32,14 @@ class SoftTopkStrategy(WeightStrategyBase):
average_fill: assign the weight to the stocks rank high averagely.
"""
super(SoftTopkStrategy, self).__init__(
model, dataset, order_generator_cls_or_obj, level_infra, common_infra, **kwargs
model, dataset, order_generator_cls_or_obj, trade_exchange, level_infra, common_infra, **kwargs
)
self.topk = topk
self.max_sold_weight = max_sold_weight
self.risk_degree = risk_degree
self.buy_method = buy_method
def get_risk_degree(self, trade_index=None):
def get_risk_degree(self, trade_step=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
Dynamically risk_degree will result in Market timing

View File

@@ -5,7 +5,7 @@ import pandas as pd
from ...utils.resam import resam_ts_data
from ...strategy.base import ModelStrategy
from ..backtest.order import Order
from ...backtest.order import Order
from .order_generator import OrderGenWInteract
@@ -21,6 +21,7 @@ class TopkDropoutStrategy(ModelStrategy):
risk_degree=0.95,
hold_thresh=1,
only_tradable=False,
trade_exchange=None,
level_infra={},
common_infra={},
**kwargs,
@@ -47,6 +48,9 @@ class TopkDropoutStrategy(ModelStrategy):
strategy will make buy sell decision without checking the tradable state of the stock.
else:
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
"""
super(TopkDropoutStrategy, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
@@ -58,6 +62,8 @@ class TopkDropoutStrategy(ModelStrategy):
self.risk_degree = risk_degree
self.hold_thresh = hold_thresh
self.only_tradable = only_tradable
if trade_exchange is not None:
self.trade_exchange = trade_exchange
def reset_common_infra(self, common_infra):
"""
@@ -73,7 +79,7 @@ class TopkDropoutStrategy(ModelStrategy):
if "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
def get_risk_degree(self, trade_index=None):
def get_risk_degree(self, trade_step=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
Dynamically risk_degree will result in Market timing.
@@ -82,9 +88,10 @@ class TopkDropoutStrategy(ModelStrategy):
return self.risk_degree
def generate_trade_decision(self, execute_result=None):
trade_index = self.calendar.get_trade_index()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1)
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if pred_score is None:
return []
@@ -179,7 +186,7 @@ class TopkDropoutStrategy(ModelStrategy):
continue
if code in sell:
# check hold limit
time_per_step = self.calendar.get_freq()
time_per_step = self.trade_calendar.get_freq()
if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh:
continue
# sell order
@@ -243,6 +250,7 @@ class WeightStrategyBase(ModelStrategy):
model,
dataset,
order_generator_cls_or_obj=OrderGenWInteract,
trade_exchange=None,
level_infra={},
common_infra={},
**kwargs,
@@ -254,6 +262,8 @@ class WeightStrategyBase(ModelStrategy):
self.order_generator = order_generator_cls_or_obj()
else:
self.order_generator = order_generator_cls_or_obj
if trade_exchange is not None:
self.trade_exchange = trade_exchange
def reset_common_infra(self, common_infra):
"""
@@ -269,7 +279,7 @@ class WeightStrategyBase(ModelStrategy):
if "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
def get_risk_degree(self, trade_index=None):
def get_risk_degree(self, trade_step=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
Dynamically risk_degree will result in Market timing.
@@ -307,9 +317,11 @@ class WeightStrategyBase(ModelStrategy):
"""
# generate_trade_decision
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
trade_index = self.calendar.get_trade_index()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1)
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if pred_score is None:
return []
@@ -320,7 +332,7 @@ class WeightStrategyBase(ModelStrategy):
order_list = self.order_generator.generate_order_list_from_target_weight_position(
current=current_temp,
trade_exchange=self.trade_exchange,
risk_degree=self.get_risk_degree(trade_index),
risk_degree=self.get_risk_degree(trade_step),
target_weight_position=target_weight_position,
pred_start_time=pred_start_time,
pred_end_time=pred_end_time,

View File

@@ -4,8 +4,8 @@
"""
This order generator is for strategies based on WeightStrategyBase
"""
from ..backtest.position import Position
from ..backtest.exchange import Exchange
from ...backtest.position import Position
from ...backtest.exchange import Exchange
import pandas as pd
import copy

View File

@@ -3,13 +3,35 @@ import warnings
from ...utils.resam import resam_ts_data
from ...data.data import D
from ...data.dataset.utils import convert_index_format
from ...strategy.base import RuleStrategy
from ..backtest.order import Order
from ...strategy.base import BaseStrategy
from ...backtest.order import Order
from ...backtest.exchange import Exchange
class TWAPStrategy(RuleStrategy):
class TWAPStrategy(BaseStrategy):
"""TWAP Strategy for trading"""
def __init__(
self,
outer_trade_decision: object = None,
trade_exchange: Exchange = None,
level_infra: dict = {},
common_infra: dict = {},
):
"""
Parameters
----------
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
"""
super(TWAPStrategy, self).__init__(
outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra
)
if trade_exchange is not None:
self.trade_exchange = trade_exchange
def reset_common_infra(self, common_infra):
"""
Parameters
@@ -44,9 +66,11 @@ class TWAPStrategy(RuleStrategy):
for order, _, _, _ in execute_result:
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
trade_index = self.calendar.get_trade_index()
trade_len = self.calendar.get_trade_len()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
trade_step = self.trade_calendar.get_trade_step()
# get the total count of trading step
trade_len = self.trade_calendar.get_trade_len()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
order_list = []
for order in self.outer_trade_decision:
if not self.trade_exchange.is_stock_tradable(
@@ -57,21 +81,21 @@ class TWAPStrategy(RuleStrategy):
_order_amount = None
# consider trade unit
if _amount_trade_unit is None:
# split the order equally
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 1)
# divide the order equally
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
# without considering trade unit
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
# split the order equally
# floor((trade_unit_cnt + trade_len - trade_index) / (trade_len - trade_index + 1)) == ceil(trade_unit_cnt / (trade_len - trade_index + 1))
# divide the order equally
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
_order_amount = (
(trade_unit_cnt + trade_len - trade_index) // (trade_len - trade_index + 1) * _amount_trade_unit
(trade_unit_cnt + trade_len - trade_step) // (trade_len - trade_step + 1) * _amount_trade_unit
)
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
_order_amount is None or trade_index == trade_len
_order_amount is None or trade_step == trade_len
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
@@ -89,7 +113,7 @@ class TWAPStrategy(RuleStrategy):
return order_list
class SBBStrategyBase(RuleStrategy):
class SBBStrategyBase(BaseStrategy):
"""
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
"""
@@ -98,6 +122,27 @@ class SBBStrategyBase(RuleStrategy):
TREND_SHORT = 1
TREND_LONG = 2
def __init__(
self,
outer_trade_decision: object = None,
trade_exchange: Exchange = None,
level_infra: dict = {},
common_infra: dict = {},
):
"""
Parameters
----------
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
"""
super(SBBStrategyBase, self).__init__(
outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra
)
if trade_exchange is not None:
self.trade_exchange = trade_exchange
def reset_common_infra(self, common_infra):
super(SBBStrategyBase, self).reset_common_infra(common_infra)
if common_infra is not None:
@@ -132,15 +177,17 @@ class SBBStrategyBase(RuleStrategy):
if execute_result is not None:
for order, _, _, _ in execute_result:
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
trade_index = self.calendar.get_trade_index()
trade_len = self.calendar.get_trade_len()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1)
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
trade_step = self.trade_calendar.get_trade_step()
# get the total count of trading step
trade_len = self.trade_calendar.get_trade_len()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
order_list = []
# for each order in in self.outer_trade_decision
for order in self.outer_trade_decision:
# predict the price trend
if trade_index % 2 == 1:
if trade_step % 2 == 0:
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
else:
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
@@ -148,7 +195,7 @@ class SBBStrategyBase(RuleStrategy):
if not self.trade_exchange.is_stock_tradable(
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
):
if trade_index % 2 == 1:
if trade_step % 2 == 0:
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
continue
# get amount of one trade unit
@@ -157,21 +204,21 @@ class SBBStrategyBase(RuleStrategy):
_order_amount = None
# considering trade unit
if _amount_trade_unit is None:
# split the order equally
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 1)
# divide the order equally
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
# without considering trade unit
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
# cal how many trade unit
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
# split the order equally
# floor((trade_unit_cnt + trade_len - trade_index) / (trade_len - trade_index + 1)) == ceil(trade_unit_cnt / (trade_len - trade_index + 1))
# divide the order equally
# floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
_order_amount = (
(trade_unit_cnt + trade_len - trade_index) // (trade_len - trade_index + 1) * _amount_trade_unit
(trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit
)
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
_order_amount is None or trade_index == trade_len
_order_amount is None or trade_step == trade_len - 1
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
@@ -190,31 +237,31 @@ class SBBStrategyBase(RuleStrategy):
_order_amount = None
# considering trade unit
if _amount_trade_unit is None:
# N trade day last, split the order into N + 1 parts, and trade 2 parts
# N trade day left, divide the order into N + 1 parts, and trade 2 parts
_order_amount = (
2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 2)
2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
)
# without considering trade unit
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
# cal how many trade unit
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
# N trade day last, split the order into N + 1 parts, and trade 2 parts
# N trade day left, divide the order into N + 1 parts, and trade 2 parts
_order_amount = (
(trade_unit_cnt + trade_len - trade_index + 1)
// (trade_len - trade_index + 2)
(trade_unit_cnt + trade_len - trade_step)
// (trade_len - trade_step + 1)
* 2
* _amount_trade_unit
)
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[(order.stock_id, order.direction)] >= 1e-5 and (
_order_amount is None or trade_index == trade_len
_order_amount is None or trade_step == trade_len - 1
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
if _order_amount:
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
if trade_index % 2 == 1:
if trade_step % 2 == 0:
# in the first of two adjacent bar
# if look short on the price, sell the stock more
# if look long on the price, sell the stock more
@@ -253,7 +300,7 @@ class SBBStrategyBase(RuleStrategy):
)
order_list.append(_order)
if trade_index % 2 == 1:
if trade_step % 2 == 0:
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
return order_list
@@ -269,6 +316,7 @@ class SBBStrategyEMA(SBBStrategyBase):
outer_trade_decision=[],
instruments="csi300",
freq="day",
trade_exchange: Exchange = None,
level_infra={},
common_infra={},
**kwargs,
@@ -288,13 +336,13 @@ class SBBStrategyEMA(SBBStrategyBase):
if isinstance(instruments, str):
self.instruments = D.instruments(instruments)
self.freq = freq
super(SBBStrategyEMA, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
super(SBBStrategyEMA, self).__init__(outer_trade_decision, trade_exchange, level_infra, common_infra, **kwargs)
def _reset_signal(self):
trade_len = self.calendar.get_trade_len()
trade_len = self.trade_calendar.get_trade_len()
fields = ["EMA($close, 10)-EMA($close, 20)"]
signal_start_time, _ = self.calendar.get_calendar_time(trade_index=1, shift=1)
_, signal_end_time = self.calendar.get_calendar_time(trade_index=trade_len, shift=1)
signal_start_time, _ = self.trade_calendar.get_step_time(trade_step=0, shift=1)
_, signal_end_time = self.trade_calendar.get_step_time(trade_step=trade_len - 1, shift=1)
signal_df = D.features(
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
)
@@ -314,8 +362,8 @@ class SBBStrategyEMA(SBBStrategyBase):
else:
self.level_infra.update(level_infra)
if "calendar" in level_infra:
self.calendar = level_infra.get("calendar")
if "trade_calendar" in level_infra:
self.trade_calendar = level_infra.get("trade_calendar")
self._reset_signal()
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):

View File

@@ -775,7 +775,7 @@ class ClientCalendarProvider(CalendarProvider):
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
self.conn.send_request(
request_type="calendar",
request_type="trade_calendar",
request_content={
"start_time": str(start_time),
"end_time": str(end_time),
@@ -990,7 +990,7 @@ class LocalProvider(BaseProvider):
:param type: The type of resource for the uri
:param **kwargs:
"""
if type == "calendar":
if type == "trade_calendar":
return Cal._uri(**kwargs)
elif type == "instrument":
return Inst._uri(**kwargs)

View File

@@ -3,8 +3,9 @@
from typing import Union
from ..backtest.executor import BaseExecutor
from .interpreter import StateInterpreter, ActionInterpreter
from ..contrib.backtest.executor import BaseExecutor
from ..utils import init_instance_by_config
from .interpreter import BaseInterpreter

View File

@@ -1,15 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
import pandas as pd
from typing import List, Union
from ..model.base import BaseModel
from ..data.dataset import DatasetH
from ..data.dataset.utils import convert_index_format
from ..contrib.backtest.order import Order
from ..rl.interpreter import ActionInterpreter, StateInterpreter
from ..utils import init_instance_by_config
@@ -44,8 +38,8 @@ class BaseStrategy:
else:
self.level_infra.update(level_infra)
if "calendar" in level_infra:
self.calendar = level_infra.get("calendar")
if "trade_calendar" in level_infra:
self.trade_calendar = level_infra.get("trade_calendar")
def reset_common_infra(self, common_infra):
if not hasattr(self, "common_infra"):
@@ -83,12 +77,6 @@ class BaseStrategy:
raise NotImplementedError("generate_trade_decision is not implemented!")
class RuleStrategy(BaseStrategy):
"""Rule-based Trading strategy"""
pass
class ModelStrategy(BaseStrategy):
"""Model-based trading strategy, use model to make predictions for trading"""

View File

@@ -40,7 +40,7 @@ def parse_freq(freq: str) -> Tuple[int, str]:
raise ValueError(
"freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min"
)
_count = int(match_obj.group(1)) if match_obj.group(1) is None else 1
_count = int(match_obj.group(1)) if match_obj.group(1) else 1
_freq = match_obj.group(2)
_freq_format_dict = {
"month": "month",
@@ -58,7 +58,8 @@ def parse_freq(freq: str) -> Tuple[int, str]:
def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
"""
Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam
Assumption: The fix length (240) of the calendar in each day.
Assumption:
- Fix length (240) of the calendar in each day.
Parameters
----------
@@ -83,16 +84,19 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
if freq_sam == "minute":
def cal_sam_minute(x, sam_minutes):
"""
Sample raw calendar into calendar with sam_minutes freq, shift represents the shift minute the market time
- open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]
- mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)]
- mid open time of stock market is [13:00 - shift*pd.Timedelta(minutes=1)]
- close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)]
"""
day_time = pd.Timestamp(x.date())
shift = C.min_data_shift
# shift represents the shift minute the market time
# - open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]
# - mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)]
# - mid open time of stock market is [13:30 - shift*pd.Timedelta(minutes=1)]
# - close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)]
open_time = day_time + pd.Timedelta(hours=9, minutes=30) - shift * pd.Timedelta(minutes=1)
mid_close_time = day_time + pd.Timedelta(hours=11, minutes=29) - shift * pd.Timedelta(minutes=1)
mid_open_time = day_time + pd.Timedelta(hours=13, minutes=30) - shift * pd.Timedelta(minutes=1)
mid_open_time = day_time + pd.Timedelta(hours=13, minutes=00) - shift * pd.Timedelta(minutes=1)
close_time = day_time + pd.Timedelta(hours=14, minutes=59) - shift * pd.Timedelta(minutes=1)
if open_time <= x <= mid_close_time:
@@ -101,7 +105,6 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
minute_index = (x - mid_open_time).seconds // 60 + 120
else:
raise ValueError("datetime of calendar is out of range")
minute_index = minute_index // sam_minutes * sam_minutes
if 0 <= minute_index < 120:
@@ -109,7 +112,7 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
elif 120 <= minute_index < 240:
return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1)
else:
raise ValueError("calendar minute_index error")
raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C")
if freq_raw != "minute":
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
@@ -189,11 +192,13 @@ def get_resam_calendar(
freq = "day"
except ValueError:
_calendar = Cal.calendar(
start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, future=future
start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future
)
freq = "min"
elif norm_freq == "minute":
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, future=future)
_calendar = Cal.calendar(
start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future
)
freq = "min"
else:
raise ValueError(f"freq {freq} is not supported")

View File

@@ -8,10 +8,10 @@ import pandas as pd
from pathlib import Path
from pprint import pprint
from ..contrib.evaluate import risk_analysis
from ..contrib.backtest import backtest as normal_backtest
from ..data.dataset import DatasetH
from ..data.dataset.handler import DataHandlerLP
from ..backtest import backtest as normal_backtest
from ..utils import init_instance_by_config, get_module_by_module_path
from ..log import get_module_logger
from ..utils import flatten_dict