mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
successful run random order gen in day script
This commit is contained in:
@@ -13,7 +13,7 @@ from qlib.tests.data import GetData
|
||||
from qlib.backtest import collect_data
|
||||
|
||||
|
||||
class NestedDecisonExecutionWorkflow:
|
||||
class NestedDecisionExecutionWorkflow:
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
@@ -229,4 +229,4 @@ class NestedDecisonExecutionWorkflow:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(NestedDecisonExecutionWorkflow)
|
||||
fire.Fire(NestedDecisionExecutionWorkflow)
|
||||
|
||||
@@ -76,7 +76,7 @@ class Account:
|
||||
'kwargs': {
|
||||
"cash": init_cash
|
||||
},
|
||||
'model_path': "qlib.backtest.position",
|
||||
'module_path': "qlib.backtest.position",
|
||||
})
|
||||
self.accum_info = AccumulatedInfo()
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
|
||||
@@ -164,13 +164,14 @@ class Account:
|
||||
def update_current(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""update current to make rtn consistent with earning at the end of bar"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
stock_list = self.current.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
self.current.update_stock_price(stock_id=code, price=bar_close)
|
||||
if not self.current.skip_update():
|
||||
stock_list = self.current.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
self.current.update_stock_price(stock_id=code, price=bar_close)
|
||||
|
||||
def update_report(self, trade_start_time, trade_end_time):
|
||||
"""update position history, report"""
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.backtest.utils import TradeDecison
|
||||
from qlib.backtest.order import BaseTradeDecision
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from ..utils.resam import parse_freq
|
||||
from ..utils.time import Freq
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor):
|
||||
"""backtest funciton for the interaction of the outermost strategy and executor in the nested decison execution
|
||||
"""backtest funciton for the interaction of the outermost strategy and executor in the nested decision execution
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -15,7 +16,7 @@ def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_exec
|
||||
it records the trading report information
|
||||
"""
|
||||
return_value = {}
|
||||
for _decison in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
||||
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
||||
pass
|
||||
return return_value.get("report"), return_value.get("indicator")
|
||||
|
||||
@@ -45,22 +46,24 @@ def collect_data_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_
|
||||
level_infra = trade_executor.get_level_infra()
|
||||
trade_strategy.reset(level_infra=level_infra)
|
||||
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: TradeDecison = trade_strategy.generate_trade_decision(_execute_result)
|
||||
_execute_result = yield from trade_executor.collect_data(_trade_decision)
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
_execute_result = yield from trade_executor.collect_data(_trade_decision)
|
||||
bar.update(trade_executor.trade_calendar.get_trade_step())
|
||||
|
||||
if return_value is not None:
|
||||
all_executors = trade_executor.get_all_executors()
|
||||
|
||||
all_reports = {
|
||||
"{}{}".format(*parse_freq(_executor.time_per_step)): _executor.get_report()
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.get_report()
|
||||
for _executor in all_executors
|
||||
if _executor.generate_report
|
||||
}
|
||||
all_indicators = {
|
||||
"{}{}".format(
|
||||
*parse_freq(_executor.time_per_step)
|
||||
*Freq.parse(_executor.time_per_step)
|
||||
): _executor.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
for _executor in all_executors
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import random
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -259,6 +260,16 @@ class Exchange:
|
||||
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def create_order(self, code, amount, start_time, end_time, direction) -> Order:
|
||||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
direction=direction,
|
||||
factor=self.get_factor(code, start_time, end_time),
|
||||
)
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0]
|
||||
|
||||
@@ -278,8 +289,20 @@ class Exchange:
|
||||
deal_price = self.get_close(stock_id, start_time, end_time)
|
||||
return deal_price
|
||||
|
||||
def get_factor(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last").iloc[0]
|
||||
def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
Union[float, None]:
|
||||
`None`: if the stock is suspended `None` may be returned
|
||||
`float`: return factor if the factor exists
|
||||
"""
|
||||
if stock_id not in self.quote:
|
||||
return None
|
||||
res = resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last")
|
||||
if res is not None:
|
||||
res = res.iloc[0]
|
||||
return res
|
||||
|
||||
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
|
||||
"""
|
||||
|
||||
@@ -3,12 +3,12 @@ import warnings
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
|
||||
from .order import Order
|
||||
from .order import Order, BaseTradeDecision
|
||||
from .exchange import Exchange
|
||||
from .utils import BaseTradeDecision, TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, TradeDecison
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure
|
||||
|
||||
from ..utils import init_instance_by_config
|
||||
from ..utils.resam import parse_freq
|
||||
from ..utils.time import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
|
||||
|
||||
@@ -135,7 +135,7 @@ class BaseExecutor:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : TradeDecison
|
||||
trade_decision : BaseTradeDecision
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -149,7 +149,7 @@ class BaseExecutor:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : TradeDecison
|
||||
trade_decision : BaseTradeDecision
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -261,7 +261,7 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
def execute(self, trade_decision):
|
||||
return_value = {}
|
||||
for _decison in self.collect_data(trade_decision, return_value):
|
||||
for _decision in self.collect_data(trade_decision, return_value):
|
||||
pass
|
||||
return return_value.get("execute_result")
|
||||
|
||||
@@ -358,13 +358,12 @@ class SimulatorExecutor(BaseExecutor):
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def execute(self, trade_decision):
|
||||
def execute(self, trade_decision: BaseTradeDecision):
|
||||
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
execute_result = []
|
||||
order_generator = trade_decision.generator()
|
||||
for order in order_generator:
|
||||
for order in trade_decision.get_decision():
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
# execute the order
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
# TODO: rename it with decision.py
|
||||
from __future__ import annotations
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass, field
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, Union, List, Set, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,3 +42,192 @@ class Order:
|
||||
if self.direction not in {Order.SELL, Order.BUY}:
|
||||
raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
|
||||
self.deal_amount = 0
|
||||
|
||||
|
||||
class BaseTradeDecision:
|
||||
"""
|
||||
Trade decisions ara made by strategy and executed by exeuter
|
||||
|
||||
Motivation:
|
||||
Here are several typical scenarios for `BaseTradeDecision`
|
||||
|
||||
Case 1:
|
||||
1. Outer strategy makes a decision. The decision is not available at the start of current interval
|
||||
2. After a period of time, the decision are updated and become available
|
||||
3. The inner strategy try to get the decision and start to execute the decision according to `get_range_limit`
|
||||
Case 2:
|
||||
1. The strategy is available at the start of the interval
|
||||
2. Same as `case 1.3`
|
||||
"""
|
||||
def __init__(self, strategy: BaseStrategy):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
strategy : BaseStrategy
|
||||
The strategy who make the decision
|
||||
"""
|
||||
self.strategy = strategy
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
"""
|
||||
get the **concrete decision** (e.g. execution orders)
|
||||
This will be called by the inner strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[object]:
|
||||
The decision result. Typically it is some orders
|
||||
Example:
|
||||
[]:
|
||||
Decision not available
|
||||
concrete_decision:
|
||||
available
|
||||
"""
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecision", None]:
|
||||
"""
|
||||
Be called at the **start** of each step
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_calendar : TradeCalendarManager
|
||||
The calendar of the **inner strategy**!!!!!
|
||||
|
||||
Returns
|
||||
-------
|
||||
None:
|
||||
No update, use previous decision(or unavailable)
|
||||
BaseTradeDecision:
|
||||
New update, use new decision
|
||||
"""
|
||||
return self.strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
def get_range_limit(self) -> Tuple[int, int]:
|
||||
"""
|
||||
return the expected step range for limiting the decision execution time
|
||||
Both left and right are **closed**
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]:
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError:
|
||||
If the decision can't provide a unified start and end
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `func` method")
|
||||
|
||||
|
||||
class TradeDecisionWO(BaseTradeDecision):
|
||||
"""
|
||||
Trade Decision (W)ith (O)rder.
|
||||
Besides, the time_range is also included.
|
||||
"""
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, idx_range: Tuple=None):
|
||||
super().__init__(strategy)
|
||||
self.order_list = order_list
|
||||
self.idx_range = idx_range
|
||||
|
||||
def get_range_limit(self) -> Tuple[int, int]:
|
||||
if self.idx_range is None:
|
||||
# Default to get full index
|
||||
return 0, self.strategy.trade_calendar.get_trade_len() - 1
|
||||
return self.idx_range
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
return self.order_list
|
||||
|
||||
|
||||
# TODO: the orders below need to be discussed ------------------------------------
|
||||
class TradeDecisionWithOrderPool:
|
||||
"""trade decision that made by strategy"""
|
||||
|
||||
def __init__(self, strategy, order_pool):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
strategy : BaseStrategy
|
||||
the original strategy that make the decision
|
||||
order_pool : list, optional
|
||||
the candinate order pool for generate trade decision
|
||||
"""
|
||||
super(TradeDecisionWithOrderPool, self).__init__(strategy)
|
||||
self.order_pool = order_pool
|
||||
self.order_list = []
|
||||
|
||||
def pop_order_pool(self, pop_len):
|
||||
if pop_len > len(self.order_pool):
|
||||
warnings.warn(
|
||||
f"pop len {pop_len} is too much length than order pool, cut it as pool length {len(self.order_pool)}"
|
||||
)
|
||||
pop_len = len(self.order_pool)
|
||||
res = self.order_pool[:pop_len]
|
||||
del self.order_pool[:pop_len]
|
||||
return res
|
||||
|
||||
def push_order_list(self, order_list):
|
||||
self.order_list.extend(order_list)
|
||||
|
||||
def get_decision(self):
|
||||
"""get the order list
|
||||
|
||||
Parameters
|
||||
----------
|
||||
only_enable : bool, optional
|
||||
wether to ignore disabled order, by default False
|
||||
only_disable : bool, optional
|
||||
wether to ignore enabled order, by default False
|
||||
Returns
|
||||
-------
|
||||
List[Order]
|
||||
the order list
|
||||
"""
|
||||
return self.order_list
|
||||
|
||||
def update(self, trade_calendar):
|
||||
"""make the original strategy update the enabled status of orders."""
|
||||
self.ori_strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
|
||||
class BaseDecisionUpdater:
|
||||
def update_decision(self, decision, trade_calendar) -> BaseTradeDecision:
|
||||
"""[summary]
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decision : BaseTradeDecision
|
||||
the trade decision to be updated
|
||||
trade_calendar : BaseTradeCalendar
|
||||
the trade calendar of inner execution
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseTradeDecision
|
||||
the updated decision
|
||||
"""
|
||||
raise NotImplementedError(f"This method is not implemented")
|
||||
|
||||
|
||||
class DecisionUpdaterWithOrderPool:
|
||||
def __init__(self, plan_config=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
plan_config : Dict[Tuple(int, float)], optional
|
||||
the plan config, by default None
|
||||
"""
|
||||
if plan_config is None:
|
||||
self.plan_config = [(0, 1)]
|
||||
else:
|
||||
self.plan_config = plan_config
|
||||
|
||||
def update_decision(self, decision, trade_calendar) -> BaseTradeDecision:
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
for _index, _ratio in self.plan_config:
|
||||
if trade_step == _index:
|
||||
pop_len = len(decision.order_pool) * _ratio
|
||||
pop_order_list = decision.pop_order_pool(pop_len)
|
||||
decision.push_order_list(pop_order_list)
|
||||
|
||||
@@ -30,6 +30,23 @@ class BasePosition:
|
||||
"""
|
||||
return False
|
||||
|
||||
def check_stock(self, stock_id: str) -> bool:
|
||||
"""
|
||||
check if is the stock in the position
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stock_id : str
|
||||
the id of the stock
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool:
|
||||
if is the stock in the position
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `check_stock` method")
|
||||
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
"""
|
||||
Parameters
|
||||
@@ -393,6 +410,10 @@ class InfPosition(BasePosition):
|
||||
""" Updating state is meaningless for InfPosition """
|
||||
return True
|
||||
|
||||
def check_stock(self, stock_id: str) -> bool:
|
||||
# InfPosition always have any stocks
|
||||
return True
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
pass
|
||||
|
||||
|
||||
@@ -11,7 +11,8 @@ from pandas.core import groupby
|
||||
|
||||
from pandas.core.frame import DataFrame
|
||||
|
||||
from ..utils.resam import parse_freq, resam_ts_data, get_higher_eq_freq_feature
|
||||
from ..utils.time import Freq
|
||||
from ..utils.resam import resam_ts_data, get_higher_eq_freq_feature
|
||||
from ..data import D
|
||||
from ..tests.config import CSI300_BENCH
|
||||
|
||||
@@ -78,6 +79,9 @@ class Report:
|
||||
|
||||
def _cal_benchmark(self, benchmark_config, freq):
|
||||
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
|
||||
if benchmark is None:
|
||||
return None
|
||||
|
||||
if isinstance(benchmark, pd.Series):
|
||||
return benchmark
|
||||
else:
|
||||
@@ -94,6 +98,9 @@ class Report:
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
|
||||
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
|
||||
if self.bench is None:
|
||||
return None
|
||||
|
||||
def cal_change(x):
|
||||
return (x + 1).prod()
|
||||
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.backtest.order import Order
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.account import Account
|
||||
import pandas as pd
|
||||
import warnings
|
||||
from typing import Tuple, Union, List, Set
|
||||
@@ -150,187 +146,3 @@ class CommonInfrastructure(BaseInfrastructure):
|
||||
class LevelInfrastructure(BaseInfrastructure):
|
||||
def get_support_infra(self):
|
||||
return ["trade_calendar"]
|
||||
|
||||
|
||||
class BaseTradeDecision:
|
||||
# TODO: put it into order.py; and replace it with decision.py
|
||||
def __init__(self, strategy: BaseStrategy):
|
||||
self.strategy = strategy
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
"""
|
||||
get the **concrete decision** (e.g. concrete decision)
|
||||
This will be called by the inner strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[object]:
|
||||
The decision result. Typically it is some orders
|
||||
Example:
|
||||
[]:
|
||||
Decision not available
|
||||
concrete_decision:
|
||||
available
|
||||
"""
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> "BaseTradeDecison":
|
||||
"""
|
||||
Be called at the **start** of each step
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_calendar : TradeCalendarManager
|
||||
The calendar of the **inner strategy**!!!!!
|
||||
|
||||
Returns
|
||||
-------
|
||||
None:
|
||||
No update, use previous decision(or unavailable)
|
||||
BaseTradeDecison:
|
||||
New update, use new decision
|
||||
"""
|
||||
return self.strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
def get_range_limit(self) -> Tuple[int, int]:
|
||||
"""
|
||||
return the expected step range for limiting the decision execution time
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]:
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError:
|
||||
If the decision can't provide a unified start and end
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `func` method")
|
||||
|
||||
|
||||
class TradeDecisonWO(BaseTradeDecision):
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy):
|
||||
super().__init__(strategy)
|
||||
self.order_list = order_list
|
||||
|
||||
|
||||
class TradeDecison(BaseTradeDecision):
|
||||
"""trade decision that made by strategy"""
|
||||
|
||||
def __init__(self, order_list, ori_strategy, init_enable=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
order_list : list
|
||||
the order list
|
||||
ori_strategy : BaseStrategy
|
||||
the original strategy that make the decison
|
||||
init_enable : bool, optional
|
||||
wether to enable order initially, default by False
|
||||
"""
|
||||
self.order_list = order_list
|
||||
self.ori_strategy = ori_strategy
|
||||
if init_enable:
|
||||
self.enable_dict = {_order.stock_id: _order for _order in self.order_list}
|
||||
self.disable_dict = dict()
|
||||
else:
|
||||
self.enable_dict = dict()
|
||||
self.disable_dict = {_order.stock_id: _order for _order in self.order_list}
|
||||
|
||||
def enable(self, enable_set: Union[List[str], Set[str]] = None, all_enable=False):
|
||||
"""enable order set
|
||||
Parameters
|
||||
----------
|
||||
enable_set : Union[List[str], Set[str]], optional
|
||||
the order set that will be enabled, by default None
|
||||
- if all_enable is True, enable_set will be ignored
|
||||
- else, enable the order whose stock_id in enable_set
|
||||
all_enable : bool, optional
|
||||
wether to enable all order, by default False
|
||||
"""
|
||||
if all_enable is True:
|
||||
self.enable_dict.update(self.disable_dict)
|
||||
self.disable_dict.clear()
|
||||
if enable_set is not None:
|
||||
warnings.warn(f"`enable_set` is ignored because `all_enable` is set True")
|
||||
else:
|
||||
enable_set = set(enable_set)
|
||||
for _stock_id in enable_set:
|
||||
enable_order = self.disable_dict.get(_stock_id)
|
||||
if enable_order is None:
|
||||
raise ValueError(f"_stock_id {_stock_id} is not found in disable set")
|
||||
self.enable_order.update({_stock_id: enable_order})
|
||||
self.disable_dict.pop(_stock_id)
|
||||
|
||||
def disable(self, disable_set: Union[List[str], Set[str]] = None, all_disable=False):
|
||||
"""disable order set
|
||||
Parameters
|
||||
----------
|
||||
disable_set : Union[List[str], Set[str]], optional
|
||||
the order set that will be disabled, by default None
|
||||
- if all_disable is True, disable_set will be ignored
|
||||
- else, disable the order whose stock_id in disable_set
|
||||
all_disable : bool, optional
|
||||
wether to disable all order, by default False
|
||||
"""
|
||||
if all_disable is True:
|
||||
self.disable_dict.update(self.enable_dict)
|
||||
self.enable_dict.clear()
|
||||
if disable_set is not None:
|
||||
warnings.warn(f"`disable_set` is ignored because `all_disable` is set True")
|
||||
else:
|
||||
disable_set = set(disable_set)
|
||||
for _stock_id in disable_set:
|
||||
disable_order = self.enable_dict.get(_stock_id)
|
||||
if disable_order is None:
|
||||
raise ValueError(f"_stock_id {_stock_id} is not found in enable set")
|
||||
self.disable_dict.update({_stock_id: disable_order})
|
||||
self.enable_dict.pop(_stock_id)
|
||||
|
||||
def generator(self, only_enable=False, only_disable=False):
|
||||
"""get order generator used for iteration
|
||||
Parameters
|
||||
----------
|
||||
only_enable : bool, optional
|
||||
wether to ignore disabled order, by default False
|
||||
only_disable : bool, optional
|
||||
wether to ignore enabled order, by default False
|
||||
"""
|
||||
if not only_disable and not only_enable:
|
||||
yield from self.order_list
|
||||
elif not only_disable:
|
||||
yield from self.enable_dict.values()
|
||||
elif not only_enable:
|
||||
yield from self.disable_dict.values()
|
||||
|
||||
def get_order_list(self, only_enable=False, only_disable=False):
|
||||
"""get the order list
|
||||
|
||||
Parameters
|
||||
----------
|
||||
only_enable : bool, optional
|
||||
wether to ignore disabled order, by default False
|
||||
only_disable : bool, optional
|
||||
wether to ignore enabled order, by default False
|
||||
Returns
|
||||
-------
|
||||
List[Order]
|
||||
the order list
|
||||
"""
|
||||
if not only_disable and not only_enable:
|
||||
return self.order_list
|
||||
elif not only_disable:
|
||||
return list(self.enable_dict.values())
|
||||
elif not only_enable:
|
||||
return list(self.disable_dict.values())
|
||||
|
||||
def update(self, trade_calendar: TradeCalendarManager):
|
||||
"""
|
||||
make the original strategy update the enabled status of orders.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_calendar : TradeCalendarManager
|
||||
the trade calendar for sub strategy
|
||||
"""
|
||||
self.ori_strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
@@ -11,7 +11,7 @@ import warnings
|
||||
from ..log import get_module_logger
|
||||
from ..backtest import get_exchange, backtest as backtest_func
|
||||
from ..utils import get_date_range
|
||||
from ..utils.resam import parse_freq, NORM_FREQ_MONTH, NORM_FREQ_WEEK, NORM_FREQ_DAY, NORM_FREQ_MINUTE
|
||||
from ..utils.resam import Freq
|
||||
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
@@ -35,12 +35,12 @@ def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
"""
|
||||
|
||||
def cal_risk_analysis_scaler(freq):
|
||||
_count, _freq = parse_freq(freq)
|
||||
_count, _freq = Freq.parse(freq)
|
||||
_freq_scaler = {
|
||||
NORM_FREQ_MINUTE: 240 * 252,
|
||||
NORM_FREQ_DAY: 252,
|
||||
NORM_FREQ_WEEK: 50,
|
||||
NORM_FREQ_MONTH: 12,
|
||||
Freq.NORM_FREQ_MINUTE: 240 * 252,
|
||||
Freq.NORM_FREQ_DAY: 252,
|
||||
Freq.NORM_FREQ_WEEK: 50,
|
||||
Freq.NORM_FREQ_MONTH: 12,
|
||||
}
|
||||
return _freq_scaler[_freq] / _count
|
||||
|
||||
|
||||
@@ -6,8 +6,7 @@ import pandas as pd
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...strategy.base import ModelStrategy
|
||||
from ...backtest.order import Order
|
||||
from ...backtest.utils import TradeDecison
|
||||
from ...backtest.order import Order, BaseTradeDecision
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
@@ -247,7 +246,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
factor=factor,
|
||||
)
|
||||
buy_order_list.append(buy_order)
|
||||
return TradeDecison(order_list=sell_order_list + buy_order_list, ori_strategy=self)
|
||||
return TradeDecision(order_list=sell_order_list + buy_order_list, ori_strategy=self)
|
||||
|
||||
|
||||
class WeightStrategyBase(ModelStrategy):
|
||||
@@ -344,4 +343,4 @@ class WeightStrategyBase(ModelStrategy):
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
)
|
||||
return TradeDecison(order_list=order_list, ori_strategy=self)
|
||||
return TradeDecision(order_list=order_list, ori_strategy=self)
|
||||
|
||||
@@ -6,7 +6,7 @@ This order generator is for strategies based on WeightStrategyBase
|
||||
"""
|
||||
from ...backtest.position import Position
|
||||
from ...backtest.exchange import Exchange
|
||||
from ...backtest.utils import TradeDecison
|
||||
from ...backtest.order import BaseTradeDecision
|
||||
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -127,7 +127,7 @@ class OrderGenWInteract(OrderGenerator):
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
)
|
||||
return TradeDecison(order_list=order_list, ori_strategy=self)
|
||||
return TradeDecision(order_list=order_list, ori_strategy=self)
|
||||
|
||||
|
||||
class OrderGenWOInteract(OrderGenerator):
|
||||
@@ -191,4 +191,4 @@ class OrderGenWOInteract(OrderGenerator):
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
)
|
||||
return TradeDecison(order_list=order_list, ori_strategy=self)
|
||||
return TradeDecision(order_list=order_list, ori_strategy=self)
|
||||
|
||||
@@ -7,9 +7,9 @@ from ...utils.resam import resam_ts_data
|
||||
from ...data.data import D
|
||||
from ...data.dataset.utils import convert_index_format
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...backtest.order import Order
|
||||
from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO
|
||||
from ...backtest.exchange import Exchange
|
||||
from ...backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeDecison
|
||||
from ...backtest.utils import CommonInfrastructure, LevelInfrastructure
|
||||
|
||||
|
||||
class TWAPStrategy(BaseStrategy):
|
||||
@@ -17,7 +17,7 @@ class TWAPStrategy(BaseStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
@@ -25,8 +25,8 @@ class TWAPStrategy(BaseStrategy):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : TradeDecison
|
||||
the trade decison of outer strategy which this startegy relies
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the trade decision of outer strategy which this startegy relies
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
@@ -57,25 +57,35 @@ class TWAPStrategy(BaseStrategy):
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, outer_trade_decision: TradeDecison = None, **kwargs):
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : TradeDecison, optional
|
||||
outer_trade_decision : BaseTradeDecision, optional
|
||||
"""
|
||||
|
||||
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
self.trade_amount = {}
|
||||
outer_order_generator = outer_trade_decision.generator()
|
||||
for order in outer_order_generator:
|
||||
for order in outer_trade_decision.get_decision():
|
||||
self.trade_amount[order.stock_id] = order.amount
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
# strategy is not available. Give an empty decision
|
||||
if len(self.outer_trade_decision.get_decision()) == 0:
|
||||
return TradeDecisionWO(order_list=[], strategy=self)
|
||||
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
# get the total count of trading step
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
start_idx, end_idx = self.outer_trade_decision.get_range_limit()
|
||||
trade_len = end_idx - start_idx + 1
|
||||
|
||||
if trade_step < start_idx:
|
||||
# It is not time to start trading
|
||||
return TradeDecisionWO(order_list=[], strategy=self)
|
||||
|
||||
rel_trade_step = trade_step - start_idx # trade_step relative to start_idx
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
@@ -84,8 +94,7 @@ class TWAPStrategy(BaseStrategy):
|
||||
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
order_list = []
|
||||
outer_order_generator = self.outer_trade_decision.generator(only_enable=True)
|
||||
for order in outer_order_generator:
|
||||
for order in self.outer_trade_decision.get_decision():
|
||||
# if not tradable, continue
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
@@ -96,21 +105,21 @@ class TWAPStrategy(BaseStrategy):
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# divide the order into equal parts, and trade one part
|
||||
_order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
|
||||
_order_amount = self.trade_amount[order.stock_id] / (trade_len - rel_trade_step)
|
||||
# without considering trade unit
|
||||
else:
|
||||
# divide the order into equal parts, and trade one part
|
||||
# calculate the total count of trade units to trade
|
||||
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
|
||||
# calculate the amount of one part, ceil the amount
|
||||
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
|
||||
# floor((trade_unit_cnt + trade_len - rel_trade_step) / (trade_len - rel_trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - rel_trade_step + 1))
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit
|
||||
(trade_unit_cnt + trade_len - rel_trade_step - 1) // (trade_len - rel_trade_step) * _amount_trade_unit
|
||||
)
|
||||
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or trade_step == trade_len - 1):
|
||||
if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or rel_trade_step == trade_len - 1):
|
||||
_order_amount = self.trade_amount[order.stock_id]
|
||||
|
||||
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
|
||||
@@ -126,7 +135,7 @@ class TWAPStrategy(BaseStrategy):
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
return TradeDecison(order_list=order_list, ori_strategy=self)
|
||||
return TradeDecisionWO(order_list=order_list, strategy=self)
|
||||
|
||||
|
||||
class SBBStrategyBase(BaseStrategy):
|
||||
@@ -140,7 +149,7 @@ class SBBStrategyBase(BaseStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
@@ -148,8 +157,8 @@ class SBBStrategyBase(BaseStrategy):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : TradeDecison
|
||||
the trade decison of outer strategy which this startegy relies
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the trade decision of outer strategy which this startegy relies
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
@@ -178,11 +187,11 @@ class SBBStrategyBase(BaseStrategy):
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, outer_trade_decision: TradeDecison = None, **kwargs):
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : TradeDecison, optional
|
||||
outer_trade_decision : BaseTradeDecision, optional
|
||||
"""
|
||||
super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
@@ -336,7 +345,7 @@ class SBBStrategyBase(BaseStrategy):
|
||||
# in the first one of two adjacent bars, store the trend for the second one to use
|
||||
self.trade_trend[order.stock_id] = _pred_trend
|
||||
|
||||
return TradeDecison(order_list=order_list, ori_strategy=self)
|
||||
return TradeDecision(order_list=order_list, ori_strategy=self)
|
||||
|
||||
|
||||
class SBBStrategyEMA(SBBStrategyBase):
|
||||
@@ -346,7 +355,7 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
instruments: Union[List, str] = "csi300",
|
||||
freq: str = "day",
|
||||
trade_exchange: Exchange = None,
|
||||
@@ -426,7 +435,7 @@ class ACStrategy(BaseStrategy):
|
||||
lamb: float = 1e-6,
|
||||
eta: float = 2.5e-6,
|
||||
window_size: int = 20,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
instruments: Union[List, str] = "csi300",
|
||||
freq: str = "day",
|
||||
trade_exchange: Exchange = None,
|
||||
@@ -503,11 +512,11 @@ class ACStrategy(BaseStrategy):
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
self._reset_signal()
|
||||
|
||||
def reset(self, outer_trade_decision: TradeDecison = None, **kwargs):
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : TradeDecison, optional
|
||||
outer_trade_decision : BaseTradeDecision, optional
|
||||
"""
|
||||
super(ACStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
@@ -592,13 +601,13 @@ class ACStrategy(BaseStrategy):
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
return TradeDecison(order_list=order_list, ori_strategy=self)
|
||||
return TradeDecision(order_list=order_list, ori_strategy=self)
|
||||
|
||||
|
||||
class RandomOrderStrategy(BaseStrategy):
|
||||
|
||||
def __init__(self,
|
||||
time_range: Tuple = ("9:30", "15:00"), # The range is closed on both left and right.
|
||||
index_range: Tuple[int, int], # The range is closed on both left and right.
|
||||
sample_ratio: float = 1.,
|
||||
volume_ratio: float = 0.01,
|
||||
market: str = "all",
|
||||
@@ -607,10 +616,10 @@ class RandomOrderStrategy(BaseStrategy):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
time_range : Tuple
|
||||
the intra day time range of the orders
|
||||
index_range : Tuple
|
||||
the intra day time index range of the orders
|
||||
the left and right is closed.
|
||||
# TODO: this is a time_range level limitation. We'll implement a more detailed limitation later.
|
||||
# TODO: this is a index_range level limitation. We'll implement a more detailed limitation later.
|
||||
sample_ratio : float
|
||||
the ratio of all orders are sampled
|
||||
volume_ratio : float
|
||||
@@ -621,12 +630,27 @@ class RandomOrderStrategy(BaseStrategy):
|
||||
"""
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.time_range = time_range
|
||||
self.index_range = index_range
|
||||
self.sample_ratio = sample_ratio
|
||||
self.volume_ratio = volume_ratio
|
||||
self.market = market
|
||||
exch: Exchange = self.common_infra.get("exchange")
|
||||
self.volume = D.features(D.instruments("market"), ["Mean($volume, 10)"], start_time=exch.start_time, end_time=exch.end_time)
|
||||
exch: Exchange = self.common_infra.get("trade_exchange")
|
||||
self.volume = D.features(D.instruments(market), ["Mean(Ref($volume, 1), 10)"], start_time=exch.start_time, end_time=exch.end_time)
|
||||
self.volume_df = self.volume.iloc[:, 0].unstack()
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
return super().generate_trade_decision(execute_result=execute_result)
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
step_time_start, step_time_end = self.trade_calendar.get_step_time(trade_step)
|
||||
|
||||
order_list = []
|
||||
for direction in Order.SELL, Order.BUY:
|
||||
for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items():
|
||||
order_list.append(
|
||||
self.common_infra.get("trade_exchange").create_order(
|
||||
code=stock_id,
|
||||
amount=volume * self.volume_ratio,
|
||||
start_time=step_time_start,
|
||||
end_time=step_time_end,
|
||||
direction=direction, # 1 for buy
|
||||
))
|
||||
return TradeDecisionWO(order_list, self)
|
||||
|
||||
@@ -7,7 +7,8 @@ from ..data.dataset import DatasetH
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from ..utils import init_instance_by_config
|
||||
from ..backtest.utils import BaseTradeDecision, CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, TradeDecison
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..backtest.order import BaseTradeDecision
|
||||
|
||||
|
||||
class BaseStrategy:
|
||||
@@ -15,16 +16,16 @@ class BaseStrategy:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : TradeDecison, optional
|
||||
the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None
|
||||
- If the strategy is used to split trade decison, it will be used
|
||||
outer_trade_decision : BaseTradeDecision, optional
|
||||
the trade decision of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None
|
||||
- If the strategy is used to split trade decision, it will be used
|
||||
- If the strategy is used for portfolio management, it can be ignored
|
||||
level_infra : LevelInfrastructure, optional
|
||||
level shared infrastructure for backtesting, including trade calendar
|
||||
@@ -34,14 +35,14 @@ class BaseStrategy:
|
||||
|
||||
self.reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)
|
||||
|
||||
def reset_level_infra(self, level_infra):
|
||||
def reset_level_infra(self, level_infra: LevelInfrastructure):
|
||||
if not hasattr(self, "level_infra"):
|
||||
self.level_infra = level_infra
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if level_infra.has("trade_calendar"):
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
self.trade_calendar: TradeCalendarManager = level_infra.get("trade_calendar")
|
||||
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure):
|
||||
if not hasattr(self, "common_infra"):
|
||||
@@ -62,7 +63,7 @@ class BaseStrategy:
|
||||
"""
|
||||
- reset `level_infra`, used to reset trade calendar, .etc
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
- reset `outer_trade_decision`, used to make split decison
|
||||
- reset `outer_trade_decision`, used to make split decision
|
||||
"""
|
||||
if level_infra is not None:
|
||||
self.reset_level_infra(level_infra)
|
||||
@@ -79,19 +80,19 @@ class BaseStrategy:
|
||||
Parameters
|
||||
----------
|
||||
execute_result : List[object], optional
|
||||
the executed result for trade decison, by default None
|
||||
the executed result for trade decision, by default None
|
||||
- When call the generate_trade_decision firstly, `execute_result` could be None
|
||||
"""
|
||||
raise NotImplementedError("generate_trade_decision is not implemented!")
|
||||
|
||||
def update_trade_decision(self, trade_decison: BaseTradeDecision, trade_calendar: TradeCalendarManager) -> Union[BaseTradeDecision, None]:
|
||||
def update_trade_decision(self, trade_decision: BaseTradeDecision, trade_calendar: TradeCalendarManager) -> Union[BaseTradeDecision, None]:
|
||||
"""
|
||||
update trade decision in each step of inner execution, this method enable all order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decison : TradeDecison
|
||||
the trade decison that will be updated
|
||||
trade_decision : BaseTradeDecision
|
||||
the trade decision that will be updated
|
||||
trade_calendar : TradeCalendarManager
|
||||
The calendar of the **inner strategy**!!!!!
|
||||
|
||||
@@ -125,7 +126,7 @@ class ModelStrategy(BaseStrategy):
|
||||
self,
|
||||
model: BaseModel,
|
||||
dataset: DatasetH,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
@@ -161,7 +162,7 @@ class RLStrategy(BaseStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
policy,
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
@@ -184,7 +185,7 @@ class RLIntStrategy(RLStrategy):
|
||||
policy,
|
||||
state_interpreter: Union[dict, StateInterpreter],
|
||||
action_interpreter: Union[dict, ActionInterpreter],
|
||||
outer_trade_decision: TradeDecison = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
|
||||
@@ -7,58 +7,7 @@ from typing import Tuple, List, Union, Optional, Callable
|
||||
|
||||
from . import lazy_sort_index
|
||||
from ..config import C
|
||||
|
||||
NORM_FREQ_MONTH = "month"
|
||||
NORM_FREQ_WEEK = "week"
|
||||
NORM_FREQ_DAY = "day"
|
||||
NORM_FREQ_MINUTE = "minute"
|
||||
|
||||
|
||||
def parse_freq(freq: str) -> Tuple[int, str]:
|
||||
"""
|
||||
Parse freq into a unified format
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
Raw freq, supported freq should match the re '^([0-9]*)(month|mon|week|w|day|d|minute|min)$'
|
||||
|
||||
Returns
|
||||
-------
|
||||
freq: Tuple[int, str]
|
||||
Unified freq, including freq count and unified freq unit. The freq unit should be '[month|week|day|minute]'.
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
print(parse_freq("day"))
|
||||
(1, "day" )
|
||||
print(parse_freq("2mon"))
|
||||
(2, "month")
|
||||
print(parse_freq("10w"))
|
||||
(10, "week")
|
||||
|
||||
"""
|
||||
freq = freq.lower()
|
||||
match_obj = re.match("^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq)
|
||||
if match_obj is None:
|
||||
raise ValueError(
|
||||
"freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min"
|
||||
)
|
||||
_count = int(match_obj.group(1)) if match_obj.group(1) else 1
|
||||
_freq = match_obj.group(2)
|
||||
_freq_format_dict = {
|
||||
"month": NORM_FREQ_MONTH,
|
||||
"mon": NORM_FREQ_MONTH,
|
||||
"week": NORM_FREQ_WEEK,
|
||||
"w": NORM_FREQ_WEEK,
|
||||
"day": NORM_FREQ_DAY,
|
||||
"d": NORM_FREQ_DAY,
|
||||
"minute": NORM_FREQ_MINUTE,
|
||||
"min": NORM_FREQ_MINUTE,
|
||||
}
|
||||
return _count, _freq_format_dict[_freq]
|
||||
|
||||
from .time import Freq
|
||||
|
||||
def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
|
||||
"""
|
||||
@@ -80,13 +29,13 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
|
||||
np.ndarray
|
||||
The calendar with frequency freq_sam
|
||||
"""
|
||||
raw_count, freq_raw = parse_freq(freq_raw)
|
||||
sam_count, freq_sam = parse_freq(freq_sam)
|
||||
raw_count, freq_raw = Freq.parse(freq_raw)
|
||||
sam_count, freq_sam = Freq.parse(freq_sam)
|
||||
if not len(calendar_raw):
|
||||
return calendar_raw
|
||||
|
||||
# if freq_sam is xminute, divide each trading day into several bars evenly
|
||||
if freq_sam == NORM_FREQ_MINUTE:
|
||||
if freq_sam == Freq.NORM_FREQ_MINUTE:
|
||||
|
||||
def cal_sam_minute(x, sam_minutes):
|
||||
"""
|
||||
@@ -119,7 +68,7 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
|
||||
else:
|
||||
raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C")
|
||||
|
||||
if freq_raw != NORM_FREQ_MINUTE:
|
||||
if freq_raw != Freq.NORM_FREQ_MINUTE:
|
||||
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
|
||||
else:
|
||||
if raw_count > sam_count:
|
||||
@@ -130,15 +79,15 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
|
||||
# else, convert the raw calendar into day calendar, and divide the whole calendar into several bars evenly
|
||||
else:
|
||||
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
|
||||
if freq_sam == NORM_FREQ_DAY:
|
||||
if freq_sam == Freq.NORM_FREQ_DAY:
|
||||
return _calendar_day[::sam_count]
|
||||
|
||||
elif freq_sam == NORM_FREQ_WEEK:
|
||||
elif freq_sam == Freq.NORM_FREQ_WEEK:
|
||||
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
|
||||
_calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]
|
||||
return _calendar_week[::sam_count]
|
||||
|
||||
elif freq_sam == NORM_FREQ_MONTH:
|
||||
elif freq_sam == Freq.NORM_FREQ_MONTH:
|
||||
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
|
||||
_calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]
|
||||
return _calendar_month[::sam_count]
|
||||
@@ -180,7 +129,7 @@ def get_resam_calendar(
|
||||
|
||||
"""
|
||||
|
||||
_, norm_freq = parse_freq(freq)
|
||||
_, norm_freq = Freq.parse(freq)
|
||||
|
||||
from ..data.data import Cal
|
||||
|
||||
@@ -189,7 +138,7 @@ def get_resam_calendar(
|
||||
freq, freq_sam = freq, None
|
||||
except (ValueError, KeyError):
|
||||
freq_sam = freq
|
||||
if norm_freq in [NORM_FREQ_MONTH, NORM_FREQ_WEEK, NORM_FREQ_DAY]:
|
||||
if norm_freq in [Freq.NORM_FREQ_MONTH, Freq.NORM_FREQ_WEEK, Freq.NORM_FREQ_DAY]:
|
||||
try:
|
||||
_calendar = Cal.calendar(
|
||||
start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, future=future
|
||||
@@ -200,7 +149,7 @@ def get_resam_calendar(
|
||||
start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future
|
||||
)
|
||||
freq = "1min"
|
||||
elif norm_freq == NORM_FREQ_MINUTE:
|
||||
elif norm_freq == Freq.NORM_FREQ_MINUTE:
|
||||
_calendar = Cal.calendar(
|
||||
start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future
|
||||
)
|
||||
@@ -224,15 +173,15 @@ def get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=No
|
||||
_result = D.features(instruments, fields, start_time, end_time, freq=freq, disk_cache=disk_cache)
|
||||
_freq = freq
|
||||
except (ValueError, KeyError):
|
||||
_, norm_freq = parse_freq(freq)
|
||||
if norm_freq in [NORM_FREQ_MONTH, NORM_FREQ_WEEK, NORM_FREQ_DAY]:
|
||||
_, norm_freq = Freq.parse(freq)
|
||||
if norm_freq in [Freq.NORM_FREQ_MONTH, Freq.NORM_FREQ_WEEK, Freq.NORM_FREQ_DAY]:
|
||||
try:
|
||||
_result = D.features(instruments, fields, start_time, end_time, freq="day", disk_cache=disk_cache)
|
||||
_freq = "day"
|
||||
except (ValueError, KeyError):
|
||||
_result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache)
|
||||
_freq = "1min"
|
||||
elif norm_freq == NORM_FREQ_MINUTE:
|
||||
elif norm_freq == Freq.NORM_FREQ_MINUTE:
|
||||
_result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache)
|
||||
_freq = "1min"
|
||||
else:
|
||||
|
||||
115
qlib/utils/time.py
Normal file
115
qlib/utils/time.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Time related utils are compiled in this script
|
||||
"""
|
||||
import bisect
|
||||
from datetime import time
|
||||
from typing import List, Tuple
|
||||
import re
|
||||
from numpy import append
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def get_min_cal() -> List[time]:
|
||||
"""
|
||||
get the minute level calendar in day period
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[time]:
|
||||
|
||||
"""
|
||||
cal = []
|
||||
for ts in list(pd.date_range("9:30", "11:29", freq="1min")) + list(pd.date_range("13:00", "14:59", freq="1min")):
|
||||
cal.append(ts.time())
|
||||
return cal
|
||||
|
||||
|
||||
class Freq:
|
||||
NORM_FREQ_MONTH = "month"
|
||||
NORM_FREQ_WEEK = "week"
|
||||
NORM_FREQ_DAY = "day"
|
||||
NORM_FREQ_MINUTE = "minute"
|
||||
SUPPORT_CAL_LIST = [NORM_FREQ_MINUTE]
|
||||
|
||||
MIN_CAL = get_min_cal()
|
||||
|
||||
def __init__(self, freq: str) -> None:
|
||||
self.count, self.base = self.parse(freq)
|
||||
|
||||
@staticmethod
|
||||
def parse(freq: str) -> Tuple[int, str]:
|
||||
"""
|
||||
Parse freq into a unified format
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
Raw freq, supported freq should match the re '^([0-9]*)(month|mon|week|w|day|d|minute|min)$'
|
||||
|
||||
Returns
|
||||
-------
|
||||
freq: Tuple[int, str]
|
||||
Unified freq, including freq count and unified freq unit. The freq unit should be '[month|week|day|minute]'.
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
print(Freq.parse("day"))
|
||||
(1, "day" )
|
||||
print(Freq.parse("2mon"))
|
||||
(2, "month")
|
||||
print(Freq.parse("10w"))
|
||||
(10, "week")
|
||||
|
||||
"""
|
||||
freq = freq.lower()
|
||||
match_obj = re.match("^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq)
|
||||
if match_obj is None:
|
||||
raise ValueError(
|
||||
"freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min"
|
||||
)
|
||||
_count = int(match_obj.group(1)) if match_obj.group(1) else 1
|
||||
_freq = match_obj.group(2)
|
||||
_freq_format_dict = {
|
||||
"month": Freq.NORM_FREQ_MONTH,
|
||||
"mon": Freq.NORM_FREQ_MONTH,
|
||||
"week": Freq.NORM_FREQ_WEEK,
|
||||
"w": Freq.NORM_FREQ_WEEK,
|
||||
"day": Freq.NORM_FREQ_DAY,
|
||||
"d": Freq.NORM_FREQ_DAY,
|
||||
"minute": Freq.NORM_FREQ_MINUTE,
|
||||
"min": Freq.NORM_FREQ_MINUTE,
|
||||
}
|
||||
return _count, _freq_format_dict[_freq]
|
||||
|
||||
|
||||
def get_day_min_idx_range(start: str, end: str, freq: str) -> Tuple[int, int]:
|
||||
"""
|
||||
get the min-bar index in a day for a time range (both left and right is closed) given a fixed frequency
|
||||
Parameters
|
||||
----------
|
||||
start : str
|
||||
e.g. "9:30"
|
||||
end : str
|
||||
e.g. "14:30"
|
||||
freq : str
|
||||
"1min"
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]:
|
||||
The index of start and end in the calendar. Both left and right are **closed**
|
||||
"""
|
||||
start = pd.Timestamp(start).time()
|
||||
end = pd.Timestamp(end).time()
|
||||
freq = Freq(freq)
|
||||
in_day_cal = Freq.MIN_CAL[::freq.count]
|
||||
left_idx = bisect.bisect_left(in_day_cal, start)
|
||||
right_idx = bisect.bisect_right(in_day_cal, end) - 1
|
||||
return left_idx, right_idx
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(get_day_min_idx_range("8:30", "14:59", "10min"))
|
||||
@@ -16,7 +16,7 @@ from ..backtest import backtest as normal_backtest
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict
|
||||
from ..utils.resam import parse_freq
|
||||
from ..utils.time import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
@@ -344,17 +344,17 @@ class PortAnaRecord(RecordTemp):
|
||||
indicator_analysis_freq = [indicator_analysis_freq]
|
||||
|
||||
self.risk_analysis_freq = [
|
||||
"{0}{1}".format(*parse_freq(_analysis_freq)) for _analysis_freq in risk_analysis_freq
|
||||
"{0}{1}".format(*Freq.parse(_analysis_freq)) for _analysis_freq in risk_analysis_freq
|
||||
]
|
||||
self.indicator_analysis_freq = [
|
||||
"{0}{1}".format(*parse_freq(_analysis_freq)) for _analysis_freq in indicator_analysis_freq
|
||||
"{0}{1}".format(*Freq.parse(_analysis_freq)) for _analysis_freq in indicator_analysis_freq
|
||||
]
|
||||
self.indicator_analysis_method = indicator_analysis_method
|
||||
|
||||
def _get_report_freq(self, executor_config):
|
||||
ret_freq = []
|
||||
if executor_config["kwargs"].get("generate_report", False):
|
||||
_count, _freq = parse_freq(executor_config["kwargs"]["time_per_step"])
|
||||
_count, _freq = Freq.parse(executor_config["kwargs"]["time_per_step"])
|
||||
ret_freq.append(f"{_count}{_freq}")
|
||||
if "sub_env" in executor_config["kwargs"]:
|
||||
ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["sub_env"]))
|
||||
|
||||
Reference in New Issue
Block a user