mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Backtest Mypy (#1130)
* Done * Fix test errors * Revert profit_attribution.py * Minor * A minor update on collect_data type hint * Resolve PR comments * Use black to format code * Fix CI errors
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
[mypy]
|
||||
exclude = (?x)(
|
||||
^qlib/backtest
|
||||
^qlib/backtest/high_performance_ds\.py$
|
||||
| ^qlib/contrib
|
||||
| ^qlib/data
|
||||
| ^qlib/model
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -23,7 +23,6 @@ from ..utils import init_instance_by_config
|
||||
from .backtest import backtest_loop, collect_data_loop
|
||||
from .decision import Order
|
||||
from .exchange import Exchange
|
||||
from .position import Position
|
||||
from .utils import CommonInfrastructure
|
||||
|
||||
# make import more user-friendly by adding `from qlib.backtest import STH`
|
||||
@@ -44,7 +43,7 @@ def get_exchange(
|
||||
min_cost: float = 5.0,
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> Exchange:
|
||||
"""get_exchange
|
||||
|
||||
@@ -52,14 +51,15 @@ def get_exchange(
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange(). It could be None or any types that are acceptable by `init_instance_by_config`.
|
||||
exchange: Exchange
|
||||
It could be None or any types that are acceptable by `init_instance_by_config`.
|
||||
freq: str
|
||||
frequency of data.
|
||||
start_time: Union[pd.Timestamp, str]
|
||||
closed start time for backtest.
|
||||
end_time: Union[pd.Timestamp, str]
|
||||
closed end time for backtest.
|
||||
codes: list|str
|
||||
codes: Union[list, str]
|
||||
list stock_id list or a string of instruments (i.e. all, csi500, sse50)
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
@@ -151,28 +151,24 @@ def create_account_instance(
|
||||
Postion type.
|
||||
"""
|
||||
if isinstance(account, (int, float)):
|
||||
pos_kwargs = {"init_cash": account}
|
||||
init_cash = account
|
||||
position_dict = {}
|
||||
elif isinstance(account, dict):
|
||||
init_cash = account["cash"]
|
||||
del account["cash"]
|
||||
pos_kwargs = {
|
||||
"init_cash": init_cash,
|
||||
"position_dict": account,
|
||||
}
|
||||
init_cash = account.pop("cash")
|
||||
position_dict = account
|
||||
else:
|
||||
raise ValueError("account must be in (int, float, Position)")
|
||||
raise ValueError("account must be in (int, float, dict)")
|
||||
|
||||
kwargs = {
|
||||
"init_cash": account,
|
||||
"benchmark_config": {
|
||||
return Account(
|
||||
init_cash=init_cash,
|
||||
position_dict=position_dict,
|
||||
pos_type=pos_type,
|
||||
benchmark_config={
|
||||
"benchmark": benchmark,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
},
|
||||
"pos_type": pos_type,
|
||||
}
|
||||
kwargs.update(pos_kwargs)
|
||||
return Account(**kwargs)
|
||||
)
|
||||
|
||||
|
||||
def get_strategy_executor(
|
||||
@@ -181,7 +177,7 @@ def get_strategy_executor(
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
) -> Tuple[BaseStrategy, BaseExecutor]:
|
||||
@@ -222,7 +218,7 @@ def backtest(
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||
@@ -285,7 +281,7 @@ def collect_data(
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
return_value: dict = None,
|
||||
@@ -339,7 +335,7 @@ def format_decisions(
|
||||
|
||||
cur_freq = decisions[0].strategy.trade_calendar.get_freq()
|
||||
|
||||
res = (cur_freq, [])
|
||||
res: Tuple[str, list] = (cur_freq, [])
|
||||
last_dec_idx = 0
|
||||
for i, dec in enumerate(decisions[1:], 1):
|
||||
if dec.strategy.trade_calendar.get_freq() == cur_freq:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -11,6 +11,7 @@ from qlib.utils import init_instance_by_config
|
||||
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
from .high_performance_ds import BaseOrderIndicator
|
||||
from .position import BasePosition
|
||||
from .report import Indicator, PortfolioMetrics
|
||||
|
||||
@@ -104,7 +105,7 @@ class Account:
|
||||
|
||||
self._pos_type = pos_type
|
||||
self._port_metr_enabled = port_metr_enabled
|
||||
self.benchmark_config = None # avoid no attribute error
|
||||
self.benchmark_config: dict = {} # avoid no attribute error
|
||||
self.init_vars(init_cash, position_dict, freq, benchmark_config)
|
||||
|
||||
def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None:
|
||||
@@ -124,8 +125,8 @@ class Account:
|
||||
self.accum_info = AccumulatedInfo()
|
||||
|
||||
# 2) following variables are not shared between layers
|
||||
self.portfolio_metrics = None
|
||||
self.hist_positions = {}
|
||||
self.portfolio_metrics: Optional[PortfolioMetrics] = None
|
||||
self.hist_positions: Dict[pd.Timestamp, BasePosition] = {}
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config)
|
||||
|
||||
def is_port_metr_enabled(self) -> bool:
|
||||
@@ -171,7 +172,7 @@ class Account:
|
||||
|
||||
self.reset_report(self.freq, self.benchmark_config)
|
||||
|
||||
def get_hist_positions(self) -> dict:
|
||||
def get_hist_positions(self) -> Dict[pd.Timestamp, BasePosition]:
|
||||
return self.hist_positions
|
||||
|
||||
def get_cash(self) -> float:
|
||||
@@ -230,13 +231,15 @@ class Account:
|
||||
"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
|
||||
assert self.current_position is not None
|
||||
|
||||
if not self.current_position.skip_update():
|
||||
stock_list = self.current_position.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
bar_close = cast(float, trade_exchange.get_close(code, trade_start_time, trade_end_time))
|
||||
self.current_position.update_stock_price(stock_id=code, price=bar_close)
|
||||
# update holding day count
|
||||
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
|
||||
@@ -249,6 +252,8 @@ class Account:
|
||||
# for the first trade date, account_value - init_cash
|
||||
# self.portfolio_metrics.is_empty() to judge is_first_trade_date
|
||||
# get last_account_value, last_total_cost, last_total_turnover
|
||||
assert self.portfolio_metrics is not None
|
||||
|
||||
if self.portfolio_metrics.is_empty():
|
||||
last_account_value = self.init_cash
|
||||
last_total_cost = 0
|
||||
@@ -299,9 +304,9 @@ class Account:
|
||||
trade_exchange: Exchange,
|
||||
atomic: bool,
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_info: list = None,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
trade_info: list = [],
|
||||
inner_order_indicators: List[BaseOrderIndicator] = [],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
|
||||
indicator_config: dict = {},
|
||||
) -> None:
|
||||
"""update trade indicators and order indicators in each bar end"""
|
||||
@@ -335,9 +340,9 @@ class Account:
|
||||
trade_exchange: Exchange,
|
||||
atomic: bool,
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_info: list = None,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
trade_info: list = [],
|
||||
inner_order_indicators: List[BaseOrderIndicator] = [],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
|
||||
indicator_config: dict = {},
|
||||
) -> None:
|
||||
"""update account at each trading bar step
|
||||
@@ -398,6 +403,7 @@ class Account:
|
||||
def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:
|
||||
"""get the history portfolio_metrics and positions instance"""
|
||||
if self.is_port_metr_enabled():
|
||||
assert self.portfolio_metrics is not None
|
||||
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
|
||||
_positions = self.get_hist_positions()
|
||||
return _portfolio_metrics, _positions
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -36,10 +36,13 @@ def backtest_loop(
|
||||
indicator: Indicator
|
||||
it computes the trading indicator
|
||||
"""
|
||||
return_value = {}
|
||||
return_value: dict = {}
|
||||
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
||||
pass
|
||||
return return_value.get("portfolio_metrics"), return_value.get("indicator")
|
||||
|
||||
portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
|
||||
indicator = cast(Indicator, return_value.get("indicator"))
|
||||
return portfolio_metrics, indicator
|
||||
|
||||
|
||||
def collect_data_loop(
|
||||
|
||||
@@ -7,7 +7,7 @@ from abc import abstractmethod
|
||||
from enum import IntEnum
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
|
||||
from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast
|
||||
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
from qlib.data.data import Cal
|
||||
@@ -24,8 +24,11 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
DecisionType = TypeVar("DecisionType")
|
||||
|
||||
|
||||
class OrderDir(IntEnum):
|
||||
# Order direction
|
||||
# Order direction
|
||||
SELL = 0
|
||||
BUY = 1
|
||||
|
||||
@@ -65,7 +68,7 @@ class Order:
|
||||
# - not tradable: the deal_amount == 0 , factor is None
|
||||
# - the stock is suspended and the entire order fails. No cost for this order
|
||||
# - dealt or partially dealt: deal_amount >= 0 and factor is not None
|
||||
deal_amount: Optional[float] = None # `deal_amount` is a non-negative value
|
||||
deal_amount: float = 0.0 # `deal_amount` is a non-negative value
|
||||
factor: Optional[float] = None
|
||||
|
||||
# TODO:
|
||||
@@ -281,7 +284,7 @@ class TradeRangeByTime(TradeRange):
|
||||
return max(val_start, start_time), min(val_end, end_time)
|
||||
|
||||
|
||||
class BaseTradeDecision:
|
||||
class BaseTradeDecision(Generic[DecisionType]):
|
||||
"""
|
||||
Trade decisions ara made by strategy and executed by executor
|
||||
|
||||
@@ -316,20 +319,21 @@ class BaseTradeDecision:
|
||||
"""
|
||||
self.strategy = strategy
|
||||
self.start_time, self.end_time = strategy.trade_calendar.get_step_time()
|
||||
self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading`
|
||||
if isinstance(trade_range, Tuple):
|
||||
# upper strategy has no knowledge about the sub executor before `_init_sub_trading`
|
||||
self.total_step: Optional[int] = None
|
||||
if isinstance(trade_range, tuple):
|
||||
# for Tuple[int, int]
|
||||
trade_range = IdxTradeRange(*trade_range)
|
||||
self.trade_range: TradeRange = trade_range
|
||||
self.trade_range: Optional[TradeRange] = trade_range
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
def get_decision(self) -> List[DecisionType]:
|
||||
"""
|
||||
get the **concrete decision** (e.g. execution orders)
|
||||
This will be called by the inner strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[object]:
|
||||
List[DecisionType:
|
||||
The decision result. Typically it is some orders
|
||||
Example:
|
||||
[]:
|
||||
@@ -363,13 +367,13 @@ class BaseTradeDecision:
|
||||
# purpose 2)
|
||||
return self.strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
def _get_range_limit(self, **kwargs) -> Tuple[int, int]:
|
||||
def _get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
|
||||
if self.trade_range is not None:
|
||||
return self.trade_range(trade_calendar=kwargs.get("inner_calendar"))
|
||||
return self.trade_range(trade_calendar=cast(TradeCalendarManager, kwargs.get("inner_calendar")))
|
||||
else:
|
||||
raise NotImplementedError("The decision didn't provide an index range")
|
||||
|
||||
def get_range_limit(self, **kwargs) -> Tuple[int, int]:
|
||||
def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
|
||||
"""
|
||||
return the expected step range for limiting the decision execution time
|
||||
Both left and right are **closed**
|
||||
@@ -421,6 +425,7 @@ class BaseTradeDecision:
|
||||
if getattr(self, "total_step", None) is not None:
|
||||
# if `self.update` is called.
|
||||
# Then the _start_idx, _end_idx should be clipped
|
||||
assert self.total_step is not None
|
||||
if _start_idx < 0 or _end_idx >= self.total_step:
|
||||
logger = get_module_logger("decision")
|
||||
logger.warning(
|
||||
@@ -516,7 +521,7 @@ class BaseTradeDecision:
|
||||
inner_trade_decision.trade_range = self.trade_range
|
||||
|
||||
|
||||
class EmptyTradeDecision(BaseTradeDecision):
|
||||
class EmptyTradeDecision(BaseTradeDecision[object]):
|
||||
def get_decision(self) -> List[object]:
|
||||
return []
|
||||
|
||||
@@ -524,23 +529,24 @@ class EmptyTradeDecision(BaseTradeDecision):
|
||||
return True
|
||||
|
||||
|
||||
class TradeDecisionWO(BaseTradeDecision):
|
||||
class TradeDecisionWO(BaseTradeDecision[Order]):
|
||||
"""
|
||||
Trade Decision (W)ith (O)rder.
|
||||
Besides, the time_range is also included.
|
||||
"""
|
||||
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None):
|
||||
def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None:
|
||||
super().__init__(strategy, trade_range=trade_range)
|
||||
self.order_list = order_list
|
||||
self.order_list = cast(List[Order], order_list)
|
||||
start, end = strategy.trade_calendar.get_step_time()
|
||||
for o in order_list:
|
||||
assert isinstance(o, Order)
|
||||
if o.start_time is None:
|
||||
o.start_time = start
|
||||
if o.end_time is None:
|
||||
o.end_time = end
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
def get_decision(self) -> List[Order]:
|
||||
return self.order_list
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
|
||||
|
||||
from ..utils.index_data import IndexData
|
||||
|
||||
@@ -42,7 +42,7 @@ class Exchange:
|
||||
impact_cost: float = 0.0,
|
||||
extra_quote: pd.DataFrame = None,
|
||||
quote_cls: Type[BaseQuote] = NumpyQuote,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""__init__
|
||||
:param freq: frequency of data
|
||||
@@ -141,7 +141,7 @@ class Exchange:
|
||||
if limit_threshold is None:
|
||||
if C.region == REG_CN:
|
||||
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
|
||||
elif self.limit_type == self.LT_FLT and abs(limit_threshold) > 0.1:
|
||||
elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1:
|
||||
if C.region == REG_CN:
|
||||
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
|
||||
|
||||
@@ -150,7 +150,7 @@ class Exchange:
|
||||
deal_price = "$" + deal_price
|
||||
self.buy_price = self.sell_price = deal_price
|
||||
elif isinstance(deal_price, (tuple, list)):
|
||||
self.buy_price, self.sell_price = deal_price
|
||||
self.buy_price, self.sell_price = cast(Tuple[str, str], deal_price)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
@@ -167,10 +167,10 @@ class Exchange:
|
||||
|
||||
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
|
||||
if self.limit_type == self.LT_TP_EXP:
|
||||
assert isinstance(limit_threshold, tuple)
|
||||
for exp in limit_threshold:
|
||||
necessary_fields.add(exp)
|
||||
all_fields = necessary_fields | set(vol_lt_fields)
|
||||
all_fields = list(all_fields | set(subscribe_fields))
|
||||
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
|
||||
|
||||
self.all_fields = all_fields
|
||||
|
||||
@@ -249,9 +249,9 @@ class Exchange:
|
||||
LT_FLT = "float" # float
|
||||
LT_NONE = "none" # none
|
||||
|
||||
def _get_limit_type(self, limit_threshold: Union[Tuple, float, None]) -> str:
|
||||
def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
|
||||
"""get limit type"""
|
||||
if isinstance(limit_threshold, Tuple):
|
||||
if isinstance(limit_threshold, tuple):
|
||||
return self.LT_TP_EXP
|
||||
elif isinstance(limit_threshold, float):
|
||||
return self.LT_FLT
|
||||
@@ -268,14 +268,16 @@ class Exchange:
|
||||
self.quote_df["limit_sell"] = False
|
||||
elif limit_type == self.LT_TP_EXP:
|
||||
# set limit
|
||||
limit_threshold = cast(tuple, limit_threshold)
|
||||
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
|
||||
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
|
||||
elif limit_type == self.LT_FLT:
|
||||
limit_threshold = cast(float, limit_threshold)
|
||||
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
|
||||
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
|
||||
|
||||
@staticmethod
|
||||
def _get_vol_limit(volume_threshold: Union[tuple, dict]) -> Tuple[Optional[list], Optional[list], set]:
|
||||
def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
|
||||
"""
|
||||
preprocess the volume limit.
|
||||
get the fields need to get from qlib.
|
||||
@@ -340,11 +342,11 @@ class Exchange:
|
||||
if direction is None:
|
||||
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
|
||||
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
|
||||
return buy_limit or sell_limit
|
||||
return bool(buy_limit or sell_limit)
|
||||
elif direction == Order.BUY:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
|
||||
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all"))
|
||||
elif direction == Order.SELL:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
|
||||
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all"))
|
||||
else:
|
||||
raise ValueError(f"direction {direction} is not supported!")
|
||||
|
||||
@@ -382,7 +384,7 @@ class Exchange:
|
||||
order: Order,
|
||||
trade_account: Account = None,
|
||||
position: BasePosition = None,
|
||||
dealt_order_amount: defaultdict = defaultdict(float),
|
||||
dealt_order_amount: Dict[str, float] = defaultdict(float),
|
||||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Deal order when the actual transaction
|
||||
@@ -426,9 +428,10 @@ class Exchange:
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
field: str,
|
||||
method: str = "ts_data_last",
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, method=method) # TODO: missing `field`?
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field=field, method=method)
|
||||
|
||||
def get_close(
|
||||
self,
|
||||
@@ -444,10 +447,10 @@ class Exchange:
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
method: str = "sum",
|
||||
method: Optional[str] = "sum",
|
||||
) -> float:
|
||||
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
|
||||
return cast(float, self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method))
|
||||
|
||||
def get_deal_price(
|
||||
self,
|
||||
@@ -455,7 +458,7 @@ class Exchange:
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: OrderDir,
|
||||
method: str = "ts_data_last",
|
||||
method: Optional[str] = "ts_data_last",
|
||||
) -> float:
|
||||
if direction == OrderDir.SELL:
|
||||
pstr = self.sell_price
|
||||
@@ -469,7 +472,7 @@ class Exchange:
|
||||
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
|
||||
self.logger.warning(f"setting deal_price to close price")
|
||||
deal_price = self.get_close(stock_id, start_time, end_time, method)
|
||||
return deal_price
|
||||
return cast(float, deal_price)
|
||||
|
||||
def get_factor(
|
||||
self,
|
||||
@@ -544,7 +547,7 @@ class Exchange:
|
||||
)
|
||||
return amount_dict
|
||||
|
||||
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float) -> float:
|
||||
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float:
|
||||
"""
|
||||
Calculate the real adjust deal amount when considering the trading unit
|
||||
:param current_amount:
|
||||
@@ -572,7 +575,7 @@ class Exchange:
|
||||
current_position: dict,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> list:
|
||||
) -> List[Order]:
|
||||
"""
|
||||
Note: some future information is used in this function
|
||||
Parameter:
|
||||
@@ -681,6 +684,7 @@ class Exchange:
|
||||
factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
else:
|
||||
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
|
||||
assert factor is not None
|
||||
return factor
|
||||
|
||||
def get_amount_of_trade_unit(
|
||||
@@ -718,12 +722,12 @@ class Exchange:
|
||||
|
||||
def round_amount_by_trade_unit(
|
||||
self,
|
||||
deal_amount,
|
||||
deal_amount: float,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> float:
|
||||
"""Parameter
|
||||
Please refer to the docs of get_amount_of_trade_unit
|
||||
deal_amount : float, adjusted amount
|
||||
@@ -741,7 +745,7 @@ class Exchange:
|
||||
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
|
||||
return deal_amount
|
||||
|
||||
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> int:
|
||||
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Optional[float]:
|
||||
"""parse the capacity limit string and return the actual amount of orders that can be executed.
|
||||
NOTE:
|
||||
this function will change the order.deal_amount **inplace**
|
||||
@@ -753,15 +757,12 @@ class Exchange:
|
||||
dealt_order_amount : dict
|
||||
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
|
||||
"""
|
||||
if order.direction == Order.BUY:
|
||||
vol_limit = self.buy_vol_limit
|
||||
elif order.direction == Order.SELL:
|
||||
vol_limit = self.sell_vol_limit
|
||||
vol_limit = self.buy_vol_limit if order.direction == Order.BUY else self.sell_vol_limit
|
||||
|
||||
if vol_limit is None:
|
||||
return order.deal_amount
|
||||
|
||||
vol_limit_num = []
|
||||
vol_limit_num: List[float] = []
|
||||
for limit in vol_limit:
|
||||
assert isinstance(limit, tuple)
|
||||
if limit[0] == "current":
|
||||
@@ -772,7 +773,7 @@ class Exchange:
|
||||
field=limit[1],
|
||||
method="sum",
|
||||
)
|
||||
vol_limit_num.append(limit_value)
|
||||
vol_limit_num.append(cast(float, limit_value))
|
||||
elif limit[0] == "cum":
|
||||
limit_value = self.quote.get_data(
|
||||
order.stock_id,
|
||||
@@ -790,12 +791,14 @@ class Exchange:
|
||||
if vol_limit_min < orig_deal_amount:
|
||||
self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}")
|
||||
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
|
||||
return None
|
||||
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price: float, cash: float, cost_ratio: float) -> float:
|
||||
"""return the real order amount after cash limit for buying.
|
||||
Parameters
|
||||
----------
|
||||
trade_price : float
|
||||
position : cash
|
||||
cash : float
|
||||
cost_ratio : float
|
||||
|
||||
Return
|
||||
@@ -803,7 +806,7 @@ class Exchange:
|
||||
float
|
||||
the real order amount after cash limit for buying.
|
||||
"""
|
||||
max_trade_amount = 0
|
||||
max_trade_amount = 0.0
|
||||
if cash >= self.min_cost:
|
||||
# critical_price means the stock transaction price when the service fee is equal to min_cost.
|
||||
critical_price = self.min_cost / cost_ratio + self.min_cost
|
||||
@@ -897,7 +900,7 @@ class Exchange:
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("order type {} error".format(order.type))
|
||||
raise NotImplementedError("order direction {} error".format(order.direction))
|
||||
|
||||
trade_val = order.deal_amount * trade_price
|
||||
trade_cost = max(trade_val * cost_ratio, self.min_cost)
|
||||
|
||||
@@ -4,7 +4,7 @@ import copy
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from types import GeneratorType
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Generator, List, Tuple, Union, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -16,13 +16,7 @@ from ..strategy.base import BaseStrategy
|
||||
from ..utils import init_instance_by_config
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
from .utils import (
|
||||
BaseInfrastructure,
|
||||
CommonInfrastructure,
|
||||
LevelInfrastructure,
|
||||
TradeCalendarManager,
|
||||
get_start_end_idx,
|
||||
)
|
||||
from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, get_start_end_idx
|
||||
|
||||
|
||||
class BaseExecutor:
|
||||
@@ -39,8 +33,8 @@ class BaseExecutor:
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
settle_type=BasePosition.ST_NO, # TODO: add typehint
|
||||
**kwargs,
|
||||
settle_type: str = BasePosition.ST_NO,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
@@ -127,10 +121,10 @@ class BaseExecutor:
|
||||
get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")
|
||||
|
||||
# record deal order amount in one day
|
||||
self.dealt_order_amount = defaultdict(float)
|
||||
self.dealt_order_amount: Dict[str, float] = defaultdict(float)
|
||||
self.deal_day = None
|
||||
|
||||
def reset_common_infra(self, common_infra: BaseInfrastructure, copy_trade_account: bool = False) -> None:
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset trade_account
|
||||
@@ -141,14 +135,15 @@ class BaseExecutor:
|
||||
self.common_infra.update(common_infra)
|
||||
|
||||
if common_infra.has("trade_account"):
|
||||
if copy_trade_account:
|
||||
# NOTE: there is a trick in the code.
|
||||
# shallow copy is used instead of deepcopy.
|
||||
# 1. So positions are shared
|
||||
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
|
||||
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
|
||||
else:
|
||||
self.trade_account: Account = common_infra.get("trade_account")
|
||||
# NOTE: there is a trick in the code.
|
||||
# shallow copy is used instead of deepcopy.
|
||||
# 1. So positions are shared
|
||||
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
|
||||
self.trade_account: Account = (
|
||||
copy.copy(common_infra.get("trade_account"))
|
||||
if copy_trade_account
|
||||
else common_infra.get("trade_account")
|
||||
)
|
||||
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
|
||||
|
||||
@property
|
||||
@@ -164,7 +159,7 @@ class BaseExecutor:
|
||||
"""
|
||||
return self.level_infra.get("trade_calendar")
|
||||
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs) -> None:
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
- reset `start_time` and `end_time`, used in trade calendar
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
@@ -200,20 +195,17 @@ class BaseExecutor:
|
||||
execute_result : List[object]
|
||||
the executed result for trade decision
|
||||
"""
|
||||
return_value = {}
|
||||
return_value: dict = {}
|
||||
for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):
|
||||
pass
|
||||
return return_value.get("execute_result")
|
||||
return cast(list, return_value.get("execute_result"))
|
||||
|
||||
@abstractmethod
|
||||
def _collect_data(
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
level: int = 0,
|
||||
) -> Union[
|
||||
Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]],
|
||||
Tuple[List[object], dict],
|
||||
]:
|
||||
) -> Union[Generator[Any, Any, Tuple[List[object], dict]], Tuple[List[object], dict]]:
|
||||
"""
|
||||
Please refer to the doc of collect_data
|
||||
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
|
||||
@@ -235,7 +227,7 @@ class BaseExecutor:
|
||||
trade_decision: BaseTradeDecision,
|
||||
return_value: dict = None,
|
||||
level: int = 0,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], List[object]]:
|
||||
) -> Generator[Any, Any, List[object]]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
his function will make a step forward
|
||||
@@ -332,7 +324,7 @@ class NestedExecutor(BaseExecutor):
|
||||
skip_empty_decision: bool = True,
|
||||
align_range_limit: bool = True,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
@@ -411,7 +403,7 @@ class NestedExecutor(BaseExecutor):
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
level: int = 0,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]]:
|
||||
) -> Generator[Any, Any, Tuple[List[object], dict]]:
|
||||
execute_result = []
|
||||
inner_order_indicators = []
|
||||
decision_list = []
|
||||
@@ -493,7 +485,7 @@ class NestedExecutor(BaseExecutor):
|
||||
the execution result of inner task
|
||||
"""
|
||||
|
||||
def get_all_executors(self) -> List[object]:
|
||||
def get_all_executors(self) -> List[BaseExecutor]:
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
return [self, *self.inner_executor.get_all_executors()]
|
||||
|
||||
@@ -536,7 +528,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
track_data: bool = False,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
trade_type: str = TT_SERIAL,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
@@ -598,7 +590,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||
trade_start_time, _ = self.trade_calendar.get_step_time()
|
||||
execute_result = []
|
||||
execute_result: list = []
|
||||
|
||||
for order in self._get_order_iterator(trade_decision):
|
||||
# execute the order.
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Dict, Iterable, List, Text, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -19,7 +21,7 @@ from ..utils.time import Freq, is_single_value
|
||||
|
||||
|
||||
class BaseQuote:
|
||||
def __init__(self, quote_df: pd.DataFrame, freq):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
|
||||
self.logger = get_module_logger("online operator", level=logging.INFO)
|
||||
|
||||
def get_all_stock(self) -> Iterable:
|
||||
@@ -39,7 +41,7 @@ class BaseQuote:
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
field: Union[str],
|
||||
method: Union[str, None] = None,
|
||||
method: Optional[str] = None,
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
"""get the specific field of stock data during start time and end_time,
|
||||
and apply method to the data.
|
||||
@@ -99,7 +101,7 @@ class BaseQuote:
|
||||
|
||||
|
||||
class PandasQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
|
||||
super().__init__(quote_df=quote_df, freq=freq)
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
@@ -124,7 +126,7 @@ class PandasQuote(BaseQuote):
|
||||
|
||||
|
||||
class NumpyQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq, region="cn"):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq: str, region: str = "cn") -> None:
|
||||
"""NumpyQuote
|
||||
|
||||
Parameters
|
||||
@@ -178,7 +180,8 @@ class NumpyQuote(BaseQuote):
|
||||
data = self._agg_data(data, method)
|
||||
return data
|
||||
|
||||
def _agg_data(self, data: IndexData, method):
|
||||
@staticmethod
|
||||
def _agg_data(data: IndexData, method: str) -> Union[IndexData, np.ndarray, None]:
|
||||
"""Agg data by specific method."""
|
||||
# FIXME: why not call the method of data directly?
|
||||
if method == "sum":
|
||||
@@ -224,31 +227,31 @@ class BaseSingleMetric:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `__init__` method")
|
||||
|
||||
def __add__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __add__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__add__` method")
|
||||
|
||||
def __radd__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __radd__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
return self + other
|
||||
|
||||
def __sub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __sub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__sub__` method")
|
||||
|
||||
def __rsub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __rsub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__rsub__` method")
|
||||
|
||||
def __mul__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __mul__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__mul__` method")
|
||||
|
||||
def __truediv__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __truediv__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__truediv__` method")
|
||||
|
||||
def __eq__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __eq__(self, other: object) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__eq__` method")
|
||||
|
||||
def __gt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __gt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__gt__` method")
|
||||
|
||||
def __lt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __lt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__lt__` method")
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -265,7 +268,7 @@ class BaseSingleMetric:
|
||||
|
||||
raise NotImplementedError(f"Please implement the `count` method")
|
||||
|
||||
def abs(self) -> "BaseSingleMetric":
|
||||
def abs(self) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `abs` method")
|
||||
|
||||
@property
|
||||
@@ -274,18 +277,18 @@ class BaseSingleMetric:
|
||||
|
||||
raise NotImplementedError(f"Please implement the `empty` method")
|
||||
|
||||
def add(self, other: "BaseSingleMetric", fill_value: float = None) -> "BaseSingleMetric":
|
||||
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
|
||||
"""Replace np.NaN with fill_value in two metrics and add them."""
|
||||
|
||||
raise NotImplementedError(f"Please implement the `add` method")
|
||||
|
||||
def replace(self, replace_dict: dict) -> "BaseSingleMetric":
|
||||
def replace(self, replace_dict: dict) -> BaseSingleMetric:
|
||||
"""Replace the value of metric according to replace_dict."""
|
||||
|
||||
raise NotImplementedError(f"Please implement the `replace` method")
|
||||
|
||||
def apply(self, func: dict) -> "BaseSingleMetric":
|
||||
"""Replace the value of metric with func(metric).
|
||||
def apply(self, func: Callable) -> BaseSingleMetric:
|
||||
"""Replace the value of metric with func (metric).
|
||||
Currently, the func is only qlib/backtest/order/Order.parse_dir.
|
||||
"""
|
||||
|
||||
@@ -304,11 +307,11 @@ class BaseOrderIndicator:
|
||||
to inherit the BaseSingleMetric.
|
||||
"""
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
def __init__(self):
|
||||
self.data = {} # will be created in the subclass
|
||||
self.logger = get_module_logger("online operator")
|
||||
|
||||
def assign(self, col: str, metric: Union[dict, pd.Series]):
|
||||
def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
|
||||
"""assign one metric.
|
||||
|
||||
Parameters
|
||||
@@ -328,7 +331,7 @@ class BaseOrderIndicator:
|
||||
|
||||
raise NotImplementedError(f"Please implement the 'assign' method")
|
||||
|
||||
def transfer(self, func: Callable, new_col: str = None) -> Union[None, BaseSingleMetric]:
|
||||
def transfer(self, func: Callable, new_col: str = None) -> Optional[BaseSingleMetric]:
|
||||
"""compute new metric with existing metrics.
|
||||
|
||||
Parameters
|
||||
@@ -352,6 +355,7 @@ class BaseOrderIndicator:
|
||||
tmp_metric = func(**func_kwargs)
|
||||
if new_col is not None:
|
||||
self.data[new_col] = tmp_metric
|
||||
return None
|
||||
else:
|
||||
return tmp_metric
|
||||
|
||||
@@ -372,7 +376,7 @@ class BaseOrderIndicator:
|
||||
|
||||
raise NotImplementedError(f"Please implement the 'get_metric_series' method")
|
||||
|
||||
def get_index_data(self, metric) -> SingleData:
|
||||
def get_index_data(self, metric: str) -> SingleData:
|
||||
"""get one metric with the format of SingleData
|
||||
|
||||
Parameters
|
||||
@@ -389,7 +393,12 @@ class BaseOrderIndicator:
|
||||
raise NotImplementedError(f"Please implement the 'get_index_data' method")
|
||||
|
||||
@staticmethod
|
||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None):
|
||||
def sum_all_indicators(
|
||||
order_indicator: BaseOrderIndicator,
|
||||
indicators: List[BaseOrderIndicator],
|
||||
metrics: Union[str, List[str]],
|
||||
fill_value: float = 0,
|
||||
) -> None:
|
||||
"""sum indicators with the same metrics.
|
||||
and assign to the order_indicator(BaseOrderIndicator).
|
||||
NOTE: indicators could be a empty list when orders in lower level all fail.
|
||||
@@ -527,16 +536,17 @@ class PandasSingleMetric(SingleMetric):
|
||||
def index(self):
|
||||
return list(self.metric.index)
|
||||
|
||||
def add(self, other, fill_value=None):
|
||||
def add(self, other: BaseSingleMetric, fill_value: float = None) -> PandasSingleMetric:
|
||||
other = cast(PandasSingleMetric, other)
|
||||
return self.__class__(self.metric.add(other.metric, fill_value=fill_value))
|
||||
|
||||
def replace(self, replace_dict: dict):
|
||||
def replace(self, replace_dict: dict) -> PandasSingleMetric:
|
||||
return self.__class__(self.metric.replace(replace_dict))
|
||||
|
||||
def apply(self, func: Callable):
|
||||
def apply(self, func: Callable) -> PandasSingleMetric:
|
||||
return self.__class__(self.metric.apply(func))
|
||||
|
||||
def reindex(self, index, fill_value):
|
||||
def reindex(self, index: Any, fill_value: float) -> PandasSingleMetric:
|
||||
return self.__class__(self.metric.reindex(index, fill_value=fill_value))
|
||||
|
||||
def __repr__(self):
|
||||
@@ -550,13 +560,14 @@ class PandasOrderIndicator(BaseOrderIndicator):
|
||||
Str is the name of metric.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super(PandasOrderIndicator, self).__init__()
|
||||
self.data: Dict[str, PandasSingleMetric] = OrderedDict()
|
||||
|
||||
def assign(self, col: str, metric: Union[dict, pd.Series]):
|
||||
def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
|
||||
self.data[col] = PandasSingleMetric(metric)
|
||||
|
||||
def get_index_data(self, metric):
|
||||
def get_index_data(self, metric: str) -> SingleData:
|
||||
if metric in self.data:
|
||||
return idd.SingleData(self.data[metric].metric)
|
||||
else:
|
||||
@@ -572,7 +583,12 @@ class PandasOrderIndicator(BaseOrderIndicator):
|
||||
return {k: v.metric for k, v in self.data.items()}
|
||||
|
||||
@staticmethod
|
||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
|
||||
def sum_all_indicators(
|
||||
order_indicator: BaseOrderIndicator,
|
||||
indicators: List[BaseOrderIndicator],
|
||||
metrics: Union[str, List[str]],
|
||||
fill_value: float = 0,
|
||||
) -> None:
|
||||
if isinstance(metrics, str):
|
||||
metrics = [metrics]
|
||||
for metric in metrics:
|
||||
@@ -592,13 +608,14 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
Str is the name of metric.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super(NumpyOrderIndicator, self).__init__()
|
||||
self.data: Dict[str, SingleData] = OrderedDict()
|
||||
|
||||
def assign(self, col: str, metric: dict):
|
||||
def assign(self, col: str, metric: dict) -> None:
|
||||
self.data[col] = idd.SingleData(metric)
|
||||
|
||||
def get_index_data(self, metric):
|
||||
def get_index_data(self, metric: str) -> SingleData:
|
||||
if metric in self.data:
|
||||
return self.data[metric]
|
||||
else:
|
||||
@@ -614,14 +631,18 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
return tmp_metric_dict
|
||||
|
||||
@staticmethod
|
||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
|
||||
def sum_all_indicators(
|
||||
order_indicator: BaseOrderIndicator,
|
||||
indicators: List[BaseOrderIndicator],
|
||||
metrics: Union[str, List[str]],
|
||||
fill_value: float = 0,
|
||||
) -> None:
|
||||
# get all index(stock_id)
|
||||
stocks = set()
|
||||
stock_set: set = set()
|
||||
for indicator in indicators:
|
||||
# set(np.ndarray.tolist()) is faster than set(np.ndarray)
|
||||
stocks = stocks | set(indicator.data[metrics[0]].index.tolist())
|
||||
stocks = list(stocks)
|
||||
stocks.sort()
|
||||
stock_set = stock_set | set(indicator.data[metrics[0]].index.tolist())
|
||||
stocks = sorted(list(stock_set))
|
||||
|
||||
# add metric by index
|
||||
if isinstance(metrics, str):
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Dict, List, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -18,9 +18,9 @@ class BasePosition:
|
||||
Please refer to the `Position` class for the position
|
||||
"""
|
||||
|
||||
def __init__(self, *args, cash: float = 0.0, **kwargs) -> None:
|
||||
def __init__(self, *args: Any, cash: float = 0.0, **kwargs: Any) -> None:
|
||||
self._settle_type = self.ST_NO
|
||||
self.position = {}
|
||||
self.position: dict = {}
|
||||
|
||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
|
||||
pass
|
||||
@@ -96,13 +96,13 @@ class BasePosition:
|
||||
def calculate_value(self) -> float:
|
||||
raise NotImplementedError(f"Please implement the `calculate_value` method")
|
||||
|
||||
def get_stock_list(self) -> List:
|
||||
def get_stock_list(self) -> List[str]:
|
||||
"""
|
||||
Get the list of stocks in the position.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_list` method")
|
||||
|
||||
def get_stock_price(self, code) -> float:
|
||||
def get_stock_price(self, code: str) -> float:
|
||||
"""
|
||||
get the latest price of the stock
|
||||
|
||||
@@ -113,7 +113,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_price` method")
|
||||
|
||||
def get_stock_amount(self, code) -> float:
|
||||
def get_stock_amount(self, code: str) -> float:
|
||||
"""
|
||||
get the amount of the stock
|
||||
|
||||
@@ -144,7 +144,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_cash` method")
|
||||
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
def get_stock_amount_dict(self) -> dict:
|
||||
"""
|
||||
generate stock amount dict {stock_id : amount of stock}
|
||||
|
||||
@@ -155,7 +155,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method")
|
||||
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
|
||||
"""
|
||||
generate stock weight dict {stock_id : value weight of stock in the position}
|
||||
it is meaningful in the beginning or the end of each trade step
|
||||
@@ -174,7 +174,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
|
||||
|
||||
def add_count_all(self, bar) -> None:
|
||||
def add_count_all(self, bar: str) -> None:
|
||||
"""
|
||||
Will be called at the end of each bar on each level
|
||||
|
||||
@@ -195,7 +195,7 @@ class BasePosition:
|
||||
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
||||
|
||||
ST_CASH = "cash"
|
||||
ST_NO = None
|
||||
ST_NO = "None" # String is more typehint friendly than None
|
||||
|
||||
def settle_start(self, settle_type: str) -> None:
|
||||
"""
|
||||
@@ -220,10 +220,10 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `settle_commit` method")
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.__dict__.__str__()
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return self.__dict__.__repr__()
|
||||
|
||||
|
||||
@@ -532,7 +532,7 @@ class InfPosition(BasePosition):
|
||||
def calculate_value(self) -> float:
|
||||
raise NotImplementedError(f"InfPosition doesn't support calculating value")
|
||||
|
||||
def get_stock_list(self) -> list:
|
||||
def get_stock_list(self) -> List[str]:
|
||||
raise NotImplementedError(f"InfPosition doesn't support stock list position")
|
||||
|
||||
def get_stock_price(self, code: str) -> float:
|
||||
@@ -545,10 +545,10 @@ class InfPosition(BasePosition):
|
||||
def get_cash(self, include_settle: bool = False) -> float:
|
||||
return np.inf
|
||||
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
def get_stock_amount_dict(self) -> dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
|
||||
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||
|
||||
def add_count_all(self, bar: str) -> None:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
import pathlib
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Text, Tuple, Type, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -15,7 +15,7 @@ from qlib.backtest.exchange import Exchange
|
||||
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
from .high_performance_ds import BaseOrderIndicator, BaseSingleMetric, NumpyOrderIndicator
|
||||
|
||||
|
||||
class PortfolioMetrics:
|
||||
@@ -38,7 +38,7 @@ class PortfolioMetrics:
|
||||
update report
|
||||
"""
|
||||
|
||||
def __init__(self, freq: str = "day", benchmark_config: dict = {}):
|
||||
def __init__(self, freq: str = "day", benchmark_config: dict = {}) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -49,13 +49,17 @@ class PortfolioMetrics:
|
||||
- benchmark : Union[str, list, pd.Series]
|
||||
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
|
||||
example:
|
||||
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
|
||||
print(
|
||||
D.features(D.instruments('csi500'),
|
||||
['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()
|
||||
)
|
||||
2017-01-04 0.011693
|
||||
2017-01-05 0.000721
|
||||
2017-01-06 -0.004322
|
||||
2017-01-09 0.006874
|
||||
2017-01-10 -0.003350
|
||||
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
|
||||
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the
|
||||
'bench'.
|
||||
- If `benchmark` is str, will use the daily change as the 'bench'.
|
||||
benchmark code, default is SH000300 CSI300
|
||||
- start_time : Union[str, pd.Timestamp], optional
|
||||
@@ -70,25 +74,26 @@ class PortfolioMetrics:
|
||||
self.init_vars()
|
||||
self.init_bench(freq=freq, benchmark_config=benchmark_config)
|
||||
|
||||
def init_vars(self):
|
||||
self.accounts = OrderedDict() # account position value for each trade time
|
||||
self.returns = OrderedDict() # daily return rate for each trade time
|
||||
self.total_turnovers = OrderedDict() # total turnover for each trade time
|
||||
self.turnovers = OrderedDict() # turnover for each trade time
|
||||
self.total_costs = OrderedDict() # total trade cost for each trade time
|
||||
self.costs = OrderedDict() # trade cost rate for each trade time
|
||||
self.values = OrderedDict() # value for each trade time
|
||||
self.cashes = OrderedDict()
|
||||
self.benches = OrderedDict()
|
||||
self.latest_pm_time = None # pd.TimeStamp
|
||||
def init_vars(self) -> None:
|
||||
self.accounts: dict = OrderedDict() # account position value for each trade time
|
||||
self.returns: dict = OrderedDict() # daily return rate for each trade time
|
||||
self.total_turnovers: dict = OrderedDict() # total turnover for each trade time
|
||||
self.turnovers: dict = OrderedDict() # turnover for each trade time
|
||||
self.total_costs: dict = OrderedDict() # total trade cost for each trade time
|
||||
self.costs: dict = OrderedDict() # trade cost rate for each trade time
|
||||
self.values: dict = OrderedDict() # value for each trade time
|
||||
self.cashes: dict = OrderedDict()
|
||||
self.benches: dict = OrderedDict()
|
||||
self.latest_pm_time: Optional[pd.TimeStamp] = None
|
||||
|
||||
def init_bench(self, freq=None, benchmark_config=None):
|
||||
def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None:
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
self.benchmark_config = benchmark_config
|
||||
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
|
||||
|
||||
def _cal_benchmark(self, benchmark_config, freq):
|
||||
@staticmethod
|
||||
def _cal_benchmark(benchmark_config: Optional[dict], freq: str) -> Optional[pd.Series]:
|
||||
if benchmark_config is None:
|
||||
return None
|
||||
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
|
||||
@@ -110,7 +115,12 @@ class PortfolioMetrics:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
|
||||
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
|
||||
def _sample_benchmark(
|
||||
self,
|
||||
bench: pd.Series,
|
||||
trade_start_time: Union[str, pd.Timestamp],
|
||||
trade_end_time: Union[str, pd.Timestamp],
|
||||
) -> Optional[float]:
|
||||
if self.bench is None:
|
||||
return None
|
||||
|
||||
@@ -120,35 +130,35 @@ class PortfolioMetrics:
|
||||
_ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)
|
||||
return 0.0 if _ret is None else _ret - 1
|
||||
|
||||
def is_empty(self):
|
||||
def is_empty(self) -> bool:
|
||||
return len(self.accounts) == 0
|
||||
|
||||
def get_latest_date(self):
|
||||
def get_latest_date(self) -> pd.Timestamp:
|
||||
return self.latest_pm_time
|
||||
|
||||
def get_latest_account_value(self):
|
||||
def get_latest_account_value(self) -> float:
|
||||
return self.accounts[self.latest_pm_time]
|
||||
|
||||
def get_latest_total_cost(self):
|
||||
def get_latest_total_cost(self) -> Any:
|
||||
return self.total_costs[self.latest_pm_time]
|
||||
|
||||
def get_latest_total_turnover(self):
|
||||
def get_latest_total_turnover(self) -> Any:
|
||||
return self.total_turnovers[self.latest_pm_time]
|
||||
|
||||
def update_portfolio_metrics_record(
|
||||
self,
|
||||
trade_start_time=None,
|
||||
trade_end_time=None,
|
||||
account_value=None,
|
||||
cash=None,
|
||||
return_rate=None,
|
||||
total_turnover=None,
|
||||
turnover_rate=None,
|
||||
total_cost=None,
|
||||
cost_rate=None,
|
||||
stock_value=None,
|
||||
bench_value=None,
|
||||
):
|
||||
trade_start_time: Union[str, pd.Timestamp] = None,
|
||||
trade_end_time: Union[str, pd.Timestamp] = None,
|
||||
account_value: float = None,
|
||||
cash: float = None,
|
||||
return_rate: float = None,
|
||||
total_turnover: float = None,
|
||||
turnover_rate: float = None,
|
||||
total_cost: float = None,
|
||||
cost_rate: float = None,
|
||||
stock_value: float = None,
|
||||
bench_value: float = None,
|
||||
) -> None:
|
||||
# check data
|
||||
if None in [
|
||||
trade_start_time,
|
||||
@@ -185,7 +195,7 @@ class PortfolioMetrics:
|
||||
self.latest_pm_time = trade_start_time
|
||||
# finish pm update in each step
|
||||
|
||||
def generate_portfolio_metrics_dataframe(self):
|
||||
def generate_portfolio_metrics_dataframe(self) -> pd.DataFrame:
|
||||
pm = pd.DataFrame()
|
||||
pm["account"] = pd.Series(self.accounts)
|
||||
pm["return"] = pd.Series(self.returns)
|
||||
@@ -199,19 +209,18 @@ class PortfolioMetrics:
|
||||
pm.index.name = "datetime"
|
||||
return pm
|
||||
|
||||
def save_portfolio_metrics(self, path):
|
||||
def save_portfolio_metrics(self, path: str) -> None:
|
||||
r = self.generate_portfolio_metrics_dataframe()
|
||||
r.to_csv(path)
|
||||
|
||||
def load_portfolio_metrics(self, path):
|
||||
def load_portfolio_metrics(self, path: str) -> None:
|
||||
"""load pm from a file
|
||||
should have format like
|
||||
columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']
|
||||
:param
|
||||
path: str/ pathlib.Path()
|
||||
"""
|
||||
path = pathlib.Path(path)
|
||||
with path.open("rb") as f:
|
||||
with pathlib.Path(path).open("rb") as f:
|
||||
r = pd.read_csv(f, index_col=0)
|
||||
r.index = pd.DatetimeIndex(r.index)
|
||||
|
||||
@@ -261,30 +270,30 @@ class Indicator:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, order_indicator_cls=NumpyOrderIndicator):
|
||||
def __init__(self, order_indicator_cls: Type[BaseOrderIndicator] = NumpyOrderIndicator) -> None:
|
||||
self.order_indicator_cls = order_indicator_cls
|
||||
|
||||
# order indicator is metrics for a single order for a specific step
|
||||
self.order_indicator_his = OrderedDict()
|
||||
self.order_indicator_his: dict = OrderedDict()
|
||||
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
|
||||
|
||||
# trade indicator is metrics for all orders for a specific step
|
||||
self.trade_indicator_his = OrderedDict()
|
||||
self.trade_indicator: Dict[str, float] = OrderedDict()
|
||||
self.trade_indicator_his: dict = OrderedDict()
|
||||
self.trade_indicator: Dict[str, Optional[BaseSingleMetric]] = OrderedDict()
|
||||
|
||||
self._trade_calendar = None
|
||||
|
||||
# def reset(self, trade_calendar: TradeCalendarManager):
|
||||
def reset(self):
|
||||
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
|
||||
def reset(self) -> None:
|
||||
self.order_indicator = self.order_indicator_cls()
|
||||
self.trade_indicator = OrderedDict()
|
||||
# self._trade_calendar = trade_calendar
|
||||
|
||||
def record(self, trade_start_time):
|
||||
def record(self, trade_start_time: Union[str, pd.Timestamp]) -> None:
|
||||
self.order_indicator_his[trade_start_time] = self.get_order_indicator()
|
||||
self.trade_indicator_his[trade_start_time] = self.get_trade_indicator()
|
||||
|
||||
def _update_order_trade_info(self, trade_info: list):
|
||||
def _update_order_trade_info(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
|
||||
amount = dict()
|
||||
deal_amount = dict()
|
||||
trade_price = dict()
|
||||
@@ -313,7 +322,7 @@ class Indicator:
|
||||
self.order_indicator.assign("trade_dir", trade_dir)
|
||||
self.order_indicator.assign("pa", pa)
|
||||
|
||||
def _update_order_fulfill_rate(self):
|
||||
def _update_order_fulfill_rate(self) -> None:
|
||||
def func(deal_amount, amount):
|
||||
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
|
||||
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
|
||||
@@ -322,11 +331,11 @@ class Indicator:
|
||||
|
||||
self.order_indicator.transfer(func, "ffr")
|
||||
|
||||
def update_order_indicators(self, trade_info: list):
|
||||
def update_order_indicators(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
|
||||
self._update_order_trade_info(trade_info=trade_info)
|
||||
self._update_order_fulfill_rate()
|
||||
|
||||
def _agg_order_trade_info(self, inner_order_indicators: List[Dict[str, pd.Series]]):
|
||||
def _agg_order_trade_info(self, inner_order_indicators: List[BaseOrderIndicator]) -> None:
|
||||
# calculate total trade amount with each inner order indicator.
|
||||
def trade_amount_func(deal_amount, trade_price):
|
||||
return deal_amount * trade_price
|
||||
@@ -355,9 +364,9 @@ class Indicator:
|
||||
|
||||
self.order_indicator.transfer(func_apply, "trade_dir")
|
||||
|
||||
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision):
|
||||
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision) -> None:
|
||||
# NOTE: these indicator is designed for order execution, so the
|
||||
decision: List[Order] = outer_trade_decision.get_decision()
|
||||
decision: List[Order] = cast(List[Order], outer_trade_decision.get_decision())
|
||||
if len(decision) == 0:
|
||||
self.order_indicator.assign("amount", {})
|
||||
else:
|
||||
@@ -372,7 +381,7 @@ class Indicator:
|
||||
decision: BaseTradeDecision,
|
||||
trade_exchange: Exchange,
|
||||
pa_config: dict = {},
|
||||
):
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""
|
||||
Get the base volume and price information
|
||||
All the base price values are rooted from this function
|
||||
@@ -412,31 +421,35 @@ class Indicator:
|
||||
# NOTE: there are some zeros in the trading price. These cases are known meaningless
|
||||
# for aligning the previous logic, remove it.
|
||||
# remove zero and negative values.
|
||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(np.bool)]
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
|
||||
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
||||
# ~(np.NaN < 1e-8) -> ~(False) -> True
|
||||
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
if agg == "vwap":
|
||||
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
|
||||
if isinstance(volume_s, (int, float, np.number)):
|
||||
volume_s = idd.SingleData(volume_s, [trade_start_time])
|
||||
assert isinstance(volume_s, idd.SingleData)
|
||||
volume_s = volume_s.reindex(price_s.index)
|
||||
elif agg == "twap":
|
||||
volume_s = idd.SingleData(1, price_s.index)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
assert isinstance(volume_s, idd.SingleData)
|
||||
base_volume = volume_s.sum()
|
||||
base_price = (price_s * volume_s).sum() / base_volume
|
||||
return base_price, base_volume
|
||||
|
||||
def _agg_base_price(
|
||||
self,
|
||||
inner_order_indicators: List[Dict[str, Union[SingleMetric, idd.SingleData]]],
|
||||
inner_order_indicators: List[BaseOrderIndicator],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||
trade_exchange: Exchange,
|
||||
pa_config: dict = {},
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
# NOTE:!!!!
|
||||
# Strong assumption!!!!!!
|
||||
@@ -444,7 +457,7 @@ class Indicator:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inner_order_indicators : List[Dict[str, pd.Series]]
|
||||
inner_order_indicators : List[BaseOrderIndicator]
|
||||
the indicators of account of inner executor
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||
a list of decisions according to inner_order_indicators
|
||||
@@ -489,14 +502,17 @@ class Indicator:
|
||||
bv_new = idd.SingleData(bv_new)
|
||||
bp_all.append(bp_new)
|
||||
bv_all.append(bv_new)
|
||||
bp_all = idd.concat(bp_all, axis=1)
|
||||
bv_all = idd.concat(bv_all, axis=1)
|
||||
bp_all_multi_data = idd.concat(bp_all, axis=1)
|
||||
bv_all_multi_data = idd.concat(bv_all, axis=1)
|
||||
|
||||
base_volume = bv_all.sum(axis=1)
|
||||
base_volume = bv_all_multi_data.sum(axis=1)
|
||||
self.order_indicator.assign("base_volume", base_volume.to_dict())
|
||||
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict())
|
||||
self.order_indicator.assign(
|
||||
"base_price",
|
||||
((bp_all_multi_data * bv_all_multi_data).sum(axis=1) / base_volume).to_dict(),
|
||||
)
|
||||
|
||||
def _agg_order_price_advantage(self):
|
||||
def _agg_order_price_advantage(self) -> None:
|
||||
def if_empty_func(trade_price):
|
||||
return trade_price.empty
|
||||
|
||||
@@ -513,12 +529,12 @@ class Indicator:
|
||||
|
||||
def agg_order_indicators(
|
||||
self,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]],
|
||||
inner_order_indicators: List[BaseOrderIndicator],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_exchange: Exchange,
|
||||
indicator_config={},
|
||||
):
|
||||
indicator_config: dict = {},
|
||||
) -> None:
|
||||
self._agg_order_trade_info(inner_order_indicators)
|
||||
self._update_trade_amount(outer_trade_decision)
|
||||
self._update_order_fulfill_rate()
|
||||
@@ -526,71 +542,66 @@ class Indicator:
|
||||
self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) # TODO
|
||||
self._agg_order_price_advantage()
|
||||
|
||||
def _cal_trade_fulfill_rate(self, method="mean"):
|
||||
def _cal_trade_fulfill_rate(self, method: str = "mean") -> Optional[BaseSingleMetric]:
|
||||
if method == "mean":
|
||||
|
||||
def func(ffr):
|
||||
return ffr.mean()
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda ffr: ffr.mean(),
|
||||
)
|
||||
elif method == "amount_weighted":
|
||||
|
||||
def func(ffr, deal_amount):
|
||||
return (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum())
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda ffr, deal_amount: (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
|
||||
)
|
||||
elif method == "value_weighted":
|
||||
|
||||
def func(ffr, trade_value):
|
||||
return (ffr * trade_value.abs()).sum() / (trade_value.abs().sum())
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda ffr, trade_value: (ffr * trade_value.abs()).sum() / (trade_value.abs().sum()),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"method {method} is not supported!")
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_trade_price_advantage(self, method="mean"):
|
||||
def _cal_trade_price_advantage(self, method: str = "mean") -> Optional[BaseSingleMetric]:
|
||||
if method == "mean":
|
||||
|
||||
def func(pa):
|
||||
return pa.mean()
|
||||
|
||||
return self.order_indicator.transfer(lambda pa: pa.mean())
|
||||
elif method == "amount_weighted":
|
||||
|
||||
def func(pa, deal_amount):
|
||||
return (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum())
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda pa, deal_amount: (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
|
||||
)
|
||||
elif method == "value_weighted":
|
||||
|
||||
def func(pa, trade_value):
|
||||
return (pa * trade_value.abs()).sum() / (trade_value.abs().sum())
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda pa, trade_value: (pa * trade_value.abs()).sum() / (trade_value.abs().sum()),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"method {method} is not supported!")
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_trade_positive_rate(self):
|
||||
def _cal_trade_positive_rate(self) -> Optional[BaseSingleMetric]:
|
||||
def func(pa):
|
||||
return (pa > 0).sum() / pa.count()
|
||||
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_deal_amount(self):
|
||||
def _cal_deal_amount(self) -> Optional[BaseSingleMetric]:
|
||||
def func(deal_amount):
|
||||
return deal_amount.abs().sum()
|
||||
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_trade_value(self):
|
||||
def _cal_trade_value(self) -> Optional[BaseSingleMetric]:
|
||||
def func(trade_value):
|
||||
return trade_value.abs().sum()
|
||||
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_trade_order_count(self):
|
||||
def _cal_trade_order_count(self) -> Optional[BaseSingleMetric]:
|
||||
def func(amount):
|
||||
return amount.count()
|
||||
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}):
|
||||
def cal_trade_indicators(
|
||||
self,
|
||||
trade_start_time: Union[str, pd.Timestamp],
|
||||
freq: str,
|
||||
indicator_config: dict = {},
|
||||
) -> None:
|
||||
show_indicator = indicator_config.get("show_indicator", False)
|
||||
ffr_config = indicator_config.get("ffr_config", {})
|
||||
pa_config = indicator_config.get("pa_config", {})
|
||||
@@ -608,22 +619,22 @@ class Indicator:
|
||||
self.trade_indicator["count"] = order_count
|
||||
if show_indicator:
|
||||
print(
|
||||
"[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format(
|
||||
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
|
||||
freq,
|
||||
trade_start_time,
|
||||
trade_start_time
|
||||
if isinstance(trade_start_time, str)
|
||||
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
fulfill_rate,
|
||||
price_advantage,
|
||||
positive_rate,
|
||||
),
|
||||
)
|
||||
|
||||
def get_order_indicator(self, raw: bool = True):
|
||||
if raw:
|
||||
return self.order_indicator
|
||||
return self.order_indicator.to_series()
|
||||
def get_order_indicator(self, raw: bool = True) -> Union[BaseOrderIndicator, Dict[Text, pd.Series]]:
|
||||
return self.order_indicator if raw else self.order_indicator.to_series()
|
||||
|
||||
def get_trade_indicator(self):
|
||||
def get_trade_indicator(self) -> Dict[str, Optional[BaseSingleMetric]]:
|
||||
return self.trade_indicator
|
||||
|
||||
def generate_trade_indicators_dataframe(self):
|
||||
def generate_trade_indicators_dataframe(self) -> pd.DataFrame:
|
||||
return pd.DataFrame.from_dict(self.trade_indicator_his, orient="index")
|
||||
|
||||
@@ -22,7 +22,7 @@ class Signal(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
|
||||
def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame, None]:
|
||||
"""
|
||||
get the signal at the end of the decision step(from `start_time` to `end_time`)
|
||||
|
||||
@@ -39,13 +39,14 @@ class SignalWCache(Signal):
|
||||
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
|
||||
"""
|
||||
|
||||
def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
|
||||
def __init__(self, signal: Union[pd.Series, pd.DataFrame]) -> None:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signal : Union[pd.Series, pd.DataFrame]
|
||||
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
|
||||
The expected format of the signal is like the data below (the order of index is not important and can be
|
||||
automatically adjusted)
|
||||
|
||||
instrument datetime
|
||||
SH600000 2008-01-02 0.079704
|
||||
@@ -56,8 +57,8 @@ class SignalWCache(Signal):
|
||||
"""
|
||||
self.signal_cache = convert_index_format(signal, level="datetime")
|
||||
|
||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
|
||||
# the frequency of the signal may not algin with the decision frequency of strategy
|
||||
def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame]:
|
||||
# the frequency of the signal may not align with the decision frequency of strategy
|
||||
# so resampling from the data is necessary
|
||||
# the latest signal leverage more recent data and therefore is used in trading.
|
||||
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
|
||||
@@ -65,7 +66,7 @@ class SignalWCache(Signal):
|
||||
|
||||
|
||||
class ModelSignal(SignalWCache):
|
||||
def __init__(self, model: BaseModel, dataset: Dataset):
|
||||
def __init__(self, model: BaseModel, dataset: Dataset) -> None:
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
pred_scores = self.model.predict(dataset)
|
||||
@@ -73,7 +74,7 @@ class ModelSignal(SignalWCache):
|
||||
pred_scores = pred_scores.iloc[:, 0]
|
||||
super().__init__(pred_scores)
|
||||
|
||||
def _update_model(self):
|
||||
def _update_model(self) -> None:
|
||||
"""
|
||||
When using online data, update model in each bar as the following steps:
|
||||
- update dataset with online data, the dataset should support online update
|
||||
|
||||
@@ -149,6 +149,8 @@ class TradeCalendarManager:
|
||||
Tuple[int, int]:
|
||||
"""
|
||||
# potential performance issue
|
||||
assert self.level_infra is not None
|
||||
|
||||
day_start = pd.Timestamp(self.start_time.date())
|
||||
day_end = epsilon_change(day_start + pd.Timedelta(days=1))
|
||||
freq = self.level_infra.get("common_infra").get("trade_exchange").freq
|
||||
@@ -182,8 +184,8 @@ class TradeCalendarManager:
|
||||
Tuple[int, int]:
|
||||
the index of the range. **the left and right are closed**
|
||||
"""
|
||||
left = bisect.bisect_right(self._calendar, start_time) - 1
|
||||
right = bisect.bisect_right(self._calendar, end_time) - 1
|
||||
left = bisect.bisect_right(list(self._calendar), start_time) - 1
|
||||
right = bisect.bisect_right(list(self._calendar), end_time) - 1
|
||||
left -= self.start_index
|
||||
right -= self.start_index
|
||||
|
||||
@@ -201,14 +203,14 @@ class TradeCalendarManager:
|
||||
|
||||
|
||||
class BaseInfrastructure:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.reset_infra(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def get_support_infra(self) -> Set[str]:
|
||||
raise NotImplementedError("`get_support_infra` is not implemented!")
|
||||
|
||||
def reset_infra(self, **kwargs) -> None:
|
||||
def reset_infra(self, **kwargs: Any) -> None:
|
||||
support_infra = self.get_support_infra()
|
||||
for k, v in kwargs.items():
|
||||
if k in support_infra:
|
||||
|
||||
@@ -339,7 +339,7 @@ def long_short_backtest(
|
||||
for stock in long_stocks:
|
||||
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
|
||||
continue
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
|
||||
if np.isnan(profit):
|
||||
long_profit.append(0)
|
||||
else:
|
||||
@@ -348,17 +348,17 @@ def long_short_backtest(
|
||||
for stock in short_stocks:
|
||||
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
|
||||
continue
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
|
||||
if np.isnan(profit):
|
||||
short_profit.append(0)
|
||||
else:
|
||||
short_profit.append(-profit)
|
||||
short_profit.append(profit * -1)
|
||||
|
||||
for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)):
|
||||
# exclude the suspend stock
|
||||
if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):
|
||||
continue
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
|
||||
if np.isnan(profit):
|
||||
all_profit.append(0)
|
||||
else:
|
||||
|
||||
@@ -108,14 +108,16 @@ class CalendarProvider(abc.ABC):
|
||||
_, _, si, ei = self.locate_index(start_time, end_time, freq, future)
|
||||
return _calendar[si : ei + 1]
|
||||
|
||||
def locate_index(self, start_time, end_time, freq, future=False):
|
||||
def locate_index(
|
||||
self, start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], freq: str, future: bool = False
|
||||
):
|
||||
"""Locate the start time index and end time index in a calendar under certain frequency.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : str
|
||||
start_time : pd.Timestamp
|
||||
start of the time range.
|
||||
end_time : str
|
||||
end_time : pd.Timestamp
|
||||
end of the time range.
|
||||
freq : str
|
||||
time frequency, available: year/quarter/month/week/day.
|
||||
|
||||
@@ -248,7 +248,7 @@ def load_orders(
|
||||
Order(
|
||||
row["instrument"],
|
||||
row["amount"],
|
||||
int(row["order_type"]),
|
||||
OrderDir(int(row["order_type"])),
|
||||
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
|
||||
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Any, Generator, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.exchange import Exchange
|
||||
@@ -122,7 +122,10 @@ class BaseStrategy:
|
||||
self.outer_trade_decision = outer_trade_decision
|
||||
|
||||
@abstractmethod
|
||||
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
|
||||
def generate_trade_decision(
|
||||
self,
|
||||
execute_result: list = None,
|
||||
) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
|
||||
"""Generate trade decision in each trading bar
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -9,6 +9,8 @@ Motivation of index_data
|
||||
`index_data` try to behave like pandas (some API will be different because we try to be simpler and more intuitive) but don't compromise the performance. It provides the basic numpy data and simple indexing feature. If users call APIs which may compromise the performance, index_data will raise Errors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Tuple, Union, Callable, List
|
||||
import bisect
|
||||
|
||||
@@ -16,7 +18,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
|
||||
def concat(data_list: Union[SingleData], axis=0) -> MultiData:
|
||||
"""concat all SingleData by index.
|
||||
TODO: now just for SingleData.
|
||||
|
||||
@@ -52,7 +54,7 @@ def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
|
||||
raise ValueError(f"axis must be 0 or 1")
|
||||
|
||||
|
||||
def sum_by_index(data_list: Union["SingleData"], new_index: list, fill_value=0) -> "SingleData":
|
||||
def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) -> SingleData:
|
||||
"""concat all SingleData by new index.
|
||||
|
||||
Parameters
|
||||
@@ -554,7 +556,7 @@ class SingleData(IndexData):
|
||||
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
|
||||
)
|
||||
|
||||
def reindex(self, index: Index, fill_value=np.NaN):
|
||||
def reindex(self, index: Index, fill_value=np.NaN) -> SingleData:
|
||||
"""reindex data and fill the missing value with np.NaN.
|
||||
|
||||
Parameters
|
||||
@@ -580,7 +582,7 @@ class SingleData(IndexData):
|
||||
pass
|
||||
return SingleData(tmp_data, index)
|
||||
|
||||
def add(self, other: "SingleData", fill_value=0):
|
||||
def add(self, other: SingleData, fill_value=0):
|
||||
# TODO: add and __add__ are a little confusing.
|
||||
# This could be a more general
|
||||
common_index = self.index | other.index
|
||||
|
||||
Reference in New Issue
Block a user