1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Backtest Mypy (#1130)

* Done

* Fix test errors

* Revert profit_attribution.py

* Minor

* A minor update on collect_data type hint

* Resolve PR comments

* Use black to format code

* Fix CI errors
This commit is contained in:
Huoran Li
2022-06-28 22:16:46 +08:00
committed by GitHub
parent 9bf3423a64
commit 23c657a7a2
17 changed files with 363 additions and 315 deletions

View File

@@ -1,6 +1,6 @@
[mypy]
exclude = (?x)(
^qlib/backtest
^qlib/backtest/high_performance_ds\.py$
| ^qlib/contrib
| ^qlib/data
| ^qlib/model

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import copy
from pathlib import Path
from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union
import pandas as pd
@@ -23,7 +23,6 @@ from ..utils import init_instance_by_config
from .backtest import backtest_loop, collect_data_loop
from .decision import Order
from .exchange import Exchange
from .position import Position
from .utils import CommonInfrastructure
# make import more user-friendly by adding `from qlib.backtest import STH`
@@ -44,7 +43,7 @@ def get_exchange(
min_cost: float = 5.0,
limit_threshold: Union[Tuple[str, str], float, None] = None,
deal_price: Union[str, Tuple[str], List[str]] = None,
**kwargs,
**kwargs: Any,
) -> Exchange:
"""get_exchange
@@ -52,14 +51,15 @@ def get_exchange(
----------
# exchange related arguments
exchange: Exchange(). It could be None or any types that are acceptable by `init_instance_by_config`.
exchange: Exchange
It could be None or any types that are acceptable by `init_instance_by_config`.
freq: str
frequency of data.
start_time: Union[pd.Timestamp, str]
closed start time for backtest.
end_time: Union[pd.Timestamp, str]
closed end time for backtest.
codes: list|str
codes: Union[list, str]
list stock_id list or a string of instruments (i.e. all, csi500, sse50)
subscribe_fields: list
subscribe fields.
@@ -151,28 +151,24 @@ def create_account_instance(
Postion type.
"""
if isinstance(account, (int, float)):
pos_kwargs = {"init_cash": account}
init_cash = account
position_dict = {}
elif isinstance(account, dict):
init_cash = account["cash"]
del account["cash"]
pos_kwargs = {
"init_cash": init_cash,
"position_dict": account,
}
init_cash = account.pop("cash")
position_dict = account
else:
raise ValueError("account must be in (int, float, Position)")
raise ValueError("account must be in (int, float, dict)")
kwargs = {
"init_cash": account,
"benchmark_config": {
return Account(
init_cash=init_cash,
position_dict=position_dict,
pos_type=pos_type,
benchmark_config={
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
},
"pos_type": pos_type,
}
kwargs.update(pos_kwargs)
return Account(**kwargs)
)
def get_strategy_executor(
@@ -181,7 +177,7 @@ def get_strategy_executor(
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9,
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
) -> Tuple[BaseStrategy, BaseExecutor]:
@@ -222,7 +218,7 @@ def backtest(
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9,
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
) -> Tuple[PortfolioMetrics, Indicator]:
@@ -285,7 +281,7 @@ def collect_data(
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9,
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
return_value: dict = None,
@@ -339,7 +335,7 @@ def format_decisions(
cur_freq = decisions[0].strategy.trade_calendar.get_freq()
res = (cur_freq, [])
res: Tuple[str, list] = (cur_freq, [])
last_dec_idx = 0
for i, dec in enumerate(decisions[1:], 1):
if dec.strategy.trade_calendar.get_freq() == cur_freq:

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import copy
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple, cast
import pandas as pd
@@ -11,6 +11,7 @@ from qlib.utils import init_instance_by_config
from .decision import BaseTradeDecision, Order
from .exchange import Exchange
from .high_performance_ds import BaseOrderIndicator
from .position import BasePosition
from .report import Indicator, PortfolioMetrics
@@ -104,7 +105,7 @@ class Account:
self._pos_type = pos_type
self._port_metr_enabled = port_metr_enabled
self.benchmark_config = None # avoid no attribute error
self.benchmark_config: dict = {} # avoid no attribute error
self.init_vars(init_cash, position_dict, freq, benchmark_config)
def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None:
@@ -124,8 +125,8 @@ class Account:
self.accum_info = AccumulatedInfo()
# 2) following variables are not shared between layers
self.portfolio_metrics = None
self.hist_positions = {}
self.portfolio_metrics: Optional[PortfolioMetrics] = None
self.hist_positions: Dict[pd.Timestamp, BasePosition] = {}
self.reset(freq=freq, benchmark_config=benchmark_config)
def is_port_metr_enabled(self) -> bool:
@@ -171,7 +172,7 @@ class Account:
self.reset_report(self.freq, self.benchmark_config)
def get_hist_positions(self) -> dict:
def get_hist_positions(self) -> Dict[pd.Timestamp, BasePosition]:
return self.hist_positions
def get_cash(self) -> float:
@@ -230,13 +231,15 @@ class Account:
"""
# update price for stock in the position and the profit from changed_price
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
assert self.current_position is not None
if not self.current_position.skip_update():
stock_list = self.current_position.get_stock_list()
for code in stock_list:
# if suspend, no new price to be updated, profit is 0
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
continue
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
bar_close = cast(float, trade_exchange.get_close(code, trade_start_time, trade_end_time))
self.current_position.update_stock_price(stock_id=code, price=bar_close)
# update holding day count
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
@@ -249,6 +252,8 @@ class Account:
# for the first trade date, account_value - init_cash
# self.portfolio_metrics.is_empty() to judge is_first_trade_date
# get last_account_value, last_total_cost, last_total_turnover
assert self.portfolio_metrics is not None
if self.portfolio_metrics.is_empty():
last_account_value = self.init_cash
last_total_cost = 0
@@ -299,9 +304,9 @@ class Account:
trade_exchange: Exchange,
atomic: bool,
outer_trade_decision: BaseTradeDecision,
trade_info: list = None,
inner_order_indicators: List[Dict[str, pd.Series]] = None,
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
trade_info: list = [],
inner_order_indicators: List[BaseOrderIndicator] = [],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
indicator_config: dict = {},
) -> None:
"""update trade indicators and order indicators in each bar end"""
@@ -335,9 +340,9 @@ class Account:
trade_exchange: Exchange,
atomic: bool,
outer_trade_decision: BaseTradeDecision,
trade_info: list = None,
inner_order_indicators: List[Dict[str, pd.Series]] = None,
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
trade_info: list = [],
inner_order_indicators: List[BaseOrderIndicator] = [],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
indicator_config: dict = {},
) -> None:
"""update account at each trading bar step
@@ -398,6 +403,7 @@ class Account:
def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:
"""get the history portfolio_metrics and positions instance"""
if self.is_port_metr_enabled():
assert self.portfolio_metrics is not None
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
_positions = self.get_hist_positions()
return _portfolio_metrics, _positions

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
import pandas as pd
@@ -36,10 +36,13 @@ def backtest_loop(
indicator: Indicator
it computes the trading indicator
"""
return_value = {}
return_value: dict = {}
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
pass
return return_value.get("portfolio_metrics"), return_value.get("indicator")
portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
indicator = cast(Indicator, return_value.get("indicator"))
return portfolio_metrics, indicator
def collect_data_loop(

View File

@@ -7,7 +7,7 @@ from abc import abstractmethod
from enum import IntEnum
# try to fix circular imports when enabling type hints
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast
from qlib.backtest.utils import TradeCalendarManager
from qlib.data.data import Cal
@@ -24,8 +24,11 @@ import numpy as np
import pandas as pd
DecisionType = TypeVar("DecisionType")
class OrderDir(IntEnum):
# Order direction
# Order direction
SELL = 0
BUY = 1
@@ -65,7 +68,7 @@ class Order:
# - not tradable: the deal_amount == 0 , factor is None
# - the stock is suspended and the entire order fails. No cost for this order
# - dealt or partially dealt: deal_amount >= 0 and factor is not None
deal_amount: Optional[float] = None # `deal_amount` is a non-negative value
deal_amount: float = 0.0 # `deal_amount` is a non-negative value
factor: Optional[float] = None
# TODO:
@@ -281,7 +284,7 @@ class TradeRangeByTime(TradeRange):
return max(val_start, start_time), min(val_end, end_time)
class BaseTradeDecision:
class BaseTradeDecision(Generic[DecisionType]):
"""
Trade decisions ara made by strategy and executed by executor
@@ -316,20 +319,21 @@ class BaseTradeDecision:
"""
self.strategy = strategy
self.start_time, self.end_time = strategy.trade_calendar.get_step_time()
self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading`
if isinstance(trade_range, Tuple):
# upper strategy has no knowledge about the sub executor before `_init_sub_trading`
self.total_step: Optional[int] = None
if isinstance(trade_range, tuple):
# for Tuple[int, int]
trade_range = IdxTradeRange(*trade_range)
self.trade_range: TradeRange = trade_range
self.trade_range: Optional[TradeRange] = trade_range
def get_decision(self) -> List[object]:
def get_decision(self) -> List[DecisionType]:
"""
get the **concrete decision** (e.g. execution orders)
This will be called by the inner strategy
Returns
-------
List[object]:
List[DecisionType:
The decision result. Typically it is some orders
Example:
[]:
@@ -363,13 +367,13 @@ class BaseTradeDecision:
# purpose 2)
return self.strategy.update_trade_decision(self, trade_calendar)
def _get_range_limit(self, **kwargs) -> Tuple[int, int]:
def _get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
if self.trade_range is not None:
return self.trade_range(trade_calendar=kwargs.get("inner_calendar"))
return self.trade_range(trade_calendar=cast(TradeCalendarManager, kwargs.get("inner_calendar")))
else:
raise NotImplementedError("The decision didn't provide an index range")
def get_range_limit(self, **kwargs) -> Tuple[int, int]:
def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
"""
return the expected step range for limiting the decision execution time
Both left and right are **closed**
@@ -421,6 +425,7 @@ class BaseTradeDecision:
if getattr(self, "total_step", None) is not None:
# if `self.update` is called.
# Then the _start_idx, _end_idx should be clipped
assert self.total_step is not None
if _start_idx < 0 or _end_idx >= self.total_step:
logger = get_module_logger("decision")
logger.warning(
@@ -516,7 +521,7 @@ class BaseTradeDecision:
inner_trade_decision.trade_range = self.trade_range
class EmptyTradeDecision(BaseTradeDecision):
class EmptyTradeDecision(BaseTradeDecision[object]):
def get_decision(self) -> List[object]:
return []
@@ -524,23 +529,24 @@ class EmptyTradeDecision(BaseTradeDecision):
return True
class TradeDecisionWO(BaseTradeDecision):
class TradeDecisionWO(BaseTradeDecision[Order]):
"""
Trade Decision (W)ith (O)rder.
Besides, the time_range is also included.
"""
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None):
def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None:
super().__init__(strategy, trade_range=trade_range)
self.order_list = order_list
self.order_list = cast(List[Order], order_list)
start, end = strategy.trade_calendar.get_step_time()
for o in order_list:
assert isinstance(o, Order)
if o.start_time is None:
o.start_time = start
if o.end_time is None:
o.end_time = end
def get_decision(self) -> List[object]:
def get_decision(self) -> List[Order]:
return self.order_list
def __repr__(self) -> str:

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
from ..utils.index_data import IndexData
@@ -42,7 +42,7 @@ class Exchange:
impact_cost: float = 0.0,
extra_quote: pd.DataFrame = None,
quote_cls: Type[BaseQuote] = NumpyQuote,
**kwargs,
**kwargs: Any,
) -> None:
"""__init__
:param freq: frequency of data
@@ -141,7 +141,7 @@ class Exchange:
if limit_threshold is None:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
elif self.limit_type == self.LT_FLT and abs(limit_threshold) > 0.1:
elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
@@ -150,7 +150,7 @@ class Exchange:
deal_price = "$" + deal_price
self.buy_price = self.sell_price = deal_price
elif isinstance(deal_price, (tuple, list)):
self.buy_price, self.sell_price = deal_price
self.buy_price, self.sell_price = cast(Tuple[str, str], deal_price)
else:
raise NotImplementedError(f"This type of input is not supported")
@@ -167,10 +167,10 @@ class Exchange:
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
if self.limit_type == self.LT_TP_EXP:
assert isinstance(limit_threshold, tuple)
for exp in limit_threshold:
necessary_fields.add(exp)
all_fields = necessary_fields | set(vol_lt_fields)
all_fields = list(all_fields | set(subscribe_fields))
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
self.all_fields = all_fields
@@ -249,9 +249,9 @@ class Exchange:
LT_FLT = "float" # float
LT_NONE = "none" # none
def _get_limit_type(self, limit_threshold: Union[Tuple, float, None]) -> str:
def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
"""get limit type"""
if isinstance(limit_threshold, Tuple):
if isinstance(limit_threshold, tuple):
return self.LT_TP_EXP
elif isinstance(limit_threshold, float):
return self.LT_FLT
@@ -268,14 +268,16 @@ class Exchange:
self.quote_df["limit_sell"] = False
elif limit_type == self.LT_TP_EXP:
# set limit
limit_threshold = cast(tuple, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
elif limit_type == self.LT_FLT:
limit_threshold = cast(float, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
@staticmethod
def _get_vol_limit(volume_threshold: Union[tuple, dict]) -> Tuple[Optional[list], Optional[list], set]:
def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
"""
preprocess the volume limit.
get the fields need to get from qlib.
@@ -340,11 +342,11 @@ class Exchange:
if direction is None:
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
return buy_limit or sell_limit
return bool(buy_limit or sell_limit)
elif direction == Order.BUY:
return self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all"))
elif direction == Order.SELL:
return self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all"))
else:
raise ValueError(f"direction {direction} is not supported!")
@@ -382,7 +384,7 @@ class Exchange:
order: Order,
trade_account: Account = None,
position: BasePosition = None,
dealt_order_amount: defaultdict = defaultdict(float),
dealt_order_amount: Dict[str, float] = defaultdict(float),
) -> Tuple[float, float, float]:
"""
Deal order when the actual transaction
@@ -426,9 +428,10 @@ class Exchange:
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
field: str,
method: str = "ts_data_last",
) -> Union[None, int, float, bool, IndexData]:
return self.quote.get_data(stock_id, start_time, end_time, method=method) # TODO: missing `field`?
return self.quote.get_data(stock_id, start_time, end_time, field=field, method=method)
def get_close(
self,
@@ -444,10 +447,10 @@ class Exchange:
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
method: str = "sum",
method: Optional[str] = "sum",
) -> float:
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
return cast(float, self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method))
def get_deal_price(
self,
@@ -455,7 +458,7 @@ class Exchange:
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: OrderDir,
method: str = "ts_data_last",
method: Optional[str] = "ts_data_last",
) -> float:
if direction == OrderDir.SELL:
pstr = self.sell_price
@@ -469,7 +472,7 @@ class Exchange:
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
self.logger.warning(f"setting deal_price to close price")
deal_price = self.get_close(stock_id, start_time, end_time, method)
return deal_price
return cast(float, deal_price)
def get_factor(
self,
@@ -544,7 +547,7 @@ class Exchange:
)
return amount_dict
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float) -> float:
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float:
"""
Calculate the real adjust deal amount when considering the trading unit
:param current_amount:
@@ -572,7 +575,7 @@ class Exchange:
current_position: dict,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
) -> list:
) -> List[Order]:
"""
Note: some future information is used in this function
Parameter:
@@ -681,6 +684,7 @@ class Exchange:
factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
else:
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
assert factor is not None
return factor
def get_amount_of_trade_unit(
@@ -718,12 +722,12 @@ class Exchange:
def round_amount_by_trade_unit(
self,
deal_amount,
deal_amount: float,
factor: float = None,
stock_id: str = None,
start_time=None,
end_time=None,
):
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> float:
"""Parameter
Please refer to the docs of get_amount_of_trade_unit
deal_amount : float, adjusted amount
@@ -741,7 +745,7 @@ class Exchange:
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
return deal_amount
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> int:
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Optional[float]:
"""parse the capacity limit string and return the actual amount of orders that can be executed.
NOTE:
this function will change the order.deal_amount **inplace**
@@ -753,15 +757,12 @@ class Exchange:
dealt_order_amount : dict
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
"""
if order.direction == Order.BUY:
vol_limit = self.buy_vol_limit
elif order.direction == Order.SELL:
vol_limit = self.sell_vol_limit
vol_limit = self.buy_vol_limit if order.direction == Order.BUY else self.sell_vol_limit
if vol_limit is None:
return order.deal_amount
vol_limit_num = []
vol_limit_num: List[float] = []
for limit in vol_limit:
assert isinstance(limit, tuple)
if limit[0] == "current":
@@ -772,7 +773,7 @@ class Exchange:
field=limit[1],
method="sum",
)
vol_limit_num.append(limit_value)
vol_limit_num.append(cast(float, limit_value))
elif limit[0] == "cum":
limit_value = self.quote.get_data(
order.stock_id,
@@ -790,12 +791,14 @@ class Exchange:
if vol_limit_min < orig_deal_amount:
self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}")
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
return None
def _get_buy_amount_by_cash_limit(self, trade_price: float, cash: float, cost_ratio: float) -> float:
"""return the real order amount after cash limit for buying.
Parameters
----------
trade_price : float
position : cash
cash : float
cost_ratio : float
Return
@@ -803,7 +806,7 @@ class Exchange:
float
the real order amount after cash limit for buying.
"""
max_trade_amount = 0
max_trade_amount = 0.0
if cash >= self.min_cost:
# critical_price means the stock transaction price when the service fee is equal to min_cost.
critical_price = self.min_cost / cost_ratio + self.min_cost
@@ -897,7 +900,7 @@ class Exchange:
order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
else:
raise NotImplementedError("order type {} error".format(order.type))
raise NotImplementedError("order direction {} error".format(order.direction))
trade_val = order.deal_amount * trade_price
trade_cost = max(trade_val * cost_ratio, self.min_cost)

View File

@@ -4,7 +4,7 @@ import copy
from abc import abstractmethod
from collections import defaultdict
from types import GeneratorType
from typing import Generator, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, List, Tuple, Union, cast
import pandas as pd
@@ -16,13 +16,7 @@ from ..strategy.base import BaseStrategy
from ..utils import init_instance_by_config
from .decision import BaseTradeDecision, Order
from .exchange import Exchange
from .utils import (
BaseInfrastructure,
CommonInfrastructure,
LevelInfrastructure,
TradeCalendarManager,
get_start_end_idx,
)
from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, get_start_end_idx
class BaseExecutor:
@@ -39,8 +33,8 @@ class BaseExecutor:
track_data: bool = False,
trade_exchange: Exchange = None,
common_infra: CommonInfrastructure = None,
settle_type=BasePosition.ST_NO, # TODO: add typehint
**kwargs,
settle_type: str = BasePosition.ST_NO,
**kwargs: Any,
) -> None:
"""
Parameters
@@ -127,10 +121,10 @@ class BaseExecutor:
get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")
# record deal order amount in one day
self.dealt_order_amount = defaultdict(float)
self.dealt_order_amount: Dict[str, float] = defaultdict(float)
self.deal_day = None
def reset_common_infra(self, common_infra: BaseInfrastructure, copy_trade_account: bool = False) -> None:
def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:
"""
reset infrastructure for trading
- reset trade_account
@@ -141,14 +135,15 @@ class BaseExecutor:
self.common_infra.update(common_infra)
if common_infra.has("trade_account"):
if copy_trade_account:
# NOTE: there is a trick in the code.
# shallow copy is used instead of deepcopy.
# 1. So positions are shared
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
else:
self.trade_account: Account = common_infra.get("trade_account")
# NOTE: there is a trick in the code.
# shallow copy is used instead of deepcopy.
# 1. So positions are shared
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
self.trade_account: Account = (
copy.copy(common_infra.get("trade_account"))
if copy_trade_account
else common_infra.get("trade_account")
)
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
@property
@@ -164,7 +159,7 @@ class BaseExecutor:
"""
return self.level_infra.get("trade_calendar")
def reset(self, common_infra: CommonInfrastructure = None, **kwargs) -> None:
def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None:
"""
- reset `start_time` and `end_time`, used in trade calendar
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
@@ -200,20 +195,17 @@ class BaseExecutor:
execute_result : List[object]
the executed result for trade decision
"""
return_value = {}
return_value: dict = {}
for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):
pass
return return_value.get("execute_result")
return cast(list, return_value.get("execute_result"))
@abstractmethod
def _collect_data(
self,
trade_decision: BaseTradeDecision,
level: int = 0,
) -> Union[
Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]],
Tuple[List[object], dict],
]:
) -> Union[Generator[Any, Any, Tuple[List[object], dict]], Tuple[List[object], dict]]:
"""
Please refer to the doc of collect_data
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
@@ -235,7 +227,7 @@ class BaseExecutor:
trade_decision: BaseTradeDecision,
return_value: dict = None,
level: int = 0,
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], List[object]]:
) -> Generator[Any, Any, List[object]]:
"""Generator for collecting the trade decision data for rl training
his function will make a step forward
@@ -332,7 +324,7 @@ class NestedExecutor(BaseExecutor):
skip_empty_decision: bool = True,
align_range_limit: bool = True,
common_infra: CommonInfrastructure = None,
**kwargs,
**kwargs: Any,
) -> None:
"""
Parameters
@@ -411,7 +403,7 @@ class NestedExecutor(BaseExecutor):
self,
trade_decision: BaseTradeDecision,
level: int = 0,
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]]:
) -> Generator[Any, Any, Tuple[List[object], dict]]:
execute_result = []
inner_order_indicators = []
decision_list = []
@@ -493,7 +485,7 @@ class NestedExecutor(BaseExecutor):
the execution result of inner task
"""
def get_all_executors(self) -> List[object]:
def get_all_executors(self) -> List[BaseExecutor]:
"""get all executors, including self and inner_executor.get_all_executors()"""
return [self, *self.inner_executor.get_all_executors()]
@@ -536,7 +528,7 @@ class SimulatorExecutor(BaseExecutor):
track_data: bool = False,
common_infra: CommonInfrastructure = None,
trade_type: str = TT_SERIAL,
**kwargs,
**kwargs: Any,
) -> None:
"""
Parameters
@@ -598,7 +590,7 @@ class SimulatorExecutor(BaseExecutor):
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
trade_start_time, _ = self.trade_calendar.get_step_time()
execute_result = []
execute_result: list = []
for order in self._get_order_iterator(trade_decision):
# execute the order.

View File

@@ -1,11 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import inspect
import logging
from collections import OrderedDict
from functools import lru_cache
from typing import Callable, Dict, Iterable, List, Text, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Union, cast
import numpy as np
import pandas as pd
@@ -19,7 +21,7 @@ from ..utils.time import Freq, is_single_value
class BaseQuote:
def __init__(self, quote_df: pd.DataFrame, freq):
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
self.logger = get_module_logger("online operator", level=logging.INFO)
def get_all_stock(self) -> Iterable:
@@ -39,7 +41,7 @@ class BaseQuote:
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
field: Union[str],
method: Union[str, None] = None,
method: Optional[str] = None,
) -> Union[None, int, float, bool, IndexData]:
"""get the specific field of stock data during start time and end_time,
and apply method to the data.
@@ -99,7 +101,7 @@ class BaseQuote:
class PandasQuote(BaseQuote):
def __init__(self, quote_df: pd.DataFrame, freq):
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
super().__init__(quote_df=quote_df, freq=freq)
quote_dict = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"):
@@ -124,7 +126,7 @@ class PandasQuote(BaseQuote):
class NumpyQuote(BaseQuote):
def __init__(self, quote_df: pd.DataFrame, freq, region="cn"):
def __init__(self, quote_df: pd.DataFrame, freq: str, region: str = "cn") -> None:
"""NumpyQuote
Parameters
@@ -178,7 +180,8 @@ class NumpyQuote(BaseQuote):
data = self._agg_data(data, method)
return data
def _agg_data(self, data: IndexData, method):
@staticmethod
def _agg_data(data: IndexData, method: str) -> Union[IndexData, np.ndarray, None]:
"""Agg data by specific method."""
# FIXME: why not call the method of data directly?
if method == "sum":
@@ -224,31 +227,31 @@ class BaseSingleMetric:
"""
raise NotImplementedError(f"Please implement the `__init__` method")
def __add__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __add__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__add__` method")
def __radd__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __radd__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
return self + other
def __sub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __sub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__sub__` method")
def __rsub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __rsub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__rsub__` method")
def __mul__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __mul__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__mul__` method")
def __truediv__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __truediv__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__truediv__` method")
def __eq__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __eq__(self, other: object) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__eq__` method")
def __gt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __gt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__gt__` method")
def __lt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __lt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__lt__` method")
def __len__(self) -> int:
@@ -265,7 +268,7 @@ class BaseSingleMetric:
raise NotImplementedError(f"Please implement the `count` method")
def abs(self) -> "BaseSingleMetric":
def abs(self) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `abs` method")
@property
@@ -274,18 +277,18 @@ class BaseSingleMetric:
raise NotImplementedError(f"Please implement the `empty` method")
def add(self, other: "BaseSingleMetric", fill_value: float = None) -> "BaseSingleMetric":
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
"""Replace np.NaN with fill_value in two metrics and add them."""
raise NotImplementedError(f"Please implement the `add` method")
def replace(self, replace_dict: dict) -> "BaseSingleMetric":
def replace(self, replace_dict: dict) -> BaseSingleMetric:
"""Replace the value of metric according to replace_dict."""
raise NotImplementedError(f"Please implement the `replace` method")
def apply(self, func: dict) -> "BaseSingleMetric":
"""Replace the value of metric with func(metric).
def apply(self, func: Callable) -> BaseSingleMetric:
"""Replace the value of metric with func (metric).
Currently, the func is only qlib/backtest/order/Order.parse_dir.
"""
@@ -304,11 +307,11 @@ class BaseOrderIndicator:
to inherit the BaseSingleMetric.
"""
def __init__(self, data):
self.data = data
def __init__(self):
self.data = {} # will be created in the subclass
self.logger = get_module_logger("online operator")
def assign(self, col: str, metric: Union[dict, pd.Series]):
def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
"""assign one metric.
Parameters
@@ -328,7 +331,7 @@ class BaseOrderIndicator:
raise NotImplementedError(f"Please implement the 'assign' method")
def transfer(self, func: Callable, new_col: str = None) -> Union[None, BaseSingleMetric]:
def transfer(self, func: Callable, new_col: str = None) -> Optional[BaseSingleMetric]:
"""compute new metric with existing metrics.
Parameters
@@ -352,6 +355,7 @@ class BaseOrderIndicator:
tmp_metric = func(**func_kwargs)
if new_col is not None:
self.data[new_col] = tmp_metric
return None
else:
return tmp_metric
@@ -372,7 +376,7 @@ class BaseOrderIndicator:
raise NotImplementedError(f"Please implement the 'get_metric_series' method")
def get_index_data(self, metric) -> SingleData:
def get_index_data(self, metric: str) -> SingleData:
"""get one metric with the format of SingleData
Parameters
@@ -389,7 +393,12 @@ class BaseOrderIndicator:
raise NotImplementedError(f"Please implement the 'get_index_data' method")
@staticmethod
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None):
def sum_all_indicators(
order_indicator: BaseOrderIndicator,
indicators: List[BaseOrderIndicator],
metrics: Union[str, List[str]],
fill_value: float = 0,
) -> None:
"""sum indicators with the same metrics.
and assign to the order_indicator(BaseOrderIndicator).
NOTE: indicators could be a empty list when orders in lower level all fail.
@@ -527,16 +536,17 @@ class PandasSingleMetric(SingleMetric):
def index(self):
return list(self.metric.index)
def add(self, other, fill_value=None):
def add(self, other: BaseSingleMetric, fill_value: float = None) -> PandasSingleMetric:
other = cast(PandasSingleMetric, other)
return self.__class__(self.metric.add(other.metric, fill_value=fill_value))
def replace(self, replace_dict: dict):
def replace(self, replace_dict: dict) -> PandasSingleMetric:
return self.__class__(self.metric.replace(replace_dict))
def apply(self, func: Callable):
def apply(self, func: Callable) -> PandasSingleMetric:
return self.__class__(self.metric.apply(func))
def reindex(self, index, fill_value):
def reindex(self, index: Any, fill_value: float) -> PandasSingleMetric:
return self.__class__(self.metric.reindex(index, fill_value=fill_value))
def __repr__(self):
@@ -550,13 +560,14 @@ class PandasOrderIndicator(BaseOrderIndicator):
Str is the name of metric.
"""
def __init__(self):
def __init__(self) -> None:
super(PandasOrderIndicator, self).__init__()
self.data: Dict[str, PandasSingleMetric] = OrderedDict()
def assign(self, col: str, metric: Union[dict, pd.Series]):
def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
self.data[col] = PandasSingleMetric(metric)
def get_index_data(self, metric):
def get_index_data(self, metric: str) -> SingleData:
if metric in self.data:
return idd.SingleData(self.data[metric].metric)
else:
@@ -572,7 +583,12 @@ class PandasOrderIndicator(BaseOrderIndicator):
return {k: v.metric for k, v in self.data.items()}
@staticmethod
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
def sum_all_indicators(
order_indicator: BaseOrderIndicator,
indicators: List[BaseOrderIndicator],
metrics: Union[str, List[str]],
fill_value: float = 0,
) -> None:
if isinstance(metrics, str):
metrics = [metrics]
for metric in metrics:
@@ -592,13 +608,14 @@ class NumpyOrderIndicator(BaseOrderIndicator):
Str is the name of metric.
"""
def __init__(self):
def __init__(self) -> None:
super(NumpyOrderIndicator, self).__init__()
self.data: Dict[str, SingleData] = OrderedDict()
def assign(self, col: str, metric: dict):
def assign(self, col: str, metric: dict) -> None:
self.data[col] = idd.SingleData(metric)
def get_index_data(self, metric):
def get_index_data(self, metric: str) -> SingleData:
if metric in self.data:
return self.data[metric]
else:
@@ -614,14 +631,18 @@ class NumpyOrderIndicator(BaseOrderIndicator):
return tmp_metric_dict
@staticmethod
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
def sum_all_indicators(
order_indicator: BaseOrderIndicator,
indicators: List[BaseOrderIndicator],
metrics: Union[str, List[str]],
fill_value: float = 0,
) -> None:
# get all index(stock_id)
stocks = set()
stock_set: set = set()
for indicator in indicators:
# set(np.ndarray.tolist()) is faster than set(np.ndarray)
stocks = stocks | set(indicator.data[metrics[0]].index.tolist())
stocks = list(stocks)
stocks.sort()
stock_set = stock_set | set(indicator.data[metrics[0]].index.tolist())
stocks = sorted(list(stock_set))
# add metric by index
if isinstance(metrics, str):

View File

@@ -3,7 +3,7 @@
from datetime import timedelta
from typing import Dict, List, Union
from typing import Any, Dict, List, Union
import numpy as np
import pandas as pd
@@ -18,9 +18,9 @@ class BasePosition:
Please refer to the `Position` class for the position
"""
def __init__(self, *args, cash: float = 0.0, **kwargs) -> None:
def __init__(self, *args: Any, cash: float = 0.0, **kwargs: Any) -> None:
self._settle_type = self.ST_NO
self.position = {}
self.position: dict = {}
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
pass
@@ -96,13 +96,13 @@ class BasePosition:
def calculate_value(self) -> float:
raise NotImplementedError(f"Please implement the `calculate_value` method")
def get_stock_list(self) -> List:
def get_stock_list(self) -> List[str]:
"""
Get the list of stocks in the position.
"""
raise NotImplementedError(f"Please implement the `get_stock_list` method")
def get_stock_price(self, code) -> float:
def get_stock_price(self, code: str) -> float:
"""
get the latest price of the stock
@@ -113,7 +113,7 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `get_stock_price` method")
def get_stock_amount(self, code) -> float:
def get_stock_amount(self, code: str) -> float:
"""
get the amount of the stock
@@ -144,7 +144,7 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `get_cash` method")
def get_stock_amount_dict(self) -> Dict:
def get_stock_amount_dict(self) -> dict:
"""
generate stock amount dict {stock_id : amount of stock}
@@ -155,7 +155,7 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method")
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
"""
generate stock weight dict {stock_id : value weight of stock in the position}
it is meaningful in the beginning or the end of each trade step
@@ -174,7 +174,7 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
def add_count_all(self, bar) -> None:
def add_count_all(self, bar: str) -> None:
"""
Will be called at the end of each bar on each level
@@ -195,7 +195,7 @@ class BasePosition:
raise NotImplementedError(f"Please implement the `add_count_all` method")
ST_CASH = "cash"
ST_NO = None
ST_NO = "None" # String is more typehint friendly than None
def settle_start(self, settle_type: str) -> None:
"""
@@ -220,10 +220,10 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `settle_commit` method")
def __str__(self):
def __str__(self) -> str:
return self.__dict__.__str__()
def __repr__(self):
def __repr__(self) -> str:
return self.__dict__.__repr__()
@@ -532,7 +532,7 @@ class InfPosition(BasePosition):
def calculate_value(self) -> float:
raise NotImplementedError(f"InfPosition doesn't support calculating value")
def get_stock_list(self) -> list:
def get_stock_list(self) -> List[str]:
raise NotImplementedError(f"InfPosition doesn't support stock list position")
def get_stock_price(self, code: str) -> float:
@@ -545,10 +545,10 @@ class InfPosition(BasePosition):
def get_cash(self, include_settle: bool = False) -> float:
return np.inf
def get_stock_amount_dict(self) -> Dict:
def get_stock_amount_dict(self) -> dict:
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
def add_count_all(self, bar: str) -> None:

View File

@@ -4,7 +4,7 @@
import pathlib
from collections import OrderedDict
from typing import Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Text, Tuple, Type, Union, cast
import numpy as np
import pandas as pd
@@ -15,7 +15,7 @@ from qlib.backtest.exchange import Exchange
from ..tests.config import CSI300_BENCH
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
from .high_performance_ds import BaseOrderIndicator, BaseSingleMetric, NumpyOrderIndicator
class PortfolioMetrics:
@@ -38,7 +38,7 @@ class PortfolioMetrics:
update report
"""
def __init__(self, freq: str = "day", benchmark_config: dict = {}):
def __init__(self, freq: str = "day", benchmark_config: dict = {}) -> None:
"""
Parameters
----------
@@ -49,13 +49,17 @@ class PortfolioMetrics:
- benchmark : Union[str, list, pd.Series]
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
example:
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
print(
D.features(D.instruments('csi500'),
['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()
)
2017-01-04 0.011693
2017-01-05 0.000721
2017-01-06 -0.004322
2017-01-09 0.006874
2017-01-10 -0.003350
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the
'bench'.
- If `benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000300 CSI300
- start_time : Union[str, pd.Timestamp], optional
@@ -70,25 +74,26 @@ class PortfolioMetrics:
self.init_vars()
self.init_bench(freq=freq, benchmark_config=benchmark_config)
def init_vars(self):
self.accounts = OrderedDict() # account position value for each trade time
self.returns = OrderedDict() # daily return rate for each trade time
self.total_turnovers = OrderedDict() # total turnover for each trade time
self.turnovers = OrderedDict() # turnover for each trade time
self.total_costs = OrderedDict() # total trade cost for each trade time
self.costs = OrderedDict() # trade cost rate for each trade time
self.values = OrderedDict() # value for each trade time
self.cashes = OrderedDict()
self.benches = OrderedDict()
self.latest_pm_time = None # pd.TimeStamp
def init_vars(self) -> None:
self.accounts: dict = OrderedDict() # account position value for each trade time
self.returns: dict = OrderedDict() # daily return rate for each trade time
self.total_turnovers: dict = OrderedDict() # total turnover for each trade time
self.turnovers: dict = OrderedDict() # turnover for each trade time
self.total_costs: dict = OrderedDict() # total trade cost for each trade time
self.costs: dict = OrderedDict() # trade cost rate for each trade time
self.values: dict = OrderedDict() # value for each trade time
self.cashes: dict = OrderedDict()
self.benches: dict = OrderedDict()
self.latest_pm_time: Optional[pd.TimeStamp] = None
def init_bench(self, freq=None, benchmark_config=None):
def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None:
if freq is not None:
self.freq = freq
self.benchmark_config = benchmark_config
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
def _cal_benchmark(self, benchmark_config, freq):
@staticmethod
def _cal_benchmark(benchmark_config: Optional[dict], freq: str) -> Optional[pd.Series]:
if benchmark_config is None:
return None
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
@@ -110,7 +115,12 @@ class PortfolioMetrics:
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
def _sample_benchmark(
self,
bench: pd.Series,
trade_start_time: Union[str, pd.Timestamp],
trade_end_time: Union[str, pd.Timestamp],
) -> Optional[float]:
if self.bench is None:
return None
@@ -120,35 +130,35 @@ class PortfolioMetrics:
_ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)
return 0.0 if _ret is None else _ret - 1
def is_empty(self):
def is_empty(self) -> bool:
return len(self.accounts) == 0
def get_latest_date(self):
def get_latest_date(self) -> pd.Timestamp:
return self.latest_pm_time
def get_latest_account_value(self):
def get_latest_account_value(self) -> float:
return self.accounts[self.latest_pm_time]
def get_latest_total_cost(self):
def get_latest_total_cost(self) -> Any:
return self.total_costs[self.latest_pm_time]
def get_latest_total_turnover(self):
def get_latest_total_turnover(self) -> Any:
return self.total_turnovers[self.latest_pm_time]
def update_portfolio_metrics_record(
self,
trade_start_time=None,
trade_end_time=None,
account_value=None,
cash=None,
return_rate=None,
total_turnover=None,
turnover_rate=None,
total_cost=None,
cost_rate=None,
stock_value=None,
bench_value=None,
):
trade_start_time: Union[str, pd.Timestamp] = None,
trade_end_time: Union[str, pd.Timestamp] = None,
account_value: float = None,
cash: float = None,
return_rate: float = None,
total_turnover: float = None,
turnover_rate: float = None,
total_cost: float = None,
cost_rate: float = None,
stock_value: float = None,
bench_value: float = None,
) -> None:
# check data
if None in [
trade_start_time,
@@ -185,7 +195,7 @@ class PortfolioMetrics:
self.latest_pm_time = trade_start_time
# finish pm update in each step
def generate_portfolio_metrics_dataframe(self):
def generate_portfolio_metrics_dataframe(self) -> pd.DataFrame:
pm = pd.DataFrame()
pm["account"] = pd.Series(self.accounts)
pm["return"] = pd.Series(self.returns)
@@ -199,19 +209,18 @@ class PortfolioMetrics:
pm.index.name = "datetime"
return pm
def save_portfolio_metrics(self, path):
def save_portfolio_metrics(self, path: str) -> None:
r = self.generate_portfolio_metrics_dataframe()
r.to_csv(path)
def load_portfolio_metrics(self, path):
def load_portfolio_metrics(self, path: str) -> None:
"""load pm from a file
should have format like
columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']
:param
path: str/ pathlib.Path()
"""
path = pathlib.Path(path)
with path.open("rb") as f:
with pathlib.Path(path).open("rb") as f:
r = pd.read_csv(f, index_col=0)
r.index = pd.DatetimeIndex(r.index)
@@ -261,30 +270,30 @@ class Indicator:
"""
def __init__(self, order_indicator_cls=NumpyOrderIndicator):
def __init__(self, order_indicator_cls: Type[BaseOrderIndicator] = NumpyOrderIndicator) -> None:
self.order_indicator_cls = order_indicator_cls
# order indicator is metrics for a single order for a specific step
self.order_indicator_his = OrderedDict()
self.order_indicator_his: dict = OrderedDict()
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
# trade indicator is metrics for all orders for a specific step
self.trade_indicator_his = OrderedDict()
self.trade_indicator: Dict[str, float] = OrderedDict()
self.trade_indicator_his: dict = OrderedDict()
self.trade_indicator: Dict[str, Optional[BaseSingleMetric]] = OrderedDict()
self._trade_calendar = None
# def reset(self, trade_calendar: TradeCalendarManager):
def reset(self):
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
def reset(self) -> None:
self.order_indicator = self.order_indicator_cls()
self.trade_indicator = OrderedDict()
# self._trade_calendar = trade_calendar
def record(self, trade_start_time):
def record(self, trade_start_time: Union[str, pd.Timestamp]) -> None:
self.order_indicator_his[trade_start_time] = self.get_order_indicator()
self.trade_indicator_his[trade_start_time] = self.get_trade_indicator()
def _update_order_trade_info(self, trade_info: list):
def _update_order_trade_info(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
amount = dict()
deal_amount = dict()
trade_price = dict()
@@ -313,7 +322,7 @@ class Indicator:
self.order_indicator.assign("trade_dir", trade_dir)
self.order_indicator.assign("pa", pa)
def _update_order_fulfill_rate(self):
def _update_order_fulfill_rate(self) -> None:
def func(deal_amount, amount):
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
@@ -322,11 +331,11 @@ class Indicator:
self.order_indicator.transfer(func, "ffr")
def update_order_indicators(self, trade_info: list):
def update_order_indicators(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
self._update_order_trade_info(trade_info=trade_info)
self._update_order_fulfill_rate()
def _agg_order_trade_info(self, inner_order_indicators: List[Dict[str, pd.Series]]):
def _agg_order_trade_info(self, inner_order_indicators: List[BaseOrderIndicator]) -> None:
# calculate total trade amount with each inner order indicator.
def trade_amount_func(deal_amount, trade_price):
return deal_amount * trade_price
@@ -355,9 +364,9 @@ class Indicator:
self.order_indicator.transfer(func_apply, "trade_dir")
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision):
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision) -> None:
# NOTE: these indicator is designed for order execution, so the
decision: List[Order] = outer_trade_decision.get_decision()
decision: List[Order] = cast(List[Order], outer_trade_decision.get_decision())
if len(decision) == 0:
self.order_indicator.assign("amount", {})
else:
@@ -372,7 +381,7 @@ class Indicator:
decision: BaseTradeDecision,
trade_exchange: Exchange,
pa_config: dict = {},
):
) -> Tuple[Optional[float], Optional[float]]:
"""
Get the base volume and price information
All the base price values are rooted from this function
@@ -412,31 +421,35 @@ class Indicator:
# NOTE: there are some zeros in the trading price. These cases are known meaningless
# for aligning the previous logic, remove it.
# remove zero and negative values.
price_s = price_s.loc[(price_s > 1e-08).data.astype(np.bool)]
assert isinstance(price_s, idd.SingleData)
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
# ~(np.NaN < 1e-8) -> ~(False) -> True
assert isinstance(price_s, idd.SingleData)
if agg == "vwap":
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
if isinstance(volume_s, (int, float, np.number)):
volume_s = idd.SingleData(volume_s, [trade_start_time])
assert isinstance(volume_s, idd.SingleData)
volume_s = volume_s.reindex(price_s.index)
elif agg == "twap":
volume_s = idd.SingleData(1, price_s.index)
else:
raise NotImplementedError(f"This type of input is not supported")
assert isinstance(volume_s, idd.SingleData)
base_volume = volume_s.sum()
base_price = (price_s * volume_s).sum() / base_volume
return base_price, base_volume
def _agg_base_price(
self,
inner_order_indicators: List[Dict[str, Union[SingleMetric, idd.SingleData]]],
inner_order_indicators: List[BaseOrderIndicator],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
trade_exchange: Exchange,
pa_config: dict = {},
):
) -> None:
"""
# NOTE:!!!!
# Strong assumption!!!!!!
@@ -444,7 +457,7 @@ class Indicator:
Parameters
----------
inner_order_indicators : List[Dict[str, pd.Series]]
inner_order_indicators : List[BaseOrderIndicator]
the indicators of account of inner executor
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
a list of decisions according to inner_order_indicators
@@ -489,14 +502,17 @@ class Indicator:
bv_new = idd.SingleData(bv_new)
bp_all.append(bp_new)
bv_all.append(bv_new)
bp_all = idd.concat(bp_all, axis=1)
bv_all = idd.concat(bv_all, axis=1)
bp_all_multi_data = idd.concat(bp_all, axis=1)
bv_all_multi_data = idd.concat(bv_all, axis=1)
base_volume = bv_all.sum(axis=1)
base_volume = bv_all_multi_data.sum(axis=1)
self.order_indicator.assign("base_volume", base_volume.to_dict())
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict())
self.order_indicator.assign(
"base_price",
((bp_all_multi_data * bv_all_multi_data).sum(axis=1) / base_volume).to_dict(),
)
def _agg_order_price_advantage(self):
def _agg_order_price_advantage(self) -> None:
def if_empty_func(trade_price):
return trade_price.empty
@@ -513,12 +529,12 @@ class Indicator:
def agg_order_indicators(
self,
inner_order_indicators: List[Dict[str, pd.Series]],
inner_order_indicators: List[BaseOrderIndicator],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
outer_trade_decision: BaseTradeDecision,
trade_exchange: Exchange,
indicator_config={},
):
indicator_config: dict = {},
) -> None:
self._agg_order_trade_info(inner_order_indicators)
self._update_trade_amount(outer_trade_decision)
self._update_order_fulfill_rate()
@@ -526,71 +542,66 @@ class Indicator:
self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) # TODO
self._agg_order_price_advantage()
def _cal_trade_fulfill_rate(self, method="mean"):
def _cal_trade_fulfill_rate(self, method: str = "mean") -> Optional[BaseSingleMetric]:
if method == "mean":
def func(ffr):
return ffr.mean()
return self.order_indicator.transfer(
lambda ffr: ffr.mean(),
)
elif method == "amount_weighted":
def func(ffr, deal_amount):
return (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum())
return self.order_indicator.transfer(
lambda ffr, deal_amount: (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
)
elif method == "value_weighted":
def func(ffr, trade_value):
return (ffr * trade_value.abs()).sum() / (trade_value.abs().sum())
return self.order_indicator.transfer(
lambda ffr, trade_value: (ffr * trade_value.abs()).sum() / (trade_value.abs().sum()),
)
else:
raise ValueError(f"method {method} is not supported!")
return self.order_indicator.transfer(func)
def _cal_trade_price_advantage(self, method="mean"):
def _cal_trade_price_advantage(self, method: str = "mean") -> Optional[BaseSingleMetric]:
if method == "mean":
def func(pa):
return pa.mean()
return self.order_indicator.transfer(lambda pa: pa.mean())
elif method == "amount_weighted":
def func(pa, deal_amount):
return (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum())
return self.order_indicator.transfer(
lambda pa, deal_amount: (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
)
elif method == "value_weighted":
def func(pa, trade_value):
return (pa * trade_value.abs()).sum() / (trade_value.abs().sum())
return self.order_indicator.transfer(
lambda pa, trade_value: (pa * trade_value.abs()).sum() / (trade_value.abs().sum()),
)
else:
raise ValueError(f"method {method} is not supported!")
return self.order_indicator.transfer(func)
def _cal_trade_positive_rate(self):
def _cal_trade_positive_rate(self) -> Optional[BaseSingleMetric]:
def func(pa):
return (pa > 0).sum() / pa.count()
return self.order_indicator.transfer(func)
def _cal_deal_amount(self):
def _cal_deal_amount(self) -> Optional[BaseSingleMetric]:
def func(deal_amount):
return deal_amount.abs().sum()
return self.order_indicator.transfer(func)
def _cal_trade_value(self):
def _cal_trade_value(self) -> Optional[BaseSingleMetric]:
def func(trade_value):
return trade_value.abs().sum()
return self.order_indicator.transfer(func)
def _cal_trade_order_count(self):
def _cal_trade_order_count(self) -> Optional[BaseSingleMetric]:
def func(amount):
return amount.count()
return self.order_indicator.transfer(func)
def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}):
def cal_trade_indicators(
self,
trade_start_time: Union[str, pd.Timestamp],
freq: str,
indicator_config: dict = {},
) -> None:
show_indicator = indicator_config.get("show_indicator", False)
ffr_config = indicator_config.get("ffr_config", {})
pa_config = indicator_config.get("pa_config", {})
@@ -608,22 +619,22 @@ class Indicator:
self.trade_indicator["count"] = order_count
if show_indicator:
print(
"[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format(
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
freq,
trade_start_time,
trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
fulfill_rate,
price_advantage,
positive_rate,
),
)
def get_order_indicator(self, raw: bool = True):
if raw:
return self.order_indicator
return self.order_indicator.to_series()
def get_order_indicator(self, raw: bool = True) -> Union[BaseOrderIndicator, Dict[Text, pd.Series]]:
return self.order_indicator if raw else self.order_indicator.to_series()
def get_trade_indicator(self):
def get_trade_indicator(self) -> Dict[str, Optional[BaseSingleMetric]]:
return self.trade_indicator
def generate_trade_indicators_dataframe(self):
def generate_trade_indicators_dataframe(self) -> pd.DataFrame:
return pd.DataFrame.from_dict(self.trade_indicator_his, orient="index")

View File

@@ -22,7 +22,7 @@ class Signal(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame, None]:
"""
get the signal at the end of the decision step(from `start_time` to `end_time`)
@@ -39,13 +39,14 @@ class SignalWCache(Signal):
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
"""
def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
def __init__(self, signal: Union[pd.Series, pd.DataFrame]) -> None:
"""
Parameters
----------
signal : Union[pd.Series, pd.DataFrame]
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
The expected format of the signal is like the data below (the order of index is not important and can be
automatically adjusted)
instrument datetime
SH600000 2008-01-02 0.079704
@@ -56,8 +57,8 @@ class SignalWCache(Signal):
"""
self.signal_cache = convert_index_format(signal, level="datetime")
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
# the frequency of the signal may not algin with the decision frequency of strategy
def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame]:
# the frequency of the signal may not align with the decision frequency of strategy
# so resampling from the data is necessary
# the latest signal leverage more recent data and therefore is used in trading.
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
@@ -65,7 +66,7 @@ class SignalWCache(Signal):
class ModelSignal(SignalWCache):
def __init__(self, model: BaseModel, dataset: Dataset):
def __init__(self, model: BaseModel, dataset: Dataset) -> None:
self.model = model
self.dataset = dataset
pred_scores = self.model.predict(dataset)
@@ -73,7 +74,7 @@ class ModelSignal(SignalWCache):
pred_scores = pred_scores.iloc[:, 0]
super().__init__(pred_scores)
def _update_model(self):
def _update_model(self) -> None:
"""
When using online data, update model in each bar as the following steps:
- update dataset with online data, the dataset should support online update

View File

@@ -149,6 +149,8 @@ class TradeCalendarManager:
Tuple[int, int]:
"""
# potential performance issue
assert self.level_infra is not None
day_start = pd.Timestamp(self.start_time.date())
day_end = epsilon_change(day_start + pd.Timedelta(days=1))
freq = self.level_infra.get("common_infra").get("trade_exchange").freq
@@ -182,8 +184,8 @@ class TradeCalendarManager:
Tuple[int, int]:
the index of the range. **the left and right are closed**
"""
left = bisect.bisect_right(self._calendar, start_time) - 1
right = bisect.bisect_right(self._calendar, end_time) - 1
left = bisect.bisect_right(list(self._calendar), start_time) - 1
right = bisect.bisect_right(list(self._calendar), end_time) - 1
left -= self.start_index
right -= self.start_index
@@ -201,14 +203,14 @@ class TradeCalendarManager:
class BaseInfrastructure:
def __init__(self, **kwargs) -> None:
def __init__(self, **kwargs: Any) -> None:
self.reset_infra(**kwargs)
@abstractmethod
def get_support_infra(self) -> Set[str]:
raise NotImplementedError("`get_support_infra` is not implemented!")
def reset_infra(self, **kwargs) -> None:
def reset_infra(self, **kwargs: Any) -> None:
support_infra = self.get_support_infra()
for k, v in kwargs.items():
if k in support_infra:

View File

@@ -339,7 +339,7 @@ def long_short_backtest(
for stock in long_stocks:
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
if np.isnan(profit):
long_profit.append(0)
else:
@@ -348,17 +348,17 @@ def long_short_backtest(
for stock in short_stocks:
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
if np.isnan(profit):
short_profit.append(0)
else:
short_profit.append(-profit)
short_profit.append(profit * -1)
for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)):
# exclude the suspend stock
if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):
continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
if np.isnan(profit):
all_profit.append(0)
else:

View File

@@ -108,14 +108,16 @@ class CalendarProvider(abc.ABC):
_, _, si, ei = self.locate_index(start_time, end_time, freq, future)
return _calendar[si : ei + 1]
def locate_index(self, start_time, end_time, freq, future=False):
def locate_index(
self, start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], freq: str, future: bool = False
):
"""Locate the start time index and end time index in a calendar under certain frequency.
Parameters
----------
start_time : str
start_time : pd.Timestamp
start of the time range.
end_time : str
end_time : pd.Timestamp
end of the time range.
freq : str
time frequency, available: year/quarter/month/week/day.

View File

@@ -248,7 +248,7 @@ def load_orders(
Order(
row["instrument"],
row["amount"],
int(row["order_type"]),
OrderDir(int(row["order_type"])),
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
)

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from abc import abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Generator, Optional
if TYPE_CHECKING:
from qlib.backtest.exchange import Exchange
@@ -122,7 +122,10 @@ class BaseStrategy:
self.outer_trade_decision = outer_trade_decision
@abstractmethod
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
def generate_trade_decision(
self,
execute_result: list = None,
) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
"""Generate trade decision in each trading bar
Parameters

View File

@@ -9,6 +9,8 @@ Motivation of index_data
`index_data` try to behave like pandas (some API will be different because we try to be simpler and more intuitive) but don't compromise the performance. It provides the basic numpy data and simple indexing feature. If users call APIs which may compromise the performance, index_data will raise Errors.
"""
from __future__ import annotations
from typing import Dict, Tuple, Union, Callable, List
import bisect
@@ -16,7 +18,7 @@ import numpy as np
import pandas as pd
def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
def concat(data_list: Union[SingleData], axis=0) -> MultiData:
"""concat all SingleData by index.
TODO: now just for SingleData.
@@ -52,7 +54,7 @@ def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
raise ValueError(f"axis must be 0 or 1")
def sum_by_index(data_list: Union["SingleData"], new_index: list, fill_value=0) -> "SingleData":
def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) -> SingleData:
"""concat all SingleData by new index.
Parameters
@@ -554,7 +556,7 @@ class SingleData(IndexData):
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
)
def reindex(self, index: Index, fill_value=np.NaN):
def reindex(self, index: Index, fill_value=np.NaN) -> SingleData:
"""reindex data and fill the missing value with np.NaN.
Parameters
@@ -580,7 +582,7 @@ class SingleData(IndexData):
pass
return SingleData(tmp_data, index)
def add(self, other: "SingleData", fill_value=0):
def add(self, other: SingleData, fill_value=0):
# TODO: add and __add__ are a little confusing.
# This could be a more general
common_index = self.index | other.index