1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 18:11:18 +08:00

black format & add comments & add randStrategy direction

This commit is contained in:
Young
2021-06-28 08:16:51 +00:00
committed by you-n-g
parent 72c9593aa7
commit 27f0db669f
13 changed files with 132 additions and 102 deletions

View File

@@ -92,7 +92,9 @@ def get_exchange(
return init_instance_by_config(exchange, accept_types=Exchange)
def create_account_instance(start_time, end_time, benchmark: str, account: float, pos_type: str="Position") -> Account:
def create_account_instance(
start_time, end_time, benchmark: str, account: float, pos_type: str = "Position"
) -> Account:
"""
# TODO: is very strange pass benchmark_config in the account(maybe for report)
# There should be a post-step to process the report.
@@ -119,26 +121,25 @@ def create_account_instance(start_time, end_time, benchmark: str, account: float
"start_time": start_time,
"end_time": end_time,
},
"pos_type": pos_type
"pos_type": pos_type,
}
return Account(**kwargs)
def get_strategy_executor(start_time,
end_time,
strategy: BaseStrategy,
executor: BaseExecutor,
benchmark: str = "SH000300",
account: Union[float, str] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
):
def get_strategy_executor(
start_time,
end_time,
strategy: BaseStrategy,
executor: BaseExecutor,
benchmark: str = "SH000300",
account: Union[float, str] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
):
trade_account = create_account_instance(start_time=start_time,
end_time=end_time,
benchmark=benchmark,
account=account,
pos_type=pos_type)
trade_account = create_account_instance(
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
)
exchange_kwargs = copy.copy(exchange_kwargs)
if "start_time" not in exchange_kwargs:
@@ -154,14 +155,16 @@ def get_strategy_executor(start_time,
return trade_strategy, trade_executor
def backtest(start_time,
end_time,
strategy,
executor,
benchmark="SH000300",
account=1e9,
exchange_kwargs={},
pos_type: str = "Position"):
def backtest(
start_time,
end_time,
strategy,
executor,
benchmark="SH000300",
account=1e9,
exchange_kwargs={},
pos_type: str = "Position",
):
trade_strategy, trade_executor = get_strategy_executor(
start_time,
@@ -178,14 +181,16 @@ def backtest(start_time,
return report_dict, indicator_dict
def collect_data(start_time,
end_time,
strategy,
executor,
benchmark="SH000300",
account=1e9,
exchange_kwargs={},
pos_type: str = "Position"):
def collect_data(
start_time,
end_time,
strategy,
executor,
benchmark="SH000300",
account=1e9,
exchange_kwargs={},
pos_type: str = "Position",
):
trade_strategy, trade_executor = get_strategy_executor(
start_time,

View File

@@ -63,7 +63,9 @@ class AccumulatedInfo:
class Account:
def __init__(self, init_cash: float=1e9, freq: str = "day", benchmark_config: dict = {}, pos_type:str = "Position"):
def __init__(
self, init_cash: float = 1e9, freq: str = "day", benchmark_config: dict = {}, pos_type: str = "Position"
):
self.pos_type = pos_type
self.init_vars(init_cash, freq, benchmark_config)
@@ -71,13 +73,13 @@ class Account:
# init cash
self.init_cash = init_cash
self.current: BasePosition = init_instance_by_config({
'class': self.pos_type,
'kwargs': {
"cash": init_cash
},
'module_path': "qlib.backtest.position",
})
self.current: BasePosition = init_instance_by_config(
{
"class": self.pos_type,
"kwargs": {"cash": init_cash},
"module_path": "qlib.backtest.position",
}
)
self.accum_info = AccumulatedInfo()
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)

View File

@@ -23,7 +23,9 @@ def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_exec
return return_value.get("report"), return_value.get("indicator")
def collect_data_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict = None):
def collect_data_loop(
start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict = None
):
"""Generator for collecting the trade decision data for rl training
Parameters
@@ -68,7 +70,7 @@ def collect_data_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_
}
all_indicators = {}
for _executor in all_executors:
key = "{}{}".format( *Freq.parse(_executor.time_per_step))
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
all_indicators[key] = _executor.get_trade_indicator().generate_trade_indicators_dataframe()
all_indicators[key + "_obj"] = _executor.get_trade_indicator()
return_value.update({"report": all_reports, "indicator": all_indicators})

View File

@@ -2,8 +2,10 @@
# Licensed under the MIT License.
# TODO: rename it with decision.py
from __future__ import annotations
# try to fix circular imports when enabling type hints
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from qlib.strategy.base import BaseStrategy
from qlib.backtest.utils import TradeCalendarManager
@@ -59,6 +61,7 @@ class BaseTradeDecision:
1. The outer strategy's decision is available at the start of the interval
2. Same as `case 1.3`
"""
def __init__(self, strategy: BaseStrategy):
"""
Parameters
@@ -125,7 +128,8 @@ class TradeDecisionWO(BaseTradeDecision):
Trade Decision (W)ith (O)rder.
Besides, the time_range is also included.
"""
def __init__(self, order_list: List[Order], strategy: BaseStrategy, idx_range: Tuple=None):
def __init__(self, order_list: List[Order], strategy: BaseStrategy, idx_range: Tuple = None):
super().__init__(strategy)
self.order_list = order_list
self.idx_range = idx_range
@@ -198,8 +202,7 @@ class TradeDecisionWithOrderPool:
class BaseDecisionUpdater:
def update_decision(self, decision, trade_calendar) -> BaseTradeDecision:
"""[summary]
"""
Parameters
----------
decision : BaseTradeDecision

View File

@@ -15,7 +15,8 @@ class BasePosition:
The Position want to maintain the position like a dictionary
Please refer to the `Position` class for the position
"""
def __init__(self, cash=0., *args, **kwargs) -> None:
def __init__(self, cash=0.0, *args, **kwargs) -> None:
pass
def skip_update(self) -> bool:
@@ -46,7 +47,6 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `check_stock` method")
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
"""
Parameters
@@ -86,6 +86,7 @@ class BasePosition:
the value(money) of all the stock
"""
raise NotImplementedError(f"Please implement the `calculate_stock_value` method")
def get_stock_list(self) -> List:
"""
Get the list of stocks in the position.
@@ -140,7 +141,7 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method")
def get_stock_weight_dict(self, only_stock: bool=False) -> Dict:
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
"""
generate stock weight dict {stock_id : value weight of stock in the position}
it is meaningful in the beginning or the end of each trade date
@@ -399,13 +400,13 @@ class Position(BasePosition):
self.position["now_account_value"] = now_account_value
class InfPosition(BasePosition):
"""
Position with infinite cash and amount.
This is useful for generating random orders.
"""
def skip_update(self) -> bool:
""" Updating state is meaningless for InfPosition """
return True

View File

@@ -18,7 +18,7 @@ from ..tests.config import CSI300_BENCH
class Report:
'''
"""
Motivation:
Report is for supporting portfolio related metrics.
@@ -26,7 +26,8 @@ class Report:
daily report of the account
contain those followings: returns, costs turnovers, accounts, cash, bench, value
update report
'''
"""
def __init__(self, freq: str = "day", benchmark_config: dict = {}):
"""
Parameters

View File

@@ -140,7 +140,6 @@ class BaseInfrastructure:
self.reset_infra(**infra_dict)
class CommonInfrastructure(BaseInfrastructure):
def get_support_infra(self):
return ["trade_account", "trade_exchange"]

View File

@@ -15,6 +15,7 @@ class TopkDropoutStrategy(ModelStrategy):
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
model,
@@ -104,7 +105,7 @@ class TopkDropoutStrategy(ModelStrategy):
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if pred_score is None:
return []
return TradeDecisionWO([], self)
if self.only_tradable:
# If The strategy only consider tradable stock when make decision
# It needs following actions to filter stocks
@@ -256,6 +257,7 @@ class WeightStrategyBase(ModelStrategy):
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
model,
@@ -332,9 +334,9 @@ class WeightStrategyBase(ModelStrategy):
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if pred_score is None:
return []
return TradeDecisionWO([], self)
current_temp = copy.deepcopy(self.trade_position)
assert(isinstance(current_temp, Position)) # Avoid InfPosition
assert isinstance(current_temp, Position) # Avoid InfPosition
target_weight_position = self.generate_target_weight_position(
score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time

View File

@@ -102,7 +102,7 @@ class TWAPStrategy(BaseStrategy):
trade_step = self.trade_calendar.get_trade_step()
# get the total count of trading step
start_idx, end_idx = get_start_end_idx(self, self.outer_trade_decision)
trade_len = end_idx - start_idx + 1
trade_len = end_idx - start_idx + 1
if trade_step < start_idx:
# It is not time to start trading
@@ -137,12 +137,16 @@ class TWAPStrategy(BaseStrategy):
# calculate the amount of one part, ceil the amount
# floor((trade_unit_cnt + trade_len - rel_trade_step) / (trade_len - rel_trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - rel_trade_step + 1))
_order_amount = (
(trade_unit_cnt + trade_len - rel_trade_step - 1) // (trade_len - rel_trade_step) * _amount_trade_unit
(trade_unit_cnt + trade_len - rel_trade_step - 1)
// (trade_len - rel_trade_step)
* _amount_trade_unit
)
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or rel_trade_step == trade_len - 1):
if self.trade_amount[order.stock_id] > 1e-5 and (
_order_amount < 1e-5 or rel_trade_step == trade_len - 1
):
_order_amount = self.trade_amount[order.stock_id]
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
@@ -173,6 +177,7 @@ class SBBStrategyBase(BaseStrategy):
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
@@ -225,8 +230,7 @@ class SBBStrategyBase(BaseStrategy):
self.trade_trend = {}
self.trade_amount = {}
# init the trade amount of order and predicted trade trend
outer_order_generator = outer_trade_decision.generator()
for order in outer_order_generator:
for order in outer_trade_decision.get_decision():
self.trade_trend[order.stock_id] = self.TREND_MID
self.trade_amount[order.stock_id] = order.amount
@@ -248,8 +252,7 @@ class SBBStrategyBase(BaseStrategy):
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
order_list = []
# for each order in in self.outer_trade_decision
outer_order_generator = self.outer_trade_decision.generator(only_enable=True)
for order in outer_order_generator:
for order in self.outer_trade_decision.get_decision():
# get the price trend
if trade_step % 2 == 0:
# in the first of two adjacent bars, predict the price trend
@@ -379,9 +382,11 @@ class SBBStrategyEMA(SBBStrategyBase):
"""
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA) signal.
"""
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
@@ -463,6 +468,7 @@ class ACStrategy(BaseStrategy):
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
lamb: float = 1e-6,
@@ -555,8 +561,7 @@ class ACStrategy(BaseStrategy):
if outer_trade_decision is not None:
self.trade_amount = {}
# init the trade amount of order and predicted trade trend
outer_order_generator = outer_trade_decision.generator()
for order in outer_order_generator:
for order in outer_trade_decision.get_decision():
self.trade_amount[order.stock_id] = order.amount
def generate_trade_decision(self, execute_result=None):
@@ -564,8 +569,6 @@ class ACStrategy(BaseStrategy):
trade_step = self.trade_calendar.get_trade_step()
# get the total count of trading step
trade_len = self.trade_calendar.get_trade_len()
# update outer trade decision
self.outer_trade_decision.update(self.trade_calendar)
# update the order amount
if execute_result is not None:
@@ -575,8 +578,7 @@ class ACStrategy(BaseStrategy):
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
order_list = []
outer_order_generator = self.outer_trade_decision.generator(only_enable=True)
for order in outer_order_generator:
for order in self.outer_trade_decision.get_decision():
# if not tradable, continue
if not self.trade_exchange.is_stock_tradable(
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
@@ -638,14 +640,16 @@ class ACStrategy(BaseStrategy):
class RandomOrderStrategy(BaseStrategy):
def __init__(self,
index_range: Tuple[int, int], # The range is closed on both left and right.
sample_ratio: float = 1.,
volume_ratio: float = 0.01,
market: str = "all",
*args,
**kwargs):
def __init__(
self,
index_range: Tuple[int, int], # The range is closed on both left and right.
sample_ratio: float = 1.0,
volume_ratio: float = 0.01,
market: str = "all",
direction: int = Order.BUY,
*args,
**kwargs,
):
"""
Parameters
----------
@@ -667,9 +671,12 @@ class RandomOrderStrategy(BaseStrategy):
self.sample_ratio = sample_ratio
self.volume_ratio = volume_ratio
self.market = market
self.direction = direction
exch: Exchange = self.common_infra.get("trade_exchange")
# TODO: this can't be online
self.volume = D.features(D.instruments(market), ["Mean(Ref($volume, 1), 10)"], start_time=exch.start_time, end_time=exch.end_time)
self.volume = D.features(
D.instruments(market), ["Mean(Ref($volume, 1), 10)"], start_time=exch.start_time, end_time=exch.end_time
)
self.volume_df = self.volume.iloc[:, 0].unstack()
def generate_trade_decision(self, execute_result=None):
@@ -677,15 +684,15 @@ class RandomOrderStrategy(BaseStrategy):
step_time_start, step_time_end = self.trade_calendar.get_step_time(trade_step)
order_list = []
for direction in Order.SELL, Order.BUY:
if step_time_start in self.volume_df:
for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items():
order_list.append(
self.common_infra.get("trade_exchange").create_order(
code=stock_id,
amount=volume * self.volume_ratio,
start_time=step_time_start,
end_time=step_time_end,
direction=direction, # 1 for buy
))
if step_time_start in self.volume_df:
for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items():
order_list.append(
self.common_infra.get("trade_exchange").create_order(
code=stock_id,
amount=volume * self.volume_ratio,
start_time=step_time_start,
end_time=step_time_end,
direction=self.direction,
)
)
return TradeDecisionWO(order_list, self, self.index_range)

View File

@@ -213,7 +213,7 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin):
self.backend = kwargs.get("backend", {})
@staticmethod
def instruments(market: Union[List, str]="all", filter_pipe: Union[List, None]=None):
def instruments(market: Union[List, str] = "all", filter_pipe: Union[List, None] = None):
"""Get the general config dictionary for a base market adding several dynamic filters.
Parameters

View File

@@ -85,7 +85,9 @@ class BaseStrategy:
"""
raise NotImplementedError("generate_trade_decision is not implemented!")
def update_trade_decision(self, trade_decision: BaseTradeDecision, trade_calendar: TradeCalendarManager) -> Union[BaseTradeDecision, None]:
def update_trade_decision(
self, trade_decision: BaseTradeDecision, trade_calendar: TradeCalendarManager
) -> Union[BaseTradeDecision, None]:
"""
update trade decision in each step of inner execution, this method enable all order

View File

@@ -9,6 +9,7 @@ from . import lazy_sort_index
from ..config import C
from .time import Freq, cal_sam_minute
def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
"""
Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam

View File

@@ -14,7 +14,7 @@ import functools
@functools.lru_cache(maxsize=240)
def get_min_cal(shift: int=0) -> List[time]:
def get_min_cal(shift: int = 0) -> List[time]:
"""
get the minute level calendar in day period
@@ -30,8 +30,9 @@ def get_min_cal(shift: int=0) -> List[time]:
"""
cal = []
for ts in list(pd.date_range("9:30", "11:29", freq="1min") - pd.Timedelta(minutes=shift)) +\
list(pd.date_range("13:00", "14:59", freq="1min") - pd.Timedelta(minutes=shift)):
for ts in list(pd.date_range("9:30", "11:29", freq="1min") - pd.Timedelta(minutes=shift)) + list(
pd.date_range("13:00", "14:59", freq="1min") - pd.Timedelta(minutes=shift)
):
cal.append(ts.time())
return cal
@@ -115,7 +116,7 @@ def get_day_min_idx_range(start: str, end: str, freq: str) -> Tuple[int, int]:
start = pd.Timestamp(start).time()
end = pd.Timestamp(end).time()
freq = Freq(freq)
in_day_cal = Freq.MIN_CAL[::freq.count]
in_day_cal = Freq.MIN_CAL[:: freq.count]
left_idx = bisect.bisect_left(in_day_cal, start)
right_idx = bisect.bisect_right(in_day_cal, end) - 1
return left_idx, right_idx
@@ -141,15 +142,19 @@ def cal_sam_minute(x: pd.Timestamp, sam_minutes: int) -> pd.Timestamp:
"""
cal = get_min_cal(C.min_data_shift)[::sam_minutes]
idx = bisect.bisect_right(cal, x.time()) - 1
date, new_time = x.date(), cal[idx]
date, new_time = x.date(), cal[idx]
return pd.Timestamp(
datetime(date.year,
month=date.month,
day=date.day,
hour=new_time.hour,
minute=new_time.minute,
second=new_time.second,
microsecond=new_time.microsecond))
datetime(
date.year,
month=date.month,
day=date.day,
hour=new_time.hour,
minute=new_time.minute,
second=new_time.second,
microsecond=new_time.microsecond,
)
)
if __name__ == "__main__":
print(get_day_min_idx_range("8:30", "14:59", "10min"))