mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
Refine backtest codes (#1120)
* Refine backtest code * Keep working * Minor * Resolve PR comments * Fix import error * Fix import error
This commit is contained in:
@@ -2,24 +2,29 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import List, Tuple, Union, TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from .account import Account
|
||||
from .report import Indicator, PortfolioMetrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..strategy.base import BaseStrategy
|
||||
from .executor import BaseExecutor
|
||||
from .decision import BaseTradeDecision
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .backtest import backtest_loop
|
||||
from .backtest import collect_data_loop
|
||||
from .utils import CommonInfrastructure
|
||||
from .decision import Order
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger
|
||||
from ..utils import init_instance_by_config
|
||||
from .backtest import backtest_loop, collect_data_loop
|
||||
from .decision import Order
|
||||
from .exchange import Exchange
|
||||
from .position import Position
|
||||
from .utils import CommonInfrastructure
|
||||
|
||||
# make import more user-friendly by adding `from qlib.backtest import STH`
|
||||
|
||||
@@ -28,26 +33,34 @@ logger = get_module_logger("backtest caller")
|
||||
|
||||
|
||||
def get_exchange(
|
||||
exchange=None,
|
||||
freq="day",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes="all",
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5.0,
|
||||
limit_threshold=None,
|
||||
exchange: Union[str, dict, object, Path] = None,
|
||||
freq: str = "day",
|
||||
start_time: Union[pd.Timestamp, str] = None,
|
||||
end_time: Union[pd.Timestamp, str] = None,
|
||||
codes: Union[list, str] = "all",
|
||||
subscribe_fields: list = [],
|
||||
open_cost: float = 0.0015,
|
||||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> Exchange:
|
||||
"""get_exchange
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange().
|
||||
exchange: Exchange(). It could be None or any types that are acceptable by `init_instance_by_config`.
|
||||
freq: str
|
||||
frequency of data.
|
||||
start_time: Union[pd.Timestamp, str]
|
||||
closed start time for backtest.
|
||||
end_time: Union[pd.Timestamp, str]
|
||||
closed end time for backtest.
|
||||
codes: list|str
|
||||
list stock_id list or a string of instruments (i.e. all, csi500, sse50)
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
@@ -57,8 +70,6 @@ def get_exchange(
|
||||
min_cost : float
|
||||
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
|
||||
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
|
||||
trade_unit : int
|
||||
Included in kwargs. Please refer to the docs of `__init__` of `Exchange`
|
||||
deal_price: Union[str, Tuple[str], List[str]]
|
||||
The `deal_price` supports following two types of input
|
||||
- <deal_price> : str
|
||||
@@ -101,10 +112,14 @@ def get_exchange(
|
||||
|
||||
|
||||
def create_account_instance(
|
||||
start_time, end_time, benchmark: str, account: Union[float, int, dict], pos_type: str = "Position"
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
benchmark: str,
|
||||
account: Union[float, int, dict],
|
||||
pos_type: str = "Position",
|
||||
) -> Account:
|
||||
"""
|
||||
# TODO: is very strange pass benchmark_config in the account(maybe for report)
|
||||
# TODO: is very strange pass benchmark_config in the account (maybe for report)
|
||||
# There should be a post-step to process the report.
|
||||
|
||||
Parameters
|
||||
@@ -132,6 +147,8 @@ def create_account_instance(
|
||||
key "cash" means initial cash.
|
||||
key "stock1" means the information of first stock with amount and price(optional).
|
||||
...
|
||||
pos_type: str
|
||||
Postion type.
|
||||
"""
|
||||
if isinstance(account, (int, float)):
|
||||
pos_kwargs = {"init_cash": account}
|
||||
@@ -159,15 +176,15 @@ def create_account_instance(
|
||||
|
||||
|
||||
def get_strategy_executor(
|
||||
start_time,
|
||||
end_time,
|
||||
strategy: BaseStrategy,
|
||||
executor: BaseExecutor,
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
):
|
||||
) -> Tuple[BaseStrategy, BaseExecutor]:
|
||||
|
||||
# NOTE:
|
||||
# - for avoiding recursive import
|
||||
@@ -176,7 +193,11 @@ def get_strategy_executor(
|
||||
from .executor import BaseExecutor # pylint: disable=C0415
|
||||
|
||||
trade_account = create_account_instance(
|
||||
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
benchmark=benchmark,
|
||||
account=account,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
@@ -196,29 +217,31 @@ def get_strategy_executor(
|
||||
|
||||
|
||||
def backtest(
|
||||
start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark="SH000300",
|
||||
account=1e9,
|
||||
exchange_kwargs={},
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
):
|
||||
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and executor in the nested decision execution
|
||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
|
||||
executor in the nested decision execution
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : pd.Timestamp|str
|
||||
start_time : Union[pd.Timestamp, str]
|
||||
closed start time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
end_time : pd.Timestamp|str
|
||||
end_time : Union[pd.Timestamp, str]
|
||||
closed end time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
|
||||
strategy : Union[str, dict, BaseStrategy]
|
||||
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more information.
|
||||
executor : Union[str, dict, BaseExecutor]
|
||||
strategy : Union[str, dict, object, Path]
|
||||
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more
|
||||
information.
|
||||
executor : Union[str, dict, object, Path]
|
||||
for initializing the outermost executor.
|
||||
benchmark: str
|
||||
the benchmark for reporting.
|
||||
@@ -257,16 +280,16 @@ def backtest(
|
||||
|
||||
|
||||
def collect_data(
|
||||
start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark="SH000300",
|
||||
account=1e9,
|
||||
exchange_kwargs={},
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
return_value: dict = None,
|
||||
):
|
||||
) -> Generator[object, None, None]:
|
||||
"""initialize the strategy and executor, then collect the trade decision data for rl training
|
||||
|
||||
please refer to the docs of the backtest for the explanation of the parameters
|
||||
@@ -291,7 +314,7 @@ def collect_data(
|
||||
|
||||
def format_decisions(
|
||||
decisions: List[BaseTradeDecision],
|
||||
) -> Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]:
|
||||
) -> Optional[Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]]:
|
||||
"""
|
||||
format the decisions collected by `qlib.backtest.collect_data`
|
||||
The decisions will be organized into a tree-like structure.
|
||||
@@ -326,4 +349,4 @@ def format_decisions(
|
||||
return res
|
||||
|
||||
|
||||
__all__ = ["Order"]
|
||||
__all__ = ["Order", "backtest"]
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Dict, List, Tuple
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from .position import BasePosition
|
||||
from .report import PortfolioMetrics, Indicator
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
from .position import BasePosition
|
||||
from .report import Indicator, PortfolioMetrics
|
||||
|
||||
"""
|
||||
rtn & earning in the Account
|
||||
@@ -34,40 +37,42 @@ class AccumulatedInfo:
|
||||
AccumulatedInfo should be shared across different levels
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.rtn = 0 # accumulated return, do not consider cost
|
||||
self.cost = 0 # accumulated cost
|
||||
self.to = 0 # accumulated turnover
|
||||
def reset(self) -> None:
|
||||
self.rtn: float = 0.0 # accumulated return, do not consider cost
|
||||
self.cost: float = 0.0 # accumulated cost
|
||||
self.to: float = 0.0 # accumulated turnover
|
||||
|
||||
def add_return_value(self, value):
|
||||
def add_return_value(self, value: float) -> None:
|
||||
self.rtn += value
|
||||
|
||||
def add_cost(self, value):
|
||||
def add_cost(self, value: float) -> None:
|
||||
self.cost += value
|
||||
|
||||
def add_turnover(self, value):
|
||||
def add_turnover(self, value: float) -> None:
|
||||
self.to += value
|
||||
|
||||
@property
|
||||
def get_return(self):
|
||||
def get_return(self) -> float:
|
||||
return self.rtn
|
||||
|
||||
@property
|
||||
def get_cost(self):
|
||||
def get_cost(self) -> float:
|
||||
return self.cost
|
||||
|
||||
@property
|
||||
def get_turnover(self):
|
||||
def get_turnover(self) -> float:
|
||||
return self.to
|
||||
|
||||
|
||||
class Account:
|
||||
"""
|
||||
The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in qlib/backtest/executor.py:NestedExecutor
|
||||
Different level of executor has different Account object when calculating metrics. But the position object is shared cross all the Account object.
|
||||
The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in
|
||||
qlib/backtest/executor.py:NestedExecutor
|
||||
Different level of executor has different Account object when calculating metrics. But the position object is
|
||||
shared cross all the Account object.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -78,7 +83,7 @@ class Account:
|
||||
benchmark_config: dict = {},
|
||||
pos_type: str = "Position",
|
||||
port_metr_enabled: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
"""the trade account of backtest.
|
||||
|
||||
Parameters
|
||||
@@ -102,7 +107,7 @@ class Account:
|
||||
self.benchmark_config = None # avoid no attribute error
|
||||
self.init_vars(init_cash, position_dict, freq, benchmark_config)
|
||||
|
||||
def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
|
||||
def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None:
|
||||
# 1) the following variables are shared by multiple layers
|
||||
# - you will see a shallow copy instead of deepcopy in the NestedExecutor;
|
||||
self.init_cash = init_cash
|
||||
@@ -114,7 +119,7 @@ class Account:
|
||||
"position_dict": position_dict,
|
||||
},
|
||||
"module_path": "qlib.backtest.position",
|
||||
}
|
||||
},
|
||||
)
|
||||
self.accum_info = AccumulatedInfo()
|
||||
|
||||
@@ -123,13 +128,13 @@ class Account:
|
||||
self.hist_positions = {}
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config)
|
||||
|
||||
def is_port_metr_enabled(self):
|
||||
def is_port_metr_enabled(self) -> bool:
|
||||
"""
|
||||
Is portfolio-based metrics enabled.
|
||||
"""
|
||||
return self._port_metr_enabled and not self.current_position.skip_update()
|
||||
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
def reset_report(self, freq: str, benchmark_config: dict) -> None:
|
||||
# portfolio related metrics
|
||||
if self.is_port_metr_enabled():
|
||||
# NOTE:
|
||||
@@ -140,13 +145,13 @@ class Account:
|
||||
# fill stock value
|
||||
# The frequency of account may not align with the trading frequency.
|
||||
# This may result in obscure bugs when data quality is low.
|
||||
if isinstance(self.benchmark_config, dict) and self.benchmark_config.get("start_time") is not None:
|
||||
if isinstance(self.benchmark_config, dict) and "start_time" in self.benchmark_config:
|
||||
self.current_position.fill_stock_value(self.benchmark_config["start_time"], self.freq)
|
||||
|
||||
# trading related metrics(e.g. high-frequency trading)
|
||||
self.indicator = Indicator()
|
||||
|
||||
def reset(self, freq=None, benchmark_config=None, port_metr_enabled: bool = None):
|
||||
def reset(self, freq: str = None, benchmark_config: dict = None, port_metr_enabled: bool = None) -> None:
|
||||
"""reset freq and report of account
|
||||
|
||||
Parameters
|
||||
@@ -155,6 +160,7 @@ class Account:
|
||||
frequency of account & report, by default None
|
||||
benchmark_config : {}, optional
|
||||
benchmark config of report, by default None
|
||||
port_metr_enabled: bool
|
||||
"""
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
@@ -165,13 +171,13 @@ class Account:
|
||||
|
||||
self.reset_report(self.freq, self.benchmark_config)
|
||||
|
||||
def get_hist_positions(self):
|
||||
def get_hist_positions(self) -> dict:
|
||||
return self.hist_positions
|
||||
|
||||
def get_cash(self):
|
||||
def get_cash(self) -> float:
|
||||
return self.current_position.get_cash()
|
||||
|
||||
def _update_state_from_order(self, order, trade_val, cost, trade_price):
|
||||
def _update_state_from_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
if self.is_port_metr_enabled():
|
||||
# update turnover
|
||||
self.accum_info.add_turnover(trade_val)
|
||||
@@ -191,13 +197,14 @@ class Account:
|
||||
profit = self.current_position.get_stock_price(order.stock_id) * trade_amount - trade_val
|
||||
self.accum_info.add_return_value(profit) # note here do not consider cost
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
if self.current_position.skip_update():
|
||||
# TODO: supporting polymorphism for account
|
||||
# updating order for infinite position is meaningless
|
||||
return
|
||||
|
||||
# if stock is sold out, no stock price information in Position, then we should update account first, then update current position
|
||||
# if stock is sold out, no stock price information in Position, then we should update account first,
|
||||
# then update current position
|
||||
# if stock is bought, there is no stock in current position, update current, then update account
|
||||
# The cost will be subtracted from the cash at last. So the trading logic can ignore the cost calculation
|
||||
if order.direction == Order.SELL:
|
||||
@@ -212,8 +219,15 @@ class Account:
|
||||
self.current_position.update_order(order, trade_val, cost, trade_price)
|
||||
self._update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_current_position(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock"""
|
||||
def update_current_position(
|
||||
self,
|
||||
trade_start_time: pd.Timestamp,
|
||||
trade_end_time: pd.Timestamp,
|
||||
trade_exchange: Exchange,
|
||||
) -> None:
|
||||
"""
|
||||
Update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock
|
||||
"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
|
||||
if not self.current_position.skip_update():
|
||||
@@ -228,7 +242,7 @@ class Account:
|
||||
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
|
||||
self.current_position.add_count_all(bar=self.freq)
|
||||
|
||||
def update_portfolio_metrics(self, trade_start_time, trade_end_time):
|
||||
def update_portfolio_metrics(self, trade_start_time: pd.Timestamp, trade_end_time: pd.Timestamp) -> None:
|
||||
"""update portfolio_metrics"""
|
||||
# calculate earning
|
||||
# account_value - last_account_value
|
||||
@@ -243,14 +257,16 @@ class Account:
|
||||
last_account_value = self.portfolio_metrics.get_latest_account_value()
|
||||
last_total_cost = self.portfolio_metrics.get_latest_total_cost()
|
||||
last_total_turnover = self.portfolio_metrics.get_latest_total_turnover()
|
||||
|
||||
# get now_account_value, now_stock_value, now_earning, now_cost, now_turnover
|
||||
now_account_value = self.current_position.calculate_value()
|
||||
now_stock_value = self.current_position.calculate_stock_value()
|
||||
now_earning = now_account_value - last_account_value
|
||||
now_cost = self.accum_info.get_cost - last_total_cost
|
||||
now_turnover = self.accum_info.get_turnover - last_total_turnover
|
||||
|
||||
# update portfolio_metrics for today
|
||||
# judge whether the the trading is begin.
|
||||
# judge whether the trading is begin.
|
||||
# and don't add init account state into portfolio_metrics, due to we don't have excess return in those days.
|
||||
self.portfolio_metrics.update_portfolio_metrics_record(
|
||||
trade_start_time=trade_start_time,
|
||||
@@ -267,7 +283,7 @@ class Account:
|
||||
stock_value=now_stock_value,
|
||||
)
|
||||
|
||||
def update_hist_positions(self, trade_start_time):
|
||||
def update_hist_positions(self, trade_start_time: pd.Timestamp) -> None:
|
||||
"""update history position"""
|
||||
now_account_value = self.current_position.calculate_value()
|
||||
# set now_account_value to position
|
||||
@@ -287,7 +303,7 @@ class Account:
|
||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
indicator_config: dict = {},
|
||||
):
|
||||
) -> None:
|
||||
"""update trade indicators and order indicators in each bar end"""
|
||||
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`
|
||||
|
||||
@@ -323,7 +339,7 @@ class Account:
|
||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
indicator_config: dict = {},
|
||||
):
|
||||
) -> None:
|
||||
"""update account at each trading bar step
|
||||
|
||||
Parameters
|
||||
@@ -338,6 +354,8 @@ class Account:
|
||||
whether the trading executor is atomic, which means there is no higher-frequency trading executor inside it
|
||||
- if atomic is True, calculate the indicators with trade_info
|
||||
- else, aggregate indicators with inner indicators
|
||||
outer_trade_decision: BaseTradeDecision
|
||||
external trade decision
|
||||
trade_info : List[(Order, float, float, float)], optional
|
||||
trading information, by default None
|
||||
- necessary if atomic is True
|
||||
@@ -377,7 +395,7 @@ class Account:
|
||||
indicator_config=indicator_config,
|
||||
)
|
||||
|
||||
def get_portfolio_metrics(self):
|
||||
def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:
|
||||
"""get the history portfolio_metrics and positions instance"""
|
||||
if self.is_port_metr_enabled():
|
||||
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
|
||||
|
||||
@@ -2,17 +2,29 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import BaseTradeDecision
|
||||
from typing import TYPE_CHECKING
|
||||
from qlib.backtest.report import Indicator, PortfolioMetrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from ..utils.time import Freq
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ..utils.time import Freq
|
||||
|
||||
def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor):
|
||||
|
||||
def backtest_loop(
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||
"""backtest function for the interaction of the outermost strategy and executor in the nested decision execution
|
||||
|
||||
please refer to the docs of `collect_data_loop`
|
||||
@@ -31,19 +43,23 @@ def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_exec
|
||||
|
||||
|
||||
def collect_data_loop(
|
||||
start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict = None
|
||||
):
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
return_value: dict = None,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : pd.Timestamp|str
|
||||
start_time : Union[pd.Timestamp, str]
|
||||
closed start time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
end_time : pd.Timestamp|str
|
||||
end_time : Union[pd.Timestamp, str]
|
||||
closed end time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
|
||||
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
|
||||
trade_strategy : BaseStrategy
|
||||
the outermost portfolio strategy
|
||||
trade_executor : BaseExecutor
|
||||
|
||||
@@ -2,23 +2,26 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from qlib.data.data import Cal
|
||||
from qlib.utils.time import concat_date_time, epsilon_change
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
from typing import ClassVar, Optional, Union, List, Tuple
|
||||
from abc import abstractmethod
|
||||
from enum import IntEnum
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
|
||||
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
from qlib.data.data import Cal
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.time import concat_date_time, epsilon_change
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class OrderDir(IntEnum):
|
||||
@@ -46,7 +49,7 @@ class Order:
|
||||
# - they are set by users and is time-invariant.
|
||||
stock_id: str
|
||||
amount: float # `amount` is a non-negative and adjusted value
|
||||
direction: int
|
||||
direction: OrderDir
|
||||
|
||||
# 2) time variant values:
|
||||
# - Users may want to set these values when using lower level APIs
|
||||
@@ -61,7 +64,7 @@ class Order:
|
||||
# What the value should be about in all kinds of cases
|
||||
# - not tradable: the deal_amount == 0 , factor is None
|
||||
# - the stock is suspended and the entire order fails. No cost for this order
|
||||
# - dealed or partially dealed: deal_amount >= 0 and factor is not None
|
||||
# - dealt or partially dealt: deal_amount >= 0 and factor is not None
|
||||
deal_amount: Optional[float] = None # `deal_amount` is a non-negative value
|
||||
factor: Optional[float] = None
|
||||
|
||||
@@ -74,10 +77,10 @@ class Order:
|
||||
SELL: ClassVar[OrderDir] = OrderDir.SELL
|
||||
BUY: ClassVar[OrderDir] = OrderDir.BUY
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.direction not in {Order.SELL, Order.BUY}:
|
||||
raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
|
||||
self.deal_amount = 0
|
||||
self.deal_amount = 0.0
|
||||
self.factor = None
|
||||
|
||||
@property
|
||||
@@ -99,7 +102,7 @@ class Order:
|
||||
return self.deal_amount * self.sign
|
||||
|
||||
@property
|
||||
def sign(self) -> float:
|
||||
def sign(self) -> int:
|
||||
"""
|
||||
return the sign of trading
|
||||
- `+1` indicates buying
|
||||
@@ -112,15 +115,12 @@ class Order:
|
||||
if isinstance(direction, OrderDir):
|
||||
return direction
|
||||
elif isinstance(direction, (int, float, np.integer, np.floating)):
|
||||
if direction > 0:
|
||||
return Order.BUY
|
||||
else:
|
||||
return Order.SELL
|
||||
return Order.BUY if direction > 0 else Order.SELL
|
||||
elif isinstance(direction, str):
|
||||
dl = direction.lower()
|
||||
if dl.strip() == "sell":
|
||||
dl = direction.lower().strip()
|
||||
if dl == "sell":
|
||||
return OrderDir.SELL
|
||||
elif dl.strip() == "buy":
|
||||
elif dl == "buy":
|
||||
return OrderDir.BUY
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
@@ -138,14 +138,14 @@ class OrderHelper:
|
||||
Motivation
|
||||
- Make generating order easier
|
||||
- User may have no knowledge about the adjust-factor information about the system.
|
||||
- It involves to much interaction with the exchange when generating orders.
|
||||
- It involves too much interaction with the exchange when generating orders.
|
||||
"""
|
||||
|
||||
def __init__(self, exchange: Exchange):
|
||||
def __init__(self, exchange: Exchange) -> None:
|
||||
self.exchange = exchange
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
self,
|
||||
code: str,
|
||||
amount: float,
|
||||
direction: OrderDir,
|
||||
@@ -175,21 +175,18 @@ class OrderHelper:
|
||||
Order:
|
||||
The created order
|
||||
"""
|
||||
if start_time is not None:
|
||||
start_time = pd.Timestamp(start_time)
|
||||
if end_time is not None:
|
||||
end_time = pd.Timestamp(end_time)
|
||||
# NOTE: factor is a value belongs to the results section. User don't have to care about it when creating orders
|
||||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
start_time=start_time if start_time is not None else pd.Timestamp(start_time),
|
||||
end_time=end_time if end_time is not None else pd.Timestamp(end_time),
|
||||
direction=direction,
|
||||
)
|
||||
|
||||
|
||||
class TradeRange:
|
||||
@abstractmethod
|
||||
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
|
||||
"""
|
||||
This method will be call with following way
|
||||
@@ -216,6 +213,7 @@ class TradeRange:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `__call__` method")
|
||||
|
||||
@abstractmethod
|
||||
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
"""
|
||||
Parameters
|
||||
@@ -234,23 +232,26 @@ class TradeRange:
|
||||
|
||||
|
||||
class IdxTradeRange(TradeRange):
|
||||
def __init__(self, start_idx: int, end_idx: int):
|
||||
def __init__(self, start_idx: int, end_idx: int) -> None:
|
||||
self._start_idx = start_idx
|
||||
self._end_idx = end_idx
|
||||
|
||||
def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
|
||||
return self._start_idx, self._end_idx
|
||||
|
||||
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TradeRangeByTime(TradeRange):
|
||||
"""This is a helper function for make decisions"""
|
||||
|
||||
def __init__(self, start_time: str, end_time: str):
|
||||
def __init__(self, start_time: str, end_time: str) -> None:
|
||||
"""
|
||||
This is a callable class.
|
||||
|
||||
**NOTE**:
|
||||
- It is designed for minute-bar for intraday trading!!!!!
|
||||
- It is designed for minute-bar for intra-day trading!!!!!
|
||||
- Both start_time and end_time are **closed** in the range
|
||||
|
||||
Parameters
|
||||
@@ -264,26 +265,25 @@ class TradeRangeByTime(TradeRange):
|
||||
self.end_time = pd.Timestamp(end_time).time()
|
||||
assert self.start_time < self.end_time
|
||||
|
||||
def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
|
||||
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
|
||||
if trade_calendar is None:
|
||||
raise NotImplementedError("trade_calendar is necessary for getting TradeRangeByTime.")
|
||||
start = trade_calendar.start_time
|
||||
val_start, val_end = concat_date_time(start.date(), self.start_time), concat_date_time(
|
||||
start.date(), self.end_time
|
||||
)
|
||||
|
||||
start_date = trade_calendar.start_time.date()
|
||||
val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time)
|
||||
return trade_calendar.get_range_idx(val_start, val_end)
|
||||
|
||||
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
start_date = start_time.date()
|
||||
val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time)
|
||||
# NOTE: `end_date` should not be used. Because the `end_date` is for slicing. It may be in the next day
|
||||
# Assumption: start_time and end_time is for intraday trading. So it is OK for only using start_date
|
||||
# Assumption: start_time and end_time is for intra-day trading. So it is OK for only using start_date
|
||||
return max(val_start, start_time), min(val_end, end_time)
|
||||
|
||||
|
||||
class BaseTradeDecision:
|
||||
"""
|
||||
Trade decisions ara made by strategy and executed by exeuter
|
||||
Trade decisions ara made by strategy and executed by executor
|
||||
|
||||
Motivation:
|
||||
Here are several typical scenarios for `BaseTradeDecision`
|
||||
@@ -297,7 +297,7 @@ class BaseTradeDecision:
|
||||
2. Same as `case 1.3`
|
||||
"""
|
||||
|
||||
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None):
|
||||
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -339,7 +339,7 @@ class BaseTradeDecision:
|
||||
"""
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecision", None]:
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> Optional[BaseTradeDecision]:
|
||||
"""
|
||||
Be called at the **start** of each step.
|
||||
|
||||
@@ -354,10 +354,8 @@ class BaseTradeDecision:
|
||||
|
||||
Returns
|
||||
-------
|
||||
None:
|
||||
No update, use previous decision(or unavailable)
|
||||
BaseTradeDecision:
|
||||
New update, use new decision
|
||||
New update, use new decision. If no updates, return None (use previous decision (or unavailable))
|
||||
"""
|
||||
# purpose 1)
|
||||
self.total_step = trade_calendar.get_trade_len()
|
||||
@@ -412,12 +410,12 @@ class BaseTradeDecision:
|
||||
"""
|
||||
try:
|
||||
_start_idx, _end_idx = self._get_range_limit(**kwargs)
|
||||
except NotImplementedError:
|
||||
except NotImplementedError as e:
|
||||
if "default_value" in kwargs:
|
||||
return kwargs["default_value"]
|
||||
else:
|
||||
# Default to get full index
|
||||
raise NotImplementedError(f"The decision didn't provide an index range") from NotImplementedError
|
||||
raise NotImplementedError(f"The decision didn't provide an index range") from e
|
||||
|
||||
# clip index
|
||||
if getattr(self, "total_step", None) is not None:
|
||||
@@ -426,7 +424,7 @@ class BaseTradeDecision:
|
||||
if _start_idx < 0 or _end_idx >= self.total_step:
|
||||
logger = get_module_logger("decision")
|
||||
logger.warning(
|
||||
f"[{_start_idx},{_end_idx}] go beyoud the total_step({self.total_step}), it will be clipped"
|
||||
f"[{_start_idx},{_end_idx}] go beyond the total_step({self.total_step}), it will be clipped.",
|
||||
)
|
||||
_start_idx, _end_idx = max(0, _start_idx), min(self.total_step - 1, _end_idx)
|
||||
return _start_idx, _end_idx
|
||||
@@ -444,7 +442,7 @@ class BaseTradeDecision:
|
||||
Parameters
|
||||
----------
|
||||
rtype: str
|
||||
- "full": return the full limitation of the deicsion in the day
|
||||
- "full": return the full limitation of the decision in the day
|
||||
- "step": return the limitation of current step
|
||||
|
||||
raise_error: bool
|
||||
@@ -497,11 +495,10 @@ class BaseTradeDecision:
|
||||
return True
|
||||
return True
|
||||
|
||||
def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision):
|
||||
def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision) -> None:
|
||||
"""
|
||||
|
||||
This method will be called on the inner_trade_decision after it is generated.
|
||||
`inner_trade_decision` will be changed **inplaced**.
|
||||
`inner_trade_decision` will be changed **inplace**.
|
||||
|
||||
Motivation of the `mod_inner_decision`
|
||||
- Leave a hook for outer decision to affect the decision generated by the inner strategy
|
||||
@@ -520,6 +517,9 @@ class BaseTradeDecision:
|
||||
|
||||
|
||||
class EmptyTradeDecision(BaseTradeDecision):
|
||||
def get_decision(self) -> List[object]:
|
||||
return []
|
||||
|
||||
def empty(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -544,4 +544,9 @@ class TradeDecisionWO(BaseTradeDecision):
|
||||
return self.order_list
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"class: {self.__class__.__name__}; strategy: {self.strategy}; trade_range: {self.trade_range}; order_list[{len(self.order_list)}]"
|
||||
return (
|
||||
f"class: {self.__class__.__name__}; "
|
||||
f"strategy: {self.strategy}; "
|
||||
f"trade_range: {self.trade_range}; "
|
||||
f"order_list[{len(self.order_list)}]"
|
||||
)
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import List, Tuple, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
||||
|
||||
from ..utils.index_data import IndexData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .account import Account
|
||||
|
||||
from qlib.backtest.position import BasePosition, Position
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..data.data import D
|
||||
from qlib.backtest.position import BasePosition
|
||||
|
||||
from ..config import C
|
||||
from ..constant import REG_CN
|
||||
from ..data.data import D
|
||||
from ..log import get_module_logger
|
||||
from .decision import Order, OrderDir, OrderHelper
|
||||
from .high_performance_ds import BaseQuote, NumpyQuote
|
||||
@@ -24,22 +28,22 @@ from .high_performance_ds import BaseQuote, NumpyQuote
|
||||
class Exchange:
|
||||
def __init__(
|
||||
self,
|
||||
freq="day",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes="all",
|
||||
freq: str = "day",
|
||||
start_time: Union[pd.Timestamp, str] = None,
|
||||
end_time: Union[pd.Timestamp, str] = None,
|
||||
codes: Union[list, str] = "all",
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
subscribe_fields=[],
|
||||
subscribe_fields: list = [],
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
volume_threshold=None,
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5,
|
||||
impact_cost=0.0,
|
||||
extra_quote=None,
|
||||
quote_cls=NumpyQuote,
|
||||
volume_threshold: Union[tuple, dict] = None,
|
||||
open_cost: float = 0.0015,
|
||||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
impact_cost: float = 0.0,
|
||||
extra_quote: pd.DataFrame = None,
|
||||
quote_cls: Type[BaseQuote] = NumpyQuote,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""__init__
|
||||
:param freq: frequency of data
|
||||
:param start_time: closed start time for backtest
|
||||
@@ -72,11 +76,12 @@ class Exchange:
|
||||
]
|
||||
1) ("cum" or "current", limit_str) denotes a single volume limit.
|
||||
- limit_str is qlib data expression which is allowed to define your own Operator.
|
||||
Please refer to qlib/contrib/ops/high_freq.py, here are any custom operator for high frequency,
|
||||
such as DayCumsum. !!!NOTE: if you want you use the custom operator, you need to
|
||||
register it in qlib_init.
|
||||
- "cum" means that this is a cumulative value over time, such as cumulative market volume.
|
||||
So when it is used as a volume limit, it is necessary to subtract the dealt amount.
|
||||
Please refer to qlib/contrib/ops/high_freq.py, here are any custom operator for
|
||||
high frequency, such as DayCumsum. !!!NOTE: if you want you use the custom
|
||||
operator, you need to register it in qlib_init.
|
||||
- "cum" means that this is a cumulative value over time, such as cumulative market
|
||||
volume. So when it is used as a volume limit, it is necessary to subtract the dealt
|
||||
amount.
|
||||
- "current" means that this is a real-time value and will not accumulate over time,
|
||||
so it can be directly used as a capacity limit.
|
||||
e.g. ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"), ("current", "$bidV1")
|
||||
@@ -84,7 +89,7 @@ class Exchange:
|
||||
"buy" means the volume limits of buying. "sell" means the volume limits of selling.
|
||||
Different volume limits will be aggregated with min(). If volume_threshold is only
|
||||
("cum" or "current", limit_str) instead of a dict, the volume limits are for
|
||||
both by deault. In other words, it is same as {"all": ("cum" or "current", limit_str)}.
|
||||
both by default. In other words, it is same as {"all": ("cum" or "current", limit_str)}.
|
||||
3) e.g. "volume_threshold": {
|
||||
"all": ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"),
|
||||
"buy": ("current", "$askV1"),
|
||||
@@ -104,13 +109,14 @@ class Exchange:
|
||||
Necessary fields:
|
||||
$close is for calculating the total value at end of each day.
|
||||
Optional fields:
|
||||
$volume is only necessary when we limit the trade amount or calculate PA(vwap) indicator
|
||||
$volume is only necessary when we limit the trade amount or calculate
|
||||
PA(vwap) indicator
|
||||
$vwap is only necessary when we use the $vwap price as the deal price
|
||||
$factor is for rounding to the trading unit
|
||||
limit_sell will be set to False by default(False indicates we can sell this
|
||||
target on this day).
|
||||
limit_buy will be set to False by default(False indicates we can buy this
|
||||
target on this day).
|
||||
limit_sell will be set to False by default (False indicates we can sell
|
||||
this target on this day).
|
||||
limit_buy will be set to False by default (False indicates we can buy
|
||||
this target on this day).
|
||||
index: MultipleIndex(instrument, pd.Datetime)
|
||||
"""
|
||||
self.freq = freq
|
||||
@@ -163,7 +169,7 @@ class Exchange:
|
||||
if self.limit_type == self.LT_TP_EXP:
|
||||
for exp in limit_threshold:
|
||||
necessary_fields.add(exp)
|
||||
all_fields = necessary_fields | vol_lt_fields
|
||||
all_fields = necessary_fields | set(vol_lt_fields)
|
||||
all_fields = list(all_fields | set(subscribe_fields))
|
||||
|
||||
self.all_fields = all_fields
|
||||
@@ -182,17 +188,22 @@ class Exchange:
|
||||
self.quote_cls = quote_cls
|
||||
self.quote: BaseQuote = self.quote_cls(self.quote_df, freq)
|
||||
|
||||
def get_quote_from_qlib(self):
|
||||
def get_quote_from_qlib(self) -> None:
|
||||
# get stock data from qlib
|
||||
if len(self.codes) == 0:
|
||||
self.codes = D.instruments()
|
||||
self.quote_df = D.features(
|
||||
self.codes, self.all_fields, self.start_time, self.end_time, freq=self.freq, disk_cache=True
|
||||
self.codes,
|
||||
self.all_fields,
|
||||
self.start_time,
|
||||
self.end_time,
|
||||
freq=self.freq,
|
||||
disk_cache=True,
|
||||
).dropna(subset=["$close"])
|
||||
self.quote_df.columns = self.all_fields
|
||||
|
||||
# check buy_price data and sell_price data
|
||||
for attr in "buy_price", "sell_price":
|
||||
for attr in ("buy_price", "sell_price"):
|
||||
pstr = getattr(self, attr) # price string
|
||||
if self.quote_df[pstr].isna().any():
|
||||
self.logger.warning("{} field data contains nan.".format(pstr))
|
||||
@@ -238,7 +249,7 @@ class Exchange:
|
||||
LT_FLT = "float" # float
|
||||
LT_NONE = "none" # none
|
||||
|
||||
def _get_limit_type(self, limit_threshold):
|
||||
def _get_limit_type(self, limit_threshold: Union[Tuple, float, None]) -> str:
|
||||
"""get limit type"""
|
||||
if isinstance(limit_threshold, Tuple):
|
||||
return self.LT_TP_EXP
|
||||
@@ -249,7 +260,7 @@ class Exchange:
|
||||
else:
|
||||
raise NotImplementedError(f"This type of `limit_threshold` is not supported")
|
||||
|
||||
def _update_limit(self, limit_threshold):
|
||||
def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None:
|
||||
# check limit_threshold
|
||||
limit_type = self._get_limit_type(limit_threshold)
|
||||
if limit_type == self.LT_NONE:
|
||||
@@ -263,9 +274,10 @@ class Exchange:
|
||||
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
|
||||
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
|
||||
|
||||
def _get_vol_limit(self, volume_threshold):
|
||||
@staticmethod
|
||||
def _get_vol_limit(volume_threshold: Union[tuple, dict]) -> Tuple[Optional[list], Optional[list], set]:
|
||||
"""
|
||||
preproccess the volume limit.
|
||||
preprocess the volume limit.
|
||||
get the fields need to get from qlib.
|
||||
get the volume limit list of buying and selling which is composed of all limits.
|
||||
Parameters
|
||||
@@ -295,8 +307,7 @@ class Exchange:
|
||||
volume_threshold = {"all": volume_threshold}
|
||||
|
||||
assert isinstance(volume_threshold, dict)
|
||||
for key in volume_threshold:
|
||||
vol_limit = volume_threshold[key]
|
||||
for key, vol_limit in volume_threshold.items():
|
||||
assert isinstance(vol_limit, tuple)
|
||||
fields.add(vol_limit[1])
|
||||
|
||||
@@ -307,10 +318,19 @@ class Exchange:
|
||||
|
||||
return buy_vol_limit, sell_vol_limit, fields
|
||||
|
||||
def check_stock_limit(self, stock_id, start_time, end_time, direction=None):
|
||||
def check_stock_limit(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: int = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
stock_id : str
|
||||
start_time: pd.Timestamp
|
||||
end_time: pd.Timestamp
|
||||
direction : int, optional
|
||||
trade direction, by default None
|
||||
- if direction is None, check if tradable for buying and selling.
|
||||
@@ -328,39 +348,42 @@ class Exchange:
|
||||
else:
|
||||
raise ValueError(f"direction {direction} is not supported!")
|
||||
|
||||
def check_stock_suspended(self, stock_id, start_time, end_time):
|
||||
def check_stock_suspended(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> bool:
|
||||
# is suspended
|
||||
if stock_id in self.quote.get_all_stock():
|
||||
return self.quote.get_data(stock_id, start_time, end_time, "$close") is None
|
||||
else:
|
||||
return True
|
||||
|
||||
def is_stock_tradable(self, stock_id, start_time, end_time, direction=None):
|
||||
def is_stock_tradable(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: int = None,
|
||||
) -> bool:
|
||||
# check if stock can be traded
|
||||
# same as check in check_order
|
||||
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(
|
||||
stock_id, start_time, end_time, direction
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return not (
|
||||
self.check_stock_suspended(stock_id, start_time, end_time)
|
||||
or self.check_stock_limit(stock_id, start_time, end_time, direction)
|
||||
)
|
||||
|
||||
def check_order(self, order):
|
||||
def check_order(self, order: Order) -> bool:
|
||||
# check limit and suspended
|
||||
if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit(
|
||||
order.stock_id, order.start_time, order.end_time, order.direction
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return self.is_stock_tradable(order.stock_id, order.start_time, order.end_time, order.direction)
|
||||
|
||||
def deal_order(
|
||||
self,
|
||||
order,
|
||||
order: Order,
|
||||
trade_account: Account = None,
|
||||
position: BasePosition = None,
|
||||
dealt_order_amount: defaultdict = defaultdict(float),
|
||||
):
|
||||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Deal order when the actual transaction
|
||||
the results section in `Order` will be changed.
|
||||
@@ -371,9 +394,9 @@ class Exchange:
|
||||
:return: trade_val, trade_cost, trade_price
|
||||
"""
|
||||
# check order first.
|
||||
if self.check_order(order) is False:
|
||||
if not self.check_order(order):
|
||||
order.deal_amount = 0.0
|
||||
# using np.nan instead of None to make it more convenient to should the value in format string
|
||||
# using np.nan instead of None to make it more convenient to show the value in format string
|
||||
self.logger.debug(f"Order failed due to trading limitation: {order}")
|
||||
return 0.0, 0.0, np.nan
|
||||
|
||||
@@ -382,7 +405,9 @@ class Exchange:
|
||||
|
||||
# NOTE: order will be changed in this function
|
||||
trade_price, trade_val, trade_cost = self._calc_trade_info_by_order(
|
||||
order, trade_account.current_position if trade_account else position, dealt_order_amount
|
||||
order,
|
||||
trade_account.current_position if trade_account else position,
|
||||
dealt_order_amount,
|
||||
)
|
||||
if trade_val > 1e-5:
|
||||
# If the order can only be deal 0 value. Nothing to be updated
|
||||
@@ -396,23 +421,49 @@ class Exchange:
|
||||
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time, method="ts_data_last"):
|
||||
return self.quote.get_data(stock_id, start_time, end_time, method=method)
|
||||
def get_quote_info(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
method: str = "ts_data_last",
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, method=method) # TODO: missing `field`?
|
||||
|
||||
def get_close(self, stock_id, start_time, end_time, method="ts_data_last"):
|
||||
def get_close(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
method: str = "ts_data_last",
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$close", method=method)
|
||||
|
||||
def get_volume(self, stock_id, start_time, end_time, method="sum"):
|
||||
def get_volume(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
method: str = "sum",
|
||||
) -> float:
|
||||
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method="ts_data_last"):
|
||||
def get_deal_price(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: OrderDir,
|
||||
method: str = "ts_data_last",
|
||||
) -> float:
|
||||
if direction == OrderDir.SELL:
|
||||
pstr = self.sell_price
|
||||
elif direction == OrderDir.BUY:
|
||||
pstr = self.buy_price
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
deal_price = self.quote.get_data(stock_id, start_time, end_time, field=pstr, method=method)
|
||||
if method is not None and (deal_price is None or np.isnan(deal_price) or deal_price <= 1e-08):
|
||||
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
|
||||
@@ -420,11 +471,16 @@ class Exchange:
|
||||
deal_price = self.get_close(stock_id, start_time, end_time, method)
|
||||
return deal_price
|
||||
|
||||
def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]:
|
||||
def get_factor(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
Union[float, None]:
|
||||
Optional[float]:
|
||||
`None`: if the stock is suspended `None` may be returned
|
||||
`float`: return factor if the factor exists
|
||||
"""
|
||||
@@ -434,11 +490,16 @@ class Exchange:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$factor", method="ts_data_last")
|
||||
|
||||
def generate_amount_position_from_weight_position(
|
||||
self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY
|
||||
):
|
||||
self,
|
||||
weight_position: dict,
|
||||
cash: float,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: OrderDir = OrderDir.BUY,
|
||||
) -> dict:
|
||||
"""
|
||||
The generate the target position according to the weight and the cash.
|
||||
NOTE: All the cash will assigned to the tadable stock.
|
||||
NOTE: All the cash will assigned to the tradable stock.
|
||||
Parameter:
|
||||
weight_position : dict {stock_id : weight}; allocate cash by weight_position
|
||||
among then, weight must be in this range: 0 < weight < 1
|
||||
@@ -451,15 +512,14 @@ class Exchange:
|
||||
|
||||
# calculate the total weight of tradable value
|
||||
tradable_weight = 0.0
|
||||
for stock_id in weight_position:
|
||||
for stock_id, wp in weight_position.items():
|
||||
if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
|
||||
# weight_position must be greater than 0 and less than 1
|
||||
if weight_position[stock_id] < 0 or weight_position[stock_id] > 1:
|
||||
if wp < 0 or wp > 1:
|
||||
raise ValueError(
|
||||
"weight_position is {}, "
|
||||
"weight_position is not in the range of (0, 1).".format(weight_position[stock_id])
|
||||
"weight_position is {}, " "weight_position is not in the range of (0, 1).".format(wp),
|
||||
)
|
||||
tradable_weight += weight_position[stock_id]
|
||||
tradable_weight += wp
|
||||
|
||||
if tradable_weight - 1.0 >= 1e-5:
|
||||
raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight))
|
||||
@@ -467,19 +527,24 @@ class Exchange:
|
||||
amount_dict = {}
|
||||
for stock_id in weight_position:
|
||||
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(
|
||||
stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
):
|
||||
amount_dict[stock_id] = (
|
||||
cash
|
||||
* weight_position[stock_id]
|
||||
/ tradable_weight
|
||||
// self.get_deal_price(
|
||||
stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
direction=direction,
|
||||
)
|
||||
)
|
||||
return amount_dict
|
||||
|
||||
def get_real_deal_amount(self, current_amount, target_amount, factor):
|
||||
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float) -> float:
|
||||
"""
|
||||
Calculate the real adjust deal amount when considering the trading unit
|
||||
:param current_amount:
|
||||
@@ -501,7 +566,13 @@ class Exchange:
|
||||
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
|
||||
return -deal_amount
|
||||
|
||||
def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
|
||||
def generate_order_for_target_amount_position(
|
||||
self,
|
||||
target_position: dict,
|
||||
current_position: dict,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> list:
|
||||
"""
|
||||
Note: some future information is used in this function
|
||||
Parameter:
|
||||
@@ -517,7 +588,8 @@ class Exchange:
|
||||
# three parts: kept stock_id, dropped stock_id, new stock_id
|
||||
# handle kept stock_id
|
||||
|
||||
# because the order of the set is not fixed, the trading order of the stock is different, so that the backtest results of the same parameter are different;
|
||||
# because the order of the set is not fixed, the trading order of the stock is different, so that the backtest
|
||||
# results of the same parameter are different;
|
||||
# so here we sort stock_id, and then randomly shuffle the order of stock_id
|
||||
# because the same random seed is used, the final stock_id order is fixed
|
||||
sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys())))
|
||||
@@ -546,7 +618,7 @@ class Exchange:
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
factor=factor,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
# sell stock
|
||||
@@ -558,14 +630,19 @@ class Exchange:
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
factor=factor,
|
||||
)
|
||||
),
|
||||
)
|
||||
# return order_list : buy + sell
|
||||
return sell_order_list + buy_order_list
|
||||
|
||||
def calculate_amount_position_value(
|
||||
self, amount_dict, start_time, end_time, only_tradable=False, direction=OrderDir.SELL
|
||||
):
|
||||
self,
|
||||
amount_dict: dict,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
only_tradable: bool = False,
|
||||
direction: OrderDir = OrderDir.SELL,
|
||||
) -> float:
|
||||
"""Parameter
|
||||
position : Position()
|
||||
amount_dict : {stock_id : amount}
|
||||
@@ -576,21 +653,28 @@ class Exchange:
|
||||
"""
|
||||
value = 0
|
||||
for stock_id in amount_dict:
|
||||
if (
|
||||
only_tradable is True
|
||||
and self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
or only_tradable is False
|
||||
if not only_tradable or (
|
||||
not self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
and not self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
):
|
||||
value += (
|
||||
self.get_deal_price(
|
||||
stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
direction=direction,
|
||||
)
|
||||
* amount_dict[stock_id]
|
||||
)
|
||||
return value
|
||||
|
||||
def _get_factor_or_raise_error(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
|
||||
def _get_factor_or_raise_error(
|
||||
self,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> float:
|
||||
"""Please refer to the docs of get_amount_of_trade_unit"""
|
||||
if factor is None:
|
||||
if stock_id is not None and start_time is not None and end_time is not None:
|
||||
@@ -599,7 +683,13 @@ class Exchange:
|
||||
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
|
||||
return factor
|
||||
|
||||
def get_amount_of_trade_unit(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
|
||||
def get_amount_of_trade_unit(
|
||||
self,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
get the trade unit of amount based on **factor**
|
||||
the factor can be given directly or calculated in given time range and stock id.
|
||||
@@ -617,14 +707,22 @@ class Exchange:
|
||||
"""
|
||||
if not self.trade_w_adj_price and self.trade_unit is not None:
|
||||
factor = self._get_factor_or_raise_error(
|
||||
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
factor=factor,
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
return self.trade_unit / factor
|
||||
else:
|
||||
return None
|
||||
|
||||
def round_amount_by_trade_unit(
|
||||
self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None
|
||||
self,
|
||||
deal_amount,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
"""Parameter
|
||||
Please refer to the docs of get_amount_of_trade_unit
|
||||
@@ -635,7 +733,10 @@ class Exchange:
|
||||
if not self.trade_w_adj_price and self.trade_unit is not None:
|
||||
# the minimal amount is 1. Add 0.1 for solving precision problem.
|
||||
factor = self._get_factor_or_raise_error(
|
||||
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
factor=factor,
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
|
||||
return deal_amount
|
||||
@@ -714,7 +815,12 @@ class Exchange:
|
||||
max_trade_amount = (cash - self.min_cost) / trade_price
|
||||
return max_trade_amount
|
||||
|
||||
def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amount):
|
||||
def _calc_trade_info_by_order(
|
||||
self,
|
||||
order: Order,
|
||||
position: Optional[BasePosition],
|
||||
dealt_order_amount: dict,
|
||||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Calculation of trade info
|
||||
**NOTE**: Order will be changed in this function
|
||||
@@ -753,7 +859,8 @@ class Exchange:
|
||||
if not np.isclose(order.deal_amount, current_amount):
|
||||
# when not selling last stock. rounding is necessary
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
min(current_amount, order.deal_amount), order.factor
|
||||
min(current_amount, order.deal_amount),
|
||||
order.factor,
|
||||
)
|
||||
|
||||
# in case of negative value of cash
|
||||
@@ -778,7 +885,8 @@ class Exchange:
|
||||
# The money is not enough
|
||||
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio)
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
min(max_buy_amount, order.deal_amount), order.factor
|
||||
min(max_buy_amount, order.deal_amount),
|
||||
order.factor,
|
||||
)
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
else:
|
||||
|
||||
@@ -1,19 +1,28 @@
|
||||
from abc import abstractmethod
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from types import GeneratorType
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.account import Account
|
||||
from qlib.backtest.position import BasePosition
|
||||
from qlib.log import get_module_logger
|
||||
from types import GeneratorType
|
||||
from qlib.backtest.account import Account
|
||||
import pandas as pd
|
||||
from typing import List, Tuple, Union
|
||||
from collections import defaultdict
|
||||
|
||||
from .decision import Order, BaseTradeDecision
|
||||
from .exchange import Exchange
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx
|
||||
|
||||
from ..utils import init_instance_by_config
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..utils import init_instance_by_config
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
from .utils import (
|
||||
BaseInfrastructure,
|
||||
CommonInfrastructure,
|
||||
LevelInfrastructure,
|
||||
TradeCalendarManager,
|
||||
get_start_end_idx,
|
||||
)
|
||||
|
||||
|
||||
class BaseExecutor:
|
||||
@@ -30,9 +39,9 @@ class BaseExecutor:
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
settle_type=BasePosition.ST_NO,
|
||||
settle_type=BasePosition.ST_NO, # TODO: add typehint
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -53,15 +62,21 @@ class BaseExecutor:
|
||||
- 'base_price': the based price than which the trading price is advanced, Optional, default by 'twap'
|
||||
- If 'base_price' is 'twap', the based price is the time weighted average price
|
||||
- If 'base_price' is 'vwap', the based price is the volume weighted average price
|
||||
- 'weight_method': weighted method when calculating total trading pa by different orders' pa in each step, optional, default by 'mean'
|
||||
- 'weight_method': weighted method when calculating total trading pa by different orders' pa in each
|
||||
step, optional, default by 'mean'
|
||||
- If 'weight_method' is 'mean', calculating mean value of different orders' pa
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' pa
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' pa
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different
|
||||
orders' pa
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different
|
||||
orders' pa
|
||||
- 'ffr_config': config for calculating fulfill rate(ffr), optional
|
||||
- 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each step, optional, default by 'mean'
|
||||
- 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each
|
||||
step, optional, default by 'mean'
|
||||
- If 'weight_method' is 'mean', calculating mean value of different orders' ffr
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' ffr
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' ffr
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different
|
||||
orders' ffr
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different
|
||||
orders' ffr
|
||||
Example:
|
||||
{
|
||||
'show_indicator': True,
|
||||
@@ -79,7 +94,8 @@ class BaseExecutor:
|
||||
whether to print trading info, by default False
|
||||
track_data : bool, optional
|
||||
whether to generate trade_decision, will be used when training rl agent
|
||||
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
|
||||
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will
|
||||
be generated by `collect_data`
|
||||
- Else, `trade_decision` will not be generated
|
||||
|
||||
trade_exchange : Exchange
|
||||
@@ -114,7 +130,7 @@ class BaseExecutor:
|
||||
self.dealt_order_amount = defaultdict(float)
|
||||
self.deal_day = None
|
||||
|
||||
def reset_common_infra(self, common_infra, copy_trade_account=False):
|
||||
def reset_common_infra(self, common_infra: BaseInfrastructure, copy_trade_account: bool = False) -> None:
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset trade_account
|
||||
@@ -132,7 +148,7 @@ class BaseExecutor:
|
||||
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
|
||||
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
|
||||
else:
|
||||
self.trade_account = common_infra.get("trade_account")
|
||||
self.trade_account: Account = common_infra.get("trade_account")
|
||||
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
|
||||
|
||||
@property
|
||||
@@ -148,7 +164,7 @@ class BaseExecutor:
|
||||
"""
|
||||
return self.level_infra.get("trade_calendar")
|
||||
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs):
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs) -> None:
|
||||
"""
|
||||
- reset `start_time` and `end_time`, used in trade calendar
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
@@ -161,13 +177,13 @@ class BaseExecutor:
|
||||
if common_infra is not None:
|
||||
self.reset_common_infra(common_infra)
|
||||
|
||||
def get_level_infra(self):
|
||||
def get_level_infra(self) -> LevelInfrastructure:
|
||||
return self.level_infra
|
||||
|
||||
def finished(self):
|
||||
def finished(self) -> bool:
|
||||
return self.trade_calendar.finished()
|
||||
|
||||
def execute(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
def execute(self, trade_decision: BaseTradeDecision, level: int = 0) -> List[object]:
|
||||
"""execute the trade decision and return the executed result
|
||||
|
||||
NOTE: this function is never used directly in the framework. Should we delete it?
|
||||
@@ -189,9 +205,15 @@ class BaseExecutor:
|
||||
pass
|
||||
return return_value.get("execute_result")
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||
def _collect_data(
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
level: int = 0,
|
||||
) -> Union[
|
||||
Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]],
|
||||
Tuple[List[object], dict],
|
||||
]:
|
||||
"""
|
||||
Please refer to the doc of collect_data
|
||||
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
|
||||
@@ -209,8 +231,11 @@ class BaseExecutor:
|
||||
"""
|
||||
|
||||
def collect_data(
|
||||
self, trade_decision: BaseTradeDecision, return_value: dict = None, level: int = 0
|
||||
) -> List[object]:
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
return_value: dict = None,
|
||||
level: int = 0,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], List[object]]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
his function will make a step forward
|
||||
@@ -253,7 +278,9 @@ class BaseExecutor:
|
||||
obj = self._collect_data(trade_decision=trade_decision, level=level)
|
||||
|
||||
if isinstance(obj, GeneratorType):
|
||||
res, kwargs = yield from obj
|
||||
yield_res = yield from obj
|
||||
assert isinstance(yield_res, tuple) and len(yield_res) == 2
|
||||
res, kwargs = yield_res
|
||||
else:
|
||||
# Some concrete executor don't have inner decisions
|
||||
res, kwargs = obj
|
||||
@@ -279,7 +306,7 @@ class BaseExecutor:
|
||||
return_value.update({"execute_result": res})
|
||||
return res
|
||||
|
||||
def get_all_executors(self):
|
||||
def get_all_executors(self) -> List[BaseExecutor]:
|
||||
"""get all executors"""
|
||||
return [self]
|
||||
|
||||
@@ -287,7 +314,8 @@ class BaseExecutor:
|
||||
class NestedExecutor(BaseExecutor):
|
||||
"""
|
||||
Nested Executor with inner strategy and executor
|
||||
- At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher frequency env.
|
||||
- At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision`
|
||||
in a higher frequency env.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -305,7 +333,7 @@ class NestedExecutor(BaseExecutor):
|
||||
align_range_limit: bool = True,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -323,10 +351,14 @@ class NestedExecutor(BaseExecutor):
|
||||
It is only for nested executor, because range_limit is given by outer strategy
|
||||
"""
|
||||
self.inner_executor: BaseExecutor = init_instance_by_config(
|
||||
inner_executor, common_infra=common_infra, accept_types=BaseExecutor
|
||||
inner_executor,
|
||||
common_infra=common_infra,
|
||||
accept_types=BaseExecutor,
|
||||
)
|
||||
self.inner_strategy: BaseStrategy = init_instance_by_config(
|
||||
inner_strategy, common_infra=common_infra, accept_types=BaseStrategy
|
||||
inner_strategy,
|
||||
common_infra=common_infra,
|
||||
accept_types=BaseStrategy,
|
||||
)
|
||||
|
||||
self._skip_empty_decision = skip_empty_decision
|
||||
@@ -344,10 +376,10 @@ class NestedExecutor(BaseExecutor):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def reset_common_infra(self, common_infra, copy_trade_account=False):
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset inner_strategyand inner_executor common infra
|
||||
- reset inner_strategy and inner_executor common infra
|
||||
"""
|
||||
# NOTE: please refer to the docs of BaseExecutor.reset_common_infra for the meaning of `copy_trade_account`
|
||||
|
||||
@@ -358,7 +390,7 @@ class NestedExecutor(BaseExecutor):
|
||||
self.inner_executor.reset_common_infra(common_infra, copy_trade_account=True)
|
||||
self.inner_strategy.reset_common_infra(common_infra)
|
||||
|
||||
def _init_sub_trading(self, trade_decision):
|
||||
def _init_sub_trading(self, trade_decision: BaseTradeDecision) -> None:
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
|
||||
self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
sub_level_infra = self.inner_executor.get_level_infra()
|
||||
@@ -368,14 +400,18 @@ class NestedExecutor(BaseExecutor):
|
||||
def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
# outer strategy have chance to update decision each iterator
|
||||
updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar)
|
||||
if updated_trade_decision is not None:
|
||||
if updated_trade_decision is not None: # TODO: always is None for now?
|
||||
trade_decision = updated_trade_decision
|
||||
# NEW UPDATE
|
||||
# create a hook for inner strategy to update outer decision
|
||||
self.inner_strategy.alter_outer_trade_decision(trade_decision)
|
||||
return trade_decision
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
def _collect_data(
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
level: int = 0,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]]:
|
||||
execute_result = []
|
||||
inner_order_indicators = []
|
||||
decision_list = []
|
||||
@@ -390,8 +426,8 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
if trade_decision.empty() and self._skip_empty_decision:
|
||||
# give one chance for outer strategy to update the strategy
|
||||
# - For updating some information in the sub executor(the strategy have no knowledge of the inner
|
||||
# executor when generating the decision)
|
||||
# - For updating some information in the sub executor (the strategy have no knowledge of the inner
|
||||
# executor when generating the decision)
|
||||
break
|
||||
|
||||
sub_cal: TradeCalendarManager = self.inner_executor.trade_calendar
|
||||
@@ -405,15 +441,19 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
# NOTE: !!!!!
|
||||
# the two lines below is for a special case in RL
|
||||
# To solve the confliction below
|
||||
# - Normally, user will create a strategy and embed it into Qlib's executor and simulator interaction loop
|
||||
# For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=> (inner Qlib Executor)])
|
||||
# To solve the conflicts below
|
||||
# - Normally, user will create a strategy and embed it into Qlib's executor and simulator interaction
|
||||
# loop For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=>
|
||||
# (inner Qlib Executor)])
|
||||
# - However, RL-based framework has it's own script to run the loop
|
||||
# For an _RL learning example_, (RL Policy) <=> (RL Env[(inner Qlib Executor)])
|
||||
# To make it possible to run _nested qlib example_ and _RL learning example_ together, the solution below is proposed
|
||||
# - The entry script follow the example of _RL learning example_ to be compatible with all kinds of RL Framework
|
||||
# To make it possible to run _nested qlib example_ and _RL learning example_ together, the solution
|
||||
# below is proposed
|
||||
# - The entry script follow the example of _RL learning example_ to be compatible with all kinds of
|
||||
# RL Framework
|
||||
# - Each step of (RL Env) will make (inner Qlib Executor) one step forward
|
||||
# - (inner Qlib Strategy) is a proxy strategy, it will give the program control right to (RL Env) by `yield from` and wait for the action from the policy
|
||||
# - (inner Qlib Strategy) is a proxy strategy, it will give the program control right to (RL Env)
|
||||
# by `yield from` and wait for the action from the policy
|
||||
# So the two lines below is the implementation of yielding control rights
|
||||
if isinstance(res, GeneratorType):
|
||||
res = yield from res
|
||||
@@ -427,13 +467,15 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
# NOTE: Trade Calendar will step forward in the follow line
|
||||
_inner_execute_result = yield from self.inner_executor.collect_data(
|
||||
trade_decision=_inner_trade_decision, level=level + 1
|
||||
trade_decision=_inner_trade_decision,
|
||||
level=level + 1,
|
||||
)
|
||||
assert isinstance(_inner_execute_result, list)
|
||||
self.post_inner_exe_step(_inner_execute_result)
|
||||
execute_result.extend(_inner_execute_result)
|
||||
|
||||
inner_order_indicators.append(
|
||||
self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True)
|
||||
self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True),
|
||||
)
|
||||
else:
|
||||
# do nothing and just step forward
|
||||
@@ -441,7 +483,7 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}
|
||||
|
||||
def post_inner_exe_step(self, inner_exe_res):
|
||||
def post_inner_exe_step(self, inner_exe_res: List[object]) -> None:
|
||||
"""
|
||||
A hook for doing sth after each step of inner strategy
|
||||
|
||||
@@ -451,11 +493,23 @@ class NestedExecutor(BaseExecutor):
|
||||
the execution result of inner task
|
||||
"""
|
||||
|
||||
def get_all_executors(self):
|
||||
def get_all_executors(self) -> List[object]:
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
return [self, *self.inner_executor.get_all_executors()]
|
||||
|
||||
|
||||
def _retrieve_orders_from_decision(trade_decision: BaseTradeDecision) -> List[Order]:
|
||||
"""
|
||||
IDE-friendly helper function.
|
||||
"""
|
||||
decisions = trade_decision.get_decision()
|
||||
orders: List[Order] = []
|
||||
for decision in decisions:
|
||||
assert isinstance(decision, Order)
|
||||
orders.append(decision)
|
||||
return orders
|
||||
|
||||
|
||||
class SimulatorExecutor(BaseExecutor):
|
||||
"""Executor that simulate the true market"""
|
||||
|
||||
@@ -464,10 +518,10 @@ class SimulatorExecutor(BaseExecutor):
|
||||
|
||||
# available trade_types
|
||||
TT_SERIAL = "serial"
|
||||
## The orders will be executed serially in a sequence
|
||||
# The orders will be executed serially in a sequence
|
||||
# In each trading step, it is possible that users sell instruments first and use the money to buy new instruments
|
||||
TT_PARAL = "parallel"
|
||||
## The orders will be executed parallelly
|
||||
# The orders will be executed in parallel
|
||||
# In each trading step, if users try to sell instruments first and buy new instruments with money, failure will
|
||||
# occur
|
||||
|
||||
@@ -483,7 +537,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
common_infra: CommonInfrastructure = None,
|
||||
trade_type: str = TT_SERIAL,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -517,7 +571,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
List[Order]:
|
||||
get a list orders according to `self.trade_type`
|
||||
"""
|
||||
orders = trade_decision.get_decision()
|
||||
orders = _retrieve_orders_from_decision(trade_decision)
|
||||
|
||||
if self.trade_type == self.TT_SERIAL:
|
||||
# Orders will be traded in a parallel way
|
||||
@@ -525,15 +579,15 @@ class SimulatorExecutor(BaseExecutor):
|
||||
elif self.trade_type == self.TT_PARAL:
|
||||
# NOTE: !!!!!!!
|
||||
# Assumption: there will not be orders in different trading direction in a single step of a strategy !!!!
|
||||
# The parallel trading failure will be caused only by the confliction of money
|
||||
# Therefore, make the buying go first will make sure the confliction happen.
|
||||
# The parallel trading failure will be caused only by the conflicts of money
|
||||
# Therefore, make the buying go first will make sure the conflicts happen.
|
||||
# It equals to parallel trading after sorting the order by direction
|
||||
order_it = sorted(orders, key=lambda order: -order.direction)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return order_it
|
||||
|
||||
def _update_dealt_order_amount(self, order):
|
||||
def _update_dealt_order_amount(self, order: Order) -> None:
|
||||
"""update date and dealt order amount in the day."""
|
||||
|
||||
now_deal_day = self.trade_calendar.get_step_time()[0].floor(freq="D")
|
||||
@@ -542,8 +596,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
self.deal_day = now_deal_day
|
||||
self.dealt_order_amount[order.stock_id] += order.deal_amount
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||
trade_start_time, _ = self.trade_calendar.get_step_time()
|
||||
execute_result = []
|
||||
|
||||
@@ -559,7 +612,8 @@ class SimulatorExecutor(BaseExecutor):
|
||||
self._update_dealt_order_amount(order)
|
||||
if self.verbose:
|
||||
print(
|
||||
"[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}, cash {:.2f}.".format(
|
||||
"[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, "
|
||||
"value {:.2f}, cash {:.2f}.".format(
|
||||
trade_start_time,
|
||||
"sell" if order.direction == Order.SELL else "buy",
|
||||
order.stock_id,
|
||||
@@ -569,6 +623,6 @@ class SimulatorExecutor(BaseExecutor):
|
||||
order.factor,
|
||||
trade_val,
|
||||
self.trade_account.get_cash(),
|
||||
)
|
||||
),
|
||||
)
|
||||
return execute_result, {"trade_info": execute_result}
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import List, Text, Union, Callable, Iterable, Dict
|
||||
from collections import OrderedDict
|
||||
|
||||
import inspect
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Dict, Iterable, List, Text, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import qlib.utils.index_data as idd
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..utils.index_data import IndexData, SingleData
|
||||
from ..utils.resam import resam_ts_data, ts_data_last
|
||||
from ..log import get_module_logger
|
||||
from ..utils.time import is_single_value, Freq
|
||||
import qlib.utils.index_data as idd
|
||||
from ..utils.time import Freq, is_single_value
|
||||
|
||||
|
||||
class BaseQuote:
|
||||
@@ -627,7 +628,9 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
metrics = [metrics]
|
||||
for metric in metrics:
|
||||
order_indicator.data[metric] = idd.sum_by_index(
|
||||
[indicator.data[metric] for indicator in indicators], stocks, fill_value
|
||||
[indicator.data[metric] for indicator in indicators],
|
||||
stocks,
|
||||
fill_value,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -2,24 +2,28 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import pandas as pd
|
||||
from datetime import timedelta
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .decision import Order
|
||||
from ..data.data import D
|
||||
from .decision import Order
|
||||
|
||||
|
||||
class BasePosition:
|
||||
"""
|
||||
The Position want to maintain the position like a dictionary
|
||||
The Position wants to maintain the position like a dictionary
|
||||
Please refer to the `Position` class for the position
|
||||
"""
|
||||
|
||||
def __init__(self, *args, cash=0.0, **kwargs):
|
||||
def __init__(self, *args, cash: float = 0.0, **kwargs) -> None:
|
||||
self._settle_type = self.ST_NO
|
||||
self.position = {}
|
||||
|
||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
|
||||
pass
|
||||
|
||||
def skip_update(self) -> bool:
|
||||
"""
|
||||
@@ -49,7 +53,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `check_stock` method")
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -64,7 +68,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_order` method")
|
||||
|
||||
def update_stock_price(self, stock_id, price: float):
|
||||
def update_stock_price(self, stock_id: str, price: float) -> None:
|
||||
"""
|
||||
Updating the latest price of the order
|
||||
The useful when clearing balance at each bar end
|
||||
@@ -89,6 +93,9 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `calculate_stock_value` method")
|
||||
|
||||
def calculate_value(self) -> float:
|
||||
raise NotImplementedError(f"Please implement the `calculate_value` method")
|
||||
|
||||
def get_stock_list(self) -> List:
|
||||
"""
|
||||
Get the list of stocks in the position.
|
||||
@@ -124,14 +131,16 @@ class BasePosition:
|
||||
|
||||
def get_cash(self, include_settle: bool = False) -> float:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
include_settle:
|
||||
will the unsettled(delayed) cash included
|
||||
Default: not include those unavailable cash
|
||||
|
||||
Returns
|
||||
-------
|
||||
float:
|
||||
the available(tradable) cash in position
|
||||
include_settle:
|
||||
will the unsettled(delayed) cash included
|
||||
Default: not include those unavailable cash
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_cash` method")
|
||||
|
||||
@@ -165,7 +174,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
def add_count_all(self, bar) -> None:
|
||||
"""
|
||||
Will be called at the end of each bar on each level
|
||||
|
||||
@@ -176,24 +185,19 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
||||
|
||||
def update_weight_all(self):
|
||||
def update_weight_all(self) -> None:
|
||||
"""
|
||||
Updating the position weight;
|
||||
|
||||
# TODO: this function is a little weird. The weight data in the position is in a wrong state after dealing order
|
||||
# and before updating weight.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bar :
|
||||
The level to be updated
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
||||
|
||||
ST_CASH = "cash"
|
||||
ST_NO = None
|
||||
|
||||
def settle_start(self, settle_type: str):
|
||||
def settle_start(self, settle_type: str) -> None:
|
||||
"""
|
||||
settlement start
|
||||
It will act like start and commit a transaction
|
||||
@@ -210,14 +214,9 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `settle_conf` method")
|
||||
|
||||
def settle_commit(self):
|
||||
def settle_commit(self) -> None:
|
||||
"""
|
||||
settlement commit
|
||||
|
||||
Parameters
|
||||
----------
|
||||
settle_type : str
|
||||
please refer to the documents of Executor
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `settle_commit` method")
|
||||
|
||||
@@ -242,13 +241,11 @@ class Position(BasePosition):
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, cash: float = 0, position_dict: Dict[str, Dict[str, float]] = {}):
|
||||
def __init__(self, cash: float = 0, position_dict: Dict[str, Union[Dict[str, float], float]] = {}) -> None:
|
||||
"""Init position by cash and position_dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time :
|
||||
the start time of backtest. It's for filling the initial value of stocks.
|
||||
cash : float, optional
|
||||
initial cash in account, by default 0
|
||||
position_dict : Dict[
|
||||
@@ -268,9 +265,9 @@ class Position(BasePosition):
|
||||
# Otherwise the initial value
|
||||
self.init_cash = cash
|
||||
self.position = position_dict.copy()
|
||||
for stock in self.position:
|
||||
if isinstance(self.position[stock], int):
|
||||
self.position[stock] = {"amount": self.position[stock]}
|
||||
for stock, value in self.position.items():
|
||||
if isinstance(value, int):
|
||||
self.position[stock] = {"amount": value}
|
||||
self.position["cash"] = cash
|
||||
|
||||
# If the stock price information is missing, the account value will not be calculated temporarily
|
||||
@@ -279,21 +276,23 @@ class Position(BasePosition):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30):
|
||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
|
||||
"""fill the stock value by the close price of latest last_days from qlib.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time :
|
||||
the start time of backtest.
|
||||
freq : str
|
||||
Frequency
|
||||
last_days : int, optional
|
||||
the days to get the latest close price, by default 30.
|
||||
"""
|
||||
stock_list = []
|
||||
for stock in self.position:
|
||||
if not isinstance(self.position[stock], dict):
|
||||
for stock, value in self.position.items():
|
||||
if not isinstance(value, dict):
|
||||
continue
|
||||
if ("price" not in self.position[stock]) or (self.position[stock]["price"] is None):
|
||||
if value.get("price", None) is None:
|
||||
stock_list.append(stock)
|
||||
|
||||
if len(stock_list) == 0:
|
||||
@@ -304,7 +303,12 @@ class Position(BasePosition):
|
||||
price_end_time = start_time
|
||||
price_start_time = start_time - timedelta(days=last_days)
|
||||
price_df = D.features(
|
||||
stock_list, ["$close"], price_start_time, price_end_time, freq=freq, disk_cache=True
|
||||
stock_list,
|
||||
["$close"],
|
||||
price_start_time,
|
||||
price_end_time,
|
||||
freq=freq,
|
||||
disk_cache=True,
|
||||
).dropna()
|
||||
price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict()
|
||||
|
||||
@@ -316,7 +320,7 @@ class Position(BasePosition):
|
||||
self.position[stock]["price"] = price_dict[stock]
|
||||
self.position["now_account_value"] = self.calculate_value()
|
||||
|
||||
def _init_stock(self, stock_id, amount, price=None):
|
||||
def _init_stock(self, stock_id: str, amount: float, price: float = None) -> None:
|
||||
"""
|
||||
initialization the stock in current position
|
||||
|
||||
@@ -334,7 +338,7 @@ class Position(BasePosition):
|
||||
self.position[stock_id]["price"] = price
|
||||
self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
|
||||
|
||||
def _buy_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
def _buy_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
trade_amount = trade_val / trade_price
|
||||
if stock_id not in self.position:
|
||||
self._init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
|
||||
@@ -344,15 +348,16 @@ class Position(BasePosition):
|
||||
|
||||
self.position["cash"] -= trade_val + cost
|
||||
|
||||
def _sell_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
def _sell_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
trade_amount = trade_val / trade_price
|
||||
if stock_id not in self.position:
|
||||
raise KeyError("{} not in current position".format(stock_id))
|
||||
else:
|
||||
if np.isclose(self.position[stock_id]["amount"], trade_amount):
|
||||
# Selling all the stocks
|
||||
# we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` consider both ralative amount and absolute amount
|
||||
# Using abs(<the final amount>) <= 1e-5 will result in error when the amount is large
|
||||
# we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` consider both
|
||||
# relative amount and absolute amount
|
||||
# Using abs(<the final amount>) <= 1e-5 will result in error when the amount is large
|
||||
self._del_stock(stock_id)
|
||||
else:
|
||||
# decrease the amount of stock
|
||||
@@ -361,8 +366,10 @@ class Position(BasePosition):
|
||||
if self.position[stock_id]["amount"] < -1e-5:
|
||||
raise ValueError(
|
||||
"only have {} {}, require {}".format(
|
||||
self.position[stock_id]["amount"] + trade_amount, stock_id, trade_amount
|
||||
)
|
||||
self.position[stock_id]["amount"] + trade_amount,
|
||||
stock_id,
|
||||
trade_amount,
|
||||
),
|
||||
)
|
||||
|
||||
new_cash = trade_val - cost
|
||||
@@ -373,13 +380,13 @@ class Position(BasePosition):
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def _del_stock(self, stock_id):
|
||||
def _del_stock(self, stock_id: str) -> None:
|
||||
del self.position[stock_id]
|
||||
|
||||
def check_stock(self, stock_id):
|
||||
def check_stock(self, stock_id: str) -> bool:
|
||||
return stock_id in self.position
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
# handle order, order is a order class, defined in exchange.py
|
||||
if order.direction == Order.BUY:
|
||||
# BUY
|
||||
@@ -390,54 +397,54 @@ class Position(BasePosition):
|
||||
else:
|
||||
raise NotImplementedError("do not support order direction {}".format(order.direction))
|
||||
|
||||
def update_stock_price(self, stock_id, price):
|
||||
def update_stock_price(self, stock_id: str, price: float) -> None:
|
||||
self.position[stock_id]["price"] = price
|
||||
|
||||
def update_stock_count(self, stock_id, bar, count):
|
||||
def update_stock_count(self, stock_id: str, bar: str, count: float) -> None: # TODO: check type of `bar`
|
||||
self.position[stock_id][f"count_{bar}"] = count
|
||||
|
||||
def update_stock_weight(self, stock_id, weight):
|
||||
def update_stock_weight(self, stock_id: str, weight: float) -> None:
|
||||
self.position[stock_id]["weight"] = weight
|
||||
|
||||
def calculate_stock_value(self):
|
||||
def calculate_stock_value(self) -> float:
|
||||
stock_list = self.get_stock_list()
|
||||
value = 0
|
||||
for stock_id in stock_list:
|
||||
value += self.position[stock_id]["amount"] * self.position[stock_id]["price"]
|
||||
return value
|
||||
|
||||
def calculate_value(self):
|
||||
def calculate_value(self) -> float:
|
||||
value = self.calculate_stock_value()
|
||||
value += self.position["cash"] + self.position.get("cash_delay", 0.0)
|
||||
return value
|
||||
|
||||
def get_stock_list(self):
|
||||
def get_stock_list(self) -> List[str]:
|
||||
stock_list = list(set(self.position.keys()) - {"cash", "now_account_value", "cash_delay"})
|
||||
return stock_list
|
||||
|
||||
def get_stock_price(self, code):
|
||||
def get_stock_price(self, code: str) -> float:
|
||||
return self.position[code]["price"]
|
||||
|
||||
def get_stock_amount(self, code):
|
||||
def get_stock_amount(self, code: str) -> float:
|
||||
return self.position[code]["amount"] if code in self.position else 0
|
||||
|
||||
def get_stock_count(self, code, bar):
|
||||
def get_stock_count(self, code: str, bar: str) -> float:
|
||||
"""the days the account has been hold, it may be used in some special strategies"""
|
||||
if f"count_{bar}" in self.position[code]:
|
||||
return self.position[code][f"count_{bar}"]
|
||||
else:
|
||||
return 0
|
||||
|
||||
def get_stock_weight(self, code):
|
||||
def get_stock_weight(self, code: str) -> float:
|
||||
return self.position[code]["weight"]
|
||||
|
||||
def get_cash(self, include_settle=False):
|
||||
def get_cash(self, include_settle: bool = False) -> float:
|
||||
cash = self.position["cash"]
|
||||
if include_settle:
|
||||
cash += self.position.get("cash_delay", 0.0)
|
||||
return cash
|
||||
|
||||
def get_stock_amount_dict(self):
|
||||
def get_stock_amount_dict(self) -> dict:
|
||||
"""generate stock amount dict {stock_id : amount of stock}"""
|
||||
d = {}
|
||||
stock_list = self.get_stock_list()
|
||||
@@ -445,7 +452,7 @@ class Position(BasePosition):
|
||||
d[stock_code] = self.get_stock_amount(code=stock_code)
|
||||
return d
|
||||
|
||||
def get_stock_weight_dict(self, only_stock=False):
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
|
||||
"""get_stock_weight_dict
|
||||
generate stock weight dict {stock_id : value weight of stock in the position}
|
||||
it is meaningful in the beginning or the end of each trade date
|
||||
@@ -463,7 +470,7 @@ class Position(BasePosition):
|
||||
d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value
|
||||
return d
|
||||
|
||||
def add_count_all(self, bar):
|
||||
def add_count_all(self, bar: str) -> None:
|
||||
stock_list = self.get_stock_list()
|
||||
for code in stock_list:
|
||||
if f"count_{bar}" in self.position[code]:
|
||||
@@ -471,18 +478,18 @@ class Position(BasePosition):
|
||||
else:
|
||||
self.position[code][f"count_{bar}"] = 1
|
||||
|
||||
def update_weight_all(self):
|
||||
def update_weight_all(self) -> None:
|
||||
weight_dict = self.get_stock_weight_dict()
|
||||
for stock_code, weight in weight_dict.items():
|
||||
self.update_stock_weight(stock_code, weight)
|
||||
|
||||
def settle_start(self, settle_type):
|
||||
def settle_start(self, settle_type: str) -> None:
|
||||
assert self._settle_type == self.ST_NO, "Currently, settlement can't be nested!!!!!"
|
||||
self._settle_type = settle_type
|
||||
if settle_type == self.ST_CASH:
|
||||
self.position["cash_delay"] = 0.0
|
||||
|
||||
def settle_commit(self):
|
||||
def settle_commit(self) -> None:
|
||||
if self._settle_type != self.ST_NO:
|
||||
if self._settle_type == self.ST_CASH:
|
||||
self.position["cash"] += self.position["cash_delay"]
|
||||
@@ -507,10 +514,10 @@ class InfPosition(BasePosition):
|
||||
# InfPosition always have any stocks
|
||||
return True
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
pass
|
||||
|
||||
def update_stock_price(self, stock_id, price: float):
|
||||
def update_stock_price(self, stock_id: str, price: float) -> None:
|
||||
pass
|
||||
|
||||
def calculate_stock_value(self) -> float:
|
||||
@@ -522,17 +529,20 @@ class InfPosition(BasePosition):
|
||||
"""
|
||||
return np.inf
|
||||
|
||||
def get_stock_list(self) -> List:
|
||||
def calculate_value(self) -> float:
|
||||
raise NotImplementedError(f"InfPosition doesn't support calculating value")
|
||||
|
||||
def get_stock_list(self) -> list:
|
||||
raise NotImplementedError(f"InfPosition doesn't support stock list position")
|
||||
|
||||
def get_stock_price(self, code) -> float:
|
||||
def get_stock_price(self, code: str) -> float:
|
||||
"""the price of the inf position is meaningless"""
|
||||
return np.nan
|
||||
|
||||
def get_stock_amount(self, code) -> float:
|
||||
def get_stock_amount(self, code: str) -> float:
|
||||
return np.inf
|
||||
|
||||
def get_cash(self, include_settle=False) -> float:
|
||||
def get_cash(self, include_settle: bool = False) -> float:
|
||||
return np.inf
|
||||
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
@@ -541,14 +551,14 @@ class InfPosition(BasePosition):
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
def add_count_all(self, bar: str) -> None:
|
||||
raise NotImplementedError(f"InfPosition doesn't support add_count_all")
|
||||
|
||||
def update_weight_all(self):
|
||||
def update_weight_all(self) -> None:
|
||||
raise NotImplementedError(f"InfPosition doesn't support update_weight_all")
|
||||
|
||||
def settle_start(self, settle_type: str):
|
||||
def settle_start(self, settle_type: str) -> None:
|
||||
pass
|
||||
|
||||
def settle_commit(self):
|
||||
def settle_commit(self) -> None:
|
||||
pass
|
||||
|
||||
@@ -4,14 +4,16 @@
|
||||
This module is not well maintained.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from .position import Position
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..config import C
|
||||
from ..data import D
|
||||
from .position import Position
|
||||
|
||||
|
||||
def get_benchmark_weight(
|
||||
bench,
|
||||
@@ -214,7 +216,9 @@ def get_stock_group(stock_group_field_df, bench_stock_weight_df, group_method, g
|
||||
for idx, row in (~bench_stock_weight_df.isna()).iterrows():
|
||||
bench_values = stock_group_field_df.loc[idx, row[row].index]
|
||||
new_stock_group_df.loc[idx] = get_daily_bin_group(
|
||||
bench_values, stock_group_field_df.loc[idx], group_n=group_n
|
||||
bench_values,
|
||||
stock_group_field_df.loc[idx],
|
||||
group_n=group_n,
|
||||
)
|
||||
return new_stock_group_df
|
||||
|
||||
@@ -315,7 +319,7 @@ def brinson_pa(
|
||||
# The excess profit from the interaction of assets allocation and stocks selection
|
||||
"RIN": Q4 - Q3 - Q2 + Q1,
|
||||
"RTotal": Q4 - Q1, # The totoal excess profit
|
||||
}
|
||||
},
|
||||
),
|
||||
{
|
||||
"port_group_ret": port_group_ret_df,
|
||||
|
||||
@@ -2,19 +2,20 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from collections import OrderedDict
|
||||
import pathlib
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.exchange import Exchange
|
||||
import qlib.utils.index_data as idd
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir
|
||||
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
from qlib.backtest.exchange import Exchange
|
||||
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
import qlib.utils.index_data as idd
|
||||
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
|
||||
|
||||
class PortfolioMetrics:
|
||||
@@ -161,7 +162,8 @@ class PortfolioMetrics:
|
||||
stock_value,
|
||||
]:
|
||||
raise ValueError(
|
||||
"None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, total_cost, cost_rate, stock_value]"
|
||||
"None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, "
|
||||
"total_cost, cost_rate, stock_value]",
|
||||
)
|
||||
|
||||
if trade_end_time is None and bench_value is None:
|
||||
@@ -335,7 +337,10 @@ class Indicator:
|
||||
# sum inner order indicators with same metric.
|
||||
all_metric = ["inner_amount", "deal_amount", "trade_price", "trade_value", "trade_cost", "trade_dir"]
|
||||
self.order_indicator_cls.sum_all_indicators(
|
||||
self.order_indicator, inner_order_indicators, all_metric, fill_value=0
|
||||
self.order_indicator,
|
||||
inner_order_indicators,
|
||||
all_metric,
|
||||
fill_value=0,
|
||||
)
|
||||
|
||||
def func(trade_price, deal_amount):
|
||||
@@ -378,12 +383,17 @@ class Indicator:
|
||||
|
||||
if decision.trade_range is not None:
|
||||
trade_start_time, trade_end_time = decision.trade_range.clip_time_range(
|
||||
start_time=trade_start_time, end_time=trade_end_time
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
)
|
||||
|
||||
if price == "deal_price":
|
||||
price_s = trade_exchange.get_deal_price(
|
||||
inst, trade_start_time, trade_end_time, direction=direction, method=None
|
||||
inst,
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
direction=direction,
|
||||
method=None,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
@@ -599,8 +609,12 @@ class Indicator:
|
||||
if show_indicator:
|
||||
print(
|
||||
"[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format(
|
||||
freq, trade_start_time, fulfill_rate, price_advantage, positive_rate
|
||||
)
|
||||
freq,
|
||||
trade_start_time,
|
||||
fulfill_rate,
|
||||
price_advantage,
|
||||
positive_rate,
|
||||
),
|
||||
)
|
||||
|
||||
def get_order_indicator(self, raw: bool = True):
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.utils import init_instance_by_config
|
||||
import abc
|
||||
from typing import Dict, List, Text, Tuple, Union
|
||||
from ..model.base import BaseModel
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
from ..data.dataset import Dataset
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..model.base import BaseModel
|
||||
from ..utils.resam import resam_ts_data
|
||||
import pandas as pd
|
||||
import abc
|
||||
|
||||
|
||||
class Signal(metaclass=abc.ABCMeta):
|
||||
@@ -82,7 +85,7 @@ class ModelSignal(SignalWCache):
|
||||
|
||||
|
||||
def create_signal_from(
|
||||
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame]
|
||||
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame],
|
||||
) -> Signal:
|
||||
"""
|
||||
create signal from diverse information
|
||||
|
||||
@@ -2,16 +2,22 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qlib.utils.time import epsilon_change
|
||||
from typing import TYPE_CHECKING, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.decision import BaseTradeDecision
|
||||
|
||||
import pandas as pd
|
||||
import warnings
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from ..data.data import Cal
|
||||
|
||||
|
||||
@@ -26,8 +32,8 @@ class TradeCalendarManager:
|
||||
freq: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
level_infra: "LevelInfrastructure" = None,
|
||||
):
|
||||
level_infra: LevelInfrastructure = None,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -43,19 +49,26 @@ class TradeCalendarManager:
|
||||
self.level_infra = level_infra
|
||||
self.reset(freq=freq, start_time=start_time, end_time=end_time)
|
||||
|
||||
def reset(self, freq, start_time, end_time):
|
||||
def reset(
|
||||
self,
|
||||
freq: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Please refer to the docs of `__init__`
|
||||
|
||||
Reset the trade calendar
|
||||
- self.trade_len : The total count for trading step
|
||||
- self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1]
|
||||
- self.trade_step : The number of trading step finished, self.trade_step can be
|
||||
[0, 1, 2, ..., self.trade_len - 1]
|
||||
"""
|
||||
self.freq = freq
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(end_time) if end_time else None
|
||||
|
||||
_calendar = Cal.calendar(freq=freq, future=True)
|
||||
assert isinstance(_calendar, np.ndarray)
|
||||
self._calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, future=True)
|
||||
self.start_index = _start_index
|
||||
@@ -63,7 +76,7 @@ class TradeCalendarManager:
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
self.trade_step = 0
|
||||
|
||||
def finished(self):
|
||||
def finished(self) -> bool:
|
||||
"""
|
||||
Check if the trading finished
|
||||
- Should check before calling strategy.generate_decisions and executor.execute
|
||||
@@ -72,29 +85,32 @@ class TradeCalendarManager:
|
||||
"""
|
||||
return self.trade_step >= self.trade_len
|
||||
|
||||
def step(self):
|
||||
def step(self) -> None:
|
||||
if self.finished():
|
||||
raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
|
||||
self.trade_step = self.trade_step + 1
|
||||
self.trade_step += 1
|
||||
|
||||
def get_freq(self):
|
||||
def get_freq(self) -> str:
|
||||
return self.freq
|
||||
|
||||
def get_trade_len(self):
|
||||
def get_trade_len(self) -> int:
|
||||
"""get the total step length"""
|
||||
return self.trade_len
|
||||
|
||||
def get_trade_step(self):
|
||||
def get_trade_step(self) -> int:
|
||||
return self.trade_step
|
||||
|
||||
def get_step_time(self, trade_step=None, shift=0):
|
||||
def get_step_time(self, trade_step: int = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
"""
|
||||
Get the left and right endpoints of the trade_step'th trading interval
|
||||
|
||||
About the endpoints:
|
||||
- Qlib uses the closed interval in time-series data selection, which has the same performance as pandas.Series.loc
|
||||
# - The returned right endpoints should minus 1 seconds because of the closed interval representation in Qlib.
|
||||
# Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time interval.
|
||||
- Qlib uses the closed interval in time-series data selection, which has the same performance as
|
||||
pandas.Series.loc
|
||||
# - The returned right endpoints should minus 1 seconds because of the closed interval representation in
|
||||
# Qlib.
|
||||
# Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time
|
||||
# interval.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -105,15 +121,14 @@ class TradeCalendarManager:
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[pd.Timestamp, pd.Timestap]
|
||||
Tuple[pd.Timestamp, pd.Timestamp]
|
||||
- If shift == 0, return the trading time range
|
||||
- If shift > 0, return the trading time range of the earlier shift bars
|
||||
- If shift < 0, return the trading time range of the later shift bar
|
||||
"""
|
||||
if trade_step is None:
|
||||
trade_step = self.get_trade_step()
|
||||
trade_step = trade_step - shift
|
||||
calendar_index = self.start_index + trade_step
|
||||
calendar_index = self.start_index + trade_step - shift
|
||||
return self._calendar[calendar_index], epsilon_change(self._calendar[calendar_index + 1])
|
||||
|
||||
def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]:
|
||||
@@ -126,7 +141,7 @@ class TradeCalendarManager:
|
||||
Parameters
|
||||
----------
|
||||
rtype: str
|
||||
- "full": return the full limitation of the deicsion in the day
|
||||
- "full": return the full limitation of the decision in the day
|
||||
- "step": return the limitation of current step
|
||||
|
||||
Returns
|
||||
@@ -148,7 +163,7 @@ class TradeCalendarManager:
|
||||
|
||||
return start_idx - day_start_idx, end_index - day_start_idx
|
||||
|
||||
def get_all_time(self):
|
||||
def get_all_time(self) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
"""Get the start_time and end_time for trading"""
|
||||
return self.start_time, self.end_time
|
||||
|
||||
@@ -167,30 +182,33 @@ class TradeCalendarManager:
|
||||
Tuple[int, int]:
|
||||
the index of the range. **the left and right are closed**
|
||||
"""
|
||||
left, right = (
|
||||
bisect.bisect_right(self._calendar, start_time) - 1,
|
||||
bisect.bisect_right(self._calendar, end_time) - 1,
|
||||
)
|
||||
left = bisect.bisect_right(self._calendar, start_time) - 1
|
||||
right = bisect.bisect_right(self._calendar, end_time) - 1
|
||||
left -= self.start_index
|
||||
right -= self.start_index
|
||||
|
||||
def clip(idx):
|
||||
def clip(idx: int) -> int:
|
||||
return min(max(0, idx), self.trade_len - 1)
|
||||
|
||||
return clip(left), clip(right)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"class: {self.__class__.__name__}; {self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]"
|
||||
return (
|
||||
f"class: {self.__class__.__name__}; "
|
||||
f"{self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: "
|
||||
f"[{self.trade_step}/{self.trade_len}]"
|
||||
)
|
||||
|
||||
|
||||
class BaseInfrastructure:
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self.reset_infra(**kwargs)
|
||||
|
||||
def get_support_infra(self):
|
||||
@abstractmethod
|
||||
def get_support_infra(self) -> Set[str]:
|
||||
raise NotImplementedError("`get_support_infra` is not implemented!")
|
||||
|
||||
def reset_infra(self, **kwargs):
|
||||
def reset_infra(self, **kwargs) -> None:
|
||||
support_infra = self.get_support_infra()
|
||||
for k, v in kwargs.items():
|
||||
if k in support_infra:
|
||||
@@ -198,53 +216,58 @@ class BaseInfrastructure:
|
||||
else:
|
||||
warnings.warn(f"{k} is ignored in `reset_infra`!")
|
||||
|
||||
def get(self, infra_name):
|
||||
def get(self, infra_name: str) -> Any:
|
||||
if hasattr(self, infra_name):
|
||||
return getattr(self, infra_name)
|
||||
else:
|
||||
warnings.warn(f"infra {infra_name} is not found!")
|
||||
|
||||
def has(self, infra_name):
|
||||
def has(self, infra_name: str) -> bool:
|
||||
return infra_name in self.get_support_infra() and hasattr(self, infra_name)
|
||||
|
||||
def update(self, other):
|
||||
def update(self, other: BaseInfrastructure) -> None:
|
||||
support_infra = other.get_support_infra()
|
||||
infra_dict = {_infra: getattr(other, _infra) for _infra in support_infra if hasattr(other, _infra)}
|
||||
self.reset_infra(**infra_dict)
|
||||
|
||||
|
||||
class CommonInfrastructure(BaseInfrastructure):
|
||||
def get_support_infra(self):
|
||||
return ["trade_account", "trade_exchange"]
|
||||
def get_support_infra(self) -> Set[str]:
|
||||
return {"trade_account", "trade_exchange"}
|
||||
|
||||
|
||||
class LevelInfrastructure(BaseInfrastructure):
|
||||
"""level infrastructure is created by executor, and then shared to strategies on the same level"""
|
||||
|
||||
def get_support_infra(self):
|
||||
def get_support_infra(self) -> Set[str]:
|
||||
"""
|
||||
Descriptions about the infrastructure
|
||||
|
||||
sub_level_infra:
|
||||
- **NOTE**: this will only work after _init_sub_trading !!!
|
||||
"""
|
||||
return ["trade_calendar", "sub_level_infra", "common_infra"]
|
||||
return {"trade_calendar", "sub_level_infra", "common_infra"}
|
||||
|
||||
def reset_cal(self, freq, start_time, end_time):
|
||||
def reset_cal(
|
||||
self,
|
||||
freq: str,
|
||||
start_time: Union[str, pd.Timestamp, None],
|
||||
end_time: Union[str, pd.Timestamp, None],
|
||||
) -> None:
|
||||
"""reset trade calendar manager"""
|
||||
if self.has("trade_calendar"):
|
||||
self.get("trade_calendar").reset(freq, start_time=start_time, end_time=end_time)
|
||||
else:
|
||||
self.reset_infra(
|
||||
trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self)
|
||||
trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self),
|
||||
)
|
||||
|
||||
def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure):
|
||||
"""this will make the calendar access easier when acrossing multi-levels"""
|
||||
def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure) -> None:
|
||||
"""this will make the calendar access easier when crossing multi-levels"""
|
||||
self.reset_infra(sub_level_infra=sub_level_infra)
|
||||
|
||||
|
||||
def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Union[int, int]:
|
||||
def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Tuple[int, int]:
|
||||
"""
|
||||
A helper function for getting the decision-level index range limitation for inner strategy
|
||||
- NOTE: this function is not applicable to order-level
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.position import BasePosition
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
from ..backtest.decision import BaseTradeDecision
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from ..utils import init_instance_by_config
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..backtest.decision import BaseTradeDecision
|
||||
|
||||
__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"]
|
||||
|
||||
@@ -25,12 +28,13 @@ class BaseStrategy:
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
trade_exchange: Exchange = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : BaseTradeDecision, optional
|
||||
the trade decision of outer strategy which this strategy relies, and it will be traded in [start_time, end_time], by default None
|
||||
the trade decision of outer strategy which this strategy relies, and it will be traded in
|
||||
[start_time, end_time], by default None
|
||||
- If the strategy is used to split trade decision, it will be used
|
||||
- If the strategy is used for portfolio management, it can be ignored
|
||||
level_infra : LevelInfrastructure, optional
|
||||
@@ -41,9 +45,10 @@ class BaseStrategy:
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
- It allowes different trade_exchanges is used in different executions.
|
||||
- It allows different trade_exchanges is used in different executions.
|
||||
- For example:
|
||||
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
|
||||
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is
|
||||
recommended because it run faster.
|
||||
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
|
||||
"""
|
||||
|
||||
@@ -63,13 +68,13 @@ class BaseStrategy:
|
||||
"""get trade exchange in a prioritized order"""
|
||||
return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange")
|
||||
|
||||
def reset_level_infra(self, level_infra: LevelInfrastructure):
|
||||
def reset_level_infra(self, level_infra: LevelInfrastructure) -> None:
|
||||
if not hasattr(self, "level_infra"):
|
||||
self.level_infra = level_infra
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure):
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure) -> None:
|
||||
if not hasattr(self, "common_infra"):
|
||||
self.common_infra: CommonInfrastructure = common_infra
|
||||
else:
|
||||
@@ -79,9 +84,9 @@ class BaseStrategy:
|
||||
self,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
outer_trade_decision=None,
|
||||
**kwargs,
|
||||
):
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
**kwargs, # TODO: remove this?
|
||||
) -> None:
|
||||
"""
|
||||
- reset `level_infra`, used to reset trade calendar, .etc
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
@@ -89,18 +94,20 @@ class BaseStrategy:
|
||||
|
||||
**NOTE**:
|
||||
split this function into `reset` and `_reset` will make following cases more convenient
|
||||
1. Users want to initialize his strategy by overriding `reset`, but they don't want to affect the `_reset` called
|
||||
when initialization
|
||||
1. Users want to initialize his strategy by overriding `reset`, but they don't want to affect the `_reset`
|
||||
called when initialization
|
||||
"""
|
||||
self._reset(
|
||||
level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision, **kwargs
|
||||
level_infra=level_infra,
|
||||
common_infra=common_infra,
|
||||
outer_trade_decision=outer_trade_decision,
|
||||
)
|
||||
|
||||
def _reset(
|
||||
self,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
outer_trade_decision=None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
):
|
||||
"""
|
||||
Please refer to the docs of `reset`
|
||||
@@ -114,7 +121,8 @@ class BaseStrategy:
|
||||
if outer_trade_decision is not None:
|
||||
self.outer_trade_decision = outer_trade_decision
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
@abstractmethod
|
||||
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
|
||||
"""Generate trade decision in each trading bar
|
||||
|
||||
Parameters
|
||||
@@ -125,9 +133,11 @@ class BaseStrategy:
|
||||
"""
|
||||
raise NotImplementedError("generate_trade_decision is not implemented!")
|
||||
|
||||
@staticmethod
|
||||
def update_trade_decision(
|
||||
self, trade_decision: BaseTradeDecision, trade_calendar: TradeCalendarManager
|
||||
) -> Union[BaseTradeDecision, None]:
|
||||
trade_decision: BaseTradeDecision,
|
||||
trade_calendar: TradeCalendarManager,
|
||||
) -> Optional[BaseTradeDecision]:
|
||||
"""
|
||||
update trade decision in each step of inner execution, this method enable all order
|
||||
|
||||
@@ -145,7 +155,8 @@ class BaseStrategy:
|
||||
# default to return None, which indicates that the trade decision is not changed
|
||||
return None
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision):
|
||||
# FIXME: do not define this method as an abstract one since it is never implemented
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
"""
|
||||
A method for updating the outer_trade_decision.
|
||||
The outer strategy may change its decision during updating.
|
||||
@@ -154,6 +165,10 @@ class BaseStrategy:
|
||||
----------
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the decision updated by the outer strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseTradeDecision
|
||||
"""
|
||||
# default to reset the decision directly
|
||||
# NOTE: normally, user should do something to the strategy due to the change of outer decision
|
||||
@@ -200,7 +215,7 @@ class RLStrategy(BaseStrategy):
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -223,7 +238,7 @@ class RLIntStrategy(RLStrategy):
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -242,7 +257,7 @@ class RLIntStrategy(RLStrategy):
|
||||
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
|
||||
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
|
||||
_interpret_state = self.state_interpreter.interpret(execute_result=execute_result)
|
||||
_action = self.policy.step(_interpret_state)
|
||||
_trade_decision = self.action_interpreter.interpret(action=_action)
|
||||
|
||||
@@ -376,7 +376,7 @@ get_cls_kwargs = get_callable_kwargs # NOTE: this is for compatibility for the
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
config: Union[str, dict, object, Path],
|
||||
config: Union[str, dict, object, Path], # TODO: use a user-defined type to replace this Union.
|
||||
default_module=None,
|
||||
accept_types: Union[type, Tuple[type]] = (),
|
||||
try_kwargs: Dict = {},
|
||||
@@ -1063,4 +1063,5 @@ __all__ = [
|
||||
"unpack_archive_with_buffer",
|
||||
"get_tmp_file_with_buffer",
|
||||
"set_log_with_config",
|
||||
"init_instance_by_config",
|
||||
]
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import unittest
|
||||
from qlib.backtest import backtest, decision
|
||||
from qlib.backtest import backtest
|
||||
from qlib.tests import TestAutoData
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
@@ -52,13 +52,12 @@ class FileStrTest(TestAutoData):
|
||||
factor = df["$factor"].item()
|
||||
price_unit = price / factor * 100
|
||||
dealt_num_for_1000 = (account_money // price_unit) * (100 / factor)
|
||||
print(price, factor, price_unit, dealt_num_for_1000)
|
||||
|
||||
# 2) generate orders
|
||||
orders = self._gen_orders(dealt_num_for_1000)
|
||||
print(orders)
|
||||
orders.to_csv(self.EXAMPLE_FILE)
|
||||
|
||||
orders = pd.read_csv(self.EXAMPLE_FILE, index_col=["datetime", "instrument"])
|
||||
print(orders)
|
||||
|
||||
# 3) run the strategy
|
||||
strategy_config = {
|
||||
@@ -101,7 +100,11 @@ class FileStrTest(TestAutoData):
|
||||
},
|
||||
},
|
||||
}
|
||||
report_dict, indicator_dict = backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
|
||||
report_dict, indicator_dict = backtest(
|
||||
executor=executor_config,
|
||||
strategy=strategy_config,
|
||||
**backtest_config,
|
||||
)
|
||||
|
||||
# ffr valid
|
||||
ffr_dict = indicator_dict["1day"]["ffr"].to_dict()
|
||||
|
||||
Reference in New Issue
Block a user