mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Backtest Mypy (#1130)
* Done * Fix test errors * Revert profit_attribution.py * Minor * A minor update on collect_data type hint * Resolve PR comments * Use black to format code * Fix CI errors
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
[mypy]
|
[mypy]
|
||||||
exclude = (?x)(
|
exclude = (?x)(
|
||||||
^qlib/backtest
|
^qlib/backtest/high_performance_ds\.py$
|
||||||
| ^qlib/contrib
|
| ^qlib/contrib
|
||||||
| ^qlib/data
|
| ^qlib/data
|
||||||
| ^qlib/model
|
| ^qlib/model
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
@@ -23,7 +23,6 @@ from ..utils import init_instance_by_config
|
|||||||
from .backtest import backtest_loop, collect_data_loop
|
from .backtest import backtest_loop, collect_data_loop
|
||||||
from .decision import Order
|
from .decision import Order
|
||||||
from .exchange import Exchange
|
from .exchange import Exchange
|
||||||
from .position import Position
|
|
||||||
from .utils import CommonInfrastructure
|
from .utils import CommonInfrastructure
|
||||||
|
|
||||||
# make import more user-friendly by adding `from qlib.backtest import STH`
|
# make import more user-friendly by adding `from qlib.backtest import STH`
|
||||||
@@ -44,7 +43,7 @@ def get_exchange(
|
|||||||
min_cost: float = 5.0,
|
min_cost: float = 5.0,
|
||||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> Exchange:
|
) -> Exchange:
|
||||||
"""get_exchange
|
"""get_exchange
|
||||||
|
|
||||||
@@ -52,14 +51,15 @@ def get_exchange(
|
|||||||
----------
|
----------
|
||||||
|
|
||||||
# exchange related arguments
|
# exchange related arguments
|
||||||
exchange: Exchange(). It could be None or any types that are acceptable by `init_instance_by_config`.
|
exchange: Exchange
|
||||||
|
It could be None or any types that are acceptable by `init_instance_by_config`.
|
||||||
freq: str
|
freq: str
|
||||||
frequency of data.
|
frequency of data.
|
||||||
start_time: Union[pd.Timestamp, str]
|
start_time: Union[pd.Timestamp, str]
|
||||||
closed start time for backtest.
|
closed start time for backtest.
|
||||||
end_time: Union[pd.Timestamp, str]
|
end_time: Union[pd.Timestamp, str]
|
||||||
closed end time for backtest.
|
closed end time for backtest.
|
||||||
codes: list|str
|
codes: Union[list, str]
|
||||||
list stock_id list or a string of instruments (i.e. all, csi500, sse50)
|
list stock_id list or a string of instruments (i.e. all, csi500, sse50)
|
||||||
subscribe_fields: list
|
subscribe_fields: list
|
||||||
subscribe fields.
|
subscribe fields.
|
||||||
@@ -151,28 +151,24 @@ def create_account_instance(
|
|||||||
Postion type.
|
Postion type.
|
||||||
"""
|
"""
|
||||||
if isinstance(account, (int, float)):
|
if isinstance(account, (int, float)):
|
||||||
pos_kwargs = {"init_cash": account}
|
init_cash = account
|
||||||
|
position_dict = {}
|
||||||
elif isinstance(account, dict):
|
elif isinstance(account, dict):
|
||||||
init_cash = account["cash"]
|
init_cash = account.pop("cash")
|
||||||
del account["cash"]
|
position_dict = account
|
||||||
pos_kwargs = {
|
|
||||||
"init_cash": init_cash,
|
|
||||||
"position_dict": account,
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("account must be in (int, float, Position)")
|
raise ValueError("account must be in (int, float, dict)")
|
||||||
|
|
||||||
kwargs = {
|
return Account(
|
||||||
"init_cash": account,
|
init_cash=init_cash,
|
||||||
"benchmark_config": {
|
position_dict=position_dict,
|
||||||
|
pos_type=pos_type,
|
||||||
|
benchmark_config={
|
||||||
"benchmark": benchmark,
|
"benchmark": benchmark,
|
||||||
"start_time": start_time,
|
"start_time": start_time,
|
||||||
"end_time": end_time,
|
"end_time": end_time,
|
||||||
},
|
},
|
||||||
"pos_type": pos_type,
|
)
|
||||||
}
|
|
||||||
kwargs.update(pos_kwargs)
|
|
||||||
return Account(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_strategy_executor(
|
def get_strategy_executor(
|
||||||
@@ -181,7 +177,7 @@ def get_strategy_executor(
|
|||||||
strategy: Union[str, dict, object, Path],
|
strategy: Union[str, dict, object, Path],
|
||||||
executor: Union[str, dict, object, Path],
|
executor: Union[str, dict, object, Path],
|
||||||
benchmark: str = "SH000300",
|
benchmark: str = "SH000300",
|
||||||
account: Union[float, int, Position] = 1e9,
|
account: Union[float, int, dict] = 1e9,
|
||||||
exchange_kwargs: dict = {},
|
exchange_kwargs: dict = {},
|
||||||
pos_type: str = "Position",
|
pos_type: str = "Position",
|
||||||
) -> Tuple[BaseStrategy, BaseExecutor]:
|
) -> Tuple[BaseStrategy, BaseExecutor]:
|
||||||
@@ -222,7 +218,7 @@ def backtest(
|
|||||||
strategy: Union[str, dict, object, Path],
|
strategy: Union[str, dict, object, Path],
|
||||||
executor: Union[str, dict, object, Path],
|
executor: Union[str, dict, object, Path],
|
||||||
benchmark: str = "SH000300",
|
benchmark: str = "SH000300",
|
||||||
account: Union[float, int, Position] = 1e9,
|
account: Union[float, int, dict] = 1e9,
|
||||||
exchange_kwargs: dict = {},
|
exchange_kwargs: dict = {},
|
||||||
pos_type: str = "Position",
|
pos_type: str = "Position",
|
||||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||||
@@ -285,7 +281,7 @@ def collect_data(
|
|||||||
strategy: Union[str, dict, object, Path],
|
strategy: Union[str, dict, object, Path],
|
||||||
executor: Union[str, dict, object, Path],
|
executor: Union[str, dict, object, Path],
|
||||||
benchmark: str = "SH000300",
|
benchmark: str = "SH000300",
|
||||||
account: Union[float, int, Position] = 1e9,
|
account: Union[float, int, dict] = 1e9,
|
||||||
exchange_kwargs: dict = {},
|
exchange_kwargs: dict = {},
|
||||||
pos_type: str = "Position",
|
pos_type: str = "Position",
|
||||||
return_value: dict = None,
|
return_value: dict = None,
|
||||||
@@ -339,7 +335,7 @@ def format_decisions(
|
|||||||
|
|
||||||
cur_freq = decisions[0].strategy.trade_calendar.get_freq()
|
cur_freq = decisions[0].strategy.trade_calendar.get_freq()
|
||||||
|
|
||||||
res = (cur_freq, [])
|
res: Tuple[str, list] = (cur_freq, [])
|
||||||
last_dec_idx = 0
|
last_dec_idx = 0
|
||||||
for i, dec in enumerate(decisions[1:], 1):
|
for i, dec in enumerate(decisions[1:], 1):
|
||||||
if dec.strategy.trade_calendar.get_freq() == cur_freq:
|
if dec.strategy.trade_calendar.get_freq() == cur_freq:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Optional, Tuple, cast
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
@@ -11,6 +11,7 @@ from qlib.utils import init_instance_by_config
|
|||||||
|
|
||||||
from .decision import BaseTradeDecision, Order
|
from .decision import BaseTradeDecision, Order
|
||||||
from .exchange import Exchange
|
from .exchange import Exchange
|
||||||
|
from .high_performance_ds import BaseOrderIndicator
|
||||||
from .position import BasePosition
|
from .position import BasePosition
|
||||||
from .report import Indicator, PortfolioMetrics
|
from .report import Indicator, PortfolioMetrics
|
||||||
|
|
||||||
@@ -104,7 +105,7 @@ class Account:
|
|||||||
|
|
||||||
self._pos_type = pos_type
|
self._pos_type = pos_type
|
||||||
self._port_metr_enabled = port_metr_enabled
|
self._port_metr_enabled = port_metr_enabled
|
||||||
self.benchmark_config = None # avoid no attribute error
|
self.benchmark_config: dict = {} # avoid no attribute error
|
||||||
self.init_vars(init_cash, position_dict, freq, benchmark_config)
|
self.init_vars(init_cash, position_dict, freq, benchmark_config)
|
||||||
|
|
||||||
def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None:
|
def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None:
|
||||||
@@ -124,8 +125,8 @@ class Account:
|
|||||||
self.accum_info = AccumulatedInfo()
|
self.accum_info = AccumulatedInfo()
|
||||||
|
|
||||||
# 2) following variables are not shared between layers
|
# 2) following variables are not shared between layers
|
||||||
self.portfolio_metrics = None
|
self.portfolio_metrics: Optional[PortfolioMetrics] = None
|
||||||
self.hist_positions = {}
|
self.hist_positions: Dict[pd.Timestamp, BasePosition] = {}
|
||||||
self.reset(freq=freq, benchmark_config=benchmark_config)
|
self.reset(freq=freq, benchmark_config=benchmark_config)
|
||||||
|
|
||||||
def is_port_metr_enabled(self) -> bool:
|
def is_port_metr_enabled(self) -> bool:
|
||||||
@@ -171,7 +172,7 @@ class Account:
|
|||||||
|
|
||||||
self.reset_report(self.freq, self.benchmark_config)
|
self.reset_report(self.freq, self.benchmark_config)
|
||||||
|
|
||||||
def get_hist_positions(self) -> dict:
|
def get_hist_positions(self) -> Dict[pd.Timestamp, BasePosition]:
|
||||||
return self.hist_positions
|
return self.hist_positions
|
||||||
|
|
||||||
def get_cash(self) -> float:
|
def get_cash(self) -> float:
|
||||||
@@ -230,13 +231,15 @@ class Account:
|
|||||||
"""
|
"""
|
||||||
# update price for stock in the position and the profit from changed_price
|
# update price for stock in the position and the profit from changed_price
|
||||||
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
|
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
|
||||||
|
assert self.current_position is not None
|
||||||
|
|
||||||
if not self.current_position.skip_update():
|
if not self.current_position.skip_update():
|
||||||
stock_list = self.current_position.get_stock_list()
|
stock_list = self.current_position.get_stock_list()
|
||||||
for code in stock_list:
|
for code in stock_list:
|
||||||
# if suspend, no new price to be updated, profit is 0
|
# if suspend, no new price to be updated, profit is 0
|
||||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||||
continue
|
continue
|
||||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
bar_close = cast(float, trade_exchange.get_close(code, trade_start_time, trade_end_time))
|
||||||
self.current_position.update_stock_price(stock_id=code, price=bar_close)
|
self.current_position.update_stock_price(stock_id=code, price=bar_close)
|
||||||
# update holding day count
|
# update holding day count
|
||||||
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
|
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
|
||||||
@@ -249,6 +252,8 @@ class Account:
|
|||||||
# for the first trade date, account_value - init_cash
|
# for the first trade date, account_value - init_cash
|
||||||
# self.portfolio_metrics.is_empty() to judge is_first_trade_date
|
# self.portfolio_metrics.is_empty() to judge is_first_trade_date
|
||||||
# get last_account_value, last_total_cost, last_total_turnover
|
# get last_account_value, last_total_cost, last_total_turnover
|
||||||
|
assert self.portfolio_metrics is not None
|
||||||
|
|
||||||
if self.portfolio_metrics.is_empty():
|
if self.portfolio_metrics.is_empty():
|
||||||
last_account_value = self.init_cash
|
last_account_value = self.init_cash
|
||||||
last_total_cost = 0
|
last_total_cost = 0
|
||||||
@@ -299,9 +304,9 @@ class Account:
|
|||||||
trade_exchange: Exchange,
|
trade_exchange: Exchange,
|
||||||
atomic: bool,
|
atomic: bool,
|
||||||
outer_trade_decision: BaseTradeDecision,
|
outer_trade_decision: BaseTradeDecision,
|
||||||
trade_info: list = None,
|
trade_info: list = [],
|
||||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
inner_order_indicators: List[BaseOrderIndicator] = [],
|
||||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
|
||||||
indicator_config: dict = {},
|
indicator_config: dict = {},
|
||||||
) -> None:
|
) -> None:
|
||||||
"""update trade indicators and order indicators in each bar end"""
|
"""update trade indicators and order indicators in each bar end"""
|
||||||
@@ -335,9 +340,9 @@ class Account:
|
|||||||
trade_exchange: Exchange,
|
trade_exchange: Exchange,
|
||||||
atomic: bool,
|
atomic: bool,
|
||||||
outer_trade_decision: BaseTradeDecision,
|
outer_trade_decision: BaseTradeDecision,
|
||||||
trade_info: list = None,
|
trade_info: list = [],
|
||||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
inner_order_indicators: List[BaseOrderIndicator] = [],
|
||||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
|
||||||
indicator_config: dict = {},
|
indicator_config: dict = {},
|
||||||
) -> None:
|
) -> None:
|
||||||
"""update account at each trading bar step
|
"""update account at each trading bar step
|
||||||
@@ -398,6 +403,7 @@ class Account:
|
|||||||
def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:
|
def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:
|
||||||
"""get the history portfolio_metrics and positions instance"""
|
"""get the history portfolio_metrics and positions instance"""
|
||||||
if self.is_port_metr_enabled():
|
if self.is_port_metr_enabled():
|
||||||
|
assert self.portfolio_metrics is not None
|
||||||
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
|
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
|
||||||
_positions = self.get_hist_positions()
|
_positions = self.get_hist_positions()
|
||||||
return _portfolio_metrics, _positions
|
return _portfolio_metrics, _positions
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
@@ -36,10 +36,13 @@ def backtest_loop(
|
|||||||
indicator: Indicator
|
indicator: Indicator
|
||||||
it computes the trading indicator
|
it computes the trading indicator
|
||||||
"""
|
"""
|
||||||
return_value = {}
|
return_value: dict = {}
|
||||||
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
||||||
pass
|
pass
|
||||||
return return_value.get("portfolio_metrics"), return_value.get("indicator")
|
|
||||||
|
portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
|
||||||
|
indicator = cast(Indicator, return_value.get("indicator"))
|
||||||
|
return portfolio_metrics, indicator
|
||||||
|
|
||||||
|
|
||||||
def collect_data_loop(
|
def collect_data_loop(
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from abc import abstractmethod
|
|||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
|
||||||
# try to fix circular imports when enabling type hints
|
# try to fix circular imports when enabling type hints
|
||||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
|
from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast
|
||||||
|
|
||||||
from qlib.backtest.utils import TradeCalendarManager
|
from qlib.backtest.utils import TradeCalendarManager
|
||||||
from qlib.data.data import Cal
|
from qlib.data.data import Cal
|
||||||
@@ -24,6 +24,9 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
DecisionType = TypeVar("DecisionType")
|
||||||
|
|
||||||
|
|
||||||
class OrderDir(IntEnum):
|
class OrderDir(IntEnum):
|
||||||
# Order direction
|
# Order direction
|
||||||
SELL = 0
|
SELL = 0
|
||||||
@@ -65,7 +68,7 @@ class Order:
|
|||||||
# - not tradable: the deal_amount == 0 , factor is None
|
# - not tradable: the deal_amount == 0 , factor is None
|
||||||
# - the stock is suspended and the entire order fails. No cost for this order
|
# - the stock is suspended and the entire order fails. No cost for this order
|
||||||
# - dealt or partially dealt: deal_amount >= 0 and factor is not None
|
# - dealt or partially dealt: deal_amount >= 0 and factor is not None
|
||||||
deal_amount: Optional[float] = None # `deal_amount` is a non-negative value
|
deal_amount: float = 0.0 # `deal_amount` is a non-negative value
|
||||||
factor: Optional[float] = None
|
factor: Optional[float] = None
|
||||||
|
|
||||||
# TODO:
|
# TODO:
|
||||||
@@ -281,7 +284,7 @@ class TradeRangeByTime(TradeRange):
|
|||||||
return max(val_start, start_time), min(val_end, end_time)
|
return max(val_start, start_time), min(val_end, end_time)
|
||||||
|
|
||||||
|
|
||||||
class BaseTradeDecision:
|
class BaseTradeDecision(Generic[DecisionType]):
|
||||||
"""
|
"""
|
||||||
Trade decisions ara made by strategy and executed by executor
|
Trade decisions ara made by strategy and executed by executor
|
||||||
|
|
||||||
@@ -316,20 +319,21 @@ class BaseTradeDecision:
|
|||||||
"""
|
"""
|
||||||
self.strategy = strategy
|
self.strategy = strategy
|
||||||
self.start_time, self.end_time = strategy.trade_calendar.get_step_time()
|
self.start_time, self.end_time = strategy.trade_calendar.get_step_time()
|
||||||
self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading`
|
# upper strategy has no knowledge about the sub executor before `_init_sub_trading`
|
||||||
if isinstance(trade_range, Tuple):
|
self.total_step: Optional[int] = None
|
||||||
|
if isinstance(trade_range, tuple):
|
||||||
# for Tuple[int, int]
|
# for Tuple[int, int]
|
||||||
trade_range = IdxTradeRange(*trade_range)
|
trade_range = IdxTradeRange(*trade_range)
|
||||||
self.trade_range: TradeRange = trade_range
|
self.trade_range: Optional[TradeRange] = trade_range
|
||||||
|
|
||||||
def get_decision(self) -> List[object]:
|
def get_decision(self) -> List[DecisionType]:
|
||||||
"""
|
"""
|
||||||
get the **concrete decision** (e.g. execution orders)
|
get the **concrete decision** (e.g. execution orders)
|
||||||
This will be called by the inner strategy
|
This will be called by the inner strategy
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
List[object]:
|
List[DecisionType:
|
||||||
The decision result. Typically it is some orders
|
The decision result. Typically it is some orders
|
||||||
Example:
|
Example:
|
||||||
[]:
|
[]:
|
||||||
@@ -363,13 +367,13 @@ class BaseTradeDecision:
|
|||||||
# purpose 2)
|
# purpose 2)
|
||||||
return self.strategy.update_trade_decision(self, trade_calendar)
|
return self.strategy.update_trade_decision(self, trade_calendar)
|
||||||
|
|
||||||
def _get_range_limit(self, **kwargs) -> Tuple[int, int]:
|
def _get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
|
||||||
if self.trade_range is not None:
|
if self.trade_range is not None:
|
||||||
return self.trade_range(trade_calendar=kwargs.get("inner_calendar"))
|
return self.trade_range(trade_calendar=cast(TradeCalendarManager, kwargs.get("inner_calendar")))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("The decision didn't provide an index range")
|
raise NotImplementedError("The decision didn't provide an index range")
|
||||||
|
|
||||||
def get_range_limit(self, **kwargs) -> Tuple[int, int]:
|
def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
|
||||||
"""
|
"""
|
||||||
return the expected step range for limiting the decision execution time
|
return the expected step range for limiting the decision execution time
|
||||||
Both left and right are **closed**
|
Both left and right are **closed**
|
||||||
@@ -421,6 +425,7 @@ class BaseTradeDecision:
|
|||||||
if getattr(self, "total_step", None) is not None:
|
if getattr(self, "total_step", None) is not None:
|
||||||
# if `self.update` is called.
|
# if `self.update` is called.
|
||||||
# Then the _start_idx, _end_idx should be clipped
|
# Then the _start_idx, _end_idx should be clipped
|
||||||
|
assert self.total_step is not None
|
||||||
if _start_idx < 0 or _end_idx >= self.total_step:
|
if _start_idx < 0 or _end_idx >= self.total_step:
|
||||||
logger = get_module_logger("decision")
|
logger = get_module_logger("decision")
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -516,7 +521,7 @@ class BaseTradeDecision:
|
|||||||
inner_trade_decision.trade_range = self.trade_range
|
inner_trade_decision.trade_range = self.trade_range
|
||||||
|
|
||||||
|
|
||||||
class EmptyTradeDecision(BaseTradeDecision):
|
class EmptyTradeDecision(BaseTradeDecision[object]):
|
||||||
def get_decision(self) -> List[object]:
|
def get_decision(self) -> List[object]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -524,23 +529,24 @@ class EmptyTradeDecision(BaseTradeDecision):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class TradeDecisionWO(BaseTradeDecision):
|
class TradeDecisionWO(BaseTradeDecision[Order]):
|
||||||
"""
|
"""
|
||||||
Trade Decision (W)ith (O)rder.
|
Trade Decision (W)ith (O)rder.
|
||||||
Besides, the time_range is also included.
|
Besides, the time_range is also included.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None):
|
def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None:
|
||||||
super().__init__(strategy, trade_range=trade_range)
|
super().__init__(strategy, trade_range=trade_range)
|
||||||
self.order_list = order_list
|
self.order_list = cast(List[Order], order_list)
|
||||||
start, end = strategy.trade_calendar.get_step_time()
|
start, end = strategy.trade_calendar.get_step_time()
|
||||||
for o in order_list:
|
for o in order_list:
|
||||||
|
assert isinstance(o, Order)
|
||||||
if o.start_time is None:
|
if o.start_time is None:
|
||||||
o.start_time = start
|
o.start_time = start
|
||||||
if o.end_time is None:
|
if o.end_time is None:
|
||||||
o.end_time = end
|
o.end_time = end
|
||||||
|
|
||||||
def get_decision(self) -> List[object]:
|
def get_decision(self) -> List[Order]:
|
||||||
return self.order_list
|
return self.order_list
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
|
||||||
|
|
||||||
from ..utils.index_data import IndexData
|
from ..utils.index_data import IndexData
|
||||||
|
|
||||||
@@ -42,7 +42,7 @@ class Exchange:
|
|||||||
impact_cost: float = 0.0,
|
impact_cost: float = 0.0,
|
||||||
extra_quote: pd.DataFrame = None,
|
extra_quote: pd.DataFrame = None,
|
||||||
quote_cls: Type[BaseQuote] = NumpyQuote,
|
quote_cls: Type[BaseQuote] = NumpyQuote,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""__init__
|
"""__init__
|
||||||
:param freq: frequency of data
|
:param freq: frequency of data
|
||||||
@@ -141,7 +141,7 @@ class Exchange:
|
|||||||
if limit_threshold is None:
|
if limit_threshold is None:
|
||||||
if C.region == REG_CN:
|
if C.region == REG_CN:
|
||||||
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
|
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
|
||||||
elif self.limit_type == self.LT_FLT and abs(limit_threshold) > 0.1:
|
elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1:
|
||||||
if C.region == REG_CN:
|
if C.region == REG_CN:
|
||||||
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
|
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
|
||||||
|
|
||||||
@@ -150,7 +150,7 @@ class Exchange:
|
|||||||
deal_price = "$" + deal_price
|
deal_price = "$" + deal_price
|
||||||
self.buy_price = self.sell_price = deal_price
|
self.buy_price = self.sell_price = deal_price
|
||||||
elif isinstance(deal_price, (tuple, list)):
|
elif isinstance(deal_price, (tuple, list)):
|
||||||
self.buy_price, self.sell_price = deal_price
|
self.buy_price, self.sell_price = cast(Tuple[str, str], deal_price)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"This type of input is not supported")
|
raise NotImplementedError(f"This type of input is not supported")
|
||||||
|
|
||||||
@@ -167,10 +167,10 @@ class Exchange:
|
|||||||
|
|
||||||
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
|
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
|
||||||
if self.limit_type == self.LT_TP_EXP:
|
if self.limit_type == self.LT_TP_EXP:
|
||||||
|
assert isinstance(limit_threshold, tuple)
|
||||||
for exp in limit_threshold:
|
for exp in limit_threshold:
|
||||||
necessary_fields.add(exp)
|
necessary_fields.add(exp)
|
||||||
all_fields = necessary_fields | set(vol_lt_fields)
|
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
|
||||||
all_fields = list(all_fields | set(subscribe_fields))
|
|
||||||
|
|
||||||
self.all_fields = all_fields
|
self.all_fields = all_fields
|
||||||
|
|
||||||
@@ -249,9 +249,9 @@ class Exchange:
|
|||||||
LT_FLT = "float" # float
|
LT_FLT = "float" # float
|
||||||
LT_NONE = "none" # none
|
LT_NONE = "none" # none
|
||||||
|
|
||||||
def _get_limit_type(self, limit_threshold: Union[Tuple, float, None]) -> str:
|
def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
|
||||||
"""get limit type"""
|
"""get limit type"""
|
||||||
if isinstance(limit_threshold, Tuple):
|
if isinstance(limit_threshold, tuple):
|
||||||
return self.LT_TP_EXP
|
return self.LT_TP_EXP
|
||||||
elif isinstance(limit_threshold, float):
|
elif isinstance(limit_threshold, float):
|
||||||
return self.LT_FLT
|
return self.LT_FLT
|
||||||
@@ -268,14 +268,16 @@ class Exchange:
|
|||||||
self.quote_df["limit_sell"] = False
|
self.quote_df["limit_sell"] = False
|
||||||
elif limit_type == self.LT_TP_EXP:
|
elif limit_type == self.LT_TP_EXP:
|
||||||
# set limit
|
# set limit
|
||||||
|
limit_threshold = cast(tuple, limit_threshold)
|
||||||
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
|
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
|
||||||
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
|
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
|
||||||
elif limit_type == self.LT_FLT:
|
elif limit_type == self.LT_FLT:
|
||||||
|
limit_threshold = cast(float, limit_threshold)
|
||||||
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
|
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
|
||||||
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
|
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_vol_limit(volume_threshold: Union[tuple, dict]) -> Tuple[Optional[list], Optional[list], set]:
|
def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
|
||||||
"""
|
"""
|
||||||
preprocess the volume limit.
|
preprocess the volume limit.
|
||||||
get the fields need to get from qlib.
|
get the fields need to get from qlib.
|
||||||
@@ -340,11 +342,11 @@ class Exchange:
|
|||||||
if direction is None:
|
if direction is None:
|
||||||
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
|
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
|
||||||
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
|
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
|
||||||
return buy_limit or sell_limit
|
return bool(buy_limit or sell_limit)
|
||||||
elif direction == Order.BUY:
|
elif direction == Order.BUY:
|
||||||
return self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
|
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all"))
|
||||||
elif direction == Order.SELL:
|
elif direction == Order.SELL:
|
||||||
return self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
|
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all"))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"direction {direction} is not supported!")
|
raise ValueError(f"direction {direction} is not supported!")
|
||||||
|
|
||||||
@@ -382,7 +384,7 @@ class Exchange:
|
|||||||
order: Order,
|
order: Order,
|
||||||
trade_account: Account = None,
|
trade_account: Account = None,
|
||||||
position: BasePosition = None,
|
position: BasePosition = None,
|
||||||
dealt_order_amount: defaultdict = defaultdict(float),
|
dealt_order_amount: Dict[str, float] = defaultdict(float),
|
||||||
) -> Tuple[float, float, float]:
|
) -> Tuple[float, float, float]:
|
||||||
"""
|
"""
|
||||||
Deal order when the actual transaction
|
Deal order when the actual transaction
|
||||||
@@ -426,9 +428,10 @@ class Exchange:
|
|||||||
stock_id: str,
|
stock_id: str,
|
||||||
start_time: pd.Timestamp,
|
start_time: pd.Timestamp,
|
||||||
end_time: pd.Timestamp,
|
end_time: pd.Timestamp,
|
||||||
|
field: str,
|
||||||
method: str = "ts_data_last",
|
method: str = "ts_data_last",
|
||||||
) -> Union[None, int, float, bool, IndexData]:
|
) -> Union[None, int, float, bool, IndexData]:
|
||||||
return self.quote.get_data(stock_id, start_time, end_time, method=method) # TODO: missing `field`?
|
return self.quote.get_data(stock_id, start_time, end_time, field=field, method=method)
|
||||||
|
|
||||||
def get_close(
|
def get_close(
|
||||||
self,
|
self,
|
||||||
@@ -444,10 +447,10 @@ class Exchange:
|
|||||||
stock_id: str,
|
stock_id: str,
|
||||||
start_time: pd.Timestamp,
|
start_time: pd.Timestamp,
|
||||||
end_time: pd.Timestamp,
|
end_time: pd.Timestamp,
|
||||||
method: str = "sum",
|
method: Optional[str] = "sum",
|
||||||
) -> float:
|
) -> float:
|
||||||
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
|
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
|
||||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
|
return cast(float, self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method))
|
||||||
|
|
||||||
def get_deal_price(
|
def get_deal_price(
|
||||||
self,
|
self,
|
||||||
@@ -455,7 +458,7 @@ class Exchange:
|
|||||||
start_time: pd.Timestamp,
|
start_time: pd.Timestamp,
|
||||||
end_time: pd.Timestamp,
|
end_time: pd.Timestamp,
|
||||||
direction: OrderDir,
|
direction: OrderDir,
|
||||||
method: str = "ts_data_last",
|
method: Optional[str] = "ts_data_last",
|
||||||
) -> float:
|
) -> float:
|
||||||
if direction == OrderDir.SELL:
|
if direction == OrderDir.SELL:
|
||||||
pstr = self.sell_price
|
pstr = self.sell_price
|
||||||
@@ -469,7 +472,7 @@ class Exchange:
|
|||||||
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
|
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
|
||||||
self.logger.warning(f"setting deal_price to close price")
|
self.logger.warning(f"setting deal_price to close price")
|
||||||
deal_price = self.get_close(stock_id, start_time, end_time, method)
|
deal_price = self.get_close(stock_id, start_time, end_time, method)
|
||||||
return deal_price
|
return cast(float, deal_price)
|
||||||
|
|
||||||
def get_factor(
|
def get_factor(
|
||||||
self,
|
self,
|
||||||
@@ -544,7 +547,7 @@ class Exchange:
|
|||||||
)
|
)
|
||||||
return amount_dict
|
return amount_dict
|
||||||
|
|
||||||
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float) -> float:
|
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float:
|
||||||
"""
|
"""
|
||||||
Calculate the real adjust deal amount when considering the trading unit
|
Calculate the real adjust deal amount when considering the trading unit
|
||||||
:param current_amount:
|
:param current_amount:
|
||||||
@@ -572,7 +575,7 @@ class Exchange:
|
|||||||
current_position: dict,
|
current_position: dict,
|
||||||
start_time: pd.Timestamp,
|
start_time: pd.Timestamp,
|
||||||
end_time: pd.Timestamp,
|
end_time: pd.Timestamp,
|
||||||
) -> list:
|
) -> List[Order]:
|
||||||
"""
|
"""
|
||||||
Note: some future information is used in this function
|
Note: some future information is used in this function
|
||||||
Parameter:
|
Parameter:
|
||||||
@@ -681,6 +684,7 @@ class Exchange:
|
|||||||
factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
|
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
|
||||||
|
assert factor is not None
|
||||||
return factor
|
return factor
|
||||||
|
|
||||||
def get_amount_of_trade_unit(
|
def get_amount_of_trade_unit(
|
||||||
@@ -718,12 +722,12 @@ class Exchange:
|
|||||||
|
|
||||||
def round_amount_by_trade_unit(
|
def round_amount_by_trade_unit(
|
||||||
self,
|
self,
|
||||||
deal_amount,
|
deal_amount: float,
|
||||||
factor: float = None,
|
factor: float = None,
|
||||||
stock_id: str = None,
|
stock_id: str = None,
|
||||||
start_time=None,
|
start_time: pd.Timestamp = None,
|
||||||
end_time=None,
|
end_time: pd.Timestamp = None,
|
||||||
):
|
) -> float:
|
||||||
"""Parameter
|
"""Parameter
|
||||||
Please refer to the docs of get_amount_of_trade_unit
|
Please refer to the docs of get_amount_of_trade_unit
|
||||||
deal_amount : float, adjusted amount
|
deal_amount : float, adjusted amount
|
||||||
@@ -741,7 +745,7 @@ class Exchange:
|
|||||||
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
|
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
|
||||||
return deal_amount
|
return deal_amount
|
||||||
|
|
||||||
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> int:
|
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Optional[float]:
|
||||||
"""parse the capacity limit string and return the actual amount of orders that can be executed.
|
"""parse the capacity limit string and return the actual amount of orders that can be executed.
|
||||||
NOTE:
|
NOTE:
|
||||||
this function will change the order.deal_amount **inplace**
|
this function will change the order.deal_amount **inplace**
|
||||||
@@ -753,15 +757,12 @@ class Exchange:
|
|||||||
dealt_order_amount : dict
|
dealt_order_amount : dict
|
||||||
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
|
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
|
||||||
"""
|
"""
|
||||||
if order.direction == Order.BUY:
|
vol_limit = self.buy_vol_limit if order.direction == Order.BUY else self.sell_vol_limit
|
||||||
vol_limit = self.buy_vol_limit
|
|
||||||
elif order.direction == Order.SELL:
|
|
||||||
vol_limit = self.sell_vol_limit
|
|
||||||
|
|
||||||
if vol_limit is None:
|
if vol_limit is None:
|
||||||
return order.deal_amount
|
return order.deal_amount
|
||||||
|
|
||||||
vol_limit_num = []
|
vol_limit_num: List[float] = []
|
||||||
for limit in vol_limit:
|
for limit in vol_limit:
|
||||||
assert isinstance(limit, tuple)
|
assert isinstance(limit, tuple)
|
||||||
if limit[0] == "current":
|
if limit[0] == "current":
|
||||||
@@ -772,7 +773,7 @@ class Exchange:
|
|||||||
field=limit[1],
|
field=limit[1],
|
||||||
method="sum",
|
method="sum",
|
||||||
)
|
)
|
||||||
vol_limit_num.append(limit_value)
|
vol_limit_num.append(cast(float, limit_value))
|
||||||
elif limit[0] == "cum":
|
elif limit[0] == "cum":
|
||||||
limit_value = self.quote.get_data(
|
limit_value = self.quote.get_data(
|
||||||
order.stock_id,
|
order.stock_id,
|
||||||
@@ -790,12 +791,14 @@ class Exchange:
|
|||||||
if vol_limit_min < orig_deal_amount:
|
if vol_limit_min < orig_deal_amount:
|
||||||
self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}")
|
self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}")
|
||||||
|
|
||||||
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
|
return None
|
||||||
|
|
||||||
|
def _get_buy_amount_by_cash_limit(self, trade_price: float, cash: float, cost_ratio: float) -> float:
|
||||||
"""return the real order amount after cash limit for buying.
|
"""return the real order amount after cash limit for buying.
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
trade_price : float
|
trade_price : float
|
||||||
position : cash
|
cash : float
|
||||||
cost_ratio : float
|
cost_ratio : float
|
||||||
|
|
||||||
Return
|
Return
|
||||||
@@ -803,7 +806,7 @@ class Exchange:
|
|||||||
float
|
float
|
||||||
the real order amount after cash limit for buying.
|
the real order amount after cash limit for buying.
|
||||||
"""
|
"""
|
||||||
max_trade_amount = 0
|
max_trade_amount = 0.0
|
||||||
if cash >= self.min_cost:
|
if cash >= self.min_cost:
|
||||||
# critical_price means the stock transaction price when the service fee is equal to min_cost.
|
# critical_price means the stock transaction price when the service fee is equal to min_cost.
|
||||||
critical_price = self.min_cost / cost_ratio + self.min_cost
|
critical_price = self.min_cost / cost_ratio + self.min_cost
|
||||||
@@ -897,7 +900,7 @@ class Exchange:
|
|||||||
order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
|
order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("order type {} error".format(order.type))
|
raise NotImplementedError("order direction {} error".format(order.direction))
|
||||||
|
|
||||||
trade_val = order.deal_amount * trade_price
|
trade_val = order.deal_amount * trade_price
|
||||||
trade_cost = max(trade_val * cost_ratio, self.min_cost)
|
trade_cost = max(trade_val * cost_ratio, self.min_cost)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import copy
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from types import GeneratorType
|
from types import GeneratorType
|
||||||
from typing import Generator, List, Optional, Tuple, Union
|
from typing import Any, Dict, Generator, List, Tuple, Union, cast
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
@@ -16,13 +16,7 @@ from ..strategy.base import BaseStrategy
|
|||||||
from ..utils import init_instance_by_config
|
from ..utils import init_instance_by_config
|
||||||
from .decision import BaseTradeDecision, Order
|
from .decision import BaseTradeDecision, Order
|
||||||
from .exchange import Exchange
|
from .exchange import Exchange
|
||||||
from .utils import (
|
from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, get_start_end_idx
|
||||||
BaseInfrastructure,
|
|
||||||
CommonInfrastructure,
|
|
||||||
LevelInfrastructure,
|
|
||||||
TradeCalendarManager,
|
|
||||||
get_start_end_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseExecutor:
|
class BaseExecutor:
|
||||||
@@ -39,8 +33,8 @@ class BaseExecutor:
|
|||||||
track_data: bool = False,
|
track_data: bool = False,
|
||||||
trade_exchange: Exchange = None,
|
trade_exchange: Exchange = None,
|
||||||
common_infra: CommonInfrastructure = None,
|
common_infra: CommonInfrastructure = None,
|
||||||
settle_type=BasePosition.ST_NO, # TODO: add typehint
|
settle_type: str = BasePosition.ST_NO,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
@@ -127,10 +121,10 @@ class BaseExecutor:
|
|||||||
get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")
|
get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")
|
||||||
|
|
||||||
# record deal order amount in one day
|
# record deal order amount in one day
|
||||||
self.dealt_order_amount = defaultdict(float)
|
self.dealt_order_amount: Dict[str, float] = defaultdict(float)
|
||||||
self.deal_day = None
|
self.deal_day = None
|
||||||
|
|
||||||
def reset_common_infra(self, common_infra: BaseInfrastructure, copy_trade_account: bool = False) -> None:
|
def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
reset infrastructure for trading
|
reset infrastructure for trading
|
||||||
- reset trade_account
|
- reset trade_account
|
||||||
@@ -141,14 +135,15 @@ class BaseExecutor:
|
|||||||
self.common_infra.update(common_infra)
|
self.common_infra.update(common_infra)
|
||||||
|
|
||||||
if common_infra.has("trade_account"):
|
if common_infra.has("trade_account"):
|
||||||
if copy_trade_account:
|
|
||||||
# NOTE: there is a trick in the code.
|
# NOTE: there is a trick in the code.
|
||||||
# shallow copy is used instead of deepcopy.
|
# shallow copy is used instead of deepcopy.
|
||||||
# 1. So positions are shared
|
# 1. So positions are shared
|
||||||
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
|
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
|
||||||
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
|
self.trade_account: Account = (
|
||||||
else:
|
copy.copy(common_infra.get("trade_account"))
|
||||||
self.trade_account: Account = common_infra.get("trade_account")
|
if copy_trade_account
|
||||||
|
else common_infra.get("trade_account")
|
||||||
|
)
|
||||||
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
|
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -164,7 +159,7 @@ class BaseExecutor:
|
|||||||
"""
|
"""
|
||||||
return self.level_infra.get("trade_calendar")
|
return self.level_infra.get("trade_calendar")
|
||||||
|
|
||||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs) -> None:
|
def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
- reset `start_time` and `end_time`, used in trade calendar
|
- reset `start_time` and `end_time`, used in trade calendar
|
||||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||||
@@ -200,20 +195,17 @@ class BaseExecutor:
|
|||||||
execute_result : List[object]
|
execute_result : List[object]
|
||||||
the executed result for trade decision
|
the executed result for trade decision
|
||||||
"""
|
"""
|
||||||
return_value = {}
|
return_value: dict = {}
|
||||||
for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):
|
for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):
|
||||||
pass
|
pass
|
||||||
return return_value.get("execute_result")
|
return cast(list, return_value.get("execute_result"))
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _collect_data(
|
def _collect_data(
|
||||||
self,
|
self,
|
||||||
trade_decision: BaseTradeDecision,
|
trade_decision: BaseTradeDecision,
|
||||||
level: int = 0,
|
level: int = 0,
|
||||||
) -> Union[
|
) -> Union[Generator[Any, Any, Tuple[List[object], dict]], Tuple[List[object], dict]]:
|
||||||
Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]],
|
|
||||||
Tuple[List[object], dict],
|
|
||||||
]:
|
|
||||||
"""
|
"""
|
||||||
Please refer to the doc of collect_data
|
Please refer to the doc of collect_data
|
||||||
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
|
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
|
||||||
@@ -235,7 +227,7 @@ class BaseExecutor:
|
|||||||
trade_decision: BaseTradeDecision,
|
trade_decision: BaseTradeDecision,
|
||||||
return_value: dict = None,
|
return_value: dict = None,
|
||||||
level: int = 0,
|
level: int = 0,
|
||||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], List[object]]:
|
) -> Generator[Any, Any, List[object]]:
|
||||||
"""Generator for collecting the trade decision data for rl training
|
"""Generator for collecting the trade decision data for rl training
|
||||||
|
|
||||||
his function will make a step forward
|
his function will make a step forward
|
||||||
@@ -332,7 +324,7 @@ class NestedExecutor(BaseExecutor):
|
|||||||
skip_empty_decision: bool = True,
|
skip_empty_decision: bool = True,
|
||||||
align_range_limit: bool = True,
|
align_range_limit: bool = True,
|
||||||
common_infra: CommonInfrastructure = None,
|
common_infra: CommonInfrastructure = None,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
@@ -411,7 +403,7 @@ class NestedExecutor(BaseExecutor):
|
|||||||
self,
|
self,
|
||||||
trade_decision: BaseTradeDecision,
|
trade_decision: BaseTradeDecision,
|
||||||
level: int = 0,
|
level: int = 0,
|
||||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]]:
|
) -> Generator[Any, Any, Tuple[List[object], dict]]:
|
||||||
execute_result = []
|
execute_result = []
|
||||||
inner_order_indicators = []
|
inner_order_indicators = []
|
||||||
decision_list = []
|
decision_list = []
|
||||||
@@ -493,7 +485,7 @@ class NestedExecutor(BaseExecutor):
|
|||||||
the execution result of inner task
|
the execution result of inner task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_all_executors(self) -> List[object]:
|
def get_all_executors(self) -> List[BaseExecutor]:
|
||||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||||
return [self, *self.inner_executor.get_all_executors()]
|
return [self, *self.inner_executor.get_all_executors()]
|
||||||
|
|
||||||
@@ -536,7 +528,7 @@ class SimulatorExecutor(BaseExecutor):
|
|||||||
track_data: bool = False,
|
track_data: bool = False,
|
||||||
common_infra: CommonInfrastructure = None,
|
common_infra: CommonInfrastructure = None,
|
||||||
trade_type: str = TT_SERIAL,
|
trade_type: str = TT_SERIAL,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
@@ -598,7 +590,7 @@ class SimulatorExecutor(BaseExecutor):
|
|||||||
|
|
||||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||||
trade_start_time, _ = self.trade_calendar.get_step_time()
|
trade_start_time, _ = self.trade_calendar.get_step_time()
|
||||||
execute_result = []
|
execute_result: list = []
|
||||||
|
|
||||||
for order in self._get_order_iterator(trade_decision):
|
for order in self._get_order_iterator(trade_decision):
|
||||||
# execute the order.
|
# execute the order.
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
# Copyright (c) Microsoft Corporation.
|
# Copyright (c) Microsoft Corporation.
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Callable, Dict, Iterable, List, Text, Union
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -19,7 +21,7 @@ from ..utils.time import Freq, is_single_value
|
|||||||
|
|
||||||
|
|
||||||
class BaseQuote:
|
class BaseQuote:
|
||||||
def __init__(self, quote_df: pd.DataFrame, freq):
|
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
|
||||||
self.logger = get_module_logger("online operator", level=logging.INFO)
|
self.logger = get_module_logger("online operator", level=logging.INFO)
|
||||||
|
|
||||||
def get_all_stock(self) -> Iterable:
|
def get_all_stock(self) -> Iterable:
|
||||||
@@ -39,7 +41,7 @@ class BaseQuote:
|
|||||||
start_time: Union[pd.Timestamp, str],
|
start_time: Union[pd.Timestamp, str],
|
||||||
end_time: Union[pd.Timestamp, str],
|
end_time: Union[pd.Timestamp, str],
|
||||||
field: Union[str],
|
field: Union[str],
|
||||||
method: Union[str, None] = None,
|
method: Optional[str] = None,
|
||||||
) -> Union[None, int, float, bool, IndexData]:
|
) -> Union[None, int, float, bool, IndexData]:
|
||||||
"""get the specific field of stock data during start time and end_time,
|
"""get the specific field of stock data during start time and end_time,
|
||||||
and apply method to the data.
|
and apply method to the data.
|
||||||
@@ -99,7 +101,7 @@ class BaseQuote:
|
|||||||
|
|
||||||
|
|
||||||
class PandasQuote(BaseQuote):
|
class PandasQuote(BaseQuote):
|
||||||
def __init__(self, quote_df: pd.DataFrame, freq):
|
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
|
||||||
super().__init__(quote_df=quote_df, freq=freq)
|
super().__init__(quote_df=quote_df, freq=freq)
|
||||||
quote_dict = {}
|
quote_dict = {}
|
||||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||||
@@ -124,7 +126,7 @@ class PandasQuote(BaseQuote):
|
|||||||
|
|
||||||
|
|
||||||
class NumpyQuote(BaseQuote):
|
class NumpyQuote(BaseQuote):
|
||||||
def __init__(self, quote_df: pd.DataFrame, freq, region="cn"):
|
def __init__(self, quote_df: pd.DataFrame, freq: str, region: str = "cn") -> None:
|
||||||
"""NumpyQuote
|
"""NumpyQuote
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -178,7 +180,8 @@ class NumpyQuote(BaseQuote):
|
|||||||
data = self._agg_data(data, method)
|
data = self._agg_data(data, method)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _agg_data(self, data: IndexData, method):
|
@staticmethod
|
||||||
|
def _agg_data(data: IndexData, method: str) -> Union[IndexData, np.ndarray, None]:
|
||||||
"""Agg data by specific method."""
|
"""Agg data by specific method."""
|
||||||
# FIXME: why not call the method of data directly?
|
# FIXME: why not call the method of data directly?
|
||||||
if method == "sum":
|
if method == "sum":
|
||||||
@@ -224,31 +227,31 @@ class BaseSingleMetric:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Please implement the `__init__` method")
|
raise NotImplementedError(f"Please implement the `__init__` method")
|
||||||
|
|
||||||
def __add__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
def __add__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||||
raise NotImplementedError(f"Please implement the `__add__` method")
|
raise NotImplementedError(f"Please implement the `__add__` method")
|
||||||
|
|
||||||
def __radd__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
def __radd__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||||
return self + other
|
return self + other
|
||||||
|
|
||||||
def __sub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
def __sub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||||
raise NotImplementedError(f"Please implement the `__sub__` method")
|
raise NotImplementedError(f"Please implement the `__sub__` method")
|
||||||
|
|
||||||
def __rsub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
def __rsub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||||
raise NotImplementedError(f"Please implement the `__rsub__` method")
|
raise NotImplementedError(f"Please implement the `__rsub__` method")
|
||||||
|
|
||||||
def __mul__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
def __mul__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||||
raise NotImplementedError(f"Please implement the `__mul__` method")
|
raise NotImplementedError(f"Please implement the `__mul__` method")
|
||||||
|
|
||||||
def __truediv__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
def __truediv__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||||
raise NotImplementedError(f"Please implement the `__truediv__` method")
|
raise NotImplementedError(f"Please implement the `__truediv__` method")
|
||||||
|
|
||||||
def __eq__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
def __eq__(self, other: object) -> BaseSingleMetric:
|
||||||
raise NotImplementedError(f"Please implement the `__eq__` method")
|
raise NotImplementedError(f"Please implement the `__eq__` method")
|
||||||
|
|
||||||
def __gt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
def __gt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||||
raise NotImplementedError(f"Please implement the `__gt__` method")
|
raise NotImplementedError(f"Please implement the `__gt__` method")
|
||||||
|
|
||||||
def __lt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
def __lt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||||
raise NotImplementedError(f"Please implement the `__lt__` method")
|
raise NotImplementedError(f"Please implement the `__lt__` method")
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@@ -265,7 +268,7 @@ class BaseSingleMetric:
|
|||||||
|
|
||||||
raise NotImplementedError(f"Please implement the `count` method")
|
raise NotImplementedError(f"Please implement the `count` method")
|
||||||
|
|
||||||
def abs(self) -> "BaseSingleMetric":
|
def abs(self) -> BaseSingleMetric:
|
||||||
raise NotImplementedError(f"Please implement the `abs` method")
|
raise NotImplementedError(f"Please implement the `abs` method")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -274,17 +277,17 @@ class BaseSingleMetric:
|
|||||||
|
|
||||||
raise NotImplementedError(f"Please implement the `empty` method")
|
raise NotImplementedError(f"Please implement the `empty` method")
|
||||||
|
|
||||||
def add(self, other: "BaseSingleMetric", fill_value: float = None) -> "BaseSingleMetric":
|
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
|
||||||
"""Replace np.NaN with fill_value in two metrics and add them."""
|
"""Replace np.NaN with fill_value in two metrics and add them."""
|
||||||
|
|
||||||
raise NotImplementedError(f"Please implement the `add` method")
|
raise NotImplementedError(f"Please implement the `add` method")
|
||||||
|
|
||||||
def replace(self, replace_dict: dict) -> "BaseSingleMetric":
|
def replace(self, replace_dict: dict) -> BaseSingleMetric:
|
||||||
"""Replace the value of metric according to replace_dict."""
|
"""Replace the value of metric according to replace_dict."""
|
||||||
|
|
||||||
raise NotImplementedError(f"Please implement the `replace` method")
|
raise NotImplementedError(f"Please implement the `replace` method")
|
||||||
|
|
||||||
def apply(self, func: dict) -> "BaseSingleMetric":
|
def apply(self, func: Callable) -> BaseSingleMetric:
|
||||||
"""Replace the value of metric with func (metric).
|
"""Replace the value of metric with func (metric).
|
||||||
Currently, the func is only qlib/backtest/order/Order.parse_dir.
|
Currently, the func is only qlib/backtest/order/Order.parse_dir.
|
||||||
"""
|
"""
|
||||||
@@ -304,11 +307,11 @@ class BaseOrderIndicator:
|
|||||||
to inherit the BaseSingleMetric.
|
to inherit the BaseSingleMetric.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data):
|
def __init__(self):
|
||||||
self.data = data
|
self.data = {} # will be created in the subclass
|
||||||
self.logger = get_module_logger("online operator")
|
self.logger = get_module_logger("online operator")
|
||||||
|
|
||||||
def assign(self, col: str, metric: Union[dict, pd.Series]):
|
def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
|
||||||
"""assign one metric.
|
"""assign one metric.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -328,7 +331,7 @@ class BaseOrderIndicator:
|
|||||||
|
|
||||||
raise NotImplementedError(f"Please implement the 'assign' method")
|
raise NotImplementedError(f"Please implement the 'assign' method")
|
||||||
|
|
||||||
def transfer(self, func: Callable, new_col: str = None) -> Union[None, BaseSingleMetric]:
|
def transfer(self, func: Callable, new_col: str = None) -> Optional[BaseSingleMetric]:
|
||||||
"""compute new metric with existing metrics.
|
"""compute new metric with existing metrics.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -352,6 +355,7 @@ class BaseOrderIndicator:
|
|||||||
tmp_metric = func(**func_kwargs)
|
tmp_metric = func(**func_kwargs)
|
||||||
if new_col is not None:
|
if new_col is not None:
|
||||||
self.data[new_col] = tmp_metric
|
self.data[new_col] = tmp_metric
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
return tmp_metric
|
return tmp_metric
|
||||||
|
|
||||||
@@ -372,7 +376,7 @@ class BaseOrderIndicator:
|
|||||||
|
|
||||||
raise NotImplementedError(f"Please implement the 'get_metric_series' method")
|
raise NotImplementedError(f"Please implement the 'get_metric_series' method")
|
||||||
|
|
||||||
def get_index_data(self, metric) -> SingleData:
|
def get_index_data(self, metric: str) -> SingleData:
|
||||||
"""get one metric with the format of SingleData
|
"""get one metric with the format of SingleData
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -389,7 +393,12 @@ class BaseOrderIndicator:
|
|||||||
raise NotImplementedError(f"Please implement the 'get_index_data' method")
|
raise NotImplementedError(f"Please implement the 'get_index_data' method")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None):
|
def sum_all_indicators(
|
||||||
|
order_indicator: BaseOrderIndicator,
|
||||||
|
indicators: List[BaseOrderIndicator],
|
||||||
|
metrics: Union[str, List[str]],
|
||||||
|
fill_value: float = 0,
|
||||||
|
) -> None:
|
||||||
"""sum indicators with the same metrics.
|
"""sum indicators with the same metrics.
|
||||||
and assign to the order_indicator(BaseOrderIndicator).
|
and assign to the order_indicator(BaseOrderIndicator).
|
||||||
NOTE: indicators could be a empty list when orders in lower level all fail.
|
NOTE: indicators could be a empty list when orders in lower level all fail.
|
||||||
@@ -527,16 +536,17 @@ class PandasSingleMetric(SingleMetric):
|
|||||||
def index(self):
|
def index(self):
|
||||||
return list(self.metric.index)
|
return list(self.metric.index)
|
||||||
|
|
||||||
def add(self, other, fill_value=None):
|
def add(self, other: BaseSingleMetric, fill_value: float = None) -> PandasSingleMetric:
|
||||||
|
other = cast(PandasSingleMetric, other)
|
||||||
return self.__class__(self.metric.add(other.metric, fill_value=fill_value))
|
return self.__class__(self.metric.add(other.metric, fill_value=fill_value))
|
||||||
|
|
||||||
def replace(self, replace_dict: dict):
|
def replace(self, replace_dict: dict) -> PandasSingleMetric:
|
||||||
return self.__class__(self.metric.replace(replace_dict))
|
return self.__class__(self.metric.replace(replace_dict))
|
||||||
|
|
||||||
def apply(self, func: Callable):
|
def apply(self, func: Callable) -> PandasSingleMetric:
|
||||||
return self.__class__(self.metric.apply(func))
|
return self.__class__(self.metric.apply(func))
|
||||||
|
|
||||||
def reindex(self, index, fill_value):
|
def reindex(self, index: Any, fill_value: float) -> PandasSingleMetric:
|
||||||
return self.__class__(self.metric.reindex(index, fill_value=fill_value))
|
return self.__class__(self.metric.reindex(index, fill_value=fill_value))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -550,13 +560,14 @@ class PandasOrderIndicator(BaseOrderIndicator):
|
|||||||
Str is the name of metric.
|
Str is the name of metric.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
|
super(PandasOrderIndicator, self).__init__()
|
||||||
self.data: Dict[str, PandasSingleMetric] = OrderedDict()
|
self.data: Dict[str, PandasSingleMetric] = OrderedDict()
|
||||||
|
|
||||||
def assign(self, col: str, metric: Union[dict, pd.Series]):
|
def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
|
||||||
self.data[col] = PandasSingleMetric(metric)
|
self.data[col] = PandasSingleMetric(metric)
|
||||||
|
|
||||||
def get_index_data(self, metric):
|
def get_index_data(self, metric: str) -> SingleData:
|
||||||
if metric in self.data:
|
if metric in self.data:
|
||||||
return idd.SingleData(self.data[metric].metric)
|
return idd.SingleData(self.data[metric].metric)
|
||||||
else:
|
else:
|
||||||
@@ -572,7 +583,12 @@ class PandasOrderIndicator(BaseOrderIndicator):
|
|||||||
return {k: v.metric for k, v in self.data.items()}
|
return {k: v.metric for k, v in self.data.items()}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
|
def sum_all_indicators(
|
||||||
|
order_indicator: BaseOrderIndicator,
|
||||||
|
indicators: List[BaseOrderIndicator],
|
||||||
|
metrics: Union[str, List[str]],
|
||||||
|
fill_value: float = 0,
|
||||||
|
) -> None:
|
||||||
if isinstance(metrics, str):
|
if isinstance(metrics, str):
|
||||||
metrics = [metrics]
|
metrics = [metrics]
|
||||||
for metric in metrics:
|
for metric in metrics:
|
||||||
@@ -592,13 +608,14 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
|||||||
Str is the name of metric.
|
Str is the name of metric.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
|
super(NumpyOrderIndicator, self).__init__()
|
||||||
self.data: Dict[str, SingleData] = OrderedDict()
|
self.data: Dict[str, SingleData] = OrderedDict()
|
||||||
|
|
||||||
def assign(self, col: str, metric: dict):
|
def assign(self, col: str, metric: dict) -> None:
|
||||||
self.data[col] = idd.SingleData(metric)
|
self.data[col] = idd.SingleData(metric)
|
||||||
|
|
||||||
def get_index_data(self, metric):
|
def get_index_data(self, metric: str) -> SingleData:
|
||||||
if metric in self.data:
|
if metric in self.data:
|
||||||
return self.data[metric]
|
return self.data[metric]
|
||||||
else:
|
else:
|
||||||
@@ -614,14 +631,18 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
|||||||
return tmp_metric_dict
|
return tmp_metric_dict
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
|
def sum_all_indicators(
|
||||||
|
order_indicator: BaseOrderIndicator,
|
||||||
|
indicators: List[BaseOrderIndicator],
|
||||||
|
metrics: Union[str, List[str]],
|
||||||
|
fill_value: float = 0,
|
||||||
|
) -> None:
|
||||||
# get all index(stock_id)
|
# get all index(stock_id)
|
||||||
stocks = set()
|
stock_set: set = set()
|
||||||
for indicator in indicators:
|
for indicator in indicators:
|
||||||
# set(np.ndarray.tolist()) is faster than set(np.ndarray)
|
# set(np.ndarray.tolist()) is faster than set(np.ndarray)
|
||||||
stocks = stocks | set(indicator.data[metrics[0]].index.tolist())
|
stock_set = stock_set | set(indicator.data[metrics[0]].index.tolist())
|
||||||
stocks = list(stocks)
|
stocks = sorted(list(stock_set))
|
||||||
stocks.sort()
|
|
||||||
|
|
||||||
# add metric by index
|
# add metric by index
|
||||||
if isinstance(metrics, str):
|
if isinstance(metrics, str):
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -18,9 +18,9 @@ class BasePosition:
|
|||||||
Please refer to the `Position` class for the position
|
Please refer to the `Position` class for the position
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, cash: float = 0.0, **kwargs) -> None:
|
def __init__(self, *args: Any, cash: float = 0.0, **kwargs: Any) -> None:
|
||||||
self._settle_type = self.ST_NO
|
self._settle_type = self.ST_NO
|
||||||
self.position = {}
|
self.position: dict = {}
|
||||||
|
|
||||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
|
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
|
||||||
pass
|
pass
|
||||||
@@ -96,13 +96,13 @@ class BasePosition:
|
|||||||
def calculate_value(self) -> float:
|
def calculate_value(self) -> float:
|
||||||
raise NotImplementedError(f"Please implement the `calculate_value` method")
|
raise NotImplementedError(f"Please implement the `calculate_value` method")
|
||||||
|
|
||||||
def get_stock_list(self) -> List:
|
def get_stock_list(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get the list of stocks in the position.
|
Get the list of stocks in the position.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Please implement the `get_stock_list` method")
|
raise NotImplementedError(f"Please implement the `get_stock_list` method")
|
||||||
|
|
||||||
def get_stock_price(self, code) -> float:
|
def get_stock_price(self, code: str) -> float:
|
||||||
"""
|
"""
|
||||||
get the latest price of the stock
|
get the latest price of the stock
|
||||||
|
|
||||||
@@ -113,7 +113,7 @@ class BasePosition:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Please implement the `get_stock_price` method")
|
raise NotImplementedError(f"Please implement the `get_stock_price` method")
|
||||||
|
|
||||||
def get_stock_amount(self, code) -> float:
|
def get_stock_amount(self, code: str) -> float:
|
||||||
"""
|
"""
|
||||||
get the amount of the stock
|
get the amount of the stock
|
||||||
|
|
||||||
@@ -144,7 +144,7 @@ class BasePosition:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Please implement the `get_cash` method")
|
raise NotImplementedError(f"Please implement the `get_cash` method")
|
||||||
|
|
||||||
def get_stock_amount_dict(self) -> Dict:
|
def get_stock_amount_dict(self) -> dict:
|
||||||
"""
|
"""
|
||||||
generate stock amount dict {stock_id : amount of stock}
|
generate stock amount dict {stock_id : amount of stock}
|
||||||
|
|
||||||
@@ -155,7 +155,7 @@ class BasePosition:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method")
|
raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method")
|
||||||
|
|
||||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
|
||||||
"""
|
"""
|
||||||
generate stock weight dict {stock_id : value weight of stock in the position}
|
generate stock weight dict {stock_id : value weight of stock in the position}
|
||||||
it is meaningful in the beginning or the end of each trade step
|
it is meaningful in the beginning or the end of each trade step
|
||||||
@@ -174,7 +174,7 @@ class BasePosition:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
|
raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
|
||||||
|
|
||||||
def add_count_all(self, bar) -> None:
|
def add_count_all(self, bar: str) -> None:
|
||||||
"""
|
"""
|
||||||
Will be called at the end of each bar on each level
|
Will be called at the end of each bar on each level
|
||||||
|
|
||||||
@@ -195,7 +195,7 @@ class BasePosition:
|
|||||||
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
||||||
|
|
||||||
ST_CASH = "cash"
|
ST_CASH = "cash"
|
||||||
ST_NO = None
|
ST_NO = "None" # String is more typehint friendly than None
|
||||||
|
|
||||||
def settle_start(self, settle_type: str) -> None:
|
def settle_start(self, settle_type: str) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -220,10 +220,10 @@ class BasePosition:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Please implement the `settle_commit` method")
|
raise NotImplementedError(f"Please implement the `settle_commit` method")
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return self.__dict__.__str__()
|
return self.__dict__.__str__()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return self.__dict__.__repr__()
|
return self.__dict__.__repr__()
|
||||||
|
|
||||||
|
|
||||||
@@ -532,7 +532,7 @@ class InfPosition(BasePosition):
|
|||||||
def calculate_value(self) -> float:
|
def calculate_value(self) -> float:
|
||||||
raise NotImplementedError(f"InfPosition doesn't support calculating value")
|
raise NotImplementedError(f"InfPosition doesn't support calculating value")
|
||||||
|
|
||||||
def get_stock_list(self) -> list:
|
def get_stock_list(self) -> List[str]:
|
||||||
raise NotImplementedError(f"InfPosition doesn't support stock list position")
|
raise NotImplementedError(f"InfPosition doesn't support stock list position")
|
||||||
|
|
||||||
def get_stock_price(self, code: str) -> float:
|
def get_stock_price(self, code: str) -> float:
|
||||||
@@ -545,10 +545,10 @@ class InfPosition(BasePosition):
|
|||||||
def get_cash(self, include_settle: bool = False) -> float:
|
def get_cash(self, include_settle: bool = False) -> float:
|
||||||
return np.inf
|
return np.inf
|
||||||
|
|
||||||
def get_stock_amount_dict(self) -> Dict:
|
def get_stock_amount_dict(self) -> dict:
|
||||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
|
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
|
||||||
|
|
||||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
|
||||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||||
|
|
||||||
def add_count_all(self, bar: str) -> None:
|
def add_count_all(self, bar: str) -> None:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Optional, Text, Tuple, Type, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -15,7 +15,7 @@ from qlib.backtest.exchange import Exchange
|
|||||||
|
|
||||||
from ..tests.config import CSI300_BENCH
|
from ..tests.config import CSI300_BENCH
|
||||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||||
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
|
from .high_performance_ds import BaseOrderIndicator, BaseSingleMetric, NumpyOrderIndicator
|
||||||
|
|
||||||
|
|
||||||
class PortfolioMetrics:
|
class PortfolioMetrics:
|
||||||
@@ -38,7 +38,7 @@ class PortfolioMetrics:
|
|||||||
update report
|
update report
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, freq: str = "day", benchmark_config: dict = {}):
|
def __init__(self, freq: str = "day", benchmark_config: dict = {}) -> None:
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -49,13 +49,17 @@ class PortfolioMetrics:
|
|||||||
- benchmark : Union[str, list, pd.Series]
|
- benchmark : Union[str, list, pd.Series]
|
||||||
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
|
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
|
||||||
example:
|
example:
|
||||||
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
|
print(
|
||||||
|
D.features(D.instruments('csi500'),
|
||||||
|
['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()
|
||||||
|
)
|
||||||
2017-01-04 0.011693
|
2017-01-04 0.011693
|
||||||
2017-01-05 0.000721
|
2017-01-05 0.000721
|
||||||
2017-01-06 -0.004322
|
2017-01-06 -0.004322
|
||||||
2017-01-09 0.006874
|
2017-01-09 0.006874
|
||||||
2017-01-10 -0.003350
|
2017-01-10 -0.003350
|
||||||
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
|
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the
|
||||||
|
'bench'.
|
||||||
- If `benchmark` is str, will use the daily change as the 'bench'.
|
- If `benchmark` is str, will use the daily change as the 'bench'.
|
||||||
benchmark code, default is SH000300 CSI300
|
benchmark code, default is SH000300 CSI300
|
||||||
- start_time : Union[str, pd.Timestamp], optional
|
- start_time : Union[str, pd.Timestamp], optional
|
||||||
@@ -70,25 +74,26 @@ class PortfolioMetrics:
|
|||||||
self.init_vars()
|
self.init_vars()
|
||||||
self.init_bench(freq=freq, benchmark_config=benchmark_config)
|
self.init_bench(freq=freq, benchmark_config=benchmark_config)
|
||||||
|
|
||||||
def init_vars(self):
|
def init_vars(self) -> None:
|
||||||
self.accounts = OrderedDict() # account position value for each trade time
|
self.accounts: dict = OrderedDict() # account position value for each trade time
|
||||||
self.returns = OrderedDict() # daily return rate for each trade time
|
self.returns: dict = OrderedDict() # daily return rate for each trade time
|
||||||
self.total_turnovers = OrderedDict() # total turnover for each trade time
|
self.total_turnovers: dict = OrderedDict() # total turnover for each trade time
|
||||||
self.turnovers = OrderedDict() # turnover for each trade time
|
self.turnovers: dict = OrderedDict() # turnover for each trade time
|
||||||
self.total_costs = OrderedDict() # total trade cost for each trade time
|
self.total_costs: dict = OrderedDict() # total trade cost for each trade time
|
||||||
self.costs = OrderedDict() # trade cost rate for each trade time
|
self.costs: dict = OrderedDict() # trade cost rate for each trade time
|
||||||
self.values = OrderedDict() # value for each trade time
|
self.values: dict = OrderedDict() # value for each trade time
|
||||||
self.cashes = OrderedDict()
|
self.cashes: dict = OrderedDict()
|
||||||
self.benches = OrderedDict()
|
self.benches: dict = OrderedDict()
|
||||||
self.latest_pm_time = None # pd.TimeStamp
|
self.latest_pm_time: Optional[pd.TimeStamp] = None
|
||||||
|
|
||||||
def init_bench(self, freq=None, benchmark_config=None):
|
def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None:
|
||||||
if freq is not None:
|
if freq is not None:
|
||||||
self.freq = freq
|
self.freq = freq
|
||||||
self.benchmark_config = benchmark_config
|
self.benchmark_config = benchmark_config
|
||||||
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
|
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
|
||||||
|
|
||||||
def _cal_benchmark(self, benchmark_config, freq):
|
@staticmethod
|
||||||
|
def _cal_benchmark(benchmark_config: Optional[dict], freq: str) -> Optional[pd.Series]:
|
||||||
if benchmark_config is None:
|
if benchmark_config is None:
|
||||||
return None
|
return None
|
||||||
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
|
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
|
||||||
@@ -110,7 +115,12 @@ class PortfolioMetrics:
|
|||||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||||
|
|
||||||
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
|
def _sample_benchmark(
|
||||||
|
self,
|
||||||
|
bench: pd.Series,
|
||||||
|
trade_start_time: Union[str, pd.Timestamp],
|
||||||
|
trade_end_time: Union[str, pd.Timestamp],
|
||||||
|
) -> Optional[float]:
|
||||||
if self.bench is None:
|
if self.bench is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -120,35 +130,35 @@ class PortfolioMetrics:
|
|||||||
_ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)
|
_ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)
|
||||||
return 0.0 if _ret is None else _ret - 1
|
return 0.0 if _ret is None else _ret - 1
|
||||||
|
|
||||||
def is_empty(self):
|
def is_empty(self) -> bool:
|
||||||
return len(self.accounts) == 0
|
return len(self.accounts) == 0
|
||||||
|
|
||||||
def get_latest_date(self):
|
def get_latest_date(self) -> pd.Timestamp:
|
||||||
return self.latest_pm_time
|
return self.latest_pm_time
|
||||||
|
|
||||||
def get_latest_account_value(self):
|
def get_latest_account_value(self) -> float:
|
||||||
return self.accounts[self.latest_pm_time]
|
return self.accounts[self.latest_pm_time]
|
||||||
|
|
||||||
def get_latest_total_cost(self):
|
def get_latest_total_cost(self) -> Any:
|
||||||
return self.total_costs[self.latest_pm_time]
|
return self.total_costs[self.latest_pm_time]
|
||||||
|
|
||||||
def get_latest_total_turnover(self):
|
def get_latest_total_turnover(self) -> Any:
|
||||||
return self.total_turnovers[self.latest_pm_time]
|
return self.total_turnovers[self.latest_pm_time]
|
||||||
|
|
||||||
def update_portfolio_metrics_record(
|
def update_portfolio_metrics_record(
|
||||||
self,
|
self,
|
||||||
trade_start_time=None,
|
trade_start_time: Union[str, pd.Timestamp] = None,
|
||||||
trade_end_time=None,
|
trade_end_time: Union[str, pd.Timestamp] = None,
|
||||||
account_value=None,
|
account_value: float = None,
|
||||||
cash=None,
|
cash: float = None,
|
||||||
return_rate=None,
|
return_rate: float = None,
|
||||||
total_turnover=None,
|
total_turnover: float = None,
|
||||||
turnover_rate=None,
|
turnover_rate: float = None,
|
||||||
total_cost=None,
|
total_cost: float = None,
|
||||||
cost_rate=None,
|
cost_rate: float = None,
|
||||||
stock_value=None,
|
stock_value: float = None,
|
||||||
bench_value=None,
|
bench_value: float = None,
|
||||||
):
|
) -> None:
|
||||||
# check data
|
# check data
|
||||||
if None in [
|
if None in [
|
||||||
trade_start_time,
|
trade_start_time,
|
||||||
@@ -185,7 +195,7 @@ class PortfolioMetrics:
|
|||||||
self.latest_pm_time = trade_start_time
|
self.latest_pm_time = trade_start_time
|
||||||
# finish pm update in each step
|
# finish pm update in each step
|
||||||
|
|
||||||
def generate_portfolio_metrics_dataframe(self):
|
def generate_portfolio_metrics_dataframe(self) -> pd.DataFrame:
|
||||||
pm = pd.DataFrame()
|
pm = pd.DataFrame()
|
||||||
pm["account"] = pd.Series(self.accounts)
|
pm["account"] = pd.Series(self.accounts)
|
||||||
pm["return"] = pd.Series(self.returns)
|
pm["return"] = pd.Series(self.returns)
|
||||||
@@ -199,19 +209,18 @@ class PortfolioMetrics:
|
|||||||
pm.index.name = "datetime"
|
pm.index.name = "datetime"
|
||||||
return pm
|
return pm
|
||||||
|
|
||||||
def save_portfolio_metrics(self, path):
|
def save_portfolio_metrics(self, path: str) -> None:
|
||||||
r = self.generate_portfolio_metrics_dataframe()
|
r = self.generate_portfolio_metrics_dataframe()
|
||||||
r.to_csv(path)
|
r.to_csv(path)
|
||||||
|
|
||||||
def load_portfolio_metrics(self, path):
|
def load_portfolio_metrics(self, path: str) -> None:
|
||||||
"""load pm from a file
|
"""load pm from a file
|
||||||
should have format like
|
should have format like
|
||||||
columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']
|
columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']
|
||||||
:param
|
:param
|
||||||
path: str/ pathlib.Path()
|
path: str/ pathlib.Path()
|
||||||
"""
|
"""
|
||||||
path = pathlib.Path(path)
|
with pathlib.Path(path).open("rb") as f:
|
||||||
with path.open("rb") as f:
|
|
||||||
r = pd.read_csv(f, index_col=0)
|
r = pd.read_csv(f, index_col=0)
|
||||||
r.index = pd.DatetimeIndex(r.index)
|
r.index = pd.DatetimeIndex(r.index)
|
||||||
|
|
||||||
@@ -261,30 +270,30 @@ class Indicator:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, order_indicator_cls=NumpyOrderIndicator):
|
def __init__(self, order_indicator_cls: Type[BaseOrderIndicator] = NumpyOrderIndicator) -> None:
|
||||||
self.order_indicator_cls = order_indicator_cls
|
self.order_indicator_cls = order_indicator_cls
|
||||||
|
|
||||||
# order indicator is metrics for a single order for a specific step
|
# order indicator is metrics for a single order for a specific step
|
||||||
self.order_indicator_his = OrderedDict()
|
self.order_indicator_his: dict = OrderedDict()
|
||||||
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
|
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
|
||||||
|
|
||||||
# trade indicator is metrics for all orders for a specific step
|
# trade indicator is metrics for all orders for a specific step
|
||||||
self.trade_indicator_his = OrderedDict()
|
self.trade_indicator_his: dict = OrderedDict()
|
||||||
self.trade_indicator: Dict[str, float] = OrderedDict()
|
self.trade_indicator: Dict[str, Optional[BaseSingleMetric]] = OrderedDict()
|
||||||
|
|
||||||
self._trade_calendar = None
|
self._trade_calendar = None
|
||||||
|
|
||||||
# def reset(self, trade_calendar: TradeCalendarManager):
|
# def reset(self, trade_calendar: TradeCalendarManager):
|
||||||
def reset(self):
|
def reset(self) -> None:
|
||||||
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
|
self.order_indicator = self.order_indicator_cls()
|
||||||
self.trade_indicator = OrderedDict()
|
self.trade_indicator = OrderedDict()
|
||||||
# self._trade_calendar = trade_calendar
|
# self._trade_calendar = trade_calendar
|
||||||
|
|
||||||
def record(self, trade_start_time):
|
def record(self, trade_start_time: Union[str, pd.Timestamp]) -> None:
|
||||||
self.order_indicator_his[trade_start_time] = self.get_order_indicator()
|
self.order_indicator_his[trade_start_time] = self.get_order_indicator()
|
||||||
self.trade_indicator_his[trade_start_time] = self.get_trade_indicator()
|
self.trade_indicator_his[trade_start_time] = self.get_trade_indicator()
|
||||||
|
|
||||||
def _update_order_trade_info(self, trade_info: list):
|
def _update_order_trade_info(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
|
||||||
amount = dict()
|
amount = dict()
|
||||||
deal_amount = dict()
|
deal_amount = dict()
|
||||||
trade_price = dict()
|
trade_price = dict()
|
||||||
@@ -313,7 +322,7 @@ class Indicator:
|
|||||||
self.order_indicator.assign("trade_dir", trade_dir)
|
self.order_indicator.assign("trade_dir", trade_dir)
|
||||||
self.order_indicator.assign("pa", pa)
|
self.order_indicator.assign("pa", pa)
|
||||||
|
|
||||||
def _update_order_fulfill_rate(self):
|
def _update_order_fulfill_rate(self) -> None:
|
||||||
def func(deal_amount, amount):
|
def func(deal_amount, amount):
|
||||||
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
|
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
|
||||||
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
|
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
|
||||||
@@ -322,11 +331,11 @@ class Indicator:
|
|||||||
|
|
||||||
self.order_indicator.transfer(func, "ffr")
|
self.order_indicator.transfer(func, "ffr")
|
||||||
|
|
||||||
def update_order_indicators(self, trade_info: list):
|
def update_order_indicators(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
|
||||||
self._update_order_trade_info(trade_info=trade_info)
|
self._update_order_trade_info(trade_info=trade_info)
|
||||||
self._update_order_fulfill_rate()
|
self._update_order_fulfill_rate()
|
||||||
|
|
||||||
def _agg_order_trade_info(self, inner_order_indicators: List[Dict[str, pd.Series]]):
|
def _agg_order_trade_info(self, inner_order_indicators: List[BaseOrderIndicator]) -> None:
|
||||||
# calculate total trade amount with each inner order indicator.
|
# calculate total trade amount with each inner order indicator.
|
||||||
def trade_amount_func(deal_amount, trade_price):
|
def trade_amount_func(deal_amount, trade_price):
|
||||||
return deal_amount * trade_price
|
return deal_amount * trade_price
|
||||||
@@ -355,9 +364,9 @@ class Indicator:
|
|||||||
|
|
||||||
self.order_indicator.transfer(func_apply, "trade_dir")
|
self.order_indicator.transfer(func_apply, "trade_dir")
|
||||||
|
|
||||||
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision):
|
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision) -> None:
|
||||||
# NOTE: these indicator is designed for order execution, so the
|
# NOTE: these indicator is designed for order execution, so the
|
||||||
decision: List[Order] = outer_trade_decision.get_decision()
|
decision: List[Order] = cast(List[Order], outer_trade_decision.get_decision())
|
||||||
if len(decision) == 0:
|
if len(decision) == 0:
|
||||||
self.order_indicator.assign("amount", {})
|
self.order_indicator.assign("amount", {})
|
||||||
else:
|
else:
|
||||||
@@ -372,7 +381,7 @@ class Indicator:
|
|||||||
decision: BaseTradeDecision,
|
decision: BaseTradeDecision,
|
||||||
trade_exchange: Exchange,
|
trade_exchange: Exchange,
|
||||||
pa_config: dict = {},
|
pa_config: dict = {},
|
||||||
):
|
) -> Tuple[Optional[float], Optional[float]]:
|
||||||
"""
|
"""
|
||||||
Get the base volume and price information
|
Get the base volume and price information
|
||||||
All the base price values are rooted from this function
|
All the base price values are rooted from this function
|
||||||
@@ -412,31 +421,35 @@ class Indicator:
|
|||||||
# NOTE: there are some zeros in the trading price. These cases are known meaningless
|
# NOTE: there are some zeros in the trading price. These cases are known meaningless
|
||||||
# for aligning the previous logic, remove it.
|
# for aligning the previous logic, remove it.
|
||||||
# remove zero and negative values.
|
# remove zero and negative values.
|
||||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(np.bool)]
|
assert isinstance(price_s, idd.SingleData)
|
||||||
|
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
|
||||||
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
||||||
# ~(np.NaN < 1e-8) -> ~(False) -> True
|
# ~(np.NaN < 1e-8) -> ~(False) -> True
|
||||||
|
|
||||||
|
assert isinstance(price_s, idd.SingleData)
|
||||||
if agg == "vwap":
|
if agg == "vwap":
|
||||||
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
|
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
|
||||||
if isinstance(volume_s, (int, float, np.number)):
|
if isinstance(volume_s, (int, float, np.number)):
|
||||||
volume_s = idd.SingleData(volume_s, [trade_start_time])
|
volume_s = idd.SingleData(volume_s, [trade_start_time])
|
||||||
|
assert isinstance(volume_s, idd.SingleData)
|
||||||
volume_s = volume_s.reindex(price_s.index)
|
volume_s = volume_s.reindex(price_s.index)
|
||||||
elif agg == "twap":
|
elif agg == "twap":
|
||||||
volume_s = idd.SingleData(1, price_s.index)
|
volume_s = idd.SingleData(1, price_s.index)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"This type of input is not supported")
|
raise NotImplementedError(f"This type of input is not supported")
|
||||||
|
|
||||||
|
assert isinstance(volume_s, idd.SingleData)
|
||||||
base_volume = volume_s.sum()
|
base_volume = volume_s.sum()
|
||||||
base_price = (price_s * volume_s).sum() / base_volume
|
base_price = (price_s * volume_s).sum() / base_volume
|
||||||
return base_price, base_volume
|
return base_price, base_volume
|
||||||
|
|
||||||
def _agg_base_price(
|
def _agg_base_price(
|
||||||
self,
|
self,
|
||||||
inner_order_indicators: List[Dict[str, Union[SingleMetric, idd.SingleData]]],
|
inner_order_indicators: List[BaseOrderIndicator],
|
||||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||||
trade_exchange: Exchange,
|
trade_exchange: Exchange,
|
||||||
pa_config: dict = {},
|
pa_config: dict = {},
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
# NOTE:!!!!
|
# NOTE:!!!!
|
||||||
# Strong assumption!!!!!!
|
# Strong assumption!!!!!!
|
||||||
@@ -444,7 +457,7 @@ class Indicator:
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
inner_order_indicators : List[Dict[str, pd.Series]]
|
inner_order_indicators : List[BaseOrderIndicator]
|
||||||
the indicators of account of inner executor
|
the indicators of account of inner executor
|
||||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||||
a list of decisions according to inner_order_indicators
|
a list of decisions according to inner_order_indicators
|
||||||
@@ -489,14 +502,17 @@ class Indicator:
|
|||||||
bv_new = idd.SingleData(bv_new)
|
bv_new = idd.SingleData(bv_new)
|
||||||
bp_all.append(bp_new)
|
bp_all.append(bp_new)
|
||||||
bv_all.append(bv_new)
|
bv_all.append(bv_new)
|
||||||
bp_all = idd.concat(bp_all, axis=1)
|
bp_all_multi_data = idd.concat(bp_all, axis=1)
|
||||||
bv_all = idd.concat(bv_all, axis=1)
|
bv_all_multi_data = idd.concat(bv_all, axis=1)
|
||||||
|
|
||||||
base_volume = bv_all.sum(axis=1)
|
base_volume = bv_all_multi_data.sum(axis=1)
|
||||||
self.order_indicator.assign("base_volume", base_volume.to_dict())
|
self.order_indicator.assign("base_volume", base_volume.to_dict())
|
||||||
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict())
|
self.order_indicator.assign(
|
||||||
|
"base_price",
|
||||||
|
((bp_all_multi_data * bv_all_multi_data).sum(axis=1) / base_volume).to_dict(),
|
||||||
|
)
|
||||||
|
|
||||||
def _agg_order_price_advantage(self):
|
def _agg_order_price_advantage(self) -> None:
|
||||||
def if_empty_func(trade_price):
|
def if_empty_func(trade_price):
|
||||||
return trade_price.empty
|
return trade_price.empty
|
||||||
|
|
||||||
@@ -513,12 +529,12 @@ class Indicator:
|
|||||||
|
|
||||||
def agg_order_indicators(
|
def agg_order_indicators(
|
||||||
self,
|
self,
|
||||||
inner_order_indicators: List[Dict[str, pd.Series]],
|
inner_order_indicators: List[BaseOrderIndicator],
|
||||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||||
outer_trade_decision: BaseTradeDecision,
|
outer_trade_decision: BaseTradeDecision,
|
||||||
trade_exchange: Exchange,
|
trade_exchange: Exchange,
|
||||||
indicator_config={},
|
indicator_config: dict = {},
|
||||||
):
|
) -> None:
|
||||||
self._agg_order_trade_info(inner_order_indicators)
|
self._agg_order_trade_info(inner_order_indicators)
|
||||||
self._update_trade_amount(outer_trade_decision)
|
self._update_trade_amount(outer_trade_decision)
|
||||||
self._update_order_fulfill_rate()
|
self._update_order_fulfill_rate()
|
||||||
@@ -526,71 +542,66 @@ class Indicator:
|
|||||||
self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) # TODO
|
self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) # TODO
|
||||||
self._agg_order_price_advantage()
|
self._agg_order_price_advantage()
|
||||||
|
|
||||||
def _cal_trade_fulfill_rate(self, method="mean"):
|
def _cal_trade_fulfill_rate(self, method: str = "mean") -> Optional[BaseSingleMetric]:
|
||||||
if method == "mean":
|
if method == "mean":
|
||||||
|
return self.order_indicator.transfer(
|
||||||
def func(ffr):
|
lambda ffr: ffr.mean(),
|
||||||
return ffr.mean()
|
)
|
||||||
|
|
||||||
elif method == "amount_weighted":
|
elif method == "amount_weighted":
|
||||||
|
return self.order_indicator.transfer(
|
||||||
def func(ffr, deal_amount):
|
lambda ffr, deal_amount: (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
|
||||||
return (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum())
|
)
|
||||||
|
|
||||||
elif method == "value_weighted":
|
elif method == "value_weighted":
|
||||||
|
return self.order_indicator.transfer(
|
||||||
def func(ffr, trade_value):
|
lambda ffr, trade_value: (ffr * trade_value.abs()).sum() / (trade_value.abs().sum()),
|
||||||
return (ffr * trade_value.abs()).sum() / (trade_value.abs().sum())
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"method {method} is not supported!")
|
raise ValueError(f"method {method} is not supported!")
|
||||||
return self.order_indicator.transfer(func)
|
|
||||||
|
|
||||||
def _cal_trade_price_advantage(self, method="mean"):
|
def _cal_trade_price_advantage(self, method: str = "mean") -> Optional[BaseSingleMetric]:
|
||||||
if method == "mean":
|
if method == "mean":
|
||||||
|
return self.order_indicator.transfer(lambda pa: pa.mean())
|
||||||
def func(pa):
|
|
||||||
return pa.mean()
|
|
||||||
|
|
||||||
elif method == "amount_weighted":
|
elif method == "amount_weighted":
|
||||||
|
return self.order_indicator.transfer(
|
||||||
def func(pa, deal_amount):
|
lambda pa, deal_amount: (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
|
||||||
return (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum())
|
)
|
||||||
|
|
||||||
elif method == "value_weighted":
|
elif method == "value_weighted":
|
||||||
|
return self.order_indicator.transfer(
|
||||||
def func(pa, trade_value):
|
lambda pa, trade_value: (pa * trade_value.abs()).sum() / (trade_value.abs().sum()),
|
||||||
return (pa * trade_value.abs()).sum() / (trade_value.abs().sum())
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"method {method} is not supported!")
|
raise ValueError(f"method {method} is not supported!")
|
||||||
return self.order_indicator.transfer(func)
|
|
||||||
|
|
||||||
def _cal_trade_positive_rate(self):
|
def _cal_trade_positive_rate(self) -> Optional[BaseSingleMetric]:
|
||||||
def func(pa):
|
def func(pa):
|
||||||
return (pa > 0).sum() / pa.count()
|
return (pa > 0).sum() / pa.count()
|
||||||
|
|
||||||
return self.order_indicator.transfer(func)
|
return self.order_indicator.transfer(func)
|
||||||
|
|
||||||
def _cal_deal_amount(self):
|
def _cal_deal_amount(self) -> Optional[BaseSingleMetric]:
|
||||||
def func(deal_amount):
|
def func(deal_amount):
|
||||||
return deal_amount.abs().sum()
|
return deal_amount.abs().sum()
|
||||||
|
|
||||||
return self.order_indicator.transfer(func)
|
return self.order_indicator.transfer(func)
|
||||||
|
|
||||||
def _cal_trade_value(self):
|
def _cal_trade_value(self) -> Optional[BaseSingleMetric]:
|
||||||
def func(trade_value):
|
def func(trade_value):
|
||||||
return trade_value.abs().sum()
|
return trade_value.abs().sum()
|
||||||
|
|
||||||
return self.order_indicator.transfer(func)
|
return self.order_indicator.transfer(func)
|
||||||
|
|
||||||
def _cal_trade_order_count(self):
|
def _cal_trade_order_count(self) -> Optional[BaseSingleMetric]:
|
||||||
def func(amount):
|
def func(amount):
|
||||||
return amount.count()
|
return amount.count()
|
||||||
|
|
||||||
return self.order_indicator.transfer(func)
|
return self.order_indicator.transfer(func)
|
||||||
|
|
||||||
def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}):
|
def cal_trade_indicators(
|
||||||
|
self,
|
||||||
|
trade_start_time: Union[str, pd.Timestamp],
|
||||||
|
freq: str,
|
||||||
|
indicator_config: dict = {},
|
||||||
|
) -> None:
|
||||||
show_indicator = indicator_config.get("show_indicator", False)
|
show_indicator = indicator_config.get("show_indicator", False)
|
||||||
ffr_config = indicator_config.get("ffr_config", {})
|
ffr_config = indicator_config.get("ffr_config", {})
|
||||||
pa_config = indicator_config.get("pa_config", {})
|
pa_config = indicator_config.get("pa_config", {})
|
||||||
@@ -608,22 +619,22 @@ class Indicator:
|
|||||||
self.trade_indicator["count"] = order_count
|
self.trade_indicator["count"] = order_count
|
||||||
if show_indicator:
|
if show_indicator:
|
||||||
print(
|
print(
|
||||||
"[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format(
|
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
|
||||||
freq,
|
freq,
|
||||||
trade_start_time,
|
trade_start_time
|
||||||
|
if isinstance(trade_start_time, str)
|
||||||
|
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
fulfill_rate,
|
fulfill_rate,
|
||||||
price_advantage,
|
price_advantage,
|
||||||
positive_rate,
|
positive_rate,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_order_indicator(self, raw: bool = True):
|
def get_order_indicator(self, raw: bool = True) -> Union[BaseOrderIndicator, Dict[Text, pd.Series]]:
|
||||||
if raw:
|
return self.order_indicator if raw else self.order_indicator.to_series()
|
||||||
return self.order_indicator
|
|
||||||
return self.order_indicator.to_series()
|
|
||||||
|
|
||||||
def get_trade_indicator(self):
|
def get_trade_indicator(self) -> Dict[str, Optional[BaseSingleMetric]]:
|
||||||
return self.trade_indicator
|
return self.trade_indicator
|
||||||
|
|
||||||
def generate_trade_indicators_dataframe(self):
|
def generate_trade_indicators_dataframe(self) -> pd.DataFrame:
|
||||||
return pd.DataFrame.from_dict(self.trade_indicator_his, orient="index")
|
return pd.DataFrame.from_dict(self.trade_indicator_his, orient="index")
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ class Signal(metaclass=abc.ABCMeta):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
|
def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame, None]:
|
||||||
"""
|
"""
|
||||||
get the signal at the end of the decision step(from `start_time` to `end_time`)
|
get the signal at the end of the decision step(from `start_time` to `end_time`)
|
||||||
|
|
||||||
@@ -39,13 +39,14 @@ class SignalWCache(Signal):
|
|||||||
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
|
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
|
def __init__(self, signal: Union[pd.Series, pd.DataFrame]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
signal : Union[pd.Series, pd.DataFrame]
|
signal : Union[pd.Series, pd.DataFrame]
|
||||||
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
|
The expected format of the signal is like the data below (the order of index is not important and can be
|
||||||
|
automatically adjusted)
|
||||||
|
|
||||||
instrument datetime
|
instrument datetime
|
||||||
SH600000 2008-01-02 0.079704
|
SH600000 2008-01-02 0.079704
|
||||||
@@ -56,8 +57,8 @@ class SignalWCache(Signal):
|
|||||||
"""
|
"""
|
||||||
self.signal_cache = convert_index_format(signal, level="datetime")
|
self.signal_cache = convert_index_format(signal, level="datetime")
|
||||||
|
|
||||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
|
def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame]:
|
||||||
# the frequency of the signal may not algin with the decision frequency of strategy
|
# the frequency of the signal may not align with the decision frequency of strategy
|
||||||
# so resampling from the data is necessary
|
# so resampling from the data is necessary
|
||||||
# the latest signal leverage more recent data and therefore is used in trading.
|
# the latest signal leverage more recent data and therefore is used in trading.
|
||||||
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
|
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
|
||||||
@@ -65,7 +66,7 @@ class SignalWCache(Signal):
|
|||||||
|
|
||||||
|
|
||||||
class ModelSignal(SignalWCache):
|
class ModelSignal(SignalWCache):
|
||||||
def __init__(self, model: BaseModel, dataset: Dataset):
|
def __init__(self, model: BaseModel, dataset: Dataset) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
pred_scores = self.model.predict(dataset)
|
pred_scores = self.model.predict(dataset)
|
||||||
@@ -73,7 +74,7 @@ class ModelSignal(SignalWCache):
|
|||||||
pred_scores = pred_scores.iloc[:, 0]
|
pred_scores = pred_scores.iloc[:, 0]
|
||||||
super().__init__(pred_scores)
|
super().__init__(pred_scores)
|
||||||
|
|
||||||
def _update_model(self):
|
def _update_model(self) -> None:
|
||||||
"""
|
"""
|
||||||
When using online data, update model in each bar as the following steps:
|
When using online data, update model in each bar as the following steps:
|
||||||
- update dataset with online data, the dataset should support online update
|
- update dataset with online data, the dataset should support online update
|
||||||
|
|||||||
@@ -149,6 +149,8 @@ class TradeCalendarManager:
|
|||||||
Tuple[int, int]:
|
Tuple[int, int]:
|
||||||
"""
|
"""
|
||||||
# potential performance issue
|
# potential performance issue
|
||||||
|
assert self.level_infra is not None
|
||||||
|
|
||||||
day_start = pd.Timestamp(self.start_time.date())
|
day_start = pd.Timestamp(self.start_time.date())
|
||||||
day_end = epsilon_change(day_start + pd.Timedelta(days=1))
|
day_end = epsilon_change(day_start + pd.Timedelta(days=1))
|
||||||
freq = self.level_infra.get("common_infra").get("trade_exchange").freq
|
freq = self.level_infra.get("common_infra").get("trade_exchange").freq
|
||||||
@@ -182,8 +184,8 @@ class TradeCalendarManager:
|
|||||||
Tuple[int, int]:
|
Tuple[int, int]:
|
||||||
the index of the range. **the left and right are closed**
|
the index of the range. **the left and right are closed**
|
||||||
"""
|
"""
|
||||||
left = bisect.bisect_right(self._calendar, start_time) - 1
|
left = bisect.bisect_right(list(self._calendar), start_time) - 1
|
||||||
right = bisect.bisect_right(self._calendar, end_time) - 1
|
right = bisect.bisect_right(list(self._calendar), end_time) - 1
|
||||||
left -= self.start_index
|
left -= self.start_index
|
||||||
right -= self.start_index
|
right -= self.start_index
|
||||||
|
|
||||||
@@ -201,14 +203,14 @@ class TradeCalendarManager:
|
|||||||
|
|
||||||
|
|
||||||
class BaseInfrastructure:
|
class BaseInfrastructure:
|
||||||
def __init__(self, **kwargs) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
self.reset_infra(**kwargs)
|
self.reset_infra(**kwargs)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_support_infra(self) -> Set[str]:
|
def get_support_infra(self) -> Set[str]:
|
||||||
raise NotImplementedError("`get_support_infra` is not implemented!")
|
raise NotImplementedError("`get_support_infra` is not implemented!")
|
||||||
|
|
||||||
def reset_infra(self, **kwargs) -> None:
|
def reset_infra(self, **kwargs: Any) -> None:
|
||||||
support_infra = self.get_support_infra()
|
support_infra = self.get_support_infra()
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if k in support_infra:
|
if k in support_infra:
|
||||||
|
|||||||
@@ -339,7 +339,7 @@ def long_short_backtest(
|
|||||||
for stock in long_stocks:
|
for stock in long_stocks:
|
||||||
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
|
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
|
||||||
continue
|
continue
|
||||||
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
|
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
|
||||||
if np.isnan(profit):
|
if np.isnan(profit):
|
||||||
long_profit.append(0)
|
long_profit.append(0)
|
||||||
else:
|
else:
|
||||||
@@ -348,17 +348,17 @@ def long_short_backtest(
|
|||||||
for stock in short_stocks:
|
for stock in short_stocks:
|
||||||
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
|
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
|
||||||
continue
|
continue
|
||||||
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
|
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
|
||||||
if np.isnan(profit):
|
if np.isnan(profit):
|
||||||
short_profit.append(0)
|
short_profit.append(0)
|
||||||
else:
|
else:
|
||||||
short_profit.append(-profit)
|
short_profit.append(profit * -1)
|
||||||
|
|
||||||
for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)):
|
for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)):
|
||||||
# exclude the suspend stock
|
# exclude the suspend stock
|
||||||
if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):
|
if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):
|
||||||
continue
|
continue
|
||||||
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
|
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
|
||||||
if np.isnan(profit):
|
if np.isnan(profit):
|
||||||
all_profit.append(0)
|
all_profit.append(0)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -108,14 +108,16 @@ class CalendarProvider(abc.ABC):
|
|||||||
_, _, si, ei = self.locate_index(start_time, end_time, freq, future)
|
_, _, si, ei = self.locate_index(start_time, end_time, freq, future)
|
||||||
return _calendar[si : ei + 1]
|
return _calendar[si : ei + 1]
|
||||||
|
|
||||||
def locate_index(self, start_time, end_time, freq, future=False):
|
def locate_index(
|
||||||
|
self, start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], freq: str, future: bool = False
|
||||||
|
):
|
||||||
"""Locate the start time index and end time index in a calendar under certain frequency.
|
"""Locate the start time index and end time index in a calendar under certain frequency.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
start_time : str
|
start_time : pd.Timestamp
|
||||||
start of the time range.
|
start of the time range.
|
||||||
end_time : str
|
end_time : pd.Timestamp
|
||||||
end of the time range.
|
end of the time range.
|
||||||
freq : str
|
freq : str
|
||||||
time frequency, available: year/quarter/month/week/day.
|
time frequency, available: year/quarter/month/week/day.
|
||||||
|
|||||||
@@ -248,7 +248,7 @@ def load_orders(
|
|||||||
Order(
|
Order(
|
||||||
row["instrument"],
|
row["instrument"],
|
||||||
row["amount"],
|
row["amount"],
|
||||||
int(row["order_type"]),
|
OrderDir(int(row["order_type"])),
|
||||||
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
|
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
|
||||||
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
|
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Any, Generator, Optional
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from qlib.backtest.exchange import Exchange
|
from qlib.backtest.exchange import Exchange
|
||||||
@@ -122,7 +122,10 @@ class BaseStrategy:
|
|||||||
self.outer_trade_decision = outer_trade_decision
|
self.outer_trade_decision = outer_trade_decision
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
|
def generate_trade_decision(
|
||||||
|
self,
|
||||||
|
execute_result: list = None,
|
||||||
|
) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
|
||||||
"""Generate trade decision in each trading bar
|
"""Generate trade decision in each trading bar
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ Motivation of index_data
|
|||||||
`index_data` try to behave like pandas (some API will be different because we try to be simpler and more intuitive) but don't compromise the performance. It provides the basic numpy data and simple indexing feature. If users call APIs which may compromise the performance, index_data will raise Errors.
|
`index_data` try to behave like pandas (some API will be different because we try to be simpler and more intuitive) but don't compromise the performance. It provides the basic numpy data and simple indexing feature. If users call APIs which may compromise the performance, index_data will raise Errors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Dict, Tuple, Union, Callable, List
|
from typing import Dict, Tuple, Union, Callable, List
|
||||||
import bisect
|
import bisect
|
||||||
|
|
||||||
@@ -16,7 +18,7 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
|
def concat(data_list: Union[SingleData], axis=0) -> MultiData:
|
||||||
"""concat all SingleData by index.
|
"""concat all SingleData by index.
|
||||||
TODO: now just for SingleData.
|
TODO: now just for SingleData.
|
||||||
|
|
||||||
@@ -52,7 +54,7 @@ def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
|
|||||||
raise ValueError(f"axis must be 0 or 1")
|
raise ValueError(f"axis must be 0 or 1")
|
||||||
|
|
||||||
|
|
||||||
def sum_by_index(data_list: Union["SingleData"], new_index: list, fill_value=0) -> "SingleData":
|
def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) -> SingleData:
|
||||||
"""concat all SingleData by new index.
|
"""concat all SingleData by new index.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -554,7 +556,7 @@ class SingleData(IndexData):
|
|||||||
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
|
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
|
||||||
)
|
)
|
||||||
|
|
||||||
def reindex(self, index: Index, fill_value=np.NaN):
|
def reindex(self, index: Index, fill_value=np.NaN) -> SingleData:
|
||||||
"""reindex data and fill the missing value with np.NaN.
|
"""reindex data and fill the missing value with np.NaN.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -580,7 +582,7 @@ class SingleData(IndexData):
|
|||||||
pass
|
pass
|
||||||
return SingleData(tmp_data, index)
|
return SingleData(tmp_data, index)
|
||||||
|
|
||||||
def add(self, other: "SingleData", fill_value=0):
|
def add(self, other: SingleData, fill_value=0):
|
||||||
# TODO: add and __add__ are a little confusing.
|
# TODO: add and __add__ are a little confusing.
|
||||||
# This could be a more general
|
# This could be a more general
|
||||||
common_index = self.index | other.index
|
common_index = self.index | other.index
|
||||||
|
|||||||
Reference in New Issue
Block a user