mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 03:50:57 +08:00
successful run random order gen in day script
This commit is contained in:
@@ -76,7 +76,7 @@ class Account:
|
||||
'kwargs': {
|
||||
"cash": init_cash
|
||||
},
|
||||
'model_path': "qlib.backtest.position",
|
||||
'module_path': "qlib.backtest.position",
|
||||
})
|
||||
self.accum_info = AccumulatedInfo()
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
|
||||
@@ -164,13 +164,14 @@ class Account:
|
||||
def update_current(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""update current to make rtn consistent with earning at the end of bar"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
stock_list = self.current.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
self.current.update_stock_price(stock_id=code, price=bar_close)
|
||||
if not self.current.skip_update():
|
||||
stock_list = self.current.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
self.current.update_stock_price(stock_id=code, price=bar_close)
|
||||
|
||||
def update_report(self, trade_start_time, trade_end_time):
|
||||
"""update position history, report"""
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.backtest.utils import TradeDecison
|
||||
from qlib.backtest.order import BaseTradeDecision
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from ..utils.resam import parse_freq
|
||||
from ..utils.time import Freq
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor):
|
||||
"""backtest funciton for the interaction of the outermost strategy and executor in the nested decison execution
|
||||
"""backtest funciton for the interaction of the outermost strategy and executor in the nested decision execution
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -15,7 +16,7 @@ def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_exec
|
||||
it records the trading report information
|
||||
"""
|
||||
return_value = {}
|
||||
for _decison in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
||||
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
||||
pass
|
||||
return return_value.get("report"), return_value.get("indicator")
|
||||
|
||||
@@ -45,22 +46,24 @@ def collect_data_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_
|
||||
level_infra = trade_executor.get_level_infra()
|
||||
trade_strategy.reset(level_infra=level_infra)
|
||||
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: TradeDecison = trade_strategy.generate_trade_decision(_execute_result)
|
||||
_execute_result = yield from trade_executor.collect_data(_trade_decision)
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
_execute_result = yield from trade_executor.collect_data(_trade_decision)
|
||||
bar.update(trade_executor.trade_calendar.get_trade_step())
|
||||
|
||||
if return_value is not None:
|
||||
all_executors = trade_executor.get_all_executors()
|
||||
|
||||
all_reports = {
|
||||
"{}{}".format(*parse_freq(_executor.time_per_step)): _executor.get_report()
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.get_report()
|
||||
for _executor in all_executors
|
||||
if _executor.generate_report
|
||||
}
|
||||
all_indicators = {
|
||||
"{}{}".format(
|
||||
*parse_freq(_executor.time_per_step)
|
||||
*Freq.parse(_executor.time_per_step)
|
||||
): _executor.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
for _executor in all_executors
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import random
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -259,6 +260,16 @@ class Exchange:
|
||||
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def create_order(self, code, amount, start_time, end_time, direction) -> Order:
|
||||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
direction=direction,
|
||||
factor=self.get_factor(code, start_time, end_time),
|
||||
)
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0]
|
||||
|
||||
@@ -278,8 +289,20 @@ class Exchange:
|
||||
deal_price = self.get_close(stock_id, start_time, end_time)
|
||||
return deal_price
|
||||
|
||||
def get_factor(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last").iloc[0]
|
||||
def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
Union[float, None]:
|
||||
`None`: if the stock is suspended `None` may be returned
|
||||
`float`: return factor if the factor exists
|
||||
"""
|
||||
if stock_id not in self.quote:
|
||||
return None
|
||||
res = resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last")
|
||||
if res is not None:
|
||||
res = res.iloc[0]
|
||||
return res
|
||||
|
||||
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
|
||||
"""
|
||||
|
||||
@@ -3,12 +3,12 @@ import warnings
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
|
||||
from .order import Order
|
||||
from .order import Order, BaseTradeDecision
|
||||
from .exchange import Exchange
|
||||
from .utils import BaseTradeDecision, TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, TradeDecison
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure
|
||||
|
||||
from ..utils import init_instance_by_config
|
||||
from ..utils.resam import parse_freq
|
||||
from ..utils.time import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
|
||||
|
||||
@@ -135,7 +135,7 @@ class BaseExecutor:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : TradeDecison
|
||||
trade_decision : BaseTradeDecision
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -149,7 +149,7 @@ class BaseExecutor:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : TradeDecison
|
||||
trade_decision : BaseTradeDecision
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -261,7 +261,7 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
def execute(self, trade_decision):
|
||||
return_value = {}
|
||||
for _decison in self.collect_data(trade_decision, return_value):
|
||||
for _decision in self.collect_data(trade_decision, return_value):
|
||||
pass
|
||||
return return_value.get("execute_result")
|
||||
|
||||
@@ -358,13 +358,12 @@ class SimulatorExecutor(BaseExecutor):
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def execute(self, trade_decision):
|
||||
def execute(self, trade_decision: BaseTradeDecision):
|
||||
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
execute_result = []
|
||||
order_generator = trade_decision.generator()
|
||||
for order in order_generator:
|
||||
for order in trade_decision.get_decision():
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
# execute the order
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
# TODO: rename it with decision.py
|
||||
from __future__ import annotations
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass, field
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, Union, List, Set, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,3 +42,192 @@ class Order:
|
||||
if self.direction not in {Order.SELL, Order.BUY}:
|
||||
raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
|
||||
self.deal_amount = 0
|
||||
|
||||
|
||||
class BaseTradeDecision:
|
||||
"""
|
||||
Trade decisions ara made by strategy and executed by exeuter
|
||||
|
||||
Motivation:
|
||||
Here are several typical scenarios for `BaseTradeDecision`
|
||||
|
||||
Case 1:
|
||||
1. Outer strategy makes a decision. The decision is not available at the start of current interval
|
||||
2. After a period of time, the decision are updated and become available
|
||||
3. The inner strategy try to get the decision and start to execute the decision according to `get_range_limit`
|
||||
Case 2:
|
||||
1. The strategy is available at the start of the interval
|
||||
2. Same as `case 1.3`
|
||||
"""
|
||||
def __init__(self, strategy: BaseStrategy):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
strategy : BaseStrategy
|
||||
The strategy who make the decision
|
||||
"""
|
||||
self.strategy = strategy
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
"""
|
||||
get the **concrete decision** (e.g. execution orders)
|
||||
This will be called by the inner strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[object]:
|
||||
The decision result. Typically it is some orders
|
||||
Example:
|
||||
[]:
|
||||
Decision not available
|
||||
concrete_decision:
|
||||
available
|
||||
"""
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecision", None]:
|
||||
"""
|
||||
Be called at the **start** of each step
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_calendar : TradeCalendarManager
|
||||
The calendar of the **inner strategy**!!!!!
|
||||
|
||||
Returns
|
||||
-------
|
||||
None:
|
||||
No update, use previous decision(or unavailable)
|
||||
BaseTradeDecision:
|
||||
New update, use new decision
|
||||
"""
|
||||
return self.strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
def get_range_limit(self) -> Tuple[int, int]:
|
||||
"""
|
||||
return the expected step range for limiting the decision execution time
|
||||
Both left and right are **closed**
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]:
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError:
|
||||
If the decision can't provide a unified start and end
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `func` method")
|
||||
|
||||
|
||||
class TradeDecisionWO(BaseTradeDecision):
|
||||
"""
|
||||
Trade Decision (W)ith (O)rder.
|
||||
Besides, the time_range is also included.
|
||||
"""
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, idx_range: Tuple=None):
|
||||
super().__init__(strategy)
|
||||
self.order_list = order_list
|
||||
self.idx_range = idx_range
|
||||
|
||||
def get_range_limit(self) -> Tuple[int, int]:
|
||||
if self.idx_range is None:
|
||||
# Default to get full index
|
||||
return 0, self.strategy.trade_calendar.get_trade_len() - 1
|
||||
return self.idx_range
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
return self.order_list
|
||||
|
||||
|
||||
# TODO: the orders below need to be discussed ------------------------------------
|
||||
class TradeDecisionWithOrderPool:
|
||||
"""trade decision that made by strategy"""
|
||||
|
||||
def __init__(self, strategy, order_pool):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
strategy : BaseStrategy
|
||||
the original strategy that make the decision
|
||||
order_pool : list, optional
|
||||
the candinate order pool for generate trade decision
|
||||
"""
|
||||
super(TradeDecisionWithOrderPool, self).__init__(strategy)
|
||||
self.order_pool = order_pool
|
||||
self.order_list = []
|
||||
|
||||
def pop_order_pool(self, pop_len):
|
||||
if pop_len > len(self.order_pool):
|
||||
warnings.warn(
|
||||
f"pop len {pop_len} is too much length than order pool, cut it as pool length {len(self.order_pool)}"
|
||||
)
|
||||
pop_len = len(self.order_pool)
|
||||
res = self.order_pool[:pop_len]
|
||||
del self.order_pool[:pop_len]
|
||||
return res
|
||||
|
||||
def push_order_list(self, order_list):
|
||||
self.order_list.extend(order_list)
|
||||
|
||||
def get_decision(self):
|
||||
"""get the order list
|
||||
|
||||
Parameters
|
||||
----------
|
||||
only_enable : bool, optional
|
||||
wether to ignore disabled order, by default False
|
||||
only_disable : bool, optional
|
||||
wether to ignore enabled order, by default False
|
||||
Returns
|
||||
-------
|
||||
List[Order]
|
||||
the order list
|
||||
"""
|
||||
return self.order_list
|
||||
|
||||
def update(self, trade_calendar):
|
||||
"""make the original strategy update the enabled status of orders."""
|
||||
self.ori_strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
|
||||
class BaseDecisionUpdater:
|
||||
def update_decision(self, decision, trade_calendar) -> BaseTradeDecision:
|
||||
"""[summary]
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decision : BaseTradeDecision
|
||||
the trade decision to be updated
|
||||
trade_calendar : BaseTradeCalendar
|
||||
the trade calendar of inner execution
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseTradeDecision
|
||||
the updated decision
|
||||
"""
|
||||
raise NotImplementedError(f"This method is not implemented")
|
||||
|
||||
|
||||
class DecisionUpdaterWithOrderPool:
|
||||
def __init__(self, plan_config=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
plan_config : Dict[Tuple(int, float)], optional
|
||||
the plan config, by default None
|
||||
"""
|
||||
if plan_config is None:
|
||||
self.plan_config = [(0, 1)]
|
||||
else:
|
||||
self.plan_config = plan_config
|
||||
|
||||
def update_decision(self, decision, trade_calendar) -> BaseTradeDecision:
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
for _index, _ratio in self.plan_config:
|
||||
if trade_step == _index:
|
||||
pop_len = len(decision.order_pool) * _ratio
|
||||
pop_order_list = decision.pop_order_pool(pop_len)
|
||||
decision.push_order_list(pop_order_list)
|
||||
|
||||
@@ -30,6 +30,23 @@ class BasePosition:
|
||||
"""
|
||||
return False
|
||||
|
||||
def check_stock(self, stock_id: str) -> bool:
|
||||
"""
|
||||
check if is the stock in the position
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stock_id : str
|
||||
the id of the stock
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool:
|
||||
if is the stock in the position
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `check_stock` method")
|
||||
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
"""
|
||||
Parameters
|
||||
@@ -393,6 +410,10 @@ class InfPosition(BasePosition):
|
||||
""" Updating state is meaningless for InfPosition """
|
||||
return True
|
||||
|
||||
def check_stock(self, stock_id: str) -> bool:
|
||||
# InfPosition always have any stocks
|
||||
return True
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
pass
|
||||
|
||||
|
||||
@@ -11,7 +11,8 @@ from pandas.core import groupby
|
||||
|
||||
from pandas.core.frame import DataFrame
|
||||
|
||||
from ..utils.resam import parse_freq, resam_ts_data, get_higher_eq_freq_feature
|
||||
from ..utils.time import Freq
|
||||
from ..utils.resam import resam_ts_data, get_higher_eq_freq_feature
|
||||
from ..data import D
|
||||
from ..tests.config import CSI300_BENCH
|
||||
|
||||
@@ -78,6 +79,9 @@ class Report:
|
||||
|
||||
def _cal_benchmark(self, benchmark_config, freq):
|
||||
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
|
||||
if benchmark is None:
|
||||
return None
|
||||
|
||||
if isinstance(benchmark, pd.Series):
|
||||
return benchmark
|
||||
else:
|
||||
@@ -94,6 +98,9 @@ class Report:
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
|
||||
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
|
||||
if self.bench is None:
|
||||
return None
|
||||
|
||||
def cal_change(x):
|
||||
return (x + 1).prod()
|
||||
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.backtest.order import Order
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.account import Account
|
||||
import pandas as pd
|
||||
import warnings
|
||||
from typing import Tuple, Union, List, Set
|
||||
@@ -150,187 +146,3 @@ class CommonInfrastructure(BaseInfrastructure):
|
||||
class LevelInfrastructure(BaseInfrastructure):
|
||||
def get_support_infra(self):
|
||||
return ["trade_calendar"]
|
||||
|
||||
|
||||
class BaseTradeDecision:
|
||||
# TODO: put it into order.py; and replace it with decision.py
|
||||
def __init__(self, strategy: BaseStrategy):
|
||||
self.strategy = strategy
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
"""
|
||||
get the **concrete decision** (e.g. concrete decision)
|
||||
This will be called by the inner strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[object]:
|
||||
The decision result. Typically it is some orders
|
||||
Example:
|
||||
[]:
|
||||
Decision not available
|
||||
concrete_decision:
|
||||
available
|
||||
"""
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> "BaseTradeDecison":
|
||||
"""
|
||||
Be called at the **start** of each step
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_calendar : TradeCalendarManager
|
||||
The calendar of the **inner strategy**!!!!!
|
||||
|
||||
Returns
|
||||
-------
|
||||
None:
|
||||
No update, use previous decision(or unavailable)
|
||||
BaseTradeDecison:
|
||||
New update, use new decision
|
||||
"""
|
||||
return self.strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
def get_range_limit(self) -> Tuple[int, int]:
|
||||
"""
|
||||
return the expected step range for limiting the decision execution time
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]:
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError:
|
||||
If the decision can't provide a unified start and end
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `func` method")
|
||||
|
||||
|
||||
class TradeDecisonWO(BaseTradeDecision):
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy):
|
||||
super().__init__(strategy)
|
||||
self.order_list = order_list
|
||||
|
||||
|
||||
class TradeDecison(BaseTradeDecision):
|
||||
"""trade decision that made by strategy"""
|
||||
|
||||
def __init__(self, order_list, ori_strategy, init_enable=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
order_list : list
|
||||
the order list
|
||||
ori_strategy : BaseStrategy
|
||||
the original strategy that make the decison
|
||||
init_enable : bool, optional
|
||||
wether to enable order initially, default by False
|
||||
"""
|
||||
self.order_list = order_list
|
||||
self.ori_strategy = ori_strategy
|
||||
if init_enable:
|
||||
self.enable_dict = {_order.stock_id: _order for _order in self.order_list}
|
||||
self.disable_dict = dict()
|
||||
else:
|
||||
self.enable_dict = dict()
|
||||
self.disable_dict = {_order.stock_id: _order for _order in self.order_list}
|
||||
|
||||
def enable(self, enable_set: Union[List[str], Set[str]] = None, all_enable=False):
|
||||
"""enable order set
|
||||
Parameters
|
||||
----------
|
||||
enable_set : Union[List[str], Set[str]], optional
|
||||
the order set that will be enabled, by default None
|
||||
- if all_enable is True, enable_set will be ignored
|
||||
- else, enable the order whose stock_id in enable_set
|
||||
all_enable : bool, optional
|
||||
wether to enable all order, by default False
|
||||
"""
|
||||
if all_enable is True:
|
||||
self.enable_dict.update(self.disable_dict)
|
||||
self.disable_dict.clear()
|
||||
if enable_set is not None:
|
||||
warnings.warn(f"`enable_set` is ignored because `all_enable` is set True")
|
||||
else:
|
||||
enable_set = set(enable_set)
|
||||
for _stock_id in enable_set:
|
||||
enable_order = self.disable_dict.get(_stock_id)
|
||||
if enable_order is None:
|
||||
raise ValueError(f"_stock_id {_stock_id} is not found in disable set")
|
||||
self.enable_order.update({_stock_id: enable_order})
|
||||
self.disable_dict.pop(_stock_id)
|
||||
|
||||
def disable(self, disable_set: Union[List[str], Set[str]] = None, all_disable=False):
|
||||
"""disable order set
|
||||
Parameters
|
||||
----------
|
||||
disable_set : Union[List[str], Set[str]], optional
|
||||
the order set that will be disabled, by default None
|
||||
- if all_disable is True, disable_set will be ignored
|
||||
- else, disable the order whose stock_id in disable_set
|
||||
all_disable : bool, optional
|
||||
wether to disable all order, by default False
|
||||
"""
|
||||
if all_disable is True:
|
||||
self.disable_dict.update(self.enable_dict)
|
||||
self.enable_dict.clear()
|
||||
if disable_set is not None:
|
||||
warnings.warn(f"`disable_set` is ignored because `all_disable` is set True")
|
||||
else:
|
||||
disable_set = set(disable_set)
|
||||
for _stock_id in disable_set:
|
||||
disable_order = self.enable_dict.get(_stock_id)
|
||||
if disable_order is None:
|
||||
raise ValueError(f"_stock_id {_stock_id} is not found in enable set")
|
||||
self.disable_dict.update({_stock_id: disable_order})
|
||||
self.enable_dict.pop(_stock_id)
|
||||
|
||||
def generator(self, only_enable=False, only_disable=False):
|
||||
"""get order generator used for iteration
|
||||
Parameters
|
||||
----------
|
||||
only_enable : bool, optional
|
||||
wether to ignore disabled order, by default False
|
||||
only_disable : bool, optional
|
||||
wether to ignore enabled order, by default False
|
||||
"""
|
||||
if not only_disable and not only_enable:
|
||||
yield from self.order_list
|
||||
elif not only_disable:
|
||||
yield from self.enable_dict.values()
|
||||
elif not only_enable:
|
||||
yield from self.disable_dict.values()
|
||||
|
||||
def get_order_list(self, only_enable=False, only_disable=False):
|
||||
"""get the order list
|
||||
|
||||
Parameters
|
||||
----------
|
||||
only_enable : bool, optional
|
||||
wether to ignore disabled order, by default False
|
||||
only_disable : bool, optional
|
||||
wether to ignore enabled order, by default False
|
||||
Returns
|
||||
-------
|
||||
List[Order]
|
||||
the order list
|
||||
"""
|
||||
if not only_disable and not only_enable:
|
||||
return self.order_list
|
||||
elif not only_disable:
|
||||
return list(self.enable_dict.values())
|
||||
elif not only_enable:
|
||||
return list(self.disable_dict.values())
|
||||
|
||||
def update(self, trade_calendar: TradeCalendarManager):
|
||||
"""
|
||||
make the original strategy update the enabled status of orders.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_calendar : TradeCalendarManager
|
||||
the trade calendar for sub strategy
|
||||
"""
|
||||
self.ori_strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
Reference in New Issue
Block a user