From 23c657a7a2a79b92066733c80fd549c025d9cd80 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Tue, 28 Jun 2022 22:16:46 +0800 Subject: [PATCH] Backtest Mypy (#1130) * Done * Fix test errors * Revert profit_attribution.py * Minor * A minor update on collect_data type hint * Resolve PR comments * Use black to format code * Fix CI errors --- .mypy.ini | 2 +- qlib/backtest/__init__.py | 44 +++--- qlib/backtest/account.py | 30 ++-- qlib/backtest/backtest.py | 9 +- qlib/backtest/decision.py | 40 +++-- qlib/backtest/exchange.py | 73 ++++----- qlib/backtest/executor.py | 58 +++---- qlib/backtest/high_performance_ds.py | 105 ++++++++----- qlib/backtest/position.py | 30 ++-- qlib/backtest/report.py | 227 ++++++++++++++------------- qlib/backtest/signal.py | 15 +- qlib/backtest/utils.py | 10 +- qlib/contrib/evaluate.py | 8 +- qlib/data/data.py | 8 +- qlib/rl/data/pickle_styled.py | 2 +- qlib/strategy/base.py | 7 +- qlib/utils/index_data.py | 10 +- 17 files changed, 363 insertions(+), 315 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index 195a0505f..d4baf0c33 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,6 +1,6 @@ [mypy] exclude = (?x)( - ^qlib/backtest + ^qlib/backtest/high_performance_ds\.py$ | ^qlib/contrib | ^qlib/data | ^qlib/model diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index 20fbe14a4..622b07d35 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -5,7 +5,7 @@ from __future__ import annotations import copy from pathlib import Path -from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union import pandas as pd @@ -23,7 +23,6 @@ from ..utils import init_instance_by_config from .backtest import backtest_loop, collect_data_loop from .decision import Order from .exchange import Exchange -from .position import Position from .utils import CommonInfrastructure # make import more user-friendly by adding `from qlib.backtest import STH` @@ -44,7 +43,7 @@ def get_exchange( min_cost: float = 5.0, limit_threshold: Union[Tuple[str, str], float, None] = None, deal_price: Union[str, Tuple[str], List[str]] = None, - **kwargs, + **kwargs: Any, ) -> Exchange: """get_exchange @@ -52,14 +51,15 @@ def get_exchange( ---------- # exchange related arguments - exchange: Exchange(). It could be None or any types that are acceptable by `init_instance_by_config`. + exchange: Exchange + It could be None or any types that are acceptable by `init_instance_by_config`. freq: str frequency of data. start_time: Union[pd.Timestamp, str] closed start time for backtest. end_time: Union[pd.Timestamp, str] closed end time for backtest. - codes: list|str + codes: Union[list, str] list stock_id list or a string of instruments (i.e. all, csi500, sse50) subscribe_fields: list subscribe fields. @@ -151,28 +151,24 @@ def create_account_instance( Postion type. """ if isinstance(account, (int, float)): - pos_kwargs = {"init_cash": account} + init_cash = account + position_dict = {} elif isinstance(account, dict): - init_cash = account["cash"] - del account["cash"] - pos_kwargs = { - "init_cash": init_cash, - "position_dict": account, - } + init_cash = account.pop("cash") + position_dict = account else: - raise ValueError("account must be in (int, float, Position)") + raise ValueError("account must be in (int, float, dict)") - kwargs = { - "init_cash": account, - "benchmark_config": { + return Account( + init_cash=init_cash, + position_dict=position_dict, + pos_type=pos_type, + benchmark_config={ "benchmark": benchmark, "start_time": start_time, "end_time": end_time, }, - "pos_type": pos_type, - } - kwargs.update(pos_kwargs) - return Account(**kwargs) + ) def get_strategy_executor( @@ -181,7 +177,7 @@ def get_strategy_executor( strategy: Union[str, dict, object, Path], executor: Union[str, dict, object, Path], benchmark: str = "SH000300", - account: Union[float, int, Position] = 1e9, + account: Union[float, int, dict] = 1e9, exchange_kwargs: dict = {}, pos_type: str = "Position", ) -> Tuple[BaseStrategy, BaseExecutor]: @@ -222,7 +218,7 @@ def backtest( strategy: Union[str, dict, object, Path], executor: Union[str, dict, object, Path], benchmark: str = "SH000300", - account: Union[float, int, Position] = 1e9, + account: Union[float, int, dict] = 1e9, exchange_kwargs: dict = {}, pos_type: str = "Position", ) -> Tuple[PortfolioMetrics, Indicator]: @@ -285,7 +281,7 @@ def collect_data( strategy: Union[str, dict, object, Path], executor: Union[str, dict, object, Path], benchmark: str = "SH000300", - account: Union[float, int, Position] = 1e9, + account: Union[float, int, dict] = 1e9, exchange_kwargs: dict = {}, pos_type: str = "Position", return_value: dict = None, @@ -339,7 +335,7 @@ def format_decisions( cur_freq = decisions[0].strategy.trade_calendar.get_freq() - res = (cur_freq, []) + res: Tuple[str, list] = (cur_freq, []) last_dec_idx = 0 for i, dec in enumerate(decisions[1:], 1): if dec.strategy.trade_calendar.get_freq() == cur_freq: diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 9d8adddb0..6054ac638 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -3,7 +3,7 @@ from __future__ import annotations import copy -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple, cast import pandas as pd @@ -11,6 +11,7 @@ from qlib.utils import init_instance_by_config from .decision import BaseTradeDecision, Order from .exchange import Exchange +from .high_performance_ds import BaseOrderIndicator from .position import BasePosition from .report import Indicator, PortfolioMetrics @@ -104,7 +105,7 @@ class Account: self._pos_type = pos_type self._port_metr_enabled = port_metr_enabled - self.benchmark_config = None # avoid no attribute error + self.benchmark_config: dict = {} # avoid no attribute error self.init_vars(init_cash, position_dict, freq, benchmark_config) def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None: @@ -124,8 +125,8 @@ class Account: self.accum_info = AccumulatedInfo() # 2) following variables are not shared between layers - self.portfolio_metrics = None - self.hist_positions = {} + self.portfolio_metrics: Optional[PortfolioMetrics] = None + self.hist_positions: Dict[pd.Timestamp, BasePosition] = {} self.reset(freq=freq, benchmark_config=benchmark_config) def is_port_metr_enabled(self) -> bool: @@ -171,7 +172,7 @@ class Account: self.reset_report(self.freq, self.benchmark_config) - def get_hist_positions(self) -> dict: + def get_hist_positions(self) -> Dict[pd.Timestamp, BasePosition]: return self.hist_positions def get_cash(self) -> float: @@ -230,13 +231,15 @@ class Account: """ # update price for stock in the position and the profit from changed_price # NOTE: updating position does not only serve portfolio metrics, it also serve the strategy + assert self.current_position is not None + if not self.current_position.skip_update(): stock_list = self.current_position.get_stock_list() for code in stock_list: # if suspend, no new price to be updated, profit is 0 if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time): continue - bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time) + bar_close = cast(float, trade_exchange.get_close(code, trade_start_time, trade_end_time)) self.current_position.update_stock_price(stock_id=code, price=bar_close) # update holding day count # NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy @@ -249,6 +252,8 @@ class Account: # for the first trade date, account_value - init_cash # self.portfolio_metrics.is_empty() to judge is_first_trade_date # get last_account_value, last_total_cost, last_total_turnover + assert self.portfolio_metrics is not None + if self.portfolio_metrics.is_empty(): last_account_value = self.init_cash last_total_cost = 0 @@ -299,9 +304,9 @@ class Account: trade_exchange: Exchange, atomic: bool, outer_trade_decision: BaseTradeDecision, - trade_info: list = None, - inner_order_indicators: List[Dict[str, pd.Series]] = None, - decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None, + trade_info: list = [], + inner_order_indicators: List[BaseOrderIndicator] = [], + decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [], indicator_config: dict = {}, ) -> None: """update trade indicators and order indicators in each bar end""" @@ -335,9 +340,9 @@ class Account: trade_exchange: Exchange, atomic: bool, outer_trade_decision: BaseTradeDecision, - trade_info: list = None, - inner_order_indicators: List[Dict[str, pd.Series]] = None, - decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None, + trade_info: list = [], + inner_order_indicators: List[BaseOrderIndicator] = [], + decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [], indicator_config: dict = {}, ) -> None: """update account at each trading bar step @@ -398,6 +403,7 @@ class Account: def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]: """get the history portfolio_metrics and positions instance""" if self.is_port_metr_enabled(): + assert self.portfolio_metrics is not None _portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe() _positions = self.get_hist_positions() return _portfolio_metrics, _positions diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index c42d6fc9b..e47655069 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast import pandas as pd @@ -36,10 +36,13 @@ def backtest_loop( indicator: Indicator it computes the trading indicator """ - return_value = {} + return_value: dict = {} for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value): pass - return return_value.get("portfolio_metrics"), return_value.get("indicator") + + portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics")) + indicator = cast(Indicator, return_value.get("indicator")) + return portfolio_metrics, indicator def collect_data_loop( diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index 9a6084214..42e798c6d 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -7,7 +7,7 @@ from abc import abstractmethod from enum import IntEnum # try to fix circular imports when enabling type hints -from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union +from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast from qlib.backtest.utils import TradeCalendarManager from qlib.data.data import Cal @@ -24,8 +24,11 @@ import numpy as np import pandas as pd +DecisionType = TypeVar("DecisionType") + + class OrderDir(IntEnum): - # Order direction + # Order direction SELL = 0 BUY = 1 @@ -65,7 +68,7 @@ class Order: # - not tradable: the deal_amount == 0 , factor is None # - the stock is suspended and the entire order fails. No cost for this order # - dealt or partially dealt: deal_amount >= 0 and factor is not None - deal_amount: Optional[float] = None # `deal_amount` is a non-negative value + deal_amount: float = 0.0 # `deal_amount` is a non-negative value factor: Optional[float] = None # TODO: @@ -281,7 +284,7 @@ class TradeRangeByTime(TradeRange): return max(val_start, start_time), min(val_end, end_time) -class BaseTradeDecision: +class BaseTradeDecision(Generic[DecisionType]): """ Trade decisions ara made by strategy and executed by executor @@ -316,20 +319,21 @@ class BaseTradeDecision: """ self.strategy = strategy self.start_time, self.end_time = strategy.trade_calendar.get_step_time() - self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading` - if isinstance(trade_range, Tuple): + # upper strategy has no knowledge about the sub executor before `_init_sub_trading` + self.total_step: Optional[int] = None + if isinstance(trade_range, tuple): # for Tuple[int, int] trade_range = IdxTradeRange(*trade_range) - self.trade_range: TradeRange = trade_range + self.trade_range: Optional[TradeRange] = trade_range - def get_decision(self) -> List[object]: + def get_decision(self) -> List[DecisionType]: """ get the **concrete decision** (e.g. execution orders) This will be called by the inner strategy Returns ------- - List[object]: + List[DecisionType: The decision result. Typically it is some orders Example: []: @@ -363,13 +367,13 @@ class BaseTradeDecision: # purpose 2) return self.strategy.update_trade_decision(self, trade_calendar) - def _get_range_limit(self, **kwargs) -> Tuple[int, int]: + def _get_range_limit(self, **kwargs: Any) -> Tuple[int, int]: if self.trade_range is not None: - return self.trade_range(trade_calendar=kwargs.get("inner_calendar")) + return self.trade_range(trade_calendar=cast(TradeCalendarManager, kwargs.get("inner_calendar"))) else: raise NotImplementedError("The decision didn't provide an index range") - def get_range_limit(self, **kwargs) -> Tuple[int, int]: + def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]: """ return the expected step range for limiting the decision execution time Both left and right are **closed** @@ -421,6 +425,7 @@ class BaseTradeDecision: if getattr(self, "total_step", None) is not None: # if `self.update` is called. # Then the _start_idx, _end_idx should be clipped + assert self.total_step is not None if _start_idx < 0 or _end_idx >= self.total_step: logger = get_module_logger("decision") logger.warning( @@ -516,7 +521,7 @@ class BaseTradeDecision: inner_trade_decision.trade_range = self.trade_range -class EmptyTradeDecision(BaseTradeDecision): +class EmptyTradeDecision(BaseTradeDecision[object]): def get_decision(self) -> List[object]: return [] @@ -524,23 +529,24 @@ class EmptyTradeDecision(BaseTradeDecision): return True -class TradeDecisionWO(BaseTradeDecision): +class TradeDecisionWO(BaseTradeDecision[Order]): """ Trade Decision (W)ith (O)rder. Besides, the time_range is also included. """ - def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None): + def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None: super().__init__(strategy, trade_range=trade_range) - self.order_list = order_list + self.order_list = cast(List[Order], order_list) start, end = strategy.trade_calendar.get_step_time() for o in order_list: + assert isinstance(o, Order) if o.start_time is None: o.start_time = start if o.end_time is None: o.end_time = end - def get_decision(self) -> List[object]: + def get_decision(self) -> List[Order]: return self.order_list def __repr__(self) -> str: diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index ba1dd2c0b..7e4210fe7 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import defaultdict -from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast from ..utils.index_data import IndexData @@ -42,7 +42,7 @@ class Exchange: impact_cost: float = 0.0, extra_quote: pd.DataFrame = None, quote_cls: Type[BaseQuote] = NumpyQuote, - **kwargs, + **kwargs: Any, ) -> None: """__init__ :param freq: frequency of data @@ -141,7 +141,7 @@ class Exchange: if limit_threshold is None: if C.region == REG_CN: self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold") - elif self.limit_type == self.LT_FLT and abs(limit_threshold) > 0.1: + elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1: if C.region == REG_CN: self.logger.warning(f"limit_threshold may not be set to a reasonable value") @@ -150,7 +150,7 @@ class Exchange: deal_price = "$" + deal_price self.buy_price = self.sell_price = deal_price elif isinstance(deal_price, (tuple, list)): - self.buy_price, self.sell_price = deal_price + self.buy_price, self.sell_price = cast(Tuple[str, str], deal_price) else: raise NotImplementedError(f"This type of input is not supported") @@ -167,10 +167,10 @@ class Exchange: necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"} if self.limit_type == self.LT_TP_EXP: + assert isinstance(limit_threshold, tuple) for exp in limit_threshold: necessary_fields.add(exp) - all_fields = necessary_fields | set(vol_lt_fields) - all_fields = list(all_fields | set(subscribe_fields)) + all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields)) self.all_fields = all_fields @@ -249,9 +249,9 @@ class Exchange: LT_FLT = "float" # float LT_NONE = "none" # none - def _get_limit_type(self, limit_threshold: Union[Tuple, float, None]) -> str: + def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str: """get limit type""" - if isinstance(limit_threshold, Tuple): + if isinstance(limit_threshold, tuple): return self.LT_TP_EXP elif isinstance(limit_threshold, float): return self.LT_FLT @@ -268,14 +268,16 @@ class Exchange: self.quote_df["limit_sell"] = False elif limit_type == self.LT_TP_EXP: # set limit + limit_threshold = cast(tuple, limit_threshold) self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]] self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]] elif limit_type == self.LT_FLT: + limit_threshold = cast(float, limit_threshold) self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130 @staticmethod - def _get_vol_limit(volume_threshold: Union[tuple, dict]) -> Tuple[Optional[list], Optional[list], set]: + def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]: """ preprocess the volume limit. get the fields need to get from qlib. @@ -340,11 +342,11 @@ class Exchange: if direction is None: buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all") sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all") - return buy_limit or sell_limit + return bool(buy_limit or sell_limit) elif direction == Order.BUY: - return self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all") + return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")) elif direction == Order.SELL: - return self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all") + return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")) else: raise ValueError(f"direction {direction} is not supported!") @@ -382,7 +384,7 @@ class Exchange: order: Order, trade_account: Account = None, position: BasePosition = None, - dealt_order_amount: defaultdict = defaultdict(float), + dealt_order_amount: Dict[str, float] = defaultdict(float), ) -> Tuple[float, float, float]: """ Deal order when the actual transaction @@ -426,9 +428,10 @@ class Exchange: stock_id: str, start_time: pd.Timestamp, end_time: pd.Timestamp, + field: str, method: str = "ts_data_last", ) -> Union[None, int, float, bool, IndexData]: - return self.quote.get_data(stock_id, start_time, end_time, method=method) # TODO: missing `field`? + return self.quote.get_data(stock_id, start_time, end_time, field=field, method=method) def get_close( self, @@ -444,10 +447,10 @@ class Exchange: stock_id: str, start_time: pd.Timestamp, end_time: pd.Timestamp, - method: str = "sum", + method: Optional[str] = "sum", ) -> float: """get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)""" - return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method) + return cast(float, self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)) def get_deal_price( self, @@ -455,7 +458,7 @@ class Exchange: start_time: pd.Timestamp, end_time: pd.Timestamp, direction: OrderDir, - method: str = "ts_data_last", + method: Optional[str] = "ts_data_last", ) -> float: if direction == OrderDir.SELL: pstr = self.sell_price @@ -469,7 +472,7 @@ class Exchange: self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!") self.logger.warning(f"setting deal_price to close price") deal_price = self.get_close(stock_id, start_time, end_time, method) - return deal_price + return cast(float, deal_price) def get_factor( self, @@ -544,7 +547,7 @@ class Exchange: ) return amount_dict - def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float) -> float: + def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float: """ Calculate the real adjust deal amount when considering the trading unit :param current_amount: @@ -572,7 +575,7 @@ class Exchange: current_position: dict, start_time: pd.Timestamp, end_time: pd.Timestamp, - ) -> list: + ) -> List[Order]: """ Note: some future information is used in this function Parameter: @@ -681,6 +684,7 @@ class Exchange: factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time) else: raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None") + assert factor is not None return factor def get_amount_of_trade_unit( @@ -718,12 +722,12 @@ class Exchange: def round_amount_by_trade_unit( self, - deal_amount, + deal_amount: float, factor: float = None, stock_id: str = None, - start_time=None, - end_time=None, - ): + start_time: pd.Timestamp = None, + end_time: pd.Timestamp = None, + ) -> float: """Parameter Please refer to the docs of get_amount_of_trade_unit deal_amount : float, adjusted amount @@ -741,7 +745,7 @@ class Exchange: return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor return deal_amount - def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> int: + def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Optional[float]: """parse the capacity limit string and return the actual amount of orders that can be executed. NOTE: this function will change the order.deal_amount **inplace** @@ -753,15 +757,12 @@ class Exchange: dealt_order_amount : dict :param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float} """ - if order.direction == Order.BUY: - vol_limit = self.buy_vol_limit - elif order.direction == Order.SELL: - vol_limit = self.sell_vol_limit + vol_limit = self.buy_vol_limit if order.direction == Order.BUY else self.sell_vol_limit if vol_limit is None: return order.deal_amount - vol_limit_num = [] + vol_limit_num: List[float] = [] for limit in vol_limit: assert isinstance(limit, tuple) if limit[0] == "current": @@ -772,7 +773,7 @@ class Exchange: field=limit[1], method="sum", ) - vol_limit_num.append(limit_value) + vol_limit_num.append(cast(float, limit_value)) elif limit[0] == "cum": limit_value = self.quote.get_data( order.stock_id, @@ -790,12 +791,14 @@ class Exchange: if vol_limit_min < orig_deal_amount: self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}") - def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio): + return None + + def _get_buy_amount_by_cash_limit(self, trade_price: float, cash: float, cost_ratio: float) -> float: """return the real order amount after cash limit for buying. Parameters ---------- trade_price : float - position : cash + cash : float cost_ratio : float Return @@ -803,7 +806,7 @@ class Exchange: float the real order amount after cash limit for buying. """ - max_trade_amount = 0 + max_trade_amount = 0.0 if cash >= self.min_cost: # critical_price means the stock transaction price when the service fee is equal to min_cost. critical_price = self.min_cost / cost_ratio + self.min_cost @@ -897,7 +900,7 @@ class Exchange: order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor) else: - raise NotImplementedError("order type {} error".format(order.type)) + raise NotImplementedError("order direction {} error".format(order.direction)) trade_val = order.deal_amount * trade_price trade_cost = max(trade_val * cost_ratio, self.min_cost) diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 2105471e1..ef507e1a0 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -4,7 +4,7 @@ import copy from abc import abstractmethod from collections import defaultdict from types import GeneratorType -from typing import Generator, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Tuple, Union, cast import pandas as pd @@ -16,13 +16,7 @@ from ..strategy.base import BaseStrategy from ..utils import init_instance_by_config from .decision import BaseTradeDecision, Order from .exchange import Exchange -from .utils import ( - BaseInfrastructure, - CommonInfrastructure, - LevelInfrastructure, - TradeCalendarManager, - get_start_end_idx, -) +from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, get_start_end_idx class BaseExecutor: @@ -39,8 +33,8 @@ class BaseExecutor: track_data: bool = False, trade_exchange: Exchange = None, common_infra: CommonInfrastructure = None, - settle_type=BasePosition.ST_NO, # TODO: add typehint - **kwargs, + settle_type: str = BasePosition.ST_NO, + **kwargs: Any, ) -> None: """ Parameters @@ -127,10 +121,10 @@ class BaseExecutor: get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}") # record deal order amount in one day - self.dealt_order_amount = defaultdict(float) + self.dealt_order_amount: Dict[str, float] = defaultdict(float) self.deal_day = None - def reset_common_infra(self, common_infra: BaseInfrastructure, copy_trade_account: bool = False) -> None: + def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None: """ reset infrastructure for trading - reset trade_account @@ -141,14 +135,15 @@ class BaseExecutor: self.common_infra.update(common_infra) if common_infra.has("trade_account"): - if copy_trade_account: - # NOTE: there is a trick in the code. - # shallow copy is used instead of deepcopy. - # 1. So positions are shared - # 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics) - self.trade_account: Account = copy.copy(common_infra.get("trade_account")) - else: - self.trade_account: Account = common_infra.get("trade_account") + # NOTE: there is a trick in the code. + # shallow copy is used instead of deepcopy. + # 1. So positions are shared + # 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics) + self.trade_account: Account = ( + copy.copy(common_infra.get("trade_account")) + if copy_trade_account + else common_infra.get("trade_account") + ) self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics) @property @@ -164,7 +159,7 @@ class BaseExecutor: """ return self.level_infra.get("trade_calendar") - def reset(self, common_infra: CommonInfrastructure = None, **kwargs) -> None: + def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None: """ - reset `start_time` and `end_time`, used in trade calendar - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc @@ -200,20 +195,17 @@ class BaseExecutor: execute_result : List[object] the executed result for trade decision """ - return_value = {} + return_value: dict = {} for _decision in self.collect_data(trade_decision, return_value=return_value, level=level): pass - return return_value.get("execute_result") + return cast(list, return_value.get("execute_result")) @abstractmethod def _collect_data( self, trade_decision: BaseTradeDecision, level: int = 0, - ) -> Union[ - Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]], - Tuple[List[object], dict], - ]: + ) -> Union[Generator[Any, Any, Tuple[List[object], dict]], Tuple[List[object], dict]]: """ Please refer to the doc of collect_data The only difference between `_collect_data` and `collect_data` is that some common steps are moved into @@ -235,7 +227,7 @@ class BaseExecutor: trade_decision: BaseTradeDecision, return_value: dict = None, level: int = 0, - ) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], List[object]]: + ) -> Generator[Any, Any, List[object]]: """Generator for collecting the trade decision data for rl training his function will make a step forward @@ -332,7 +324,7 @@ class NestedExecutor(BaseExecutor): skip_empty_decision: bool = True, align_range_limit: bool = True, common_infra: CommonInfrastructure = None, - **kwargs, + **kwargs: Any, ) -> None: """ Parameters @@ -411,7 +403,7 @@ class NestedExecutor(BaseExecutor): self, trade_decision: BaseTradeDecision, level: int = 0, - ) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]]: + ) -> Generator[Any, Any, Tuple[List[object], dict]]: execute_result = [] inner_order_indicators = [] decision_list = [] @@ -493,7 +485,7 @@ class NestedExecutor(BaseExecutor): the execution result of inner task """ - def get_all_executors(self) -> List[object]: + def get_all_executors(self) -> List[BaseExecutor]: """get all executors, including self and inner_executor.get_all_executors()""" return [self, *self.inner_executor.get_all_executors()] @@ -536,7 +528,7 @@ class SimulatorExecutor(BaseExecutor): track_data: bool = False, common_infra: CommonInfrastructure = None, trade_type: str = TT_SERIAL, - **kwargs, + **kwargs: Any, ) -> None: """ Parameters @@ -598,7 +590,7 @@ class SimulatorExecutor(BaseExecutor): def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]: trade_start_time, _ = self.trade_calendar.get_step_time() - execute_result = [] + execute_result: list = [] for order in self._get_order_iterator(trade_decision): # execute the order. diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index 8cfa9bacc..dc467bd59 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -1,11 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations + import inspect import logging from collections import OrderedDict from functools import lru_cache -from typing import Callable, Dict, Iterable, List, Text, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Union, cast import numpy as np import pandas as pd @@ -19,7 +21,7 @@ from ..utils.time import Freq, is_single_value class BaseQuote: - def __init__(self, quote_df: pd.DataFrame, freq): + def __init__(self, quote_df: pd.DataFrame, freq: str) -> None: self.logger = get_module_logger("online operator", level=logging.INFO) def get_all_stock(self) -> Iterable: @@ -39,7 +41,7 @@ class BaseQuote: start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], field: Union[str], - method: Union[str, None] = None, + method: Optional[str] = None, ) -> Union[None, int, float, bool, IndexData]: """get the specific field of stock data during start time and end_time, and apply method to the data. @@ -99,7 +101,7 @@ class BaseQuote: class PandasQuote(BaseQuote): - def __init__(self, quote_df: pd.DataFrame, freq): + def __init__(self, quote_df: pd.DataFrame, freq: str) -> None: super().__init__(quote_df=quote_df, freq=freq) quote_dict = {} for stock_id, stock_val in quote_df.groupby(level="instrument"): @@ -124,7 +126,7 @@ class PandasQuote(BaseQuote): class NumpyQuote(BaseQuote): - def __init__(self, quote_df: pd.DataFrame, freq, region="cn"): + def __init__(self, quote_df: pd.DataFrame, freq: str, region: str = "cn") -> None: """NumpyQuote Parameters @@ -178,7 +180,8 @@ class NumpyQuote(BaseQuote): data = self._agg_data(data, method) return data - def _agg_data(self, data: IndexData, method): + @staticmethod + def _agg_data(data: IndexData, method: str) -> Union[IndexData, np.ndarray, None]: """Agg data by specific method.""" # FIXME: why not call the method of data directly? if method == "sum": @@ -224,31 +227,31 @@ class BaseSingleMetric: """ raise NotImplementedError(f"Please implement the `__init__` method") - def __add__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": + def __add__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric: raise NotImplementedError(f"Please implement the `__add__` method") - def __radd__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": + def __radd__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric: return self + other - def __sub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": + def __sub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric: raise NotImplementedError(f"Please implement the `__sub__` method") - def __rsub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": + def __rsub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric: raise NotImplementedError(f"Please implement the `__rsub__` method") - def __mul__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": + def __mul__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric: raise NotImplementedError(f"Please implement the `__mul__` method") - def __truediv__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": + def __truediv__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric: raise NotImplementedError(f"Please implement the `__truediv__` method") - def __eq__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": + def __eq__(self, other: object) -> BaseSingleMetric: raise NotImplementedError(f"Please implement the `__eq__` method") - def __gt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": + def __gt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric: raise NotImplementedError(f"Please implement the `__gt__` method") - def __lt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": + def __lt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric: raise NotImplementedError(f"Please implement the `__lt__` method") def __len__(self) -> int: @@ -265,7 +268,7 @@ class BaseSingleMetric: raise NotImplementedError(f"Please implement the `count` method") - def abs(self) -> "BaseSingleMetric": + def abs(self) -> BaseSingleMetric: raise NotImplementedError(f"Please implement the `abs` method") @property @@ -274,18 +277,18 @@ class BaseSingleMetric: raise NotImplementedError(f"Please implement the `empty` method") - def add(self, other: "BaseSingleMetric", fill_value: float = None) -> "BaseSingleMetric": + def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric: """Replace np.NaN with fill_value in two metrics and add them.""" raise NotImplementedError(f"Please implement the `add` method") - def replace(self, replace_dict: dict) -> "BaseSingleMetric": + def replace(self, replace_dict: dict) -> BaseSingleMetric: """Replace the value of metric according to replace_dict.""" raise NotImplementedError(f"Please implement the `replace` method") - def apply(self, func: dict) -> "BaseSingleMetric": - """Replace the value of metric with func(metric). + def apply(self, func: Callable) -> BaseSingleMetric: + """Replace the value of metric with func (metric). Currently, the func is only qlib/backtest/order/Order.parse_dir. """ @@ -304,11 +307,11 @@ class BaseOrderIndicator: to inherit the BaseSingleMetric. """ - def __init__(self, data): - self.data = data + def __init__(self): + self.data = {} # will be created in the subclass self.logger = get_module_logger("online operator") - def assign(self, col: str, metric: Union[dict, pd.Series]): + def assign(self, col: str, metric: Union[dict, pd.Series]) -> None: """assign one metric. Parameters @@ -328,7 +331,7 @@ class BaseOrderIndicator: raise NotImplementedError(f"Please implement the 'assign' method") - def transfer(self, func: Callable, new_col: str = None) -> Union[None, BaseSingleMetric]: + def transfer(self, func: Callable, new_col: str = None) -> Optional[BaseSingleMetric]: """compute new metric with existing metrics. Parameters @@ -352,6 +355,7 @@ class BaseOrderIndicator: tmp_metric = func(**func_kwargs) if new_col is not None: self.data[new_col] = tmp_metric + return None else: return tmp_metric @@ -372,7 +376,7 @@ class BaseOrderIndicator: raise NotImplementedError(f"Please implement the 'get_metric_series' method") - def get_index_data(self, metric) -> SingleData: + def get_index_data(self, metric: str) -> SingleData: """get one metric with the format of SingleData Parameters @@ -389,7 +393,12 @@ class BaseOrderIndicator: raise NotImplementedError(f"Please implement the 'get_index_data' method") @staticmethod - def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None): + def sum_all_indicators( + order_indicator: BaseOrderIndicator, + indicators: List[BaseOrderIndicator], + metrics: Union[str, List[str]], + fill_value: float = 0, + ) -> None: """sum indicators with the same metrics. and assign to the order_indicator(BaseOrderIndicator). NOTE: indicators could be a empty list when orders in lower level all fail. @@ -527,16 +536,17 @@ class PandasSingleMetric(SingleMetric): def index(self): return list(self.metric.index) - def add(self, other, fill_value=None): + def add(self, other: BaseSingleMetric, fill_value: float = None) -> PandasSingleMetric: + other = cast(PandasSingleMetric, other) return self.__class__(self.metric.add(other.metric, fill_value=fill_value)) - def replace(self, replace_dict: dict): + def replace(self, replace_dict: dict) -> PandasSingleMetric: return self.__class__(self.metric.replace(replace_dict)) - def apply(self, func: Callable): + def apply(self, func: Callable) -> PandasSingleMetric: return self.__class__(self.metric.apply(func)) - def reindex(self, index, fill_value): + def reindex(self, index: Any, fill_value: float) -> PandasSingleMetric: return self.__class__(self.metric.reindex(index, fill_value=fill_value)) def __repr__(self): @@ -550,13 +560,14 @@ class PandasOrderIndicator(BaseOrderIndicator): Str is the name of metric. """ - def __init__(self): + def __init__(self) -> None: + super(PandasOrderIndicator, self).__init__() self.data: Dict[str, PandasSingleMetric] = OrderedDict() - def assign(self, col: str, metric: Union[dict, pd.Series]): + def assign(self, col: str, metric: Union[dict, pd.Series]) -> None: self.data[col] = PandasSingleMetric(metric) - def get_index_data(self, metric): + def get_index_data(self, metric: str) -> SingleData: if metric in self.data: return idd.SingleData(self.data[metric].metric) else: @@ -572,7 +583,12 @@ class PandasOrderIndicator(BaseOrderIndicator): return {k: v.metric for k, v in self.data.items()} @staticmethod - def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0): + def sum_all_indicators( + order_indicator: BaseOrderIndicator, + indicators: List[BaseOrderIndicator], + metrics: Union[str, List[str]], + fill_value: float = 0, + ) -> None: if isinstance(metrics, str): metrics = [metrics] for metric in metrics: @@ -592,13 +608,14 @@ class NumpyOrderIndicator(BaseOrderIndicator): Str is the name of metric. """ - def __init__(self): + def __init__(self) -> None: + super(NumpyOrderIndicator, self).__init__() self.data: Dict[str, SingleData] = OrderedDict() - def assign(self, col: str, metric: dict): + def assign(self, col: str, metric: dict) -> None: self.data[col] = idd.SingleData(metric) - def get_index_data(self, metric): + def get_index_data(self, metric: str) -> SingleData: if metric in self.data: return self.data[metric] else: @@ -614,14 +631,18 @@ class NumpyOrderIndicator(BaseOrderIndicator): return tmp_metric_dict @staticmethod - def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0): + def sum_all_indicators( + order_indicator: BaseOrderIndicator, + indicators: List[BaseOrderIndicator], + metrics: Union[str, List[str]], + fill_value: float = 0, + ) -> None: # get all index(stock_id) - stocks = set() + stock_set: set = set() for indicator in indicators: # set(np.ndarray.tolist()) is faster than set(np.ndarray) - stocks = stocks | set(indicator.data[metrics[0]].index.tolist()) - stocks = list(stocks) - stocks.sort() + stock_set = stock_set | set(indicator.data[metrics[0]].index.tolist()) + stocks = sorted(list(stock_set)) # add metric by index if isinstance(metrics, str): diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index 06218a67d..ea6b7c57b 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -3,7 +3,7 @@ from datetime import timedelta -from typing import Dict, List, Union +from typing import Any, Dict, List, Union import numpy as np import pandas as pd @@ -18,9 +18,9 @@ class BasePosition: Please refer to the `Position` class for the position """ - def __init__(self, *args, cash: float = 0.0, **kwargs) -> None: + def __init__(self, *args: Any, cash: float = 0.0, **kwargs: Any) -> None: self._settle_type = self.ST_NO - self.position = {} + self.position: dict = {} def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None: pass @@ -96,13 +96,13 @@ class BasePosition: def calculate_value(self) -> float: raise NotImplementedError(f"Please implement the `calculate_value` method") - def get_stock_list(self) -> List: + def get_stock_list(self) -> List[str]: """ Get the list of stocks in the position. """ raise NotImplementedError(f"Please implement the `get_stock_list` method") - def get_stock_price(self, code) -> float: + def get_stock_price(self, code: str) -> float: """ get the latest price of the stock @@ -113,7 +113,7 @@ class BasePosition: """ raise NotImplementedError(f"Please implement the `get_stock_price` method") - def get_stock_amount(self, code) -> float: + def get_stock_amount(self, code: str) -> float: """ get the amount of the stock @@ -144,7 +144,7 @@ class BasePosition: """ raise NotImplementedError(f"Please implement the `get_cash` method") - def get_stock_amount_dict(self) -> Dict: + def get_stock_amount_dict(self) -> dict: """ generate stock amount dict {stock_id : amount of stock} @@ -155,7 +155,7 @@ class BasePosition: """ raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method") - def get_stock_weight_dict(self, only_stock: bool = False) -> Dict: + def get_stock_weight_dict(self, only_stock: bool = False) -> dict: """ generate stock weight dict {stock_id : value weight of stock in the position} it is meaningful in the beginning or the end of each trade step @@ -174,7 +174,7 @@ class BasePosition: """ raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method") - def add_count_all(self, bar) -> None: + def add_count_all(self, bar: str) -> None: """ Will be called at the end of each bar on each level @@ -195,7 +195,7 @@ class BasePosition: raise NotImplementedError(f"Please implement the `add_count_all` method") ST_CASH = "cash" - ST_NO = None + ST_NO = "None" # String is more typehint friendly than None def settle_start(self, settle_type: str) -> None: """ @@ -220,10 +220,10 @@ class BasePosition: """ raise NotImplementedError(f"Please implement the `settle_commit` method") - def __str__(self): + def __str__(self) -> str: return self.__dict__.__str__() - def __repr__(self): + def __repr__(self) -> str: return self.__dict__.__repr__() @@ -532,7 +532,7 @@ class InfPosition(BasePosition): def calculate_value(self) -> float: raise NotImplementedError(f"InfPosition doesn't support calculating value") - def get_stock_list(self) -> list: + def get_stock_list(self) -> List[str]: raise NotImplementedError(f"InfPosition doesn't support stock list position") def get_stock_price(self, code: str) -> float: @@ -545,10 +545,10 @@ class InfPosition(BasePosition): def get_cash(self, include_settle: bool = False) -> float: return np.inf - def get_stock_amount_dict(self) -> Dict: + def get_stock_amount_dict(self) -> dict: raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict") - def get_stock_weight_dict(self, only_stock: bool = False) -> Dict: + def get_stock_weight_dict(self, only_stock: bool = False) -> dict: raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict") def add_count_all(self, bar: str) -> None: diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 77e43c8e7..b8aa8273c 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -4,7 +4,7 @@ import pathlib from collections import OrderedDict -from typing import Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Text, Tuple, Type, Union, cast import numpy as np import pandas as pd @@ -15,7 +15,7 @@ from qlib.backtest.exchange import Exchange from ..tests.config import CSI300_BENCH from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data -from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric +from .high_performance_ds import BaseOrderIndicator, BaseSingleMetric, NumpyOrderIndicator class PortfolioMetrics: @@ -38,7 +38,7 @@ class PortfolioMetrics: update report """ - def __init__(self, freq: str = "day", benchmark_config: dict = {}): + def __init__(self, freq: str = "day", benchmark_config: dict = {}) -> None: """ Parameters ---------- @@ -49,13 +49,17 @@ class PortfolioMetrics: - benchmark : Union[str, list, pd.Series] - If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T. example: - print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()) + print( + D.features(D.instruments('csi500'), + ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head() + ) 2017-01-04 0.011693 2017-01-05 0.000721 2017-01-06 -0.004322 2017-01-09 0.006874 2017-01-10 -0.003350 - - If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'. + - If `benchmark` is list, will use the daily average change of the stock pool in the list as the + 'bench'. - If `benchmark` is str, will use the daily change as the 'bench'. benchmark code, default is SH000300 CSI300 - start_time : Union[str, pd.Timestamp], optional @@ -70,25 +74,26 @@ class PortfolioMetrics: self.init_vars() self.init_bench(freq=freq, benchmark_config=benchmark_config) - def init_vars(self): - self.accounts = OrderedDict() # account position value for each trade time - self.returns = OrderedDict() # daily return rate for each trade time - self.total_turnovers = OrderedDict() # total turnover for each trade time - self.turnovers = OrderedDict() # turnover for each trade time - self.total_costs = OrderedDict() # total trade cost for each trade time - self.costs = OrderedDict() # trade cost rate for each trade time - self.values = OrderedDict() # value for each trade time - self.cashes = OrderedDict() - self.benches = OrderedDict() - self.latest_pm_time = None # pd.TimeStamp + def init_vars(self) -> None: + self.accounts: dict = OrderedDict() # account position value for each trade time + self.returns: dict = OrderedDict() # daily return rate for each trade time + self.total_turnovers: dict = OrderedDict() # total turnover for each trade time + self.turnovers: dict = OrderedDict() # turnover for each trade time + self.total_costs: dict = OrderedDict() # total trade cost for each trade time + self.costs: dict = OrderedDict() # trade cost rate for each trade time + self.values: dict = OrderedDict() # value for each trade time + self.cashes: dict = OrderedDict() + self.benches: dict = OrderedDict() + self.latest_pm_time: Optional[pd.TimeStamp] = None - def init_bench(self, freq=None, benchmark_config=None): + def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None: if freq is not None: self.freq = freq self.benchmark_config = benchmark_config self.bench = self._cal_benchmark(self.benchmark_config, self.freq) - def _cal_benchmark(self, benchmark_config, freq): + @staticmethod + def _cal_benchmark(benchmark_config: Optional[dict], freq: str) -> Optional[pd.Series]: if benchmark_config is None: return None benchmark = benchmark_config.get("benchmark", CSI300_BENCH) @@ -110,7 +115,12 @@ class PortfolioMetrics: raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark") return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0) - def _sample_benchmark(self, bench, trade_start_time, trade_end_time): + def _sample_benchmark( + self, + bench: pd.Series, + trade_start_time: Union[str, pd.Timestamp], + trade_end_time: Union[str, pd.Timestamp], + ) -> Optional[float]: if self.bench is None: return None @@ -120,35 +130,35 @@ class PortfolioMetrics: _ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change) return 0.0 if _ret is None else _ret - 1 - def is_empty(self): + def is_empty(self) -> bool: return len(self.accounts) == 0 - def get_latest_date(self): + def get_latest_date(self) -> pd.Timestamp: return self.latest_pm_time - def get_latest_account_value(self): + def get_latest_account_value(self) -> float: return self.accounts[self.latest_pm_time] - def get_latest_total_cost(self): + def get_latest_total_cost(self) -> Any: return self.total_costs[self.latest_pm_time] - def get_latest_total_turnover(self): + def get_latest_total_turnover(self) -> Any: return self.total_turnovers[self.latest_pm_time] def update_portfolio_metrics_record( self, - trade_start_time=None, - trade_end_time=None, - account_value=None, - cash=None, - return_rate=None, - total_turnover=None, - turnover_rate=None, - total_cost=None, - cost_rate=None, - stock_value=None, - bench_value=None, - ): + trade_start_time: Union[str, pd.Timestamp] = None, + trade_end_time: Union[str, pd.Timestamp] = None, + account_value: float = None, + cash: float = None, + return_rate: float = None, + total_turnover: float = None, + turnover_rate: float = None, + total_cost: float = None, + cost_rate: float = None, + stock_value: float = None, + bench_value: float = None, + ) -> None: # check data if None in [ trade_start_time, @@ -185,7 +195,7 @@ class PortfolioMetrics: self.latest_pm_time = trade_start_time # finish pm update in each step - def generate_portfolio_metrics_dataframe(self): + def generate_portfolio_metrics_dataframe(self) -> pd.DataFrame: pm = pd.DataFrame() pm["account"] = pd.Series(self.accounts) pm["return"] = pd.Series(self.returns) @@ -199,19 +209,18 @@ class PortfolioMetrics: pm.index.name = "datetime" return pm - def save_portfolio_metrics(self, path): + def save_portfolio_metrics(self, path: str) -> None: r = self.generate_portfolio_metrics_dataframe() r.to_csv(path) - def load_portfolio_metrics(self, path): + def load_portfolio_metrics(self, path: str) -> None: """load pm from a file should have format like columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench'] :param path: str/ pathlib.Path() """ - path = pathlib.Path(path) - with path.open("rb") as f: + with pathlib.Path(path).open("rb") as f: r = pd.read_csv(f, index_col=0) r.index = pd.DatetimeIndex(r.index) @@ -261,30 +270,30 @@ class Indicator: """ - def __init__(self, order_indicator_cls=NumpyOrderIndicator): + def __init__(self, order_indicator_cls: Type[BaseOrderIndicator] = NumpyOrderIndicator) -> None: self.order_indicator_cls = order_indicator_cls # order indicator is metrics for a single order for a specific step - self.order_indicator_his = OrderedDict() + self.order_indicator_his: dict = OrderedDict() self.order_indicator: BaseOrderIndicator = self.order_indicator_cls() # trade indicator is metrics for all orders for a specific step - self.trade_indicator_his = OrderedDict() - self.trade_indicator: Dict[str, float] = OrderedDict() + self.trade_indicator_his: dict = OrderedDict() + self.trade_indicator: Dict[str, Optional[BaseSingleMetric]] = OrderedDict() self._trade_calendar = None # def reset(self, trade_calendar: TradeCalendarManager): - def reset(self): - self.order_indicator: BaseOrderIndicator = self.order_indicator_cls() + def reset(self) -> None: + self.order_indicator = self.order_indicator_cls() self.trade_indicator = OrderedDict() # self._trade_calendar = trade_calendar - def record(self, trade_start_time): + def record(self, trade_start_time: Union[str, pd.Timestamp]) -> None: self.order_indicator_his[trade_start_time] = self.get_order_indicator() self.trade_indicator_his[trade_start_time] = self.get_trade_indicator() - def _update_order_trade_info(self, trade_info: list): + def _update_order_trade_info(self, trade_info: List[Tuple[Order, float, float, float]]) -> None: amount = dict() deal_amount = dict() trade_price = dict() @@ -313,7 +322,7 @@ class Indicator: self.order_indicator.assign("trade_dir", trade_dir) self.order_indicator.assign("pa", pa) - def _update_order_fulfill_rate(self): + def _update_order_fulfill_rate(self) -> None: def func(deal_amount, amount): # deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0. tmp_deal_amount = deal_amount.reindex(amount.index, 0) @@ -322,11 +331,11 @@ class Indicator: self.order_indicator.transfer(func, "ffr") - def update_order_indicators(self, trade_info: list): + def update_order_indicators(self, trade_info: List[Tuple[Order, float, float, float]]) -> None: self._update_order_trade_info(trade_info=trade_info) self._update_order_fulfill_rate() - def _agg_order_trade_info(self, inner_order_indicators: List[Dict[str, pd.Series]]): + def _agg_order_trade_info(self, inner_order_indicators: List[BaseOrderIndicator]) -> None: # calculate total trade amount with each inner order indicator. def trade_amount_func(deal_amount, trade_price): return deal_amount * trade_price @@ -355,9 +364,9 @@ class Indicator: self.order_indicator.transfer(func_apply, "trade_dir") - def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision): + def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision) -> None: # NOTE: these indicator is designed for order execution, so the - decision: List[Order] = outer_trade_decision.get_decision() + decision: List[Order] = cast(List[Order], outer_trade_decision.get_decision()) if len(decision) == 0: self.order_indicator.assign("amount", {}) else: @@ -372,7 +381,7 @@ class Indicator: decision: BaseTradeDecision, trade_exchange: Exchange, pa_config: dict = {}, - ): + ) -> Tuple[Optional[float], Optional[float]]: """ Get the base volume and price information All the base price values are rooted from this function @@ -412,31 +421,35 @@ class Indicator: # NOTE: there are some zeros in the trading price. These cases are known meaningless # for aligning the previous logic, remove it. # remove zero and negative values. - price_s = price_s.loc[(price_s > 1e-08).data.astype(np.bool)] + assert isinstance(price_s, idd.SingleData) + price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)] # NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8 # ~(np.NaN < 1e-8) -> ~(False) -> True + assert isinstance(price_s, idd.SingleData) if agg == "vwap": volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None) if isinstance(volume_s, (int, float, np.number)): volume_s = idd.SingleData(volume_s, [trade_start_time]) + assert isinstance(volume_s, idd.SingleData) volume_s = volume_s.reindex(price_s.index) elif agg == "twap": volume_s = idd.SingleData(1, price_s.index) else: raise NotImplementedError(f"This type of input is not supported") + assert isinstance(volume_s, idd.SingleData) base_volume = volume_s.sum() base_price = (price_s * volume_s).sum() / base_volume return base_price, base_volume def _agg_base_price( self, - inner_order_indicators: List[Dict[str, Union[SingleMetric, idd.SingleData]]], + inner_order_indicators: List[BaseOrderIndicator], decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]], trade_exchange: Exchange, pa_config: dict = {}, - ): + ) -> None: """ # NOTE:!!!! # Strong assumption!!!!!! @@ -444,7 +457,7 @@ class Indicator: Parameters ---------- - inner_order_indicators : List[Dict[str, pd.Series]] + inner_order_indicators : List[BaseOrderIndicator] the indicators of account of inner executor decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]], a list of decisions according to inner_order_indicators @@ -489,14 +502,17 @@ class Indicator: bv_new = idd.SingleData(bv_new) bp_all.append(bp_new) bv_all.append(bv_new) - bp_all = idd.concat(bp_all, axis=1) - bv_all = idd.concat(bv_all, axis=1) + bp_all_multi_data = idd.concat(bp_all, axis=1) + bv_all_multi_data = idd.concat(bv_all, axis=1) - base_volume = bv_all.sum(axis=1) + base_volume = bv_all_multi_data.sum(axis=1) self.order_indicator.assign("base_volume", base_volume.to_dict()) - self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict()) + self.order_indicator.assign( + "base_price", + ((bp_all_multi_data * bv_all_multi_data).sum(axis=1) / base_volume).to_dict(), + ) - def _agg_order_price_advantage(self): + def _agg_order_price_advantage(self) -> None: def if_empty_func(trade_price): return trade_price.empty @@ -513,12 +529,12 @@ class Indicator: def agg_order_indicators( self, - inner_order_indicators: List[Dict[str, pd.Series]], + inner_order_indicators: List[BaseOrderIndicator], decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]], outer_trade_decision: BaseTradeDecision, trade_exchange: Exchange, - indicator_config={}, - ): + indicator_config: dict = {}, + ) -> None: self._agg_order_trade_info(inner_order_indicators) self._update_trade_amount(outer_trade_decision) self._update_order_fulfill_rate() @@ -526,71 +542,66 @@ class Indicator: self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) # TODO self._agg_order_price_advantage() - def _cal_trade_fulfill_rate(self, method="mean"): + def _cal_trade_fulfill_rate(self, method: str = "mean") -> Optional[BaseSingleMetric]: if method == "mean": - - def func(ffr): - return ffr.mean() - + return self.order_indicator.transfer( + lambda ffr: ffr.mean(), + ) elif method == "amount_weighted": - - def func(ffr, deal_amount): - return (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum()) - + return self.order_indicator.transfer( + lambda ffr, deal_amount: (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum()), + ) elif method == "value_weighted": - - def func(ffr, trade_value): - return (ffr * trade_value.abs()).sum() / (trade_value.abs().sum()) - + return self.order_indicator.transfer( + lambda ffr, trade_value: (ffr * trade_value.abs()).sum() / (trade_value.abs().sum()), + ) else: raise ValueError(f"method {method} is not supported!") - return self.order_indicator.transfer(func) - def _cal_trade_price_advantage(self, method="mean"): + def _cal_trade_price_advantage(self, method: str = "mean") -> Optional[BaseSingleMetric]: if method == "mean": - - def func(pa): - return pa.mean() - + return self.order_indicator.transfer(lambda pa: pa.mean()) elif method == "amount_weighted": - - def func(pa, deal_amount): - return (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum()) - + return self.order_indicator.transfer( + lambda pa, deal_amount: (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum()), + ) elif method == "value_weighted": - - def func(pa, trade_value): - return (pa * trade_value.abs()).sum() / (trade_value.abs().sum()) - + return self.order_indicator.transfer( + lambda pa, trade_value: (pa * trade_value.abs()).sum() / (trade_value.abs().sum()), + ) else: raise ValueError(f"method {method} is not supported!") - return self.order_indicator.transfer(func) - def _cal_trade_positive_rate(self): + def _cal_trade_positive_rate(self) -> Optional[BaseSingleMetric]: def func(pa): return (pa > 0).sum() / pa.count() return self.order_indicator.transfer(func) - def _cal_deal_amount(self): + def _cal_deal_amount(self) -> Optional[BaseSingleMetric]: def func(deal_amount): return deal_amount.abs().sum() return self.order_indicator.transfer(func) - def _cal_trade_value(self): + def _cal_trade_value(self) -> Optional[BaseSingleMetric]: def func(trade_value): return trade_value.abs().sum() return self.order_indicator.transfer(func) - def _cal_trade_order_count(self): + def _cal_trade_order_count(self) -> Optional[BaseSingleMetric]: def func(amount): return amount.count() return self.order_indicator.transfer(func) - def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}): + def cal_trade_indicators( + self, + trade_start_time: Union[str, pd.Timestamp], + freq: str, + indicator_config: dict = {}, + ) -> None: show_indicator = indicator_config.get("show_indicator", False) ffr_config = indicator_config.get("ffr_config", {}) pa_config = indicator_config.get("pa_config", {}) @@ -608,22 +619,22 @@ class Indicator: self.trade_indicator["count"] = order_count if show_indicator: print( - "[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format( + "[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format( freq, - trade_start_time, + trade_start_time + if isinstance(trade_start_time, str) + else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"), fulfill_rate, price_advantage, positive_rate, ), ) - def get_order_indicator(self, raw: bool = True): - if raw: - return self.order_indicator - return self.order_indicator.to_series() + def get_order_indicator(self, raw: bool = True) -> Union[BaseOrderIndicator, Dict[Text, pd.Series]]: + return self.order_indicator if raw else self.order_indicator.to_series() - def get_trade_indicator(self): + def get_trade_indicator(self) -> Dict[str, Optional[BaseSingleMetric]]: return self.trade_indicator - def generate_trade_indicators_dataframe(self): + def generate_trade_indicators_dataframe(self) -> pd.DataFrame: return pd.DataFrame.from_dict(self.trade_indicator_his, orient="index") diff --git a/qlib/backtest/signal.py b/qlib/backtest/signal.py index 4615a89c0..cedc9bb17 100644 --- a/qlib/backtest/signal.py +++ b/qlib/backtest/signal.py @@ -22,7 +22,7 @@ class Signal(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]: + def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame, None]: """ get the signal at the end of the decision step(from `start_time` to `end_time`) @@ -39,13 +39,14 @@ class SignalWCache(Signal): SignalWCache will store the prepared signal as a attribute and give the according signal based on input query """ - def __init__(self, signal: Union[pd.Series, pd.DataFrame]): + def __init__(self, signal: Union[pd.Series, pd.DataFrame]) -> None: """ Parameters ---------- signal : Union[pd.Series, pd.DataFrame] - The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted) + The expected format of the signal is like the data below (the order of index is not important and can be + automatically adjusted) instrument datetime SH600000 2008-01-02 0.079704 @@ -56,8 +57,8 @@ class SignalWCache(Signal): """ self.signal_cache = convert_index_format(signal, level="datetime") - def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]: - # the frequency of the signal may not algin with the decision frequency of strategy + def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame]: + # the frequency of the signal may not align with the decision frequency of strategy # so resampling from the data is necessary # the latest signal leverage more recent data and therefore is used in trading. signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last") @@ -65,7 +66,7 @@ class SignalWCache(Signal): class ModelSignal(SignalWCache): - def __init__(self, model: BaseModel, dataset: Dataset): + def __init__(self, model: BaseModel, dataset: Dataset) -> None: self.model = model self.dataset = dataset pred_scores = self.model.predict(dataset) @@ -73,7 +74,7 @@ class ModelSignal(SignalWCache): pred_scores = pred_scores.iloc[:, 0] super().__init__(pred_scores) - def _update_model(self): + def _update_model(self) -> None: """ When using online data, update model in each bar as the following steps: - update dataset with online data, the dataset should support online update diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 2077986bc..db35dc482 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -149,6 +149,8 @@ class TradeCalendarManager: Tuple[int, int]: """ # potential performance issue + assert self.level_infra is not None + day_start = pd.Timestamp(self.start_time.date()) day_end = epsilon_change(day_start + pd.Timedelta(days=1)) freq = self.level_infra.get("common_infra").get("trade_exchange").freq @@ -182,8 +184,8 @@ class TradeCalendarManager: Tuple[int, int]: the index of the range. **the left and right are closed** """ - left = bisect.bisect_right(self._calendar, start_time) - 1 - right = bisect.bisect_right(self._calendar, end_time) - 1 + left = bisect.bisect_right(list(self._calendar), start_time) - 1 + right = bisect.bisect_right(list(self._calendar), end_time) - 1 left -= self.start_index right -= self.start_index @@ -201,14 +203,14 @@ class TradeCalendarManager: class BaseInfrastructure: - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: self.reset_infra(**kwargs) @abstractmethod def get_support_infra(self) -> Set[str]: raise NotImplementedError("`get_support_infra` is not implemented!") - def reset_infra(self, **kwargs) -> None: + def reset_infra(self, **kwargs: Any) -> None: support_infra = self.get_support_infra() for k, v in kwargs.items(): if k in support_infra: diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index cb703cc20..2901a40ea 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -339,7 +339,7 @@ def long_short_backtest( for stock in long_stocks: if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date): continue - profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str] + profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str) if np.isnan(profit): long_profit.append(0) else: @@ -348,17 +348,17 @@ def long_short_backtest( for stock in short_stocks: if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date): continue - profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str] + profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str) if np.isnan(profit): short_profit.append(0) else: - short_profit.append(-profit) + short_profit.append(profit * -1) for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)): # exclude the suspend stock if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date): continue - profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str] + profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str) if np.isnan(profit): all_profit.append(0) else: diff --git a/qlib/data/data.py b/qlib/data/data.py index 08320cae5..a6b1ce19a 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -108,14 +108,16 @@ class CalendarProvider(abc.ABC): _, _, si, ei = self.locate_index(start_time, end_time, freq, future) return _calendar[si : ei + 1] - def locate_index(self, start_time, end_time, freq, future=False): + def locate_index( + self, start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], freq: str, future: bool = False + ): """Locate the start time index and end time index in a calendar under certain frequency. Parameters ---------- - start_time : str + start_time : pd.Timestamp start of the time range. - end_time : str + end_time : pd.Timestamp end of the time range. freq : str time frequency, available: year/quarter/month/week/day. diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index 6cf386801..e2d0382b1 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -248,7 +248,7 @@ def load_orders( Order( row["instrument"], row["amount"], - int(row["order_type"]), + OrderDir(int(row["order_type"])), row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second), row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second), ) diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 3ca8a8bd0..37998a4af 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Generator, Optional if TYPE_CHECKING: from qlib.backtest.exchange import Exchange @@ -122,7 +122,10 @@ class BaseStrategy: self.outer_trade_decision = outer_trade_decision @abstractmethod - def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision: + def generate_trade_decision( + self, + execute_result: list = None, + ) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]: """Generate trade decision in each trading bar Parameters diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index a6320362c..9f1aab4fe 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -9,6 +9,8 @@ Motivation of index_data `index_data` try to behave like pandas (some API will be different because we try to be simpler and more intuitive) but don't compromise the performance. It provides the basic numpy data and simple indexing feature. If users call APIs which may compromise the performance, index_data will raise Errors. """ +from __future__ import annotations + from typing import Dict, Tuple, Union, Callable, List import bisect @@ -16,7 +18,7 @@ import numpy as np import pandas as pd -def concat(data_list: Union["SingleData"], axis=0) -> "MultiData": +def concat(data_list: Union[SingleData], axis=0) -> MultiData: """concat all SingleData by index. TODO: now just for SingleData. @@ -52,7 +54,7 @@ def concat(data_list: Union["SingleData"], axis=0) -> "MultiData": raise ValueError(f"axis must be 0 or 1") -def sum_by_index(data_list: Union["SingleData"], new_index: list, fill_value=0) -> "SingleData": +def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) -> SingleData: """concat all SingleData by new index. Parameters @@ -554,7 +556,7 @@ class SingleData(IndexData): f"The indexes of self and other do not meet the requirements of the four arithmetic operations" ) - def reindex(self, index: Index, fill_value=np.NaN): + def reindex(self, index: Index, fill_value=np.NaN) -> SingleData: """reindex data and fill the missing value with np.NaN. Parameters @@ -580,7 +582,7 @@ class SingleData(IndexData): pass return SingleData(tmp_data, index) - def add(self, other: "SingleData", fill_value=0): + def add(self, other: SingleData, fill_value=0): # TODO: add and __add__ are a little confusing. # This could be a more general common_index = self.index | other.index