1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Backtest Mypy (#1130)

* Done

* Fix test errors

* Revert profit_attribution.py

* Minor

* A minor update on collect_data type hint

* Resolve PR comments

* Use black to format code

* Fix CI errors
This commit is contained in:
Huoran Li
2022-06-28 22:16:46 +08:00
committed by GitHub
parent 9bf3423a64
commit 23c657a7a2
17 changed files with 363 additions and 315 deletions

View File

@@ -1,6 +1,6 @@
[mypy] [mypy]
exclude = (?x)( exclude = (?x)(
^qlib/backtest ^qlib/backtest/high_performance_ds\.py$
| ^qlib/contrib | ^qlib/contrib
| ^qlib/data | ^qlib/data
| ^qlib/model | ^qlib/model

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import copy import copy
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union
import pandas as pd import pandas as pd
@@ -23,7 +23,6 @@ from ..utils import init_instance_by_config
from .backtest import backtest_loop, collect_data_loop from .backtest import backtest_loop, collect_data_loop
from .decision import Order from .decision import Order
from .exchange import Exchange from .exchange import Exchange
from .position import Position
from .utils import CommonInfrastructure from .utils import CommonInfrastructure
# make import more user-friendly by adding `from qlib.backtest import STH` # make import more user-friendly by adding `from qlib.backtest import STH`
@@ -44,7 +43,7 @@ def get_exchange(
min_cost: float = 5.0, min_cost: float = 5.0,
limit_threshold: Union[Tuple[str, str], float, None] = None, limit_threshold: Union[Tuple[str, str], float, None] = None,
deal_price: Union[str, Tuple[str], List[str]] = None, deal_price: Union[str, Tuple[str], List[str]] = None,
**kwargs, **kwargs: Any,
) -> Exchange: ) -> Exchange:
"""get_exchange """get_exchange
@@ -52,14 +51,15 @@ def get_exchange(
---------- ----------
# exchange related arguments # exchange related arguments
exchange: Exchange(). It could be None or any types that are acceptable by `init_instance_by_config`. exchange: Exchange
It could be None or any types that are acceptable by `init_instance_by_config`.
freq: str freq: str
frequency of data. frequency of data.
start_time: Union[pd.Timestamp, str] start_time: Union[pd.Timestamp, str]
closed start time for backtest. closed start time for backtest.
end_time: Union[pd.Timestamp, str] end_time: Union[pd.Timestamp, str]
closed end time for backtest. closed end time for backtest.
codes: list|str codes: Union[list, str]
list stock_id list or a string of instruments (i.e. all, csi500, sse50) list stock_id list or a string of instruments (i.e. all, csi500, sse50)
subscribe_fields: list subscribe_fields: list
subscribe fields. subscribe fields.
@@ -151,28 +151,24 @@ def create_account_instance(
Postion type. Postion type.
""" """
if isinstance(account, (int, float)): if isinstance(account, (int, float)):
pos_kwargs = {"init_cash": account} init_cash = account
position_dict = {}
elif isinstance(account, dict): elif isinstance(account, dict):
init_cash = account["cash"] init_cash = account.pop("cash")
del account["cash"] position_dict = account
pos_kwargs = {
"init_cash": init_cash,
"position_dict": account,
}
else: else:
raise ValueError("account must be in (int, float, Position)") raise ValueError("account must be in (int, float, dict)")
kwargs = { return Account(
"init_cash": account, init_cash=init_cash,
"benchmark_config": { position_dict=position_dict,
pos_type=pos_type,
benchmark_config={
"benchmark": benchmark, "benchmark": benchmark,
"start_time": start_time, "start_time": start_time,
"end_time": end_time, "end_time": end_time,
}, },
"pos_type": pos_type, )
}
kwargs.update(pos_kwargs)
return Account(**kwargs)
def get_strategy_executor( def get_strategy_executor(
@@ -181,7 +177,7 @@ def get_strategy_executor(
strategy: Union[str, dict, object, Path], strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path], executor: Union[str, dict, object, Path],
benchmark: str = "SH000300", benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9, account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {}, exchange_kwargs: dict = {},
pos_type: str = "Position", pos_type: str = "Position",
) -> Tuple[BaseStrategy, BaseExecutor]: ) -> Tuple[BaseStrategy, BaseExecutor]:
@@ -222,7 +218,7 @@ def backtest(
strategy: Union[str, dict, object, Path], strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path], executor: Union[str, dict, object, Path],
benchmark: str = "SH000300", benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9, account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {}, exchange_kwargs: dict = {},
pos_type: str = "Position", pos_type: str = "Position",
) -> Tuple[PortfolioMetrics, Indicator]: ) -> Tuple[PortfolioMetrics, Indicator]:
@@ -285,7 +281,7 @@ def collect_data(
strategy: Union[str, dict, object, Path], strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path], executor: Union[str, dict, object, Path],
benchmark: str = "SH000300", benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9, account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {}, exchange_kwargs: dict = {},
pos_type: str = "Position", pos_type: str = "Position",
return_value: dict = None, return_value: dict = None,
@@ -339,7 +335,7 @@ def format_decisions(
cur_freq = decisions[0].strategy.trade_calendar.get_freq() cur_freq = decisions[0].strategy.trade_calendar.get_freq()
res = (cur_freq, []) res: Tuple[str, list] = (cur_freq, [])
last_dec_idx = 0 last_dec_idx = 0
for i, dec in enumerate(decisions[1:], 1): for i, dec in enumerate(decisions[1:], 1):
if dec.strategy.trade_calendar.get_freq() == cur_freq: if dec.strategy.trade_calendar.get_freq() == cur_freq:

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import copy import copy
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Tuple, cast
import pandas as pd import pandas as pd
@@ -11,6 +11,7 @@ from qlib.utils import init_instance_by_config
from .decision import BaseTradeDecision, Order from .decision import BaseTradeDecision, Order
from .exchange import Exchange from .exchange import Exchange
from .high_performance_ds import BaseOrderIndicator
from .position import BasePosition from .position import BasePosition
from .report import Indicator, PortfolioMetrics from .report import Indicator, PortfolioMetrics
@@ -104,7 +105,7 @@ class Account:
self._pos_type = pos_type self._pos_type = pos_type
self._port_metr_enabled = port_metr_enabled self._port_metr_enabled = port_metr_enabled
self.benchmark_config = None # avoid no attribute error self.benchmark_config: dict = {} # avoid no attribute error
self.init_vars(init_cash, position_dict, freq, benchmark_config) self.init_vars(init_cash, position_dict, freq, benchmark_config)
def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None: def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None:
@@ -124,8 +125,8 @@ class Account:
self.accum_info = AccumulatedInfo() self.accum_info = AccumulatedInfo()
# 2) following variables are not shared between layers # 2) following variables are not shared between layers
self.portfolio_metrics = None self.portfolio_metrics: Optional[PortfolioMetrics] = None
self.hist_positions = {} self.hist_positions: Dict[pd.Timestamp, BasePosition] = {}
self.reset(freq=freq, benchmark_config=benchmark_config) self.reset(freq=freq, benchmark_config=benchmark_config)
def is_port_metr_enabled(self) -> bool: def is_port_metr_enabled(self) -> bool:
@@ -171,7 +172,7 @@ class Account:
self.reset_report(self.freq, self.benchmark_config) self.reset_report(self.freq, self.benchmark_config)
def get_hist_positions(self) -> dict: def get_hist_positions(self) -> Dict[pd.Timestamp, BasePosition]:
return self.hist_positions return self.hist_positions
def get_cash(self) -> float: def get_cash(self) -> float:
@@ -230,13 +231,15 @@ class Account:
""" """
# update price for stock in the position and the profit from changed_price # update price for stock in the position and the profit from changed_price
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy # NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
assert self.current_position is not None
if not self.current_position.skip_update(): if not self.current_position.skip_update():
stock_list = self.current_position.get_stock_list() stock_list = self.current_position.get_stock_list()
for code in stock_list: for code in stock_list:
# if suspend, no new price to be updated, profit is 0 # if suspend, no new price to be updated, profit is 0
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time): if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
continue continue
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time) bar_close = cast(float, trade_exchange.get_close(code, trade_start_time, trade_end_time))
self.current_position.update_stock_price(stock_id=code, price=bar_close) self.current_position.update_stock_price(stock_id=code, price=bar_close)
# update holding day count # update holding day count
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy # NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
@@ -249,6 +252,8 @@ class Account:
# for the first trade date, account_value - init_cash # for the first trade date, account_value - init_cash
# self.portfolio_metrics.is_empty() to judge is_first_trade_date # self.portfolio_metrics.is_empty() to judge is_first_trade_date
# get last_account_value, last_total_cost, last_total_turnover # get last_account_value, last_total_cost, last_total_turnover
assert self.portfolio_metrics is not None
if self.portfolio_metrics.is_empty(): if self.portfolio_metrics.is_empty():
last_account_value = self.init_cash last_account_value = self.init_cash
last_total_cost = 0 last_total_cost = 0
@@ -299,9 +304,9 @@ class Account:
trade_exchange: Exchange, trade_exchange: Exchange,
atomic: bool, atomic: bool,
outer_trade_decision: BaseTradeDecision, outer_trade_decision: BaseTradeDecision,
trade_info: list = None, trade_info: list = [],
inner_order_indicators: List[Dict[str, pd.Series]] = None, inner_order_indicators: List[BaseOrderIndicator] = [],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None, decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
indicator_config: dict = {}, indicator_config: dict = {},
) -> None: ) -> None:
"""update trade indicators and order indicators in each bar end""" """update trade indicators and order indicators in each bar end"""
@@ -335,9 +340,9 @@ class Account:
trade_exchange: Exchange, trade_exchange: Exchange,
atomic: bool, atomic: bool,
outer_trade_decision: BaseTradeDecision, outer_trade_decision: BaseTradeDecision,
trade_info: list = None, trade_info: list = [],
inner_order_indicators: List[Dict[str, pd.Series]] = None, inner_order_indicators: List[BaseOrderIndicator] = [],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None, decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
indicator_config: dict = {}, indicator_config: dict = {},
) -> None: ) -> None:
"""update account at each trading bar step """update account at each trading bar step
@@ -398,6 +403,7 @@ class Account:
def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]: def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:
"""get the history portfolio_metrics and positions instance""" """get the history portfolio_metrics and positions instance"""
if self.is_port_metr_enabled(): if self.is_port_metr_enabled():
assert self.portfolio_metrics is not None
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe() _portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
_positions = self.get_hist_positions() _positions = self.get_hist_positions()
return _portfolio_metrics, _positions return _portfolio_metrics, _positions

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
import pandas as pd import pandas as pd
@@ -36,10 +36,13 @@ def backtest_loop(
indicator: Indicator indicator: Indicator
it computes the trading indicator it computes the trading indicator
""" """
return_value = {} return_value: dict = {}
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value): for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
pass pass
return return_value.get("portfolio_metrics"), return_value.get("indicator")
portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
indicator = cast(Indicator, return_value.get("indicator"))
return portfolio_metrics, indicator
def collect_data_loop( def collect_data_loop(

View File

@@ -7,7 +7,7 @@ from abc import abstractmethod
from enum import IntEnum from enum import IntEnum
# try to fix circular imports when enabling type hints # try to fix circular imports when enabling type hints
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast
from qlib.backtest.utils import TradeCalendarManager from qlib.backtest.utils import TradeCalendarManager
from qlib.data.data import Cal from qlib.data.data import Cal
@@ -24,6 +24,9 @@ import numpy as np
import pandas as pd import pandas as pd
DecisionType = TypeVar("DecisionType")
class OrderDir(IntEnum): class OrderDir(IntEnum):
# Order direction # Order direction
SELL = 0 SELL = 0
@@ -65,7 +68,7 @@ class Order:
# - not tradable: the deal_amount == 0 , factor is None # - not tradable: the deal_amount == 0 , factor is None
# - the stock is suspended and the entire order fails. No cost for this order # - the stock is suspended and the entire order fails. No cost for this order
# - dealt or partially dealt: deal_amount >= 0 and factor is not None # - dealt or partially dealt: deal_amount >= 0 and factor is not None
deal_amount: Optional[float] = None # `deal_amount` is a non-negative value deal_amount: float = 0.0 # `deal_amount` is a non-negative value
factor: Optional[float] = None factor: Optional[float] = None
# TODO: # TODO:
@@ -281,7 +284,7 @@ class TradeRangeByTime(TradeRange):
return max(val_start, start_time), min(val_end, end_time) return max(val_start, start_time), min(val_end, end_time)
class BaseTradeDecision: class BaseTradeDecision(Generic[DecisionType]):
""" """
Trade decisions ara made by strategy and executed by executor Trade decisions ara made by strategy and executed by executor
@@ -316,20 +319,21 @@ class BaseTradeDecision:
""" """
self.strategy = strategy self.strategy = strategy
self.start_time, self.end_time = strategy.trade_calendar.get_step_time() self.start_time, self.end_time = strategy.trade_calendar.get_step_time()
self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading` # upper strategy has no knowledge about the sub executor before `_init_sub_trading`
if isinstance(trade_range, Tuple): self.total_step: Optional[int] = None
if isinstance(trade_range, tuple):
# for Tuple[int, int] # for Tuple[int, int]
trade_range = IdxTradeRange(*trade_range) trade_range = IdxTradeRange(*trade_range)
self.trade_range: TradeRange = trade_range self.trade_range: Optional[TradeRange] = trade_range
def get_decision(self) -> List[object]: def get_decision(self) -> List[DecisionType]:
""" """
get the **concrete decision** (e.g. execution orders) get the **concrete decision** (e.g. execution orders)
This will be called by the inner strategy This will be called by the inner strategy
Returns Returns
------- -------
List[object]: List[DecisionType:
The decision result. Typically it is some orders The decision result. Typically it is some orders
Example: Example:
[]: []:
@@ -363,13 +367,13 @@ class BaseTradeDecision:
# purpose 2) # purpose 2)
return self.strategy.update_trade_decision(self, trade_calendar) return self.strategy.update_trade_decision(self, trade_calendar)
def _get_range_limit(self, **kwargs) -> Tuple[int, int]: def _get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
if self.trade_range is not None: if self.trade_range is not None:
return self.trade_range(trade_calendar=kwargs.get("inner_calendar")) return self.trade_range(trade_calendar=cast(TradeCalendarManager, kwargs.get("inner_calendar")))
else: else:
raise NotImplementedError("The decision didn't provide an index range") raise NotImplementedError("The decision didn't provide an index range")
def get_range_limit(self, **kwargs) -> Tuple[int, int]: def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
""" """
return the expected step range for limiting the decision execution time return the expected step range for limiting the decision execution time
Both left and right are **closed** Both left and right are **closed**
@@ -421,6 +425,7 @@ class BaseTradeDecision:
if getattr(self, "total_step", None) is not None: if getattr(self, "total_step", None) is not None:
# if `self.update` is called. # if `self.update` is called.
# Then the _start_idx, _end_idx should be clipped # Then the _start_idx, _end_idx should be clipped
assert self.total_step is not None
if _start_idx < 0 or _end_idx >= self.total_step: if _start_idx < 0 or _end_idx >= self.total_step:
logger = get_module_logger("decision") logger = get_module_logger("decision")
logger.warning( logger.warning(
@@ -516,7 +521,7 @@ class BaseTradeDecision:
inner_trade_decision.trade_range = self.trade_range inner_trade_decision.trade_range = self.trade_range
class EmptyTradeDecision(BaseTradeDecision): class EmptyTradeDecision(BaseTradeDecision[object]):
def get_decision(self) -> List[object]: def get_decision(self) -> List[object]:
return [] return []
@@ -524,23 +529,24 @@ class EmptyTradeDecision(BaseTradeDecision):
return True return True
class TradeDecisionWO(BaseTradeDecision): class TradeDecisionWO(BaseTradeDecision[Order]):
""" """
Trade Decision (W)ith (O)rder. Trade Decision (W)ith (O)rder.
Besides, the time_range is also included. Besides, the time_range is also included.
""" """
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None): def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None:
super().__init__(strategy, trade_range=trade_range) super().__init__(strategy, trade_range=trade_range)
self.order_list = order_list self.order_list = cast(List[Order], order_list)
start, end = strategy.trade_calendar.get_step_time() start, end = strategy.trade_calendar.get_step_time()
for o in order_list: for o in order_list:
assert isinstance(o, Order)
if o.start_time is None: if o.start_time is None:
o.start_time = start o.start_time = start
if o.end_time is None: if o.end_time is None:
o.end_time = end o.end_time = end
def get_decision(self) -> List[object]: def get_decision(self) -> List[Order]:
return self.order_list return self.order_list
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
from ..utils.index_data import IndexData from ..utils.index_data import IndexData
@@ -42,7 +42,7 @@ class Exchange:
impact_cost: float = 0.0, impact_cost: float = 0.0,
extra_quote: pd.DataFrame = None, extra_quote: pd.DataFrame = None,
quote_cls: Type[BaseQuote] = NumpyQuote, quote_cls: Type[BaseQuote] = NumpyQuote,
**kwargs, **kwargs: Any,
) -> None: ) -> None:
"""__init__ """__init__
:param freq: frequency of data :param freq: frequency of data
@@ -141,7 +141,7 @@ class Exchange:
if limit_threshold is None: if limit_threshold is None:
if C.region == REG_CN: if C.region == REG_CN:
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold") self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
elif self.limit_type == self.LT_FLT and abs(limit_threshold) > 0.1: elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1:
if C.region == REG_CN: if C.region == REG_CN:
self.logger.warning(f"limit_threshold may not be set to a reasonable value") self.logger.warning(f"limit_threshold may not be set to a reasonable value")
@@ -150,7 +150,7 @@ class Exchange:
deal_price = "$" + deal_price deal_price = "$" + deal_price
self.buy_price = self.sell_price = deal_price self.buy_price = self.sell_price = deal_price
elif isinstance(deal_price, (tuple, list)): elif isinstance(deal_price, (tuple, list)):
self.buy_price, self.sell_price = deal_price self.buy_price, self.sell_price = cast(Tuple[str, str], deal_price)
else: else:
raise NotImplementedError(f"This type of input is not supported") raise NotImplementedError(f"This type of input is not supported")
@@ -167,10 +167,10 @@ class Exchange:
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"} necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
if self.limit_type == self.LT_TP_EXP: if self.limit_type == self.LT_TP_EXP:
assert isinstance(limit_threshold, tuple)
for exp in limit_threshold: for exp in limit_threshold:
necessary_fields.add(exp) necessary_fields.add(exp)
all_fields = necessary_fields | set(vol_lt_fields) all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
all_fields = list(all_fields | set(subscribe_fields))
self.all_fields = all_fields self.all_fields = all_fields
@@ -249,9 +249,9 @@ class Exchange:
LT_FLT = "float" # float LT_FLT = "float" # float
LT_NONE = "none" # none LT_NONE = "none" # none
def _get_limit_type(self, limit_threshold: Union[Tuple, float, None]) -> str: def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
"""get limit type""" """get limit type"""
if isinstance(limit_threshold, Tuple): if isinstance(limit_threshold, tuple):
return self.LT_TP_EXP return self.LT_TP_EXP
elif isinstance(limit_threshold, float): elif isinstance(limit_threshold, float):
return self.LT_FLT return self.LT_FLT
@@ -268,14 +268,16 @@ class Exchange:
self.quote_df["limit_sell"] = False self.quote_df["limit_sell"] = False
elif limit_type == self.LT_TP_EXP: elif limit_type == self.LT_TP_EXP:
# set limit # set limit
limit_threshold = cast(tuple, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]] self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]] self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
elif limit_type == self.LT_FLT: elif limit_type == self.LT_FLT:
limit_threshold = cast(float, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130 self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
@staticmethod @staticmethod
def _get_vol_limit(volume_threshold: Union[tuple, dict]) -> Tuple[Optional[list], Optional[list], set]: def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
""" """
preprocess the volume limit. preprocess the volume limit.
get the fields need to get from qlib. get the fields need to get from qlib.
@@ -340,11 +342,11 @@ class Exchange:
if direction is None: if direction is None:
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all") buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all") sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
return buy_limit or sell_limit return bool(buy_limit or sell_limit)
elif direction == Order.BUY: elif direction == Order.BUY:
return self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all") return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all"))
elif direction == Order.SELL: elif direction == Order.SELL:
return self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all") return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all"))
else: else:
raise ValueError(f"direction {direction} is not supported!") raise ValueError(f"direction {direction} is not supported!")
@@ -382,7 +384,7 @@ class Exchange:
order: Order, order: Order,
trade_account: Account = None, trade_account: Account = None,
position: BasePosition = None, position: BasePosition = None,
dealt_order_amount: defaultdict = defaultdict(float), dealt_order_amount: Dict[str, float] = defaultdict(float),
) -> Tuple[float, float, float]: ) -> Tuple[float, float, float]:
""" """
Deal order when the actual transaction Deal order when the actual transaction
@@ -426,9 +428,10 @@ class Exchange:
stock_id: str, stock_id: str,
start_time: pd.Timestamp, start_time: pd.Timestamp,
end_time: pd.Timestamp, end_time: pd.Timestamp,
field: str,
method: str = "ts_data_last", method: str = "ts_data_last",
) -> Union[None, int, float, bool, IndexData]: ) -> Union[None, int, float, bool, IndexData]:
return self.quote.get_data(stock_id, start_time, end_time, method=method) # TODO: missing `field`? return self.quote.get_data(stock_id, start_time, end_time, field=field, method=method)
def get_close( def get_close(
self, self,
@@ -444,10 +447,10 @@ class Exchange:
stock_id: str, stock_id: str,
start_time: pd.Timestamp, start_time: pd.Timestamp,
end_time: pd.Timestamp, end_time: pd.Timestamp,
method: str = "sum", method: Optional[str] = "sum",
) -> float: ) -> float:
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)""" """get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method) return cast(float, self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method))
def get_deal_price( def get_deal_price(
self, self,
@@ -455,7 +458,7 @@ class Exchange:
start_time: pd.Timestamp, start_time: pd.Timestamp,
end_time: pd.Timestamp, end_time: pd.Timestamp,
direction: OrderDir, direction: OrderDir,
method: str = "ts_data_last", method: Optional[str] = "ts_data_last",
) -> float: ) -> float:
if direction == OrderDir.SELL: if direction == OrderDir.SELL:
pstr = self.sell_price pstr = self.sell_price
@@ -469,7 +472,7 @@ class Exchange:
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!") self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
self.logger.warning(f"setting deal_price to close price") self.logger.warning(f"setting deal_price to close price")
deal_price = self.get_close(stock_id, start_time, end_time, method) deal_price = self.get_close(stock_id, start_time, end_time, method)
return deal_price return cast(float, deal_price)
def get_factor( def get_factor(
self, self,
@@ -544,7 +547,7 @@ class Exchange:
) )
return amount_dict return amount_dict
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float) -> float: def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float:
""" """
Calculate the real adjust deal amount when considering the trading unit Calculate the real adjust deal amount when considering the trading unit
:param current_amount: :param current_amount:
@@ -572,7 +575,7 @@ class Exchange:
current_position: dict, current_position: dict,
start_time: pd.Timestamp, start_time: pd.Timestamp,
end_time: pd.Timestamp, end_time: pd.Timestamp,
) -> list: ) -> List[Order]:
""" """
Note: some future information is used in this function Note: some future information is used in this function
Parameter: Parameter:
@@ -681,6 +684,7 @@ class Exchange:
factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time) factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
else: else:
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None") raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
assert factor is not None
return factor return factor
def get_amount_of_trade_unit( def get_amount_of_trade_unit(
@@ -718,12 +722,12 @@ class Exchange:
def round_amount_by_trade_unit( def round_amount_by_trade_unit(
self, self,
deal_amount, deal_amount: float,
factor: float = None, factor: float = None,
stock_id: str = None, stock_id: str = None,
start_time=None, start_time: pd.Timestamp = None,
end_time=None, end_time: pd.Timestamp = None,
): ) -> float:
"""Parameter """Parameter
Please refer to the docs of get_amount_of_trade_unit Please refer to the docs of get_amount_of_trade_unit
deal_amount : float, adjusted amount deal_amount : float, adjusted amount
@@ -741,7 +745,7 @@ class Exchange:
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
return deal_amount return deal_amount
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> int: def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Optional[float]:
"""parse the capacity limit string and return the actual amount of orders that can be executed. """parse the capacity limit string and return the actual amount of orders that can be executed.
NOTE: NOTE:
this function will change the order.deal_amount **inplace** this function will change the order.deal_amount **inplace**
@@ -753,15 +757,12 @@ class Exchange:
dealt_order_amount : dict dealt_order_amount : dict
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float} :param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
""" """
if order.direction == Order.BUY: vol_limit = self.buy_vol_limit if order.direction == Order.BUY else self.sell_vol_limit
vol_limit = self.buy_vol_limit
elif order.direction == Order.SELL:
vol_limit = self.sell_vol_limit
if vol_limit is None: if vol_limit is None:
return order.deal_amount return order.deal_amount
vol_limit_num = [] vol_limit_num: List[float] = []
for limit in vol_limit: for limit in vol_limit:
assert isinstance(limit, tuple) assert isinstance(limit, tuple)
if limit[0] == "current": if limit[0] == "current":
@@ -772,7 +773,7 @@ class Exchange:
field=limit[1], field=limit[1],
method="sum", method="sum",
) )
vol_limit_num.append(limit_value) vol_limit_num.append(cast(float, limit_value))
elif limit[0] == "cum": elif limit[0] == "cum":
limit_value = self.quote.get_data( limit_value = self.quote.get_data(
order.stock_id, order.stock_id,
@@ -790,12 +791,14 @@ class Exchange:
if vol_limit_min < orig_deal_amount: if vol_limit_min < orig_deal_amount:
self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}") self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}")
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio): return None
def _get_buy_amount_by_cash_limit(self, trade_price: float, cash: float, cost_ratio: float) -> float:
"""return the real order amount after cash limit for buying. """return the real order amount after cash limit for buying.
Parameters Parameters
---------- ----------
trade_price : float trade_price : float
position : cash cash : float
cost_ratio : float cost_ratio : float
Return Return
@@ -803,7 +806,7 @@ class Exchange:
float float
the real order amount after cash limit for buying. the real order amount after cash limit for buying.
""" """
max_trade_amount = 0 max_trade_amount = 0.0
if cash >= self.min_cost: if cash >= self.min_cost:
# critical_price means the stock transaction price when the service fee is equal to min_cost. # critical_price means the stock transaction price when the service fee is equal to min_cost.
critical_price = self.min_cost / cost_ratio + self.min_cost critical_price = self.min_cost / cost_ratio + self.min_cost
@@ -897,7 +900,7 @@ class Exchange:
order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor) order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
else: else:
raise NotImplementedError("order type {} error".format(order.type)) raise NotImplementedError("order direction {} error".format(order.direction))
trade_val = order.deal_amount * trade_price trade_val = order.deal_amount * trade_price
trade_cost = max(trade_val * cost_ratio, self.min_cost) trade_cost = max(trade_val * cost_ratio, self.min_cost)

View File

@@ -4,7 +4,7 @@ import copy
from abc import abstractmethod from abc import abstractmethod
from collections import defaultdict from collections import defaultdict
from types import GeneratorType from types import GeneratorType
from typing import Generator, List, Optional, Tuple, Union from typing import Any, Dict, Generator, List, Tuple, Union, cast
import pandas as pd import pandas as pd
@@ -16,13 +16,7 @@ from ..strategy.base import BaseStrategy
from ..utils import init_instance_by_config from ..utils import init_instance_by_config
from .decision import BaseTradeDecision, Order from .decision import BaseTradeDecision, Order
from .exchange import Exchange from .exchange import Exchange
from .utils import ( from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, get_start_end_idx
BaseInfrastructure,
CommonInfrastructure,
LevelInfrastructure,
TradeCalendarManager,
get_start_end_idx,
)
class BaseExecutor: class BaseExecutor:
@@ -39,8 +33,8 @@ class BaseExecutor:
track_data: bool = False, track_data: bool = False,
trade_exchange: Exchange = None, trade_exchange: Exchange = None,
common_infra: CommonInfrastructure = None, common_infra: CommonInfrastructure = None,
settle_type=BasePosition.ST_NO, # TODO: add typehint settle_type: str = BasePosition.ST_NO,
**kwargs, **kwargs: Any,
) -> None: ) -> None:
""" """
Parameters Parameters
@@ -127,10 +121,10 @@ class BaseExecutor:
get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}") get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")
# record deal order amount in one day # record deal order amount in one day
self.dealt_order_amount = defaultdict(float) self.dealt_order_amount: Dict[str, float] = defaultdict(float)
self.deal_day = None self.deal_day = None
def reset_common_infra(self, common_infra: BaseInfrastructure, copy_trade_account: bool = False) -> None: def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:
""" """
reset infrastructure for trading reset infrastructure for trading
- reset trade_account - reset trade_account
@@ -141,14 +135,15 @@ class BaseExecutor:
self.common_infra.update(common_infra) self.common_infra.update(common_infra)
if common_infra.has("trade_account"): if common_infra.has("trade_account"):
if copy_trade_account:
# NOTE: there is a trick in the code. # NOTE: there is a trick in the code.
# shallow copy is used instead of deepcopy. # shallow copy is used instead of deepcopy.
# 1. So positions are shared # 1. So positions are shared
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics) # 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
self.trade_account: Account = copy.copy(common_infra.get("trade_account")) self.trade_account: Account = (
else: copy.copy(common_infra.get("trade_account"))
self.trade_account: Account = common_infra.get("trade_account") if copy_trade_account
else common_infra.get("trade_account")
)
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics) self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
@property @property
@@ -164,7 +159,7 @@ class BaseExecutor:
""" """
return self.level_infra.get("trade_calendar") return self.level_infra.get("trade_calendar")
def reset(self, common_infra: CommonInfrastructure = None, **kwargs) -> None: def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None:
""" """
- reset `start_time` and `end_time`, used in trade calendar - reset `start_time` and `end_time`, used in trade calendar
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
@@ -200,20 +195,17 @@ class BaseExecutor:
execute_result : List[object] execute_result : List[object]
the executed result for trade decision the executed result for trade decision
""" """
return_value = {} return_value: dict = {}
for _decision in self.collect_data(trade_decision, return_value=return_value, level=level): for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):
pass pass
return return_value.get("execute_result") return cast(list, return_value.get("execute_result"))
@abstractmethod @abstractmethod
def _collect_data( def _collect_data(
self, self,
trade_decision: BaseTradeDecision, trade_decision: BaseTradeDecision,
level: int = 0, level: int = 0,
) -> Union[ ) -> Union[Generator[Any, Any, Tuple[List[object], dict]], Tuple[List[object], dict]]:
Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]],
Tuple[List[object], dict],
]:
""" """
Please refer to the doc of collect_data Please refer to the doc of collect_data
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
@@ -235,7 +227,7 @@ class BaseExecutor:
trade_decision: BaseTradeDecision, trade_decision: BaseTradeDecision,
return_value: dict = None, return_value: dict = None,
level: int = 0, level: int = 0,
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], List[object]]: ) -> Generator[Any, Any, List[object]]:
"""Generator for collecting the trade decision data for rl training """Generator for collecting the trade decision data for rl training
his function will make a step forward his function will make a step forward
@@ -332,7 +324,7 @@ class NestedExecutor(BaseExecutor):
skip_empty_decision: bool = True, skip_empty_decision: bool = True,
align_range_limit: bool = True, align_range_limit: bool = True,
common_infra: CommonInfrastructure = None, common_infra: CommonInfrastructure = None,
**kwargs, **kwargs: Any,
) -> None: ) -> None:
""" """
Parameters Parameters
@@ -411,7 +403,7 @@ class NestedExecutor(BaseExecutor):
self, self,
trade_decision: BaseTradeDecision, trade_decision: BaseTradeDecision,
level: int = 0, level: int = 0,
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]]: ) -> Generator[Any, Any, Tuple[List[object], dict]]:
execute_result = [] execute_result = []
inner_order_indicators = [] inner_order_indicators = []
decision_list = [] decision_list = []
@@ -493,7 +485,7 @@ class NestedExecutor(BaseExecutor):
the execution result of inner task the execution result of inner task
""" """
def get_all_executors(self) -> List[object]: def get_all_executors(self) -> List[BaseExecutor]:
"""get all executors, including self and inner_executor.get_all_executors()""" """get all executors, including self and inner_executor.get_all_executors()"""
return [self, *self.inner_executor.get_all_executors()] return [self, *self.inner_executor.get_all_executors()]
@@ -536,7 +528,7 @@ class SimulatorExecutor(BaseExecutor):
track_data: bool = False, track_data: bool = False,
common_infra: CommonInfrastructure = None, common_infra: CommonInfrastructure = None,
trade_type: str = TT_SERIAL, trade_type: str = TT_SERIAL,
**kwargs, **kwargs: Any,
) -> None: ) -> None:
""" """
Parameters Parameters
@@ -598,7 +590,7 @@ class SimulatorExecutor(BaseExecutor):
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]: def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
trade_start_time, _ = self.trade_calendar.get_step_time() trade_start_time, _ = self.trade_calendar.get_step_time()
execute_result = [] execute_result: list = []
for order in self._get_order_iterator(trade_decision): for order in self._get_order_iterator(trade_decision):
# execute the order. # execute the order.

View File

@@ -1,11 +1,13 @@
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT License. # Licensed under the MIT License.
from __future__ import annotations
import inspect import inspect
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from functools import lru_cache from functools import lru_cache
from typing import Callable, Dict, Iterable, List, Text, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Union, cast
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@@ -19,7 +21,7 @@ from ..utils.time import Freq, is_single_value
class BaseQuote: class BaseQuote:
def __init__(self, quote_df: pd.DataFrame, freq): def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
self.logger = get_module_logger("online operator", level=logging.INFO) self.logger = get_module_logger("online operator", level=logging.INFO)
def get_all_stock(self) -> Iterable: def get_all_stock(self) -> Iterable:
@@ -39,7 +41,7 @@ class BaseQuote:
start_time: Union[pd.Timestamp, str], start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str],
field: Union[str], field: Union[str],
method: Union[str, None] = None, method: Optional[str] = None,
) -> Union[None, int, float, bool, IndexData]: ) -> Union[None, int, float, bool, IndexData]:
"""get the specific field of stock data during start time and end_time, """get the specific field of stock data during start time and end_time,
and apply method to the data. and apply method to the data.
@@ -99,7 +101,7 @@ class BaseQuote:
class PandasQuote(BaseQuote): class PandasQuote(BaseQuote):
def __init__(self, quote_df: pd.DataFrame, freq): def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
super().__init__(quote_df=quote_df, freq=freq) super().__init__(quote_df=quote_df, freq=freq)
quote_dict = {} quote_dict = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"): for stock_id, stock_val in quote_df.groupby(level="instrument"):
@@ -124,7 +126,7 @@ class PandasQuote(BaseQuote):
class NumpyQuote(BaseQuote): class NumpyQuote(BaseQuote):
def __init__(self, quote_df: pd.DataFrame, freq, region="cn"): def __init__(self, quote_df: pd.DataFrame, freq: str, region: str = "cn") -> None:
"""NumpyQuote """NumpyQuote
Parameters Parameters
@@ -178,7 +180,8 @@ class NumpyQuote(BaseQuote):
data = self._agg_data(data, method) data = self._agg_data(data, method)
return data return data
def _agg_data(self, data: IndexData, method): @staticmethod
def _agg_data(data: IndexData, method: str) -> Union[IndexData, np.ndarray, None]:
"""Agg data by specific method.""" """Agg data by specific method."""
# FIXME: why not call the method of data directly? # FIXME: why not call the method of data directly?
if method == "sum": if method == "sum":
@@ -224,31 +227,31 @@ class BaseSingleMetric:
""" """
raise NotImplementedError(f"Please implement the `__init__` method") raise NotImplementedError(f"Please implement the `__init__` method")
def __add__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": def __add__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__add__` method") raise NotImplementedError(f"Please implement the `__add__` method")
def __radd__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": def __radd__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
return self + other return self + other
def __sub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": def __sub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__sub__` method") raise NotImplementedError(f"Please implement the `__sub__` method")
def __rsub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": def __rsub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__rsub__` method") raise NotImplementedError(f"Please implement the `__rsub__` method")
def __mul__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": def __mul__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__mul__` method") raise NotImplementedError(f"Please implement the `__mul__` method")
def __truediv__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": def __truediv__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__truediv__` method") raise NotImplementedError(f"Please implement the `__truediv__` method")
def __eq__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": def __eq__(self, other: object) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__eq__` method") raise NotImplementedError(f"Please implement the `__eq__` method")
def __gt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": def __gt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__gt__` method") raise NotImplementedError(f"Please implement the `__gt__` method")
def __lt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric": def __lt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__lt__` method") raise NotImplementedError(f"Please implement the `__lt__` method")
def __len__(self) -> int: def __len__(self) -> int:
@@ -265,7 +268,7 @@ class BaseSingleMetric:
raise NotImplementedError(f"Please implement the `count` method") raise NotImplementedError(f"Please implement the `count` method")
def abs(self) -> "BaseSingleMetric": def abs(self) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `abs` method") raise NotImplementedError(f"Please implement the `abs` method")
@property @property
@@ -274,17 +277,17 @@ class BaseSingleMetric:
raise NotImplementedError(f"Please implement the `empty` method") raise NotImplementedError(f"Please implement the `empty` method")
def add(self, other: "BaseSingleMetric", fill_value: float = None) -> "BaseSingleMetric": def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
"""Replace np.NaN with fill_value in two metrics and add them.""" """Replace np.NaN with fill_value in two metrics and add them."""
raise NotImplementedError(f"Please implement the `add` method") raise NotImplementedError(f"Please implement the `add` method")
def replace(self, replace_dict: dict) -> "BaseSingleMetric": def replace(self, replace_dict: dict) -> BaseSingleMetric:
"""Replace the value of metric according to replace_dict.""" """Replace the value of metric according to replace_dict."""
raise NotImplementedError(f"Please implement the `replace` method") raise NotImplementedError(f"Please implement the `replace` method")
def apply(self, func: dict) -> "BaseSingleMetric": def apply(self, func: Callable) -> BaseSingleMetric:
"""Replace the value of metric with func (metric). """Replace the value of metric with func (metric).
Currently, the func is only qlib/backtest/order/Order.parse_dir. Currently, the func is only qlib/backtest/order/Order.parse_dir.
""" """
@@ -304,11 +307,11 @@ class BaseOrderIndicator:
to inherit the BaseSingleMetric. to inherit the BaseSingleMetric.
""" """
def __init__(self, data): def __init__(self):
self.data = data self.data = {} # will be created in the subclass
self.logger = get_module_logger("online operator") self.logger = get_module_logger("online operator")
def assign(self, col: str, metric: Union[dict, pd.Series]): def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
"""assign one metric. """assign one metric.
Parameters Parameters
@@ -328,7 +331,7 @@ class BaseOrderIndicator:
raise NotImplementedError(f"Please implement the 'assign' method") raise NotImplementedError(f"Please implement the 'assign' method")
def transfer(self, func: Callable, new_col: str = None) -> Union[None, BaseSingleMetric]: def transfer(self, func: Callable, new_col: str = None) -> Optional[BaseSingleMetric]:
"""compute new metric with existing metrics. """compute new metric with existing metrics.
Parameters Parameters
@@ -352,6 +355,7 @@ class BaseOrderIndicator:
tmp_metric = func(**func_kwargs) tmp_metric = func(**func_kwargs)
if new_col is not None: if new_col is not None:
self.data[new_col] = tmp_metric self.data[new_col] = tmp_metric
return None
else: else:
return tmp_metric return tmp_metric
@@ -372,7 +376,7 @@ class BaseOrderIndicator:
raise NotImplementedError(f"Please implement the 'get_metric_series' method") raise NotImplementedError(f"Please implement the 'get_metric_series' method")
def get_index_data(self, metric) -> SingleData: def get_index_data(self, metric: str) -> SingleData:
"""get one metric with the format of SingleData """get one metric with the format of SingleData
Parameters Parameters
@@ -389,7 +393,12 @@ class BaseOrderIndicator:
raise NotImplementedError(f"Please implement the 'get_index_data' method") raise NotImplementedError(f"Please implement the 'get_index_data' method")
@staticmethod @staticmethod
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None): def sum_all_indicators(
order_indicator: BaseOrderIndicator,
indicators: List[BaseOrderIndicator],
metrics: Union[str, List[str]],
fill_value: float = 0,
) -> None:
"""sum indicators with the same metrics. """sum indicators with the same metrics.
and assign to the order_indicator(BaseOrderIndicator). and assign to the order_indicator(BaseOrderIndicator).
NOTE: indicators could be a empty list when orders in lower level all fail. NOTE: indicators could be a empty list when orders in lower level all fail.
@@ -527,16 +536,17 @@ class PandasSingleMetric(SingleMetric):
def index(self): def index(self):
return list(self.metric.index) return list(self.metric.index)
def add(self, other, fill_value=None): def add(self, other: BaseSingleMetric, fill_value: float = None) -> PandasSingleMetric:
other = cast(PandasSingleMetric, other)
return self.__class__(self.metric.add(other.metric, fill_value=fill_value)) return self.__class__(self.metric.add(other.metric, fill_value=fill_value))
def replace(self, replace_dict: dict): def replace(self, replace_dict: dict) -> PandasSingleMetric:
return self.__class__(self.metric.replace(replace_dict)) return self.__class__(self.metric.replace(replace_dict))
def apply(self, func: Callable): def apply(self, func: Callable) -> PandasSingleMetric:
return self.__class__(self.metric.apply(func)) return self.__class__(self.metric.apply(func))
def reindex(self, index, fill_value): def reindex(self, index: Any, fill_value: float) -> PandasSingleMetric:
return self.__class__(self.metric.reindex(index, fill_value=fill_value)) return self.__class__(self.metric.reindex(index, fill_value=fill_value))
def __repr__(self): def __repr__(self):
@@ -550,13 +560,14 @@ class PandasOrderIndicator(BaseOrderIndicator):
Str is the name of metric. Str is the name of metric.
""" """
def __init__(self): def __init__(self) -> None:
super(PandasOrderIndicator, self).__init__()
self.data: Dict[str, PandasSingleMetric] = OrderedDict() self.data: Dict[str, PandasSingleMetric] = OrderedDict()
def assign(self, col: str, metric: Union[dict, pd.Series]): def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
self.data[col] = PandasSingleMetric(metric) self.data[col] = PandasSingleMetric(metric)
def get_index_data(self, metric): def get_index_data(self, metric: str) -> SingleData:
if metric in self.data: if metric in self.data:
return idd.SingleData(self.data[metric].metric) return idd.SingleData(self.data[metric].metric)
else: else:
@@ -572,7 +583,12 @@ class PandasOrderIndicator(BaseOrderIndicator):
return {k: v.metric for k, v in self.data.items()} return {k: v.metric for k, v in self.data.items()}
@staticmethod @staticmethod
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0): def sum_all_indicators(
order_indicator: BaseOrderIndicator,
indicators: List[BaseOrderIndicator],
metrics: Union[str, List[str]],
fill_value: float = 0,
) -> None:
if isinstance(metrics, str): if isinstance(metrics, str):
metrics = [metrics] metrics = [metrics]
for metric in metrics: for metric in metrics:
@@ -592,13 +608,14 @@ class NumpyOrderIndicator(BaseOrderIndicator):
Str is the name of metric. Str is the name of metric.
""" """
def __init__(self): def __init__(self) -> None:
super(NumpyOrderIndicator, self).__init__()
self.data: Dict[str, SingleData] = OrderedDict() self.data: Dict[str, SingleData] = OrderedDict()
def assign(self, col: str, metric: dict): def assign(self, col: str, metric: dict) -> None:
self.data[col] = idd.SingleData(metric) self.data[col] = idd.SingleData(metric)
def get_index_data(self, metric): def get_index_data(self, metric: str) -> SingleData:
if metric in self.data: if metric in self.data:
return self.data[metric] return self.data[metric]
else: else:
@@ -614,14 +631,18 @@ class NumpyOrderIndicator(BaseOrderIndicator):
return tmp_metric_dict return tmp_metric_dict
@staticmethod @staticmethod
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0): def sum_all_indicators(
order_indicator: BaseOrderIndicator,
indicators: List[BaseOrderIndicator],
metrics: Union[str, List[str]],
fill_value: float = 0,
) -> None:
# get all index(stock_id) # get all index(stock_id)
stocks = set() stock_set: set = set()
for indicator in indicators: for indicator in indicators:
# set(np.ndarray.tolist()) is faster than set(np.ndarray) # set(np.ndarray.tolist()) is faster than set(np.ndarray)
stocks = stocks | set(indicator.data[metrics[0]].index.tolist()) stock_set = stock_set | set(indicator.data[metrics[0]].index.tolist())
stocks = list(stocks) stocks = sorted(list(stock_set))
stocks.sort()
# add metric by index # add metric by index
if isinstance(metrics, str): if isinstance(metrics, str):

View File

@@ -3,7 +3,7 @@
from datetime import timedelta from datetime import timedelta
from typing import Dict, List, Union from typing import Any, Dict, List, Union
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@@ -18,9 +18,9 @@ class BasePosition:
Please refer to the `Position` class for the position Please refer to the `Position` class for the position
""" """
def __init__(self, *args, cash: float = 0.0, **kwargs) -> None: def __init__(self, *args: Any, cash: float = 0.0, **kwargs: Any) -> None:
self._settle_type = self.ST_NO self._settle_type = self.ST_NO
self.position = {} self.position: dict = {}
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None: def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
pass pass
@@ -96,13 +96,13 @@ class BasePosition:
def calculate_value(self) -> float: def calculate_value(self) -> float:
raise NotImplementedError(f"Please implement the `calculate_value` method") raise NotImplementedError(f"Please implement the `calculate_value` method")
def get_stock_list(self) -> List: def get_stock_list(self) -> List[str]:
""" """
Get the list of stocks in the position. Get the list of stocks in the position.
""" """
raise NotImplementedError(f"Please implement the `get_stock_list` method") raise NotImplementedError(f"Please implement the `get_stock_list` method")
def get_stock_price(self, code) -> float: def get_stock_price(self, code: str) -> float:
""" """
get the latest price of the stock get the latest price of the stock
@@ -113,7 +113,7 @@ class BasePosition:
""" """
raise NotImplementedError(f"Please implement the `get_stock_price` method") raise NotImplementedError(f"Please implement the `get_stock_price` method")
def get_stock_amount(self, code) -> float: def get_stock_amount(self, code: str) -> float:
""" """
get the amount of the stock get the amount of the stock
@@ -144,7 +144,7 @@ class BasePosition:
""" """
raise NotImplementedError(f"Please implement the `get_cash` method") raise NotImplementedError(f"Please implement the `get_cash` method")
def get_stock_amount_dict(self) -> Dict: def get_stock_amount_dict(self) -> dict:
""" """
generate stock amount dict {stock_id : amount of stock} generate stock amount dict {stock_id : amount of stock}
@@ -155,7 +155,7 @@ class BasePosition:
""" """
raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method") raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method")
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict: def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
""" """
generate stock weight dict {stock_id : value weight of stock in the position} generate stock weight dict {stock_id : value weight of stock in the position}
it is meaningful in the beginning or the end of each trade step it is meaningful in the beginning or the end of each trade step
@@ -174,7 +174,7 @@ class BasePosition:
""" """
raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method") raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
def add_count_all(self, bar) -> None: def add_count_all(self, bar: str) -> None:
""" """
Will be called at the end of each bar on each level Will be called at the end of each bar on each level
@@ -195,7 +195,7 @@ class BasePosition:
raise NotImplementedError(f"Please implement the `add_count_all` method") raise NotImplementedError(f"Please implement the `add_count_all` method")
ST_CASH = "cash" ST_CASH = "cash"
ST_NO = None ST_NO = "None" # String is more typehint friendly than None
def settle_start(self, settle_type: str) -> None: def settle_start(self, settle_type: str) -> None:
""" """
@@ -220,10 +220,10 @@ class BasePosition:
""" """
raise NotImplementedError(f"Please implement the `settle_commit` method") raise NotImplementedError(f"Please implement the `settle_commit` method")
def __str__(self): def __str__(self) -> str:
return self.__dict__.__str__() return self.__dict__.__str__()
def __repr__(self): def __repr__(self) -> str:
return self.__dict__.__repr__() return self.__dict__.__repr__()
@@ -532,7 +532,7 @@ class InfPosition(BasePosition):
def calculate_value(self) -> float: def calculate_value(self) -> float:
raise NotImplementedError(f"InfPosition doesn't support calculating value") raise NotImplementedError(f"InfPosition doesn't support calculating value")
def get_stock_list(self) -> list: def get_stock_list(self) -> List[str]:
raise NotImplementedError(f"InfPosition doesn't support stock list position") raise NotImplementedError(f"InfPosition doesn't support stock list position")
def get_stock_price(self, code: str) -> float: def get_stock_price(self, code: str) -> float:
@@ -545,10 +545,10 @@ class InfPosition(BasePosition):
def get_cash(self, include_settle: bool = False) -> float: def get_cash(self, include_settle: bool = False) -> float:
return np.inf return np.inf
def get_stock_amount_dict(self) -> Dict: def get_stock_amount_dict(self) -> dict:
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict") raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict: def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict") raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
def add_count_all(self, bar: str) -> None: def add_count_all(self, bar: str) -> None:

View File

@@ -4,7 +4,7 @@
import pathlib import pathlib
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, Tuple, Union from typing import Any, Dict, List, Optional, Text, Tuple, Type, Union, cast
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@@ -15,7 +15,7 @@ from qlib.backtest.exchange import Exchange
from ..tests.config import CSI300_BENCH from ..tests.config import CSI300_BENCH
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric from .high_performance_ds import BaseOrderIndicator, BaseSingleMetric, NumpyOrderIndicator
class PortfolioMetrics: class PortfolioMetrics:
@@ -38,7 +38,7 @@ class PortfolioMetrics:
update report update report
""" """
def __init__(self, freq: str = "day", benchmark_config: dict = {}): def __init__(self, freq: str = "day", benchmark_config: dict = {}) -> None:
""" """
Parameters Parameters
---------- ----------
@@ -49,13 +49,17 @@ class PortfolioMetrics:
- benchmark : Union[str, list, pd.Series] - benchmark : Union[str, list, pd.Series]
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T. - If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
example: example:
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()) print(
D.features(D.instruments('csi500'),
['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()
)
2017-01-04 0.011693 2017-01-04 0.011693
2017-01-05 0.000721 2017-01-05 0.000721
2017-01-06 -0.004322 2017-01-06 -0.004322
2017-01-09 0.006874 2017-01-09 0.006874
2017-01-10 -0.003350 2017-01-10 -0.003350
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'. - If `benchmark` is list, will use the daily average change of the stock pool in the list as the
'bench'.
- If `benchmark` is str, will use the daily change as the 'bench'. - If `benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000300 CSI300 benchmark code, default is SH000300 CSI300
- start_time : Union[str, pd.Timestamp], optional - start_time : Union[str, pd.Timestamp], optional
@@ -70,25 +74,26 @@ class PortfolioMetrics:
self.init_vars() self.init_vars()
self.init_bench(freq=freq, benchmark_config=benchmark_config) self.init_bench(freq=freq, benchmark_config=benchmark_config)
def init_vars(self): def init_vars(self) -> None:
self.accounts = OrderedDict() # account position value for each trade time self.accounts: dict = OrderedDict() # account position value for each trade time
self.returns = OrderedDict() # daily return rate for each trade time self.returns: dict = OrderedDict() # daily return rate for each trade time
self.total_turnovers = OrderedDict() # total turnover for each trade time self.total_turnovers: dict = OrderedDict() # total turnover for each trade time
self.turnovers = OrderedDict() # turnover for each trade time self.turnovers: dict = OrderedDict() # turnover for each trade time
self.total_costs = OrderedDict() # total trade cost for each trade time self.total_costs: dict = OrderedDict() # total trade cost for each trade time
self.costs = OrderedDict() # trade cost rate for each trade time self.costs: dict = OrderedDict() # trade cost rate for each trade time
self.values = OrderedDict() # value for each trade time self.values: dict = OrderedDict() # value for each trade time
self.cashes = OrderedDict() self.cashes: dict = OrderedDict()
self.benches = OrderedDict() self.benches: dict = OrderedDict()
self.latest_pm_time = None # pd.TimeStamp self.latest_pm_time: Optional[pd.TimeStamp] = None
def init_bench(self, freq=None, benchmark_config=None): def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None:
if freq is not None: if freq is not None:
self.freq = freq self.freq = freq
self.benchmark_config = benchmark_config self.benchmark_config = benchmark_config
self.bench = self._cal_benchmark(self.benchmark_config, self.freq) self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
def _cal_benchmark(self, benchmark_config, freq): @staticmethod
def _cal_benchmark(benchmark_config: Optional[dict], freq: str) -> Optional[pd.Series]:
if benchmark_config is None: if benchmark_config is None:
return None return None
benchmark = benchmark_config.get("benchmark", CSI300_BENCH) benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
@@ -110,7 +115,12 @@ class PortfolioMetrics:
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark") raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0) return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
def _sample_benchmark(self, bench, trade_start_time, trade_end_time): def _sample_benchmark(
self,
bench: pd.Series,
trade_start_time: Union[str, pd.Timestamp],
trade_end_time: Union[str, pd.Timestamp],
) -> Optional[float]:
if self.bench is None: if self.bench is None:
return None return None
@@ -120,35 +130,35 @@ class PortfolioMetrics:
_ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change) _ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)
return 0.0 if _ret is None else _ret - 1 return 0.0 if _ret is None else _ret - 1
def is_empty(self): def is_empty(self) -> bool:
return len(self.accounts) == 0 return len(self.accounts) == 0
def get_latest_date(self): def get_latest_date(self) -> pd.Timestamp:
return self.latest_pm_time return self.latest_pm_time
def get_latest_account_value(self): def get_latest_account_value(self) -> float:
return self.accounts[self.latest_pm_time] return self.accounts[self.latest_pm_time]
def get_latest_total_cost(self): def get_latest_total_cost(self) -> Any:
return self.total_costs[self.latest_pm_time] return self.total_costs[self.latest_pm_time]
def get_latest_total_turnover(self): def get_latest_total_turnover(self) -> Any:
return self.total_turnovers[self.latest_pm_time] return self.total_turnovers[self.latest_pm_time]
def update_portfolio_metrics_record( def update_portfolio_metrics_record(
self, self,
trade_start_time=None, trade_start_time: Union[str, pd.Timestamp] = None,
trade_end_time=None, trade_end_time: Union[str, pd.Timestamp] = None,
account_value=None, account_value: float = None,
cash=None, cash: float = None,
return_rate=None, return_rate: float = None,
total_turnover=None, total_turnover: float = None,
turnover_rate=None, turnover_rate: float = None,
total_cost=None, total_cost: float = None,
cost_rate=None, cost_rate: float = None,
stock_value=None, stock_value: float = None,
bench_value=None, bench_value: float = None,
): ) -> None:
# check data # check data
if None in [ if None in [
trade_start_time, trade_start_time,
@@ -185,7 +195,7 @@ class PortfolioMetrics:
self.latest_pm_time = trade_start_time self.latest_pm_time = trade_start_time
# finish pm update in each step # finish pm update in each step
def generate_portfolio_metrics_dataframe(self): def generate_portfolio_metrics_dataframe(self) -> pd.DataFrame:
pm = pd.DataFrame() pm = pd.DataFrame()
pm["account"] = pd.Series(self.accounts) pm["account"] = pd.Series(self.accounts)
pm["return"] = pd.Series(self.returns) pm["return"] = pd.Series(self.returns)
@@ -199,19 +209,18 @@ class PortfolioMetrics:
pm.index.name = "datetime" pm.index.name = "datetime"
return pm return pm
def save_portfolio_metrics(self, path): def save_portfolio_metrics(self, path: str) -> None:
r = self.generate_portfolio_metrics_dataframe() r = self.generate_portfolio_metrics_dataframe()
r.to_csv(path) r.to_csv(path)
def load_portfolio_metrics(self, path): def load_portfolio_metrics(self, path: str) -> None:
"""load pm from a file """load pm from a file
should have format like should have format like
columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench'] columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']
:param :param
path: str/ pathlib.Path() path: str/ pathlib.Path()
""" """
path = pathlib.Path(path) with pathlib.Path(path).open("rb") as f:
with path.open("rb") as f:
r = pd.read_csv(f, index_col=0) r = pd.read_csv(f, index_col=0)
r.index = pd.DatetimeIndex(r.index) r.index = pd.DatetimeIndex(r.index)
@@ -261,30 +270,30 @@ class Indicator:
""" """
def __init__(self, order_indicator_cls=NumpyOrderIndicator): def __init__(self, order_indicator_cls: Type[BaseOrderIndicator] = NumpyOrderIndicator) -> None:
self.order_indicator_cls = order_indicator_cls self.order_indicator_cls = order_indicator_cls
# order indicator is metrics for a single order for a specific step # order indicator is metrics for a single order for a specific step
self.order_indicator_his = OrderedDict() self.order_indicator_his: dict = OrderedDict()
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls() self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
# trade indicator is metrics for all orders for a specific step # trade indicator is metrics for all orders for a specific step
self.trade_indicator_his = OrderedDict() self.trade_indicator_his: dict = OrderedDict()
self.trade_indicator: Dict[str, float] = OrderedDict() self.trade_indicator: Dict[str, Optional[BaseSingleMetric]] = OrderedDict()
self._trade_calendar = None self._trade_calendar = None
# def reset(self, trade_calendar: TradeCalendarManager): # def reset(self, trade_calendar: TradeCalendarManager):
def reset(self): def reset(self) -> None:
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls() self.order_indicator = self.order_indicator_cls()
self.trade_indicator = OrderedDict() self.trade_indicator = OrderedDict()
# self._trade_calendar = trade_calendar # self._trade_calendar = trade_calendar
def record(self, trade_start_time): def record(self, trade_start_time: Union[str, pd.Timestamp]) -> None:
self.order_indicator_his[trade_start_time] = self.get_order_indicator() self.order_indicator_his[trade_start_time] = self.get_order_indicator()
self.trade_indicator_his[trade_start_time] = self.get_trade_indicator() self.trade_indicator_his[trade_start_time] = self.get_trade_indicator()
def _update_order_trade_info(self, trade_info: list): def _update_order_trade_info(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
amount = dict() amount = dict()
deal_amount = dict() deal_amount = dict()
trade_price = dict() trade_price = dict()
@@ -313,7 +322,7 @@ class Indicator:
self.order_indicator.assign("trade_dir", trade_dir) self.order_indicator.assign("trade_dir", trade_dir)
self.order_indicator.assign("pa", pa) self.order_indicator.assign("pa", pa)
def _update_order_fulfill_rate(self): def _update_order_fulfill_rate(self) -> None:
def func(deal_amount, amount): def func(deal_amount, amount):
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0. # deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
tmp_deal_amount = deal_amount.reindex(amount.index, 0) tmp_deal_amount = deal_amount.reindex(amount.index, 0)
@@ -322,11 +331,11 @@ class Indicator:
self.order_indicator.transfer(func, "ffr") self.order_indicator.transfer(func, "ffr")
def update_order_indicators(self, trade_info: list): def update_order_indicators(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
self._update_order_trade_info(trade_info=trade_info) self._update_order_trade_info(trade_info=trade_info)
self._update_order_fulfill_rate() self._update_order_fulfill_rate()
def _agg_order_trade_info(self, inner_order_indicators: List[Dict[str, pd.Series]]): def _agg_order_trade_info(self, inner_order_indicators: List[BaseOrderIndicator]) -> None:
# calculate total trade amount with each inner order indicator. # calculate total trade amount with each inner order indicator.
def trade_amount_func(deal_amount, trade_price): def trade_amount_func(deal_amount, trade_price):
return deal_amount * trade_price return deal_amount * trade_price
@@ -355,9 +364,9 @@ class Indicator:
self.order_indicator.transfer(func_apply, "trade_dir") self.order_indicator.transfer(func_apply, "trade_dir")
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision): def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision) -> None:
# NOTE: these indicator is designed for order execution, so the # NOTE: these indicator is designed for order execution, so the
decision: List[Order] = outer_trade_decision.get_decision() decision: List[Order] = cast(List[Order], outer_trade_decision.get_decision())
if len(decision) == 0: if len(decision) == 0:
self.order_indicator.assign("amount", {}) self.order_indicator.assign("amount", {})
else: else:
@@ -372,7 +381,7 @@ class Indicator:
decision: BaseTradeDecision, decision: BaseTradeDecision,
trade_exchange: Exchange, trade_exchange: Exchange,
pa_config: dict = {}, pa_config: dict = {},
): ) -> Tuple[Optional[float], Optional[float]]:
""" """
Get the base volume and price information Get the base volume and price information
All the base price values are rooted from this function All the base price values are rooted from this function
@@ -412,31 +421,35 @@ class Indicator:
# NOTE: there are some zeros in the trading price. These cases are known meaningless # NOTE: there are some zeros in the trading price. These cases are known meaningless
# for aligning the previous logic, remove it. # for aligning the previous logic, remove it.
# remove zero and negative values. # remove zero and negative values.
price_s = price_s.loc[(price_s > 1e-08).data.astype(np.bool)] assert isinstance(price_s, idd.SingleData)
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8 # NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
# ~(np.NaN < 1e-8) -> ~(False) -> True # ~(np.NaN < 1e-8) -> ~(False) -> True
assert isinstance(price_s, idd.SingleData)
if agg == "vwap": if agg == "vwap":
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None) volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
if isinstance(volume_s, (int, float, np.number)): if isinstance(volume_s, (int, float, np.number)):
volume_s = idd.SingleData(volume_s, [trade_start_time]) volume_s = idd.SingleData(volume_s, [trade_start_time])
assert isinstance(volume_s, idd.SingleData)
volume_s = volume_s.reindex(price_s.index) volume_s = volume_s.reindex(price_s.index)
elif agg == "twap": elif agg == "twap":
volume_s = idd.SingleData(1, price_s.index) volume_s = idd.SingleData(1, price_s.index)
else: else:
raise NotImplementedError(f"This type of input is not supported") raise NotImplementedError(f"This type of input is not supported")
assert isinstance(volume_s, idd.SingleData)
base_volume = volume_s.sum() base_volume = volume_s.sum()
base_price = (price_s * volume_s).sum() / base_volume base_price = (price_s * volume_s).sum() / base_volume
return base_price, base_volume return base_price, base_volume
def _agg_base_price( def _agg_base_price(
self, self,
inner_order_indicators: List[Dict[str, Union[SingleMetric, idd.SingleData]]], inner_order_indicators: List[BaseOrderIndicator],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]], decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
trade_exchange: Exchange, trade_exchange: Exchange,
pa_config: dict = {}, pa_config: dict = {},
): ) -> None:
""" """
# NOTE:!!!! # NOTE:!!!!
# Strong assumption!!!!!! # Strong assumption!!!!!!
@@ -444,7 +457,7 @@ class Indicator:
Parameters Parameters
---------- ----------
inner_order_indicators : List[Dict[str, pd.Series]] inner_order_indicators : List[BaseOrderIndicator]
the indicators of account of inner executor the indicators of account of inner executor
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]], decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
a list of decisions according to inner_order_indicators a list of decisions according to inner_order_indicators
@@ -489,14 +502,17 @@ class Indicator:
bv_new = idd.SingleData(bv_new) bv_new = idd.SingleData(bv_new)
bp_all.append(bp_new) bp_all.append(bp_new)
bv_all.append(bv_new) bv_all.append(bv_new)
bp_all = idd.concat(bp_all, axis=1) bp_all_multi_data = idd.concat(bp_all, axis=1)
bv_all = idd.concat(bv_all, axis=1) bv_all_multi_data = idd.concat(bv_all, axis=1)
base_volume = bv_all.sum(axis=1) base_volume = bv_all_multi_data.sum(axis=1)
self.order_indicator.assign("base_volume", base_volume.to_dict()) self.order_indicator.assign("base_volume", base_volume.to_dict())
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict()) self.order_indicator.assign(
"base_price",
((bp_all_multi_data * bv_all_multi_data).sum(axis=1) / base_volume).to_dict(),
)
def _agg_order_price_advantage(self): def _agg_order_price_advantage(self) -> None:
def if_empty_func(trade_price): def if_empty_func(trade_price):
return trade_price.empty return trade_price.empty
@@ -513,12 +529,12 @@ class Indicator:
def agg_order_indicators( def agg_order_indicators(
self, self,
inner_order_indicators: List[Dict[str, pd.Series]], inner_order_indicators: List[BaseOrderIndicator],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]], decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
outer_trade_decision: BaseTradeDecision, outer_trade_decision: BaseTradeDecision,
trade_exchange: Exchange, trade_exchange: Exchange,
indicator_config={}, indicator_config: dict = {},
): ) -> None:
self._agg_order_trade_info(inner_order_indicators) self._agg_order_trade_info(inner_order_indicators)
self._update_trade_amount(outer_trade_decision) self._update_trade_amount(outer_trade_decision)
self._update_order_fulfill_rate() self._update_order_fulfill_rate()
@@ -526,71 +542,66 @@ class Indicator:
self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) # TODO self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) # TODO
self._agg_order_price_advantage() self._agg_order_price_advantage()
def _cal_trade_fulfill_rate(self, method="mean"): def _cal_trade_fulfill_rate(self, method: str = "mean") -> Optional[BaseSingleMetric]:
if method == "mean": if method == "mean":
return self.order_indicator.transfer(
def func(ffr): lambda ffr: ffr.mean(),
return ffr.mean() )
elif method == "amount_weighted": elif method == "amount_weighted":
return self.order_indicator.transfer(
def func(ffr, deal_amount): lambda ffr, deal_amount: (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
return (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum()) )
elif method == "value_weighted": elif method == "value_weighted":
return self.order_indicator.transfer(
def func(ffr, trade_value): lambda ffr, trade_value: (ffr * trade_value.abs()).sum() / (trade_value.abs().sum()),
return (ffr * trade_value.abs()).sum() / (trade_value.abs().sum()) )
else: else:
raise ValueError(f"method {method} is not supported!") raise ValueError(f"method {method} is not supported!")
return self.order_indicator.transfer(func)
def _cal_trade_price_advantage(self, method="mean"): def _cal_trade_price_advantage(self, method: str = "mean") -> Optional[BaseSingleMetric]:
if method == "mean": if method == "mean":
return self.order_indicator.transfer(lambda pa: pa.mean())
def func(pa):
return pa.mean()
elif method == "amount_weighted": elif method == "amount_weighted":
return self.order_indicator.transfer(
def func(pa, deal_amount): lambda pa, deal_amount: (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
return (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum()) )
elif method == "value_weighted": elif method == "value_weighted":
return self.order_indicator.transfer(
def func(pa, trade_value): lambda pa, trade_value: (pa * trade_value.abs()).sum() / (trade_value.abs().sum()),
return (pa * trade_value.abs()).sum() / (trade_value.abs().sum()) )
else: else:
raise ValueError(f"method {method} is not supported!") raise ValueError(f"method {method} is not supported!")
return self.order_indicator.transfer(func)
def _cal_trade_positive_rate(self): def _cal_trade_positive_rate(self) -> Optional[BaseSingleMetric]:
def func(pa): def func(pa):
return (pa > 0).sum() / pa.count() return (pa > 0).sum() / pa.count()
return self.order_indicator.transfer(func) return self.order_indicator.transfer(func)
def _cal_deal_amount(self): def _cal_deal_amount(self) -> Optional[BaseSingleMetric]:
def func(deal_amount): def func(deal_amount):
return deal_amount.abs().sum() return deal_amount.abs().sum()
return self.order_indicator.transfer(func) return self.order_indicator.transfer(func)
def _cal_trade_value(self): def _cal_trade_value(self) -> Optional[BaseSingleMetric]:
def func(trade_value): def func(trade_value):
return trade_value.abs().sum() return trade_value.abs().sum()
return self.order_indicator.transfer(func) return self.order_indicator.transfer(func)
def _cal_trade_order_count(self): def _cal_trade_order_count(self) -> Optional[BaseSingleMetric]:
def func(amount): def func(amount):
return amount.count() return amount.count()
return self.order_indicator.transfer(func) return self.order_indicator.transfer(func)
def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}): def cal_trade_indicators(
self,
trade_start_time: Union[str, pd.Timestamp],
freq: str,
indicator_config: dict = {},
) -> None:
show_indicator = indicator_config.get("show_indicator", False) show_indicator = indicator_config.get("show_indicator", False)
ffr_config = indicator_config.get("ffr_config", {}) ffr_config = indicator_config.get("ffr_config", {})
pa_config = indicator_config.get("pa_config", {}) pa_config = indicator_config.get("pa_config", {})
@@ -608,22 +619,22 @@ class Indicator:
self.trade_indicator["count"] = order_count self.trade_indicator["count"] = order_count
if show_indicator: if show_indicator:
print( print(
"[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format( "[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
freq, freq,
trade_start_time, trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
fulfill_rate, fulfill_rate,
price_advantage, price_advantage,
positive_rate, positive_rate,
), ),
) )
def get_order_indicator(self, raw: bool = True): def get_order_indicator(self, raw: bool = True) -> Union[BaseOrderIndicator, Dict[Text, pd.Series]]:
if raw: return self.order_indicator if raw else self.order_indicator.to_series()
return self.order_indicator
return self.order_indicator.to_series()
def get_trade_indicator(self): def get_trade_indicator(self) -> Dict[str, Optional[BaseSingleMetric]]:
return self.trade_indicator return self.trade_indicator
def generate_trade_indicators_dataframe(self): def generate_trade_indicators_dataframe(self) -> pd.DataFrame:
return pd.DataFrame.from_dict(self.trade_indicator_his, orient="index") return pd.DataFrame.from_dict(self.trade_indicator_his, orient="index")

View File

@@ -22,7 +22,7 @@ class Signal(metaclass=abc.ABCMeta):
""" """
@abc.abstractmethod @abc.abstractmethod
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]: def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame, None]:
""" """
get the signal at the end of the decision step(from `start_time` to `end_time`) get the signal at the end of the decision step(from `start_time` to `end_time`)
@@ -39,13 +39,14 @@ class SignalWCache(Signal):
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
""" """
def __init__(self, signal: Union[pd.Series, pd.DataFrame]): def __init__(self, signal: Union[pd.Series, pd.DataFrame]) -> None:
""" """
Parameters Parameters
---------- ----------
signal : Union[pd.Series, pd.DataFrame] signal : Union[pd.Series, pd.DataFrame]
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted) The expected format of the signal is like the data below (the order of index is not important and can be
automatically adjusted)
instrument datetime instrument datetime
SH600000 2008-01-02 0.079704 SH600000 2008-01-02 0.079704
@@ -56,8 +57,8 @@ class SignalWCache(Signal):
""" """
self.signal_cache = convert_index_format(signal, level="datetime") self.signal_cache = convert_index_format(signal, level="datetime")
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]: def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame]:
# the frequency of the signal may not algin with the decision frequency of strategy # the frequency of the signal may not align with the decision frequency of strategy
# so resampling from the data is necessary # so resampling from the data is necessary
# the latest signal leverage more recent data and therefore is used in trading. # the latest signal leverage more recent data and therefore is used in trading.
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last") signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
@@ -65,7 +66,7 @@ class SignalWCache(Signal):
class ModelSignal(SignalWCache): class ModelSignal(SignalWCache):
def __init__(self, model: BaseModel, dataset: Dataset): def __init__(self, model: BaseModel, dataset: Dataset) -> None:
self.model = model self.model = model
self.dataset = dataset self.dataset = dataset
pred_scores = self.model.predict(dataset) pred_scores = self.model.predict(dataset)
@@ -73,7 +74,7 @@ class ModelSignal(SignalWCache):
pred_scores = pred_scores.iloc[:, 0] pred_scores = pred_scores.iloc[:, 0]
super().__init__(pred_scores) super().__init__(pred_scores)
def _update_model(self): def _update_model(self) -> None:
""" """
When using online data, update model in each bar as the following steps: When using online data, update model in each bar as the following steps:
- update dataset with online data, the dataset should support online update - update dataset with online data, the dataset should support online update

View File

@@ -149,6 +149,8 @@ class TradeCalendarManager:
Tuple[int, int]: Tuple[int, int]:
""" """
# potential performance issue # potential performance issue
assert self.level_infra is not None
day_start = pd.Timestamp(self.start_time.date()) day_start = pd.Timestamp(self.start_time.date())
day_end = epsilon_change(day_start + pd.Timedelta(days=1)) day_end = epsilon_change(day_start + pd.Timedelta(days=1))
freq = self.level_infra.get("common_infra").get("trade_exchange").freq freq = self.level_infra.get("common_infra").get("trade_exchange").freq
@@ -182,8 +184,8 @@ class TradeCalendarManager:
Tuple[int, int]: Tuple[int, int]:
the index of the range. **the left and right are closed** the index of the range. **the left and right are closed**
""" """
left = bisect.bisect_right(self._calendar, start_time) - 1 left = bisect.bisect_right(list(self._calendar), start_time) - 1
right = bisect.bisect_right(self._calendar, end_time) - 1 right = bisect.bisect_right(list(self._calendar), end_time) - 1
left -= self.start_index left -= self.start_index
right -= self.start_index right -= self.start_index
@@ -201,14 +203,14 @@ class TradeCalendarManager:
class BaseInfrastructure: class BaseInfrastructure:
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs: Any) -> None:
self.reset_infra(**kwargs) self.reset_infra(**kwargs)
@abstractmethod @abstractmethod
def get_support_infra(self) -> Set[str]: def get_support_infra(self) -> Set[str]:
raise NotImplementedError("`get_support_infra` is not implemented!") raise NotImplementedError("`get_support_infra` is not implemented!")
def reset_infra(self, **kwargs) -> None: def reset_infra(self, **kwargs: Any) -> None:
support_infra = self.get_support_infra() support_infra = self.get_support_infra()
for k, v in kwargs.items(): for k, v in kwargs.items():
if k in support_infra: if k in support_infra:

View File

@@ -339,7 +339,7 @@ def long_short_backtest(
for stock in long_stocks: for stock in long_stocks:
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date): if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
continue continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str] profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
if np.isnan(profit): if np.isnan(profit):
long_profit.append(0) long_profit.append(0)
else: else:
@@ -348,17 +348,17 @@ def long_short_backtest(
for stock in short_stocks: for stock in short_stocks:
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date): if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
continue continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str] profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
if np.isnan(profit): if np.isnan(profit):
short_profit.append(0) short_profit.append(0)
else: else:
short_profit.append(-profit) short_profit.append(profit * -1)
for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)): for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)):
# exclude the suspend stock # exclude the suspend stock
if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date): if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):
continue continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str] profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
if np.isnan(profit): if np.isnan(profit):
all_profit.append(0) all_profit.append(0)
else: else:

View File

@@ -108,14 +108,16 @@ class CalendarProvider(abc.ABC):
_, _, si, ei = self.locate_index(start_time, end_time, freq, future) _, _, si, ei = self.locate_index(start_time, end_time, freq, future)
return _calendar[si : ei + 1] return _calendar[si : ei + 1]
def locate_index(self, start_time, end_time, freq, future=False): def locate_index(
self, start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], freq: str, future: bool = False
):
"""Locate the start time index and end time index in a calendar under certain frequency. """Locate the start time index and end time index in a calendar under certain frequency.
Parameters Parameters
---------- ----------
start_time : str start_time : pd.Timestamp
start of the time range. start of the time range.
end_time : str end_time : pd.Timestamp
end of the time range. end of the time range.
freq : str freq : str
time frequency, available: year/quarter/month/week/day. time frequency, available: year/quarter/month/week/day.

View File

@@ -248,7 +248,7 @@ def load_orders(
Order( Order(
row["instrument"], row["instrument"],
row["amount"], row["amount"],
int(row["order_type"]), OrderDir(int(row["order_type"])),
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second), row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second), row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
) )

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Any, Generator, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from qlib.backtest.exchange import Exchange from qlib.backtest.exchange import Exchange
@@ -122,7 +122,10 @@ class BaseStrategy:
self.outer_trade_decision = outer_trade_decision self.outer_trade_decision = outer_trade_decision
@abstractmethod @abstractmethod
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision: def generate_trade_decision(
self,
execute_result: list = None,
) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
"""Generate trade decision in each trading bar """Generate trade decision in each trading bar
Parameters Parameters

View File

@@ -9,6 +9,8 @@ Motivation of index_data
`index_data` try to behave like pandas (some API will be different because we try to be simpler and more intuitive) but don't compromise the performance. It provides the basic numpy data and simple indexing feature. If users call APIs which may compromise the performance, index_data will raise Errors. `index_data` try to behave like pandas (some API will be different because we try to be simpler and more intuitive) but don't compromise the performance. It provides the basic numpy data and simple indexing feature. If users call APIs which may compromise the performance, index_data will raise Errors.
""" """
from __future__ import annotations
from typing import Dict, Tuple, Union, Callable, List from typing import Dict, Tuple, Union, Callable, List
import bisect import bisect
@@ -16,7 +18,7 @@ import numpy as np
import pandas as pd import pandas as pd
def concat(data_list: Union["SingleData"], axis=0) -> "MultiData": def concat(data_list: Union[SingleData], axis=0) -> MultiData:
"""concat all SingleData by index. """concat all SingleData by index.
TODO: now just for SingleData. TODO: now just for SingleData.
@@ -52,7 +54,7 @@ def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
raise ValueError(f"axis must be 0 or 1") raise ValueError(f"axis must be 0 or 1")
def sum_by_index(data_list: Union["SingleData"], new_index: list, fill_value=0) -> "SingleData": def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) -> SingleData:
"""concat all SingleData by new index. """concat all SingleData by new index.
Parameters Parameters
@@ -554,7 +556,7 @@ class SingleData(IndexData):
f"The indexes of self and other do not meet the requirements of the four arithmetic operations" f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
) )
def reindex(self, index: Index, fill_value=np.NaN): def reindex(self, index: Index, fill_value=np.NaN) -> SingleData:
"""reindex data and fill the missing value with np.NaN. """reindex data and fill the missing value with np.NaN.
Parameters Parameters
@@ -580,7 +582,7 @@ class SingleData(IndexData):
pass pass
return SingleData(tmp_data, index) return SingleData(tmp_data, index)
def add(self, other: "SingleData", fill_value=0): def add(self, other: SingleData, fill_value=0):
# TODO: add and __add__ are a little confusing. # TODO: add and __add__ are a little confusing.
# This could be a more general # This could be a more general
common_index = self.index | other.index common_index = self.index | other.index