mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 01:51:18 +08:00
add InfPosition
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import copy
|
||||
from typing import Union
|
||||
|
||||
from .account import Account
|
||||
from .exchange import Exchange
|
||||
@@ -91,17 +92,53 @@ def get_exchange(
|
||||
return init_instance_by_config(exchange, accept_types=Exchange)
|
||||
|
||||
|
||||
def get_strategy_executor(
|
||||
start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}
|
||||
):
|
||||
trade_account = Account(
|
||||
init_cash=account,
|
||||
benchmark_config={
|
||||
def create_account_instance(start_time, end_time, benchmark: str, account: float, pos_type: str="Position") -> Account:
|
||||
"""
|
||||
# TODO: is very strange pass benchmark_config in the account(maybe for report)
|
||||
# There should be a post-step to process the report.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time :
|
||||
start time of the benchmark
|
||||
end_time :
|
||||
end time of the benchmark
|
||||
benchmark : str
|
||||
the benchmark for reporting
|
||||
account : Union[float, str]
|
||||
information for describing how to creating the account
|
||||
For `float`
|
||||
Using Account with a normal position
|
||||
For `str`:
|
||||
Using account with a specific Position
|
||||
"""
|
||||
kwargs = {
|
||||
"init_cash": account,
|
||||
"benchmark_config": {
|
||||
"benchmark": benchmark,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
},
|
||||
)
|
||||
"pos_type": pos_type
|
||||
}
|
||||
return Account(**kwargs)
|
||||
|
||||
|
||||
def get_strategy_executor(start_time,
|
||||
end_time,
|
||||
strategy: BaseStrategy,
|
||||
executor: BaseExecutor,
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, str] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
):
|
||||
|
||||
trade_account = create_account_instance(start_time=start_time,
|
||||
end_time=end_time,
|
||||
benchmark=benchmark,
|
||||
account=account,
|
||||
pos_type=pos_type)
|
||||
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
if "start_time" not in exchange_kwargs:
|
||||
@@ -117,19 +154,47 @@ def get_strategy_executor(
|
||||
return trade_strategy, trade_executor
|
||||
|
||||
|
||||
def backtest(start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}):
|
||||
def backtest(start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark="SH000300",
|
||||
account=1e9,
|
||||
exchange_kwargs={},
|
||||
pos_type: str = "Position"):
|
||||
|
||||
trade_strategy, trade_executor = get_strategy_executor(
|
||||
start_time, end_time, strategy, executor, benchmark, account, exchange_kwargs
|
||||
start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark,
|
||||
account,
|
||||
exchange_kwargs,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
report_dict, indicator_dict = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
|
||||
return report_dict, indicator_dict
|
||||
|
||||
|
||||
def collect_data(start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}):
|
||||
def collect_data(start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark="SH000300",
|
||||
account=1e9,
|
||||
exchange_kwargs={},
|
||||
pos_type: str = "Position"):
|
||||
|
||||
trade_strategy, trade_executor = get_strategy_executor(
|
||||
start_time, end_time, strategy, executor, benchmark, account, exchange_kwargs
|
||||
start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark,
|
||||
account,
|
||||
exchange_kwargs,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
|
||||
|
||||
import copy
|
||||
from qlib.utils import init_instance_by_config
|
||||
import warnings
|
||||
import pandas as pd
|
||||
|
||||
from .position import Position
|
||||
from .position import BasePosition, InfPosition, Position
|
||||
from .report import Report, Indicator
|
||||
from .order import Order
|
||||
from .exchange import Exchange
|
||||
@@ -62,22 +63,32 @@ class AccumulatedInfo:
|
||||
|
||||
|
||||
class Account:
|
||||
def __init__(self, init_cash, freq: str = "day", benchmark_config: dict = {}):
|
||||
def __init__(self, init_cash: float=1e9, freq: str = "day", benchmark_config: dict = {}, pos_type:str = "Position"):
|
||||
self.pos_type = pos_type
|
||||
self.init_vars(init_cash, freq, benchmark_config)
|
||||
|
||||
def init_vars(self, init_cash, freq: str, benchmark_config: dict):
|
||||
|
||||
# init cash
|
||||
self.init_cash = init_cash
|
||||
self.current = Position(cash=init_cash)
|
||||
self.current: BasePosition = init_instance_by_config({
|
||||
'class': self.pos_type,
|
||||
'kwargs': {
|
||||
"cash": init_cash
|
||||
},
|
||||
'model_path': "qlib.backtest.position",
|
||||
})
|
||||
self.accum_info = AccumulatedInfo()
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
|
||||
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
# portfolio related metrics
|
||||
self.report = Report(freq, benchmark_config)
|
||||
self.indicator = Indicator()
|
||||
self.positions = {}
|
||||
|
||||
# trading related matric(e.g. high-frequency trading)
|
||||
self.indicator = Indicator()
|
||||
|
||||
def reset(self, freq=None, benchmark_config=None, init_report=False):
|
||||
"""reset freq and report of account
|
||||
|
||||
@@ -102,7 +113,7 @@ class Account:
|
||||
return self.positions
|
||||
|
||||
def get_cash(self):
|
||||
return self.current.position["cash"]
|
||||
return self.current.get_cash()
|
||||
|
||||
def _update_state_from_order(self, order, trade_val, cost, trade_price):
|
||||
# update turnover
|
||||
@@ -124,6 +135,11 @@ class Account:
|
||||
self.accum_info.add_return_value(profit) # note here do not consider cost
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
if self.current.skip_update():
|
||||
# TODO: supporting polymorphism for account
|
||||
# updating order for infinite position is meaningless
|
||||
return
|
||||
|
||||
# if stock is sold out, no stock price information in Position, then we should update account first, then update current position
|
||||
# if stock is bought, there is no stock in current position, update current, then update account
|
||||
# The cost will be substracted from the cash at last. So the trading logic can ignore the cost calculation
|
||||
@@ -142,7 +158,8 @@ class Account:
|
||||
def update_bar_count(self):
|
||||
"""at the end of the trading bar, update holding bar, count of stock"""
|
||||
# update holding day count
|
||||
self.current.add_count_all(bar=self.freq)
|
||||
if not self.current.skip_update():
|
||||
self.current.add_count_all(bar=self.freq)
|
||||
|
||||
def update_current(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""update current to make rtn consistent with earning at the end of bar"""
|
||||
@@ -243,11 +260,14 @@ class Account:
|
||||
elif atomic is False and inner_order_indicators is None:
|
||||
raise ValueError("inner_order_indicators is necessary in unatomic executor")
|
||||
|
||||
self.update_bar_count()
|
||||
self.update_current(trade_start_time, trade_end_time, trade_exchange)
|
||||
if generate_report:
|
||||
# report is portfolio related analysis
|
||||
# TODO: `update_bar_count` and `update_current` should placed in Position and be merged.
|
||||
self.update_bar_count()
|
||||
self.update_current(trade_start_time, trade_end_time, trade_exchange)
|
||||
self.update_report(trade_start_time, trade_end_time)
|
||||
|
||||
# indicator is trading (e.g. high-frequency order execution) related analysis
|
||||
self.indicator.clear()
|
||||
|
||||
if atomic:
|
||||
|
||||
@@ -282,7 +282,10 @@ class NestedExecutor(BaseExecutor):
|
||||
self.inner_strategy.alter_decision(trade_decision)
|
||||
|
||||
_inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result)
|
||||
|
||||
# NOTE: Trade Calendar will step forward in the follow line
|
||||
_inner_execute_result = yield from self.inner_executor.collect_data(trade_decision=_inner_trade_decision)
|
||||
|
||||
execute_result.extend(_inner_execute_result)
|
||||
inner_order_indicators.append(self.inner_executor.get_trade_indicator().get_order_indicator)
|
||||
|
||||
|
||||
@@ -4,30 +4,182 @@
|
||||
|
||||
import copy
|
||||
import pathlib
|
||||
from typing import Dict, List
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from .order import Order
|
||||
|
||||
"""
|
||||
Position module
|
||||
"""
|
||||
|
||||
"""
|
||||
current state of position
|
||||
a typical example is :{
|
||||
<instrument_id>: {
|
||||
'count': <how many days the security has been hold>,
|
||||
'amount': <the amount of the security>,
|
||||
'price': <the close price of security in the last trading day>,
|
||||
'weight': <the security weight of total position value>,
|
||||
},
|
||||
}
|
||||
class BasePosition:
|
||||
"""
|
||||
The Position want to maintain the position like a dictionary
|
||||
Please refer to the `Position` class for the position
|
||||
"""
|
||||
def __init__(self, cash=0., *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
"""
|
||||
def skip_update(self) -> bool:
|
||||
"""
|
||||
Should we skip updating operation for this position
|
||||
For example, updating is meaningless for InfPosition
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool:
|
||||
should we skip the updating operator
|
||||
"""
|
||||
return False
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
order : Order
|
||||
the order to update the position
|
||||
trade_val : float
|
||||
the trade value(money) of dealing results
|
||||
cost : float
|
||||
the trade cost of the dealing results
|
||||
trade_price : float
|
||||
the trade price of the dealing results
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_order` method")
|
||||
|
||||
def update_stock_price(self, stock_id, price: float):
|
||||
"""
|
||||
Updating the latest price of the order
|
||||
The useful when clearing balance at each bar end
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stock_id :
|
||||
the id of the stock
|
||||
price : float
|
||||
the price to be updated
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update stock price` method")
|
||||
|
||||
def calculate_stock_value(self) -> float:
|
||||
"""
|
||||
calculate the value of the all assets except cash in the position
|
||||
|
||||
Returns
|
||||
-------
|
||||
float:
|
||||
the value(money) of all the stock
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `calculate_stock_value` method")
|
||||
def get_stock_list(self) -> List:
|
||||
"""
|
||||
Get the list of stocks in the position.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_list` method")
|
||||
|
||||
def get_stock_price(self, code) -> float:
|
||||
"""
|
||||
get the latest price of the stock
|
||||
|
||||
Parameters
|
||||
----------
|
||||
code :
|
||||
the code of the stock
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_price` method")
|
||||
|
||||
def get_stock_amount(self, code) -> float:
|
||||
"""
|
||||
get the amount of the stock
|
||||
|
||||
Parameters
|
||||
----------
|
||||
code :
|
||||
the code of the stock
|
||||
|
||||
Returns
|
||||
-------
|
||||
float:
|
||||
the amount of the stock
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_amount` method")
|
||||
|
||||
def get_cash(self) -> float:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
float:
|
||||
the cash in position
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_cash` method")
|
||||
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
"""
|
||||
generate stock amount dict {stock_id : amount of stock}
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict:
|
||||
{stock_id : amount of stock}
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method")
|
||||
|
||||
def get_stock_weight_dict(self, only_stock: bool=False) -> Dict:
|
||||
"""
|
||||
generate stock weight dict {stock_id : value weight of stock in the position}
|
||||
it is meaningful in the beginning or the end of each trade date
|
||||
|
||||
Parameters
|
||||
----------
|
||||
only_stock : bool
|
||||
If only_stock=True, the weight of each stock in total stock will be returned
|
||||
If only_stock=False, the weight of each stock in total assets(stock + cash) will be returned
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict:
|
||||
{stock_id : value weight of stock in the position}
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
"""
|
||||
Will be called at the end of each bar on each level
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bar :
|
||||
The level to be updated
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
||||
|
||||
def update_weight_all(self):
|
||||
"""
|
||||
Updating the position weight;
|
||||
|
||||
# TODO: this function is a little weird. The weight data in the position is in a wrong state after dealing order
|
||||
# and before updating weight.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bar :
|
||||
The level to be updated
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
||||
|
||||
|
||||
class Position:
|
||||
"""Position"""
|
||||
class Position(BasePosition):
|
||||
"""Position
|
||||
|
||||
current state of position
|
||||
a typical example is :{
|
||||
<instrument_id>: {
|
||||
'count': <how many days the security has been hold>,
|
||||
'amount': <the amount of the security>,
|
||||
'price': <the close price of security in the last trading day>,
|
||||
'weight': <the security weight of total position value>,
|
||||
},
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, cash=0, position_dict={}, now_account_value=0):
|
||||
# NOTE: The position dict must be copied!!!
|
||||
@@ -37,23 +189,35 @@ class Position:
|
||||
self.position["cash"] = cash
|
||||
self.position["now_account_value"] = now_account_value
|
||||
|
||||
def init_stock(self, stock_id, amount, price=None):
|
||||
def _init_stock(self, stock_id, amount, price=None):
|
||||
"""
|
||||
initialization the stock in current position
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stock_id :
|
||||
the id of the stock
|
||||
amount : float
|
||||
the amount of the stock
|
||||
price :
|
||||
the price when buying the init stock
|
||||
"""
|
||||
self.position[stock_id] = {}
|
||||
self.position[stock_id]["amount"] = amount
|
||||
self.position[stock_id]["price"] = price
|
||||
self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
|
||||
|
||||
def buy_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
def _buy_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
trade_amount = trade_val / trade_price
|
||||
if stock_id not in self.position:
|
||||
self.init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
|
||||
self._init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
|
||||
else:
|
||||
# exist, add amount
|
||||
self.position[stock_id]["amount"] += trade_amount
|
||||
|
||||
self.position["cash"] -= trade_val + cost
|
||||
|
||||
def sell_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
def _sell_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
trade_amount = trade_val / trade_price
|
||||
if stock_id not in self.position:
|
||||
raise KeyError("{} not in current position".format(stock_id))
|
||||
@@ -66,11 +230,11 @@ class Position:
|
||||
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
|
||||
)
|
||||
elif abs(self.position[stock_id]["amount"]) <= 1e-5:
|
||||
self.del_stock(stock_id)
|
||||
self._del_stock(stock_id)
|
||||
|
||||
self.position["cash"] += trade_val - cost
|
||||
|
||||
def del_stock(self, stock_id):
|
||||
def _del_stock(self, stock_id):
|
||||
del self.position[stock_id]
|
||||
|
||||
def check_stock(self, stock_id):
|
||||
@@ -80,10 +244,10 @@ class Position:
|
||||
# handle order, order is a order class, defined in exchange.py
|
||||
if order.direction == Order.BUY:
|
||||
# BUY
|
||||
self.buy_stock(order.stock_id, trade_val, cost, trade_price)
|
||||
self._buy_stock(order.stock_id, trade_val, cost, trade_price)
|
||||
elif order.direction == Order.SELL:
|
||||
# SELL
|
||||
self.sell_stock(order.stock_id, trade_val, cost, trade_price)
|
||||
self._sell_stock(order.stock_id, trade_val, cost, trade_price)
|
||||
else:
|
||||
raise NotImplementedError("do not support order direction {}".format(order.direction))
|
||||
|
||||
@@ -122,6 +286,7 @@ class Position:
|
||||
return self.position[code]["amount"]
|
||||
|
||||
def get_stock_count(self, code, bar):
|
||||
"""the days the account has been hold, it may be used in some special strategies"""
|
||||
if f"count_{bar}" in self.position[code]:
|
||||
return self.position[code][f"count_{bar}"]
|
||||
else:
|
||||
@@ -215,3 +380,55 @@ class Position:
|
||||
self.position = positions
|
||||
self.position["cash"] = cash
|
||||
self.position["now_account_value"] = now_account_value
|
||||
|
||||
|
||||
|
||||
class InfPosition(BasePosition):
|
||||
"""
|
||||
Position with infinite cash and amount.
|
||||
|
||||
This is useful for generating random orders.
|
||||
"""
|
||||
def skip_update(self) -> bool:
|
||||
""" Updating state is meaningless for InfPosition """
|
||||
return True
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
pass
|
||||
|
||||
def update_stock_price(self, stock_id, price: float):
|
||||
pass
|
||||
|
||||
def calculate_stock_value(self) -> float:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
float:
|
||||
infinity stock value
|
||||
"""
|
||||
return np.inf
|
||||
|
||||
def get_stock_list(self) -> List:
|
||||
raise NotImplementedError(f"InfPosition doesn't support stock list position")
|
||||
|
||||
def get_stock_price(self, code) -> float:
|
||||
"""the price of the inf position is meaningless"""
|
||||
return np.nan
|
||||
|
||||
def get_stock_amount(self, code) -> float:
|
||||
return np.inf
|
||||
|
||||
def get_cash(self) -> float:
|
||||
return np.inf
|
||||
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
|
||||
|
||||
def get_stock_weight_dict(self, only_stock: bool) -> Dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||
|
||||
def update_weight_all(self):
|
||||
raise NotImplementedError(f"InfPosition doesn't support update_weight_all")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This module is not well maintained.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@@ -17,9 +17,15 @@ from ..tests.config import CSI300_BENCH
|
||||
|
||||
|
||||
class Report:
|
||||
# daily report of the account
|
||||
# contain those followings: returns, costs turnovers, accounts, cash, bench, value
|
||||
# update report
|
||||
'''
|
||||
Motivation:
|
||||
Report is for supporting portfolio related metrics.
|
||||
|
||||
Implementation:
|
||||
daily report of the account
|
||||
contain those followings: returns, costs turnovers, accounts, cash, bench, value
|
||||
update report
|
||||
'''
|
||||
def __init__(self, freq: str = "day", benchmark_config: dict = {}):
|
||||
"""
|
||||
Parameters
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.backtest.order import Order
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.account import Account
|
||||
@@ -158,7 +159,7 @@ class BaseTradeDecision:
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
"""
|
||||
get the concrete decision of the order
|
||||
get the **concrete decision** (e.g. concrete decision)
|
||||
This will be called by the inner strategy
|
||||
|
||||
Returns
|
||||
@@ -173,13 +174,15 @@ class BaseTradeDecision:
|
||||
"""
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
NOT_AVAIL = 0
|
||||
NO_UPDATE = 1
|
||||
NEW_UPDATE = 2
|
||||
def update(self, trade_step: int, trade_len: int) -> "BaseTradeDecison":
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> "BaseTradeDecison":
|
||||
"""
|
||||
Be called at the **start** of each step
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_calendar : TradeCalendarManager
|
||||
The calendar of the **inner strategy**!!!!!
|
||||
|
||||
Returns
|
||||
-------
|
||||
None:
|
||||
@@ -187,23 +190,28 @@ class BaseTradeDecision:
|
||||
BaseTradeDecison:
|
||||
New update, use new decision
|
||||
"""
|
||||
return self.strategy.update_trade_decision(self, trade_step, trade_len)
|
||||
return self.strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
def get_range_limit(self) -> Tuple[int, int]:
|
||||
"""
|
||||
return the expected step range for limiting the dealing time of the order
|
||||
return the expected step range for limiting the decision execution time
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]:
|
||||
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError:
|
||||
If the decision can't provide a unified start and end
|
||||
"""
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
raise NotImplementedError(f"Please implement the `func` method")
|
||||
|
||||
|
||||
class TradeDecisonWO(BaseTradeDecision):
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy):
|
||||
super().__init__(strategy)
|
||||
self.order_list = order_list
|
||||
|
||||
|
||||
class TradeDecison(BaseTradeDecision):
|
||||
@@ -316,6 +324,13 @@ class TradeDecison(BaseTradeDecision):
|
||||
elif not only_enable:
|
||||
return list(self.disable_dict.values())
|
||||
|
||||
def update(self, trade_step, trade_len):
|
||||
"""make the original strategy update the enabled status of orders."""
|
||||
self.ori_strategy.update_trade_decision(self, trade_step, trade_len)
|
||||
def update(self, trade_calendar: TradeCalendarManager):
|
||||
"""
|
||||
make the original strategy update the enabled status of orders.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_calendar : TradeCalendarManager
|
||||
the trade calendar for sub strategy
|
||||
"""
|
||||
self.ori_strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
This strategy is not well maintained
|
||||
"""
|
||||
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
from qlib.backtest.position import Position
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -328,6 +329,8 @@ class WeightStrategyBase(ModelStrategy):
|
||||
if pred_score is None:
|
||||
return []
|
||||
current_temp = copy.deepcopy(self.trade_position)
|
||||
assert(isinstance(current_temp, Position)) # Avoid InfPosition
|
||||
|
||||
target_weight_position = self.generate_target_weight_position(
|
||||
score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time
|
||||
)
|
||||
|
||||
@@ -76,8 +76,6 @@ class TWAPStrategy(BaseStrategy):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
# get the total count of trading step
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
# update outer trade decision
|
||||
self.outer_trade_decision.update(trade_step, trade_len)
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
@@ -204,8 +202,6 @@ class SBBStrategyBase(BaseStrategy):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
# get the total count of trading step
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
# update outer trade decision
|
||||
self.outer_trade_decision.update(trade_step, trade_len)
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
@@ -527,7 +523,7 @@ class ACStrategy(BaseStrategy):
|
||||
# get the total count of trading step
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
# update outer trade decision
|
||||
self.outer_trade_decision.update(trade_step, trade_len)
|
||||
self.outer_trade_decision.update(self.trade_calendar)
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
@@ -602,7 +598,7 @@ class ACStrategy(BaseStrategy):
|
||||
class RandomOrderStrategy(BaseStrategy):
|
||||
|
||||
def __init__(self,
|
||||
time_range: Tuple = ("9:30", "15:00"), # left closed and right closed.
|
||||
time_range: Tuple = ("9:30", "15:00"), # The range is closed on both left and right.
|
||||
sample_ratio: float = 1.,
|
||||
volume_ratio: float = 0.01,
|
||||
market: str = "all",
|
||||
@@ -614,6 +610,7 @@ class RandomOrderStrategy(BaseStrategy):
|
||||
time_range : Tuple
|
||||
the intra day time range of the orders
|
||||
the left and right is closed.
|
||||
# TODO: this is a time_range level limitation. We'll implement a more detailed limitation later.
|
||||
sample_ratio : float
|
||||
the ratio of all orders are sampled
|
||||
volume_ratio : float
|
||||
@@ -632,6 +629,4 @@ class RandomOrderStrategy(BaseStrategy):
|
||||
self.volume = D.features(D.instruments("market"), ["Mean($volume, 10)"], start_time=exch.start_time, end_time=exch.end_time)
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
|
||||
|
||||
return super().generate_trade_decision(execute_result=execute_result)
|
||||
|
||||
@@ -7,7 +7,7 @@ from ..data.dataset import DatasetH
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from ..utils import init_instance_by_config
|
||||
from ..backtest.utils import BaseTradeDecision, CommonInfrastructure, LevelInfrastructure, TradeDecison
|
||||
from ..backtest.utils import BaseTradeDecision, CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, TradeDecison
|
||||
|
||||
|
||||
class BaseStrategy:
|
||||
@@ -84,19 +84,23 @@ class BaseStrategy:
|
||||
"""
|
||||
raise NotImplementedError("generate_trade_decision is not implemented!")
|
||||
|
||||
def update_trade_decision(self, trade_decison: BaseTradeDecision, trade_step: int, trade_len: int) -> BaseTradeDecision:
|
||||
"""update trade decision in each step of inner execution, this method enable all order
|
||||
def update_trade_decision(self, trade_decison: BaseTradeDecision, trade_calendar: TradeCalendarManager) -> Union[BaseTradeDecision, None]:
|
||||
"""
|
||||
update trade decision in each step of inner execution, this method enable all order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decison : TradeDecison
|
||||
the trade decison that will be updated
|
||||
trade_calendar : TradeCalendarManager
|
||||
The calendar of the **inner strategy**!!!!!
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseTradeDecision:
|
||||
"""
|
||||
if trade_step == 0:
|
||||
trade_decison.enable(all_enable=True)
|
||||
# default to return None, which indicates that the trade decision is not changed
|
||||
return None
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision):
|
||||
"""
|
||||
@@ -108,6 +112,9 @@ class BaseStrategy:
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the decision updated by the outer strategy
|
||||
"""
|
||||
|
||||
# default to reset the decision directly
|
||||
# NOTE: normally, user should do something to the strategy due to the change of outer decision
|
||||
self.outer_trade_decision = outer_trade_decision
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user