mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
support empty benchmark
Empty benchmark could accelerate the learning process
This commit is contained in:
@@ -8,9 +8,9 @@ from .account import Account
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..strategy.base import BaseStrategy
|
||||
from .executor import BaseExecutor
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .executor import BaseExecutor
|
||||
from .backtest import backtest_loop
|
||||
from .backtest import collect_data_loop
|
||||
from .order import Order
|
||||
@@ -155,6 +155,7 @@ def get_strategy_executor(
|
||||
# - for avoiding recursive import
|
||||
# - typing annotations is not reliable
|
||||
from ..strategy.base import BaseStrategy
|
||||
from .executor import BaseExecutor
|
||||
|
||||
trade_account = create_account_instance(
|
||||
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
|
||||
|
||||
@@ -75,17 +75,7 @@ class Account:
|
||||
):
|
||||
self._pos_type = pos_type
|
||||
self._port_metr_enabled = port_metr_enabled
|
||||
self.init_vars(init_cash, position_dict, freq, benchmark_config)
|
||||
|
||||
def is_port_metr_enabled(self):
|
||||
"""
|
||||
Is portfolio-based metrics enabled.
|
||||
"""
|
||||
return self._port_metr_enabled and not self.current.skip_update()
|
||||
|
||||
def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
|
||||
|
||||
# init cash
|
||||
self.init_cash = init_cash
|
||||
self.current: BasePosition = init_instance_by_config(
|
||||
{
|
||||
@@ -100,8 +90,19 @@ class Account:
|
||||
self.accum_info = AccumulatedInfo()
|
||||
self.report = None
|
||||
self.positions = {}
|
||||
|
||||
# in of reset ignore None values
|
||||
self.benchmark_config = benchmark_config
|
||||
self.freq = freq
|
||||
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
|
||||
|
||||
def is_port_metr_enabled(self):
|
||||
"""
|
||||
Is portfolio-based metrics enabled.
|
||||
"""
|
||||
return self._port_metr_enabled and not self.current.skip_update()
|
||||
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
# portfolio related metrics
|
||||
if self.is_port_metr_enabled():
|
||||
|
||||
@@ -512,7 +512,7 @@ class Exchange:
|
||||
def _get_factor_or_raise_erorr(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
|
||||
"""Please refer to the docs of get_amount_of_trade_unit"""
|
||||
if factor is None:
|
||||
if stock_id is not None and start_time is not None and end_time is not None :
|
||||
if stock_id is not None and start_time is not None and end_time is not None:
|
||||
factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
else:
|
||||
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
|
||||
@@ -537,15 +537,16 @@ class Exchange:
|
||||
the end time of trading range
|
||||
"""
|
||||
if not self.trade_w_adj_price and self.trade_unit is not None:
|
||||
factor = self._get_factor_or_raise_erorr(factor=factor,
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time)
|
||||
factor = self._get_factor_or_raise_erorr(
|
||||
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
)
|
||||
return self.trade_unit / factor
|
||||
else:
|
||||
return None
|
||||
|
||||
def round_amount_by_trade_unit(self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
|
||||
def round_amount_by_trade_unit(
|
||||
self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None
|
||||
):
|
||||
"""Parameter
|
||||
Please refer to the docs of get_amount_of_trade_unit
|
||||
|
||||
@@ -555,10 +556,9 @@ class Exchange:
|
||||
"""
|
||||
if not self.trade_w_adj_price and self.trade_unit is not None:
|
||||
# the minimal amount is 1. Add 0.1 for solving precision problem.
|
||||
factor = self._get_factor_or_raise_erorr(factor=factor,
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time)
|
||||
factor = self._get_factor_or_raise_erorr(
|
||||
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
)
|
||||
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
|
||||
return deal_amount
|
||||
|
||||
|
||||
@@ -80,11 +80,12 @@ class Report:
|
||||
def init_bench(self, freq=None, benchmark_config=None):
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
if benchmark_config is not None:
|
||||
self.benchmark_config = benchmark_config
|
||||
self.benchmark_config = benchmark_config
|
||||
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
|
||||
|
||||
def _cal_benchmark(self, benchmark_config, freq):
|
||||
if benchmark_config is None:
|
||||
return None
|
||||
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
|
||||
if benchmark is None:
|
||||
return None
|
||||
|
||||
@@ -63,9 +63,9 @@ class TWAPStrategy(BaseStrategy):
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
continue
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(stock_id=order.stock_id,
|
||||
start_time=order.start_time,
|
||||
end_time=order.end_time)
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
|
||||
stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
|
||||
)
|
||||
_order_amount = None
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
@@ -169,9 +169,9 @@ class SBBStrategyBase(BaseStrategy):
|
||||
self.trade_trend[order.stock_id] = _pred_trend
|
||||
continue
|
||||
# get amount of one trade unit
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(stock_id=order.stock_id,
|
||||
start_time=order.start_time,
|
||||
end_time=order.end_time)
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
|
||||
stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
|
||||
)
|
||||
if _pred_trend == self.TREND_MID:
|
||||
_order_amount = None
|
||||
# considering trade unit
|
||||
@@ -471,9 +471,9 @@ class ACStrategy(BaseStrategy):
|
||||
|
||||
if sig_sam is None or np.isnan(sig_sam):
|
||||
# no signal, TWAP
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(stock_id=order.stock_id,
|
||||
start_time=order.start_time,
|
||||
end_time=order.end_time)
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
|
||||
stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
|
||||
)
|
||||
if _amount_trade_unit is None:
|
||||
# divide the order into equal parts, and trade one part
|
||||
_order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
|
||||
@@ -494,10 +494,9 @@ class ACStrategy(BaseStrategy):
|
||||
np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1))
|
||||
) / np.sinh(kappa * trade_len)
|
||||
_order_amount = order.amount * amount_ratio
|
||||
_order_amount = self.trade_exchange.round_amount_by_trade_unit(_order_amount,
|
||||
stock_id=order.stock_id,
|
||||
start_time=order.start_time,
|
||||
end_time=order.end_time)
|
||||
_order_amount = self.trade_exchange.round_amount_by_trade_unit(
|
||||
_order_amount, stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
|
||||
)
|
||||
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
@@ -584,8 +583,11 @@ class FileOrderStrategy(BaseStrategy):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, file: Union[IO, str, Path, pd.DataFrame],
|
||||
trade_range: Union[Tuple[int, int], TradeRange] = None, *args, **kwargs
|
||||
self,
|
||||
file: Union[IO, str, Path, pd.DataFrame],
|
||||
trade_range: Union[Tuple[int, int], TradeRange] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.position import BasePosition
|
||||
from qlib.backtest.position import BasePosition
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from ..model.base import BaseModel
|
||||
|
||||
Reference in New Issue
Block a user