mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
move backtest to core, fix calendar bugs, add some docstring
This commit is contained in:
@@ -10,7 +10,7 @@ from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.contrib.backtest import collect_data
|
||||
from qlib.backtest import collect_data
|
||||
|
||||
|
||||
class MultiLevelTradingWorkflow:
|
||||
@@ -61,17 +61,17 @@ class MultiLevelTradingWorkflow:
|
||||
}
|
||||
|
||||
trade_start_time = "2017-01-01"
|
||||
trade_end_time = "2017-02-01"
|
||||
trade_end_time = "2020-08-01"
|
||||
|
||||
port_analysis_config = {
|
||||
"executor": {
|
||||
"class": "SplitExecutor",
|
||||
"module_path": "qlib.contrib.backtest.executor",
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "week",
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.contrib.backtest.executor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "day",
|
||||
"verbose": True,
|
||||
|
||||
@@ -66,7 +66,6 @@
|
||||
"from qlib.config import REG_CN\n",
|
||||
"from qlib.contrib.model.gbdt import LGBModel\n",
|
||||
"from qlib.contrib.data.handler import Alpha158\n",
|
||||
"from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
|
||||
"from qlib.contrib.evaluate import (\n",
|
||||
" backtest as normal_backtest,\n",
|
||||
" risk_analysis,\n",
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
|
||||
@@ -7,10 +7,10 @@ from .executor import BaseExecutor
|
||||
from .backtest import backtest as backtest_func
|
||||
from .backtest import collect_data as data_generator
|
||||
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...utils import init_instance_by_config
|
||||
from ...log import get_module_logger
|
||||
from ...config import C
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
from ..config import C
|
||||
|
||||
|
||||
logger = get_module_logger("backtest caller")
|
||||
@@ -24,6 +24,8 @@ rtn & earning in the Account
|
||||
**is consider cost**
|
||||
while earning is the difference of two position value, so it considers cost, it is the true return rate
|
||||
in the specific accomplishment for rtn, it does not consider cost, in other words, rtn - cost = earning
|
||||
|
||||
Now rtn has been removed in the hierarchical backtest implemention.
|
||||
"""
|
||||
|
||||
|
||||
@@ -8,11 +8,11 @@ import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...data.data import D
|
||||
from ...data.dataset.utils import get_level_index
|
||||
from ...config import C, REG_CN
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...log import get_module_logger
|
||||
from ..data.data import D
|
||||
from ..data.dataset.utils import get_level_index
|
||||
from ..config import C, REG_CN
|
||||
from ..utils.resam import resam_ts_data
|
||||
from ..log import get_module_logger
|
||||
from .order import Order
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class Exchange:
|
||||
"""__init__
|
||||
|
||||
:param freq: frequency of data
|
||||
:param start_time: closed start time for backtest
|
||||
:param start_time: closed start time for backtest
|
||||
:param end_time: closed end time for backtest
|
||||
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
|
||||
:param deal_price: str, 'close', 'open', 'vwap'
|
||||
@@ -3,8 +3,8 @@ import warnings
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
|
||||
from ...utils import init_instance_by_config
|
||||
from ...utils.resam import parse_freq
|
||||
from ..utils import init_instance_by_config
|
||||
from ..utils.resam import parse_freq
|
||||
|
||||
|
||||
from .order import Order
|
||||
@@ -30,7 +30,7 @@ class BaseExecutor:
|
||||
Parameters
|
||||
----------
|
||||
time_per_step : str
|
||||
trade time per trading step, used for genreate trade calendar
|
||||
trade time per trading step, used for genreate the trade calendar
|
||||
generate_report : bool, optional
|
||||
whether to generate report, by default False
|
||||
verbose : bool, optional
|
||||
@@ -80,16 +80,18 @@ class BaseExecutor:
|
||||
if "start_time" in kwargs or "end_time" in kwargs:
|
||||
start_time = kwargs.get("start_time")
|
||||
end_time = kwargs.get("end_time")
|
||||
self.calendar = TradeCalendarManager(freq=self.time_per_step, start_time=start_time, end_time=end_time)
|
||||
self.trade_calendar = TradeCalendarManager(
|
||||
freq=self.time_per_step, start_time=start_time, end_time=end_time
|
||||
)
|
||||
|
||||
if common_infra is not None:
|
||||
self.reset_common_infra(common_infra)
|
||||
|
||||
def get_level_infra(self):
|
||||
return {"calendar": self.calendar}
|
||||
return {"trade_calendar": self.trade_calendar}
|
||||
|
||||
def finished(self):
|
||||
return self.calendar.finished()
|
||||
return self.trade_calendar.finished()
|
||||
|
||||
def execute(self, trade_decision):
|
||||
"""execute the trade decision and return the executed result
|
||||
@@ -117,8 +119,13 @@ class BaseExecutor:
|
||||
raise NotImplementedError("get_report is not implemented!")
|
||||
|
||||
|
||||
class SplitExecutor(BaseExecutor):
|
||||
from ...strategy.base import BaseStrategy
|
||||
class NestedExecutor(BaseExecutor):
|
||||
"""
|
||||
Nested Executor with inner strategy and executor
|
||||
- At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher frequency env.
|
||||
"""
|
||||
|
||||
from ..strategy.base import BaseStrategy
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -127,10 +134,10 @@ class SplitExecutor(BaseExecutor):
|
||||
inner_strategy: Union[BaseStrategy, dict],
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
trade_exchange: Exchange = None,
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: dict = {},
|
||||
**kwargs,
|
||||
):
|
||||
@@ -153,7 +160,7 @@ class SplitExecutor(BaseExecutor):
|
||||
inner_strategy, common_infra=common_infra, accept_types=self.BaseStrategy
|
||||
)
|
||||
|
||||
super(SplitExecutor, self).__init__(
|
||||
super(NestedExecutor, self).__init__(
|
||||
time_per_step=time_per_step,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
@@ -173,7 +180,7 @@ class SplitExecutor(BaseExecutor):
|
||||
- reset trade_exchange
|
||||
- reset inner_strategyand inner_executor common infra
|
||||
"""
|
||||
super(SplitExecutor, self).reset_common_infra(common_infra)
|
||||
super(NestedExecutor, self).reset_common_infra(common_infra)
|
||||
|
||||
if self.generate_report and "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
@@ -182,15 +189,15 @@ class SplitExecutor(BaseExecutor):
|
||||
self.inner_strategy.reset_common_infra(common_infra)
|
||||
|
||||
def _init_sub_trading(self, trade_decision):
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
sub_level_infra = self.inner_executor.get_level_infra()
|
||||
self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision)
|
||||
|
||||
def _update_trade_account(self):
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
self.trade_account.update_bar_count()
|
||||
if self.generate_report:
|
||||
self.trade_account.update_bar_report(
|
||||
@@ -200,7 +207,6 @@ class SplitExecutor(BaseExecutor):
|
||||
)
|
||||
|
||||
def execute(self, trade_decision):
|
||||
self.calendar.step()
|
||||
self._init_sub_trading(trade_decision)
|
||||
execute_result = []
|
||||
_inner_execute_result = None
|
||||
@@ -210,13 +216,13 @@ class SplitExecutor(BaseExecutor):
|
||||
execute_result.extend(_inner_execute_result)
|
||||
if hasattr(self, "trade_account"):
|
||||
self._update_trade_account()
|
||||
|
||||
self.trade_calendar.step()
|
||||
return execute_result
|
||||
|
||||
def collect_data(self, trade_decision):
|
||||
if self.track_data:
|
||||
yield trade_decision
|
||||
self.calendar.step()
|
||||
self.trade_calendar.step()
|
||||
self._init_sub_trading(trade_decision)
|
||||
execute_result = []
|
||||
_inner_execute_result = None
|
||||
@@ -240,15 +246,17 @@ class SplitExecutor(BaseExecutor):
|
||||
|
||||
|
||||
class SimulatorExecutor(BaseExecutor):
|
||||
"""Executor that simulate the true market"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
time_per_step: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
trade_exchange: Exchange = None,
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: dict = {},
|
||||
**kwargs,
|
||||
):
|
||||
@@ -282,9 +290,9 @@ class SimulatorExecutor(BaseExecutor):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def execute(self, trade_decision):
|
||||
self.calendar.step()
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
execute_result = []
|
||||
for order in trade_decision:
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
@@ -333,7 +341,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
trade_end_time=trade_end_time,
|
||||
trade_exchange=self.trade_exchange,
|
||||
)
|
||||
|
||||
self.trade_calendar.step()
|
||||
return execute_result
|
||||
|
||||
def get_report(self):
|
||||
@@ -5,8 +5,8 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from .position import Position
|
||||
from ...data import D
|
||||
from ...config import C
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -10,8 +10,8 @@ import warnings
|
||||
|
||||
from pandas.core.frame import DataFrame
|
||||
|
||||
from ...utils.resam import parse_freq, resam_ts_data
|
||||
from ...data import D
|
||||
from ..utils.resam import parse_freq, resam_ts_data
|
||||
from ..data import D
|
||||
|
||||
|
||||
class Report:
|
||||
@@ -86,9 +86,9 @@ class Report:
|
||||
try:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1)
|
||||
except ValueError:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1)
|
||||
elif norm_freq == "minute":
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1)
|
||||
else:
|
||||
raise ValueError(f"benchmark freq {freq} is not supported")
|
||||
if len(_temp_result) == 0:
|
||||
98
qlib/backtest/utils.py
Normal file
98
qlib/backtest/utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
|
||||
from ..utils.resam import get_resam_calendar
|
||||
from ..data.data import Cal
|
||||
|
||||
|
||||
class TradeCalendarManager:
|
||||
"""
|
||||
Manager for trading calendar
|
||||
- BaseStrategy and BaseExecutor will use it
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
frequency of trading calendar, also trade time per trading step
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
closed start of the trading calendar, by default None
|
||||
If `start_time` is None, it must be reset before trading.
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
closed end of the trade time range, by default None
|
||||
If `end_time` is None, it must be reset before trading.
|
||||
"""
|
||||
self.freq = freq
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(end_time) if end_time else None
|
||||
self._init_trade_calendar(freq=freq, start_time=start_time, end_time=end_time)
|
||||
|
||||
def _init_trade_calendar(self, freq, start_time, end_time):
|
||||
"""
|
||||
Reset the trade calendar
|
||||
- self.trade_len : The total count for trading step
|
||||
- self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1]
|
||||
"""
|
||||
_calendar, freq, freq_sam = get_resam_calendar(freq=freq)
|
||||
self.trade_calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
|
||||
self.start_index = _start_index
|
||||
self.end_index = _end_index
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
self.trade_step = 0
|
||||
|
||||
def finished(self):
|
||||
"""
|
||||
Check if the trading finished
|
||||
- Should check before calling strategy.generate_decisions and executor.execute
|
||||
- If self.trade_step >= self.self.trade_len, it means the trading is finished
|
||||
- If self.trade_step < self.self.trade_len, it means the number of trading step finished is self.trade_step
|
||||
"""
|
||||
return self.trade_step >= self.trade_len
|
||||
|
||||
def step(self):
|
||||
if self.finished():
|
||||
raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
|
||||
self.trade_step = self.trade_step + 1
|
||||
|
||||
def get_freq(self):
|
||||
return self.freq
|
||||
|
||||
def get_trade_len(self):
|
||||
return self.trade_len
|
||||
|
||||
def get_trade_step(self):
|
||||
return self.trade_step
|
||||
|
||||
def get_step_time(self, trade_step=0, shift=0):
|
||||
"""
|
||||
Get the time range of trading step
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_step : int, optional
|
||||
the number of trading step finished, by default 0
|
||||
shift : int, optional
|
||||
shift bars , by default 0
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[pd.Timestamp, pd.Timestap]
|
||||
- If shift == 0, return the trading time range
|
||||
- If shift > 0, return the trading time range of the earlier shift bars
|
||||
- If shift < 0, return the trading time range of the later shift bar
|
||||
"""
|
||||
trade_step = trade_step - shift
|
||||
calendar_index = self.start_index + trade_step
|
||||
return self.trade_calendar[calendar_index], self.trade_calendar[calendar_index + 1] - pd.Timedelta(seconds=1)
|
||||
|
||||
def get_all_time(self):
|
||||
"""Get the start_time and end_time for trading"""
|
||||
return self.start_time, self.end_time
|
||||
@@ -149,9 +149,9 @@ _default_config = {
|
||||
"task_db_name": "default_task_db",
|
||||
},
|
||||
# Shift minute for highfreq minite data, used in backtest
|
||||
# if min_data_shift == 0, use default market time [9:30, 11:29, 1:30, 2:59]
|
||||
# if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:30, 2:59] - shift*minute
|
||||
"min_data_shift": {0},
|
||||
# if min_data_shift == 0, use default market time [9:30, 11:29, 1:00, 2:59]
|
||||
# if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:00, 2:59] - shift*minute
|
||||
"min_data_shift": 0,
|
||||
}
|
||||
|
||||
MODE_CONF = {
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
|
||||
from ...utils.resam import get_resam_calendar
|
||||
from ...data.data import Cal
|
||||
|
||||
|
||||
class TradeCalendarManager:
|
||||
"""
|
||||
Manager for trading calendar
|
||||
- BaseStrategy and BaseExecutor will use it
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
frequency of trading calendar, also trade time per trading step
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
closed start of the trading calendar, by default None
|
||||
If `start_time` is None, it must be reset before trading.
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
closed end of the trade time range, by default None
|
||||
If `end_time` is None, it must be reset before trading.
|
||||
"""
|
||||
self.freq = freq
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(start_time) if start_time else None
|
||||
self._init_trade_calendar(freq=freq, start_time=start_time, end_time=end_time)
|
||||
|
||||
def _init_trade_calendar(self, freq, start_time, end_time):
|
||||
"""reset trade calendar"""
|
||||
_calendar, freq, freq_sam = get_resam_calendar(freq=freq)
|
||||
self.calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
|
||||
self.start_index = _start_index
|
||||
self.end_index = _end_index
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
self.trade_index = 0
|
||||
|
||||
def finished(self):
|
||||
return self.trade_index >= self.trade_len
|
||||
|
||||
def step(self):
|
||||
if self.finished():
|
||||
raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
|
||||
self.trade_index = self.trade_index + 1
|
||||
|
||||
def get_freq(self):
|
||||
return self.freq
|
||||
|
||||
def get_trade_len(self):
|
||||
return self.trade_len
|
||||
|
||||
def get_trade_index(self):
|
||||
return self.trade_index
|
||||
|
||||
def get_calendar_time(self, trade_index=1, shift=0):
|
||||
trade_index = trade_index - shift
|
||||
calendar_index = self.start_index + trade_index
|
||||
return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1)
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import warnings
|
||||
from ..log import get_module_logger
|
||||
from .backtest import get_exchange, backtest as backtest_func
|
||||
from ..backtest import get_exchange, backtest as backtest_func
|
||||
from ..utils import get_date_range
|
||||
from ..utils.resam import parse_freq
|
||||
|
||||
@@ -141,9 +141,7 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
|
||||
whether to print log.
|
||||
|
||||
"""
|
||||
warnings.warn(
|
||||
"this function is deprecated, please use backtest function in qlib.contrib.backtest", DeprecationWarning
|
||||
)
|
||||
warnings.warn("this function is deprecated, please use backtest function in qlib.backtest", DeprecationWarning)
|
||||
report_dict = backtest_func(
|
||||
pred=pred, account=account, shift=shift, benchmark=benchmark, verbose=verbose, return_order=False, **kwargs
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import pandas as pd
|
||||
|
||||
|
||||
from ...backtest.profit_attribution import get_stock_weight_df
|
||||
from ....backtest.profit_attribution import get_stock_weight_df
|
||||
|
||||
|
||||
def parse_position(position: dict = None) -> pd.DataFrame:
|
||||
|
||||
@@ -97,7 +97,7 @@ def rank_label_graph(
|
||||
qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
|
||||
|
||||
|
||||
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result.
|
||||
:param position: position data; **qlib.backtest.backtest** result.
|
||||
:param label_data: **D.features** result; index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[label]**.
|
||||
**The label T is the change from T to T+1**, it is recommended to use ``close``, example: `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`.
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ class SoftTopkStrategy(WeightStrategyBase):
|
||||
max_sold_weight=1.0,
|
||||
risk_degree=0.95,
|
||||
buy_method="first_fill",
|
||||
trade_exchange=None,
|
||||
level_infra={},
|
||||
common_infra={},
|
||||
**kwargs,
|
||||
@@ -31,14 +32,14 @@ class SoftTopkStrategy(WeightStrategyBase):
|
||||
average_fill: assign the weight to the stocks rank high averagely.
|
||||
"""
|
||||
super(SoftTopkStrategy, self).__init__(
|
||||
model, dataset, order_generator_cls_or_obj, level_infra, common_infra, **kwargs
|
||||
model, dataset, order_generator_cls_or_obj, trade_exchange, level_infra, common_infra, **kwargs
|
||||
)
|
||||
self.topk = topk
|
||||
self.max_sold_weight = max_sold_weight
|
||||
self.risk_degree = risk_degree
|
||||
self.buy_method = buy_method
|
||||
|
||||
def get_risk_degree(self, trade_index=None):
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
Dynamically risk_degree will result in Market timing
|
||||
|
||||
@@ -5,7 +5,7 @@ import pandas as pd
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...strategy.base import ModelStrategy
|
||||
from ..backtest.order import Order
|
||||
from ...backtest.order import Order
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
risk_degree=0.95,
|
||||
hold_thresh=1,
|
||||
only_tradable=False,
|
||||
trade_exchange=None,
|
||||
level_infra={},
|
||||
common_infra={},
|
||||
**kwargs,
|
||||
@@ -47,6 +48,9 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
strategy will make buy sell decision without checking the tradable state of the stock.
|
||||
else:
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
|
||||
@@ -58,6 +62,8 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
self.risk_degree = risk_degree
|
||||
self.hold_thresh = hold_thresh
|
||||
self.only_tradable = only_tradable
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
@@ -73,7 +79,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
if "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def get_risk_degree(self, trade_index=None):
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
Dynamically risk_degree will result in Market timing.
|
||||
@@ -82,9 +88,10 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
return self.risk_degree
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1)
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if pred_score is None:
|
||||
return []
|
||||
@@ -179,7 +186,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
continue
|
||||
if code in sell:
|
||||
# check hold limit
|
||||
time_per_step = self.calendar.get_freq()
|
||||
time_per_step = self.trade_calendar.get_freq()
|
||||
if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh:
|
||||
continue
|
||||
# sell order
|
||||
@@ -243,6 +250,7 @@ class WeightStrategyBase(ModelStrategy):
|
||||
model,
|
||||
dataset,
|
||||
order_generator_cls_or_obj=OrderGenWInteract,
|
||||
trade_exchange=None,
|
||||
level_infra={},
|
||||
common_infra={},
|
||||
**kwargs,
|
||||
@@ -254,6 +262,8 @@ class WeightStrategyBase(ModelStrategy):
|
||||
self.order_generator = order_generator_cls_or_obj()
|
||||
else:
|
||||
self.order_generator = order_generator_cls_or_obj
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
@@ -269,7 +279,7 @@ class WeightStrategyBase(ModelStrategy):
|
||||
if "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def get_risk_degree(self, trade_index=None):
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
Dynamically risk_degree will result in Market timing.
|
||||
@@ -307,9 +317,11 @@ class WeightStrategyBase(ModelStrategy):
|
||||
"""
|
||||
# generate_trade_decision
|
||||
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1)
|
||||
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if pred_score is None:
|
||||
return []
|
||||
@@ -320,7 +332,7 @@ class WeightStrategyBase(ModelStrategy):
|
||||
order_list = self.order_generator.generate_order_list_from_target_weight_position(
|
||||
current=current_temp,
|
||||
trade_exchange=self.trade_exchange,
|
||||
risk_degree=self.get_risk_degree(trade_index),
|
||||
risk_degree=self.get_risk_degree(trade_step),
|
||||
target_weight_position=target_weight_position,
|
||||
pred_start_time=pred_start_time,
|
||||
pred_end_time=pred_end_time,
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
"""
|
||||
This order generator is for strategies based on WeightStrategyBase
|
||||
"""
|
||||
from ..backtest.position import Position
|
||||
from ..backtest.exchange import Exchange
|
||||
from ...backtest.position import Position
|
||||
from ...backtest.exchange import Exchange
|
||||
import pandas as pd
|
||||
import copy
|
||||
|
||||
|
||||
@@ -3,13 +3,35 @@ import warnings
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...data.data import D
|
||||
from ...data.dataset.utils import convert_index_format
|
||||
from ...strategy.base import RuleStrategy
|
||||
from ..backtest.order import Order
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...backtest.order import Order
|
||||
from ...backtest.exchange import Exchange
|
||||
|
||||
|
||||
class TWAPStrategy(RuleStrategy):
|
||||
class TWAPStrategy(BaseStrategy):
|
||||
"""TWAP Strategy for trading"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: object = None,
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
"""
|
||||
super(TWAPStrategy, self).__init__(
|
||||
outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra
|
||||
)
|
||||
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
Parameters
|
||||
@@ -44,9 +66,11 @@ class TWAPStrategy(RuleStrategy):
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_len = self.calendar.get_trade_len()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
# get the total count of trading step
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
order_list = []
|
||||
for order in self.outer_trade_decision:
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
@@ -57,21 +81,21 @@ class TWAPStrategy(RuleStrategy):
|
||||
_order_amount = None
|
||||
# consider trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# split the order equally
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 1)
|
||||
# divide the order equally
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
|
||||
# without considering trade unit
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
# split the order equally
|
||||
# floor((trade_unit_cnt + trade_len - trade_index) / (trade_len - trade_index + 1)) == ceil(trade_unit_cnt / (trade_len - trade_index + 1))
|
||||
# divide the order equally
|
||||
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + trade_len - trade_index) // (trade_len - trade_index + 1) * _amount_trade_unit
|
||||
(trade_unit_cnt + trade_len - trade_step) // (trade_len - trade_step + 1) * _amount_trade_unit
|
||||
)
|
||||
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
|
||||
_order_amount is None or trade_index == trade_len
|
||||
_order_amount is None or trade_step == trade_len
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
|
||||
@@ -89,7 +113,7 @@ class TWAPStrategy(RuleStrategy):
|
||||
return order_list
|
||||
|
||||
|
||||
class SBBStrategyBase(RuleStrategy):
|
||||
class SBBStrategyBase(BaseStrategy):
|
||||
"""
|
||||
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
|
||||
"""
|
||||
@@ -98,6 +122,27 @@ class SBBStrategyBase(RuleStrategy):
|
||||
TREND_SHORT = 1
|
||||
TREND_LONG = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: object = None,
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
"""
|
||||
super(SBBStrategyBase, self).__init__(
|
||||
outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra
|
||||
)
|
||||
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
super(SBBStrategyBase, self).reset_common_infra(common_infra)
|
||||
if common_infra is not None:
|
||||
@@ -132,15 +177,17 @@ class SBBStrategyBase(RuleStrategy):
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_len = self.calendar.get_trade_len()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1)
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
# get the total count of trading step
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
order_list = []
|
||||
# for each order in in self.outer_trade_decision
|
||||
for order in self.outer_trade_decision:
|
||||
# predict the price trend
|
||||
if trade_index % 2 == 1:
|
||||
if trade_step % 2 == 0:
|
||||
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
|
||||
else:
|
||||
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
|
||||
@@ -148,7 +195,7 @@ class SBBStrategyBase(RuleStrategy):
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
if trade_index % 2 == 1:
|
||||
if trade_step % 2 == 0:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
|
||||
continue
|
||||
# get amount of one trade unit
|
||||
@@ -157,21 +204,21 @@ class SBBStrategyBase(RuleStrategy):
|
||||
_order_amount = None
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# split the order equally
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 1)
|
||||
# divide the order equally
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
|
||||
# without considering trade unit
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
# cal how many trade unit
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
# split the order equally
|
||||
# floor((trade_unit_cnt + trade_len - trade_index) / (trade_len - trade_index + 1)) == ceil(trade_unit_cnt / (trade_len - trade_index + 1))
|
||||
# divide the order equally
|
||||
# floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + trade_len - trade_index) // (trade_len - trade_index + 1) * _amount_trade_unit
|
||||
(trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit
|
||||
)
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
|
||||
_order_amount is None or trade_index == trade_len
|
||||
_order_amount is None or trade_step == trade_len - 1
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
|
||||
@@ -190,31 +237,31 @@ class SBBStrategyBase(RuleStrategy):
|
||||
_order_amount = None
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# N trade day last, split the order into N + 1 parts, and trade 2 parts
|
||||
# N trade day left, divide the order into N + 1 parts, and trade 2 parts
|
||||
_order_amount = (
|
||||
2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 2)
|
||||
2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
|
||||
)
|
||||
# without considering trade unit
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
# cal how many trade unit
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
# N trade day last, split the order into N + 1 parts, and trade 2 parts
|
||||
# N trade day left, divide the order into N + 1 parts, and trade 2 parts
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + trade_len - trade_index + 1)
|
||||
// (trade_len - trade_index + 2)
|
||||
(trade_unit_cnt + trade_len - trade_step)
|
||||
// (trade_len - trade_step + 1)
|
||||
* 2
|
||||
* _amount_trade_unit
|
||||
)
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] >= 1e-5 and (
|
||||
_order_amount is None or trade_index == trade_len
|
||||
_order_amount is None or trade_step == trade_len - 1
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
|
||||
if _order_amount:
|
||||
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
|
||||
if trade_index % 2 == 1:
|
||||
if trade_step % 2 == 0:
|
||||
# in the first of two adjacent bar
|
||||
# if look short on the price, sell the stock more
|
||||
# if look long on the price, sell the stock more
|
||||
@@ -253,7 +300,7 @@ class SBBStrategyBase(RuleStrategy):
|
||||
)
|
||||
order_list.append(_order)
|
||||
|
||||
if trade_index % 2 == 1:
|
||||
if trade_step % 2 == 0:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
|
||||
|
||||
return order_list
|
||||
@@ -269,6 +316,7 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
outer_trade_decision=[],
|
||||
instruments="csi300",
|
||||
freq="day",
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra={},
|
||||
common_infra={},
|
||||
**kwargs,
|
||||
@@ -288,13 +336,13 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
if isinstance(instruments, str):
|
||||
self.instruments = D.instruments(instruments)
|
||||
self.freq = freq
|
||||
super(SBBStrategyEMA, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
super(SBBStrategyEMA, self).__init__(outer_trade_decision, trade_exchange, level_infra, common_infra, **kwargs)
|
||||
|
||||
def _reset_signal(self):
|
||||
trade_len = self.calendar.get_trade_len()
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
fields = ["EMA($close, 10)-EMA($close, 20)"]
|
||||
signal_start_time, _ = self.calendar.get_calendar_time(trade_index=1, shift=1)
|
||||
_, signal_end_time = self.calendar.get_calendar_time(trade_index=trade_len, shift=1)
|
||||
signal_start_time, _ = self.trade_calendar.get_step_time(trade_step=0, shift=1)
|
||||
_, signal_end_time = self.trade_calendar.get_step_time(trade_step=trade_len - 1, shift=1)
|
||||
signal_df = D.features(
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
|
||||
)
|
||||
@@ -314,8 +362,8 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if "calendar" in level_infra:
|
||||
self.calendar = level_infra.get("calendar")
|
||||
if "trade_calendar" in level_infra:
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
self._reset_signal()
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
|
||||
@@ -775,7 +775,7 @@ class ClientCalendarProvider(CalendarProvider):
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
|
||||
|
||||
self.conn.send_request(
|
||||
request_type="calendar",
|
||||
request_type="trade_calendar",
|
||||
request_content={
|
||||
"start_time": str(start_time),
|
||||
"end_time": str(end_time),
|
||||
@@ -990,7 +990,7 @@ class LocalProvider(BaseProvider):
|
||||
:param type: The type of resource for the uri
|
||||
:param **kwargs:
|
||||
"""
|
||||
if type == "calendar":
|
||||
if type == "trade_calendar":
|
||||
return Cal._uri(**kwargs)
|
||||
elif type == "instrument":
|
||||
return Inst._uri(**kwargs)
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
|
||||
from typing import Union
|
||||
|
||||
|
||||
from ..backtest.executor import BaseExecutor
|
||||
from .interpreter import StateInterpreter, ActionInterpreter
|
||||
from ..contrib.backtest.executor import BaseExecutor
|
||||
from ..utils import init_instance_by_config
|
||||
from .interpreter import BaseInterpreter
|
||||
|
||||
|
||||
@@ -1,15 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import pandas as pd
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
from ..model.base import BaseModel
|
||||
from ..data.dataset import DatasetH
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..contrib.backtest.order import Order
|
||||
from ..rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from ..utils import init_instance_by_config
|
||||
|
||||
@@ -44,8 +38,8 @@ class BaseStrategy:
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if "calendar" in level_infra:
|
||||
self.calendar = level_infra.get("calendar")
|
||||
if "trade_calendar" in level_infra:
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
if not hasattr(self, "common_infra"):
|
||||
@@ -83,12 +77,6 @@ class BaseStrategy:
|
||||
raise NotImplementedError("generate_trade_decision is not implemented!")
|
||||
|
||||
|
||||
class RuleStrategy(BaseStrategy):
|
||||
"""Rule-based Trading strategy"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelStrategy(BaseStrategy):
|
||||
"""Model-based trading strategy, use model to make predictions for trading"""
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def parse_freq(freq: str) -> Tuple[int, str]:
|
||||
raise ValueError(
|
||||
"freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min"
|
||||
)
|
||||
_count = int(match_obj.group(1)) if match_obj.group(1) is None else 1
|
||||
_count = int(match_obj.group(1)) if match_obj.group(1) else 1
|
||||
_freq = match_obj.group(2)
|
||||
_freq_format_dict = {
|
||||
"month": "month",
|
||||
@@ -58,7 +58,8 @@ def parse_freq(freq: str) -> Tuple[int, str]:
|
||||
def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
|
||||
"""
|
||||
Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam
|
||||
Assumption: The fix length (240) of the calendar in each day.
|
||||
Assumption:
|
||||
- Fix length (240) of the calendar in each day.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -83,16 +84,19 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
|
||||
if freq_sam == "minute":
|
||||
|
||||
def cal_sam_minute(x, sam_minutes):
|
||||
"""
|
||||
Sample raw calendar into calendar with sam_minutes freq, shift represents the shift minute the market time
|
||||
- open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]
|
||||
- mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)]
|
||||
- mid open time of stock market is [13:00 - shift*pd.Timedelta(minutes=1)]
|
||||
- close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)]
|
||||
"""
|
||||
day_time = pd.Timestamp(x.date())
|
||||
shift = C.min_data_shift
|
||||
# shift represents the shift minute the market time
|
||||
# - open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]
|
||||
# - mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)]
|
||||
# - mid open time of stock market is [13:30 - shift*pd.Timedelta(minutes=1)]
|
||||
# - close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)]
|
||||
|
||||
open_time = day_time + pd.Timedelta(hours=9, minutes=30) - shift * pd.Timedelta(minutes=1)
|
||||
mid_close_time = day_time + pd.Timedelta(hours=11, minutes=29) - shift * pd.Timedelta(minutes=1)
|
||||
mid_open_time = day_time + pd.Timedelta(hours=13, minutes=30) - shift * pd.Timedelta(minutes=1)
|
||||
mid_open_time = day_time + pd.Timedelta(hours=13, minutes=00) - shift * pd.Timedelta(minutes=1)
|
||||
close_time = day_time + pd.Timedelta(hours=14, minutes=59) - shift * pd.Timedelta(minutes=1)
|
||||
|
||||
if open_time <= x <= mid_close_time:
|
||||
@@ -101,7 +105,6 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
|
||||
minute_index = (x - mid_open_time).seconds // 60 + 120
|
||||
else:
|
||||
raise ValueError("datetime of calendar is out of range")
|
||||
|
||||
minute_index = minute_index // sam_minutes * sam_minutes
|
||||
|
||||
if 0 <= minute_index < 120:
|
||||
@@ -109,7 +112,7 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
|
||||
elif 120 <= minute_index < 240:
|
||||
return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1)
|
||||
else:
|
||||
raise ValueError("calendar minute_index error")
|
||||
raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C")
|
||||
|
||||
if freq_raw != "minute":
|
||||
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
|
||||
@@ -189,11 +192,13 @@ def get_resam_calendar(
|
||||
freq = "day"
|
||||
except ValueError:
|
||||
_calendar = Cal.calendar(
|
||||
start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, future=future
|
||||
start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future
|
||||
)
|
||||
freq = "min"
|
||||
elif norm_freq == "minute":
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, future=future)
|
||||
_calendar = Cal.calendar(
|
||||
start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future
|
||||
)
|
||||
freq = "min"
|
||||
else:
|
||||
raise ValueError(f"freq {freq} is not supported")
|
||||
|
||||
@@ -8,10 +8,10 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
from ..contrib.evaluate import risk_analysis
|
||||
from ..contrib.backtest import backtest as normal_backtest
|
||||
|
||||
from ..data.dataset import DatasetH
|
||||
from ..data.dataset.handler import DataHandlerLP
|
||||
from ..backtest import backtest as normal_backtest
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict
|
||||
|
||||
Reference in New Issue
Block a user