mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
Merge branch 'nested_decision_exe' of https://github.com/microsoft/qlib into rl-dummy
This commit is contained in:
@@ -12,7 +12,7 @@ import pandas as pd
|
||||
from ..data.data import D
|
||||
from ..data.dataset.utils import get_level_index
|
||||
from ..config import C, REG_CN
|
||||
from ..utils.resam import resam_ts_data
|
||||
from ..utils.resam import resam_ts_data, ts_data_last
|
||||
from ..log import get_module_logger
|
||||
from .order import Order, OrderDir, OrderHelper
|
||||
|
||||
@@ -166,7 +166,7 @@ class Exchange:
|
||||
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
quote_dict[stock_id] = stock_val
|
||||
quote_dict[stock_id] = stock_val.droplevel(level="instrument")
|
||||
|
||||
self.quote = quote_dict
|
||||
|
||||
@@ -186,13 +186,13 @@ class Exchange:
|
||||
|
||||
"""
|
||||
if direction is None:
|
||||
buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0]
|
||||
sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0]
|
||||
buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all")
|
||||
sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all")
|
||||
return buy_limit or sell_limit
|
||||
elif direction == Order.BUY:
|
||||
return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all")
|
||||
elif direction == Order.SELL:
|
||||
return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all")
|
||||
else:
|
||||
raise ValueError(f"direction {direction} is not supported!")
|
||||
|
||||
@@ -242,6 +242,7 @@ class Exchange:
|
||||
raise ValueError("trade_account and position can only choose one")
|
||||
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time)
|
||||
# NOTE: order will be changed in this function
|
||||
trade_val, trade_cost = self._calc_trade_info_by_order(
|
||||
order, trade_account.current if trade_account else position
|
||||
)
|
||||
@@ -256,27 +257,17 @@ class Exchange:
|
||||
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def create_order(self, code, amount, start_time, end_time, direction) -> Order:
|
||||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
direction=direction,
|
||||
factor=self.get_factor(code, start_time, end_time),
|
||||
)
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id], start_time, end_time, method=ts_data_last)
|
||||
|
||||
def get_close(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method="last").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method=ts_data_last)
|
||||
|
||||
def get_volume(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum")
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time):
|
||||
deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method="last").iloc[0]
|
||||
deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method=ts_data_last)
|
||||
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
|
||||
self.logger.warning(
|
||||
f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!"
|
||||
@@ -295,10 +286,7 @@ class Exchange:
|
||||
"""
|
||||
if stock_id not in self.quote:
|
||||
return None
|
||||
res = resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last")
|
||||
if res is not None:
|
||||
res = res.iloc[0]
|
||||
return res
|
||||
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method=ts_data_last)
|
||||
|
||||
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
|
||||
"""
|
||||
@@ -471,6 +459,8 @@ class Exchange:
|
||||
"""
|
||||
Calculation of trade info
|
||||
|
||||
**NOTE**: Order will be changed in this function
|
||||
|
||||
:param order:
|
||||
:param position: Position
|
||||
:return: trade_val, trade_cost
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import copy
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
from typing import List, Union
|
||||
|
||||
from qlib.backtest.report import Indicator
|
||||
|
||||
@@ -318,6 +318,15 @@ class NestedExecutor(BaseExecutor):
|
||||
class SimulatorExecutor(BaseExecutor):
|
||||
"""Executor that simulate the true market"""
|
||||
|
||||
# available trade_types
|
||||
TT_SERIAL = "serial"
|
||||
## The orders will be executed serially in a sequence
|
||||
# In each trading step, it is possible that users sell instruments first and use the money to buy new instruments
|
||||
TT_PARAL = "parallel"
|
||||
## The orders will be executed parallelly
|
||||
# In each trading step, if users try to sell instruments first and buy new instruments with money, failure will
|
||||
# occur
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
time_per_step: str,
|
||||
@@ -329,6 +338,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
trade_type: str = TT_PARAL,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -337,6 +347,8 @@ class SimulatorExecutor(BaseExecutor):
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
trade_type: str
|
||||
please refer to the doc of `TT_SERIAL` & `TT_PARAL`
|
||||
"""
|
||||
super(SimulatorExecutor, self).__init__(
|
||||
time_per_step=time_per_step,
|
||||
@@ -352,6 +364,8 @@ class SimulatorExecutor(BaseExecutor):
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
self.trade_type = trade_type
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
@@ -361,14 +375,45 @@ class SimulatorExecutor(BaseExecutor):
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def _get_order_iterator(self, trade_decision: BaseTradeDecision) -> List[Order]:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : BaseTradeDecision
|
||||
the trade decision given by the strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Order]:
|
||||
get a list orders according to `self.trade_type`
|
||||
"""
|
||||
orders = trade_decision.get_decision()
|
||||
|
||||
if self.trade_type == self.TT_SERIAL:
|
||||
# Orders will be traded in a parallel way
|
||||
order_it = orders
|
||||
elif self.trade_type == self.TT_PARAL:
|
||||
# NOTE: !!!!!!!
|
||||
# Assumption: there will not be orders in different trading direction in a single step of a strategy !!!!
|
||||
# The parallel trading failure will be caused only by the confliction of money
|
||||
# Therefore, make the buying go first will make sure the confliction happen.
|
||||
# It equals to parallel trading after sorting the order by direction
|
||||
order_it = sorted(orders, key=lambda order: -order.direction)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return order_it
|
||||
|
||||
def execute(self, trade_decision: BaseTradeDecision):
|
||||
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
execute_result = []
|
||||
for order in trade_decision.get_decision():
|
||||
|
||||
for order in self._get_order_iterator(trade_decision):
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
# execute the order
|
||||
# execute the order.
|
||||
# NOTE: The trade_account will be changed in this function
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
|
||||
order, trade_account=self.trade_account
|
||||
)
|
||||
@@ -405,6 +450,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
# do nothing
|
||||
pass
|
||||
|
||||
# Account will not be changed in this function
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
|
||||
@@ -93,7 +93,7 @@ class Report:
|
||||
|
||||
if freq is None:
|
||||
raise ValueError("benchmark freq can't be None!")
|
||||
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
|
||||
_codes = benchmark if isinstance(benchmark, (list, dict)) else [benchmark]
|
||||
fields = ["$close/Ref($close,1)-1"]
|
||||
_temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq)
|
||||
if len(_temp_result) == 0:
|
||||
|
||||
@@ -7,7 +7,7 @@ from qlib.data.dataset.utils import convert_index_format
|
||||
|
||||
from qlib.utils import lazy_sort_index
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...utils.resam import resam_ts_data, ts_data_last
|
||||
from ...data.data import D
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO
|
||||
@@ -432,7 +432,7 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
|
||||
if not signal_df.empty:
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
self.signal[stock_id] = stock_val
|
||||
self.signal[stock_id] = stock_val["signal"].droplevel(level="instrument")
|
||||
|
||||
def reset_level_infra(self, level_infra):
|
||||
"""
|
||||
@@ -454,13 +454,16 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
return self.TREND_MID
|
||||
else:
|
||||
_sample_signal = resam_ts_data(
|
||||
self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last"
|
||||
self.signal[stock_id],
|
||||
pred_start_time,
|
||||
pred_end_time,
|
||||
method=ts_data_last,
|
||||
)
|
||||
# if EMA signal == 0 or None, return mid trend
|
||||
if _sample_signal is None or _sample_signal.iloc[0] == 0:
|
||||
if _sample_signal is None or np.isnan(_sample_signal) or _sample_signal == 0:
|
||||
return self.TREND_MID
|
||||
# if EMA signal > 0, return long trend
|
||||
elif _sample_signal.iloc[0] > 0:
|
||||
elif _sample_signal > 0:
|
||||
return self.TREND_LONG
|
||||
# if EMA signal < 0, return short trend
|
||||
else:
|
||||
@@ -523,7 +526,7 @@ class ACStrategy(BaseStrategy):
|
||||
|
||||
if not signal_df.empty:
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
self.signal[stock_id] = stock_val
|
||||
self.signal[stock_id] = stock_val["volatility"].droplevel(level="instrument")
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
@@ -590,12 +593,12 @@ class ACStrategy(BaseStrategy):
|
||||
# considering trade unit
|
||||
|
||||
sig_sam = (
|
||||
resam_ts_data(self.signal[order.stock_id]["volatility"], pred_start_time, pred_end_time, method="last")
|
||||
resam_ts_data(self.signal[order.stock_id], pred_start_time, pred_end_time, method=ts_data_last)
|
||||
if order.stock_id in self.signal
|
||||
else None
|
||||
)
|
||||
|
||||
if sig_sam is None or sig_sam.iloc[0] is None:
|
||||
if sig_sam is None or np.isnan(sig_sam):
|
||||
# no signal, TWAP
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
|
||||
if _amount_trade_unit is None:
|
||||
@@ -612,7 +615,7 @@ class ACStrategy(BaseStrategy):
|
||||
)
|
||||
else:
|
||||
# VA strategy
|
||||
kappa_tild = self.lamb / self.eta * sig_sam.iloc[0] * sig_sam.iloc[0]
|
||||
kappa_tild = self.lamb / self.eta * sig_sam * sig_sam
|
||||
kappa = np.arccosh(kappa_tild / 2 + 1)
|
||||
amount_ratio = (
|
||||
np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1))
|
||||
@@ -707,12 +710,36 @@ class RandomOrderStrategy(BaseStrategy):
|
||||
|
||||
class FileOrderStrategy(BaseStrategy):
|
||||
"""
|
||||
Motivtaion:
|
||||
Motivation:
|
||||
- This class provides an interface for user to read orders from csv files.
|
||||
- It is supposed to be used in
|
||||
"""
|
||||
|
||||
def __init__(self, file: Union[IO, str, Path], index_range: Tuple[int, int] = None, *args, **kwargs):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file : Union[IO, str, Path]
|
||||
this parameters will specify the info of expected orders
|
||||
|
||||
Here is an example of the content
|
||||
|
||||
1) Amount (**adjusted**) based strategy
|
||||
|
||||
datetime,instrument,amount,direction
|
||||
20200102, SH600519, 1000, sell
|
||||
20200103, SH600519, 1000, buy
|
||||
20200106, SH600519, 1000, sell
|
||||
|
||||
index_range : Tuple[int, int]
|
||||
the intra day time index range of the orders
|
||||
the left and right is closed.
|
||||
|
||||
If you want to get the index_range in intra-day
|
||||
- `qlib/utils/time.py:def get_day_min_idx_range` can help you create the index range easier
|
||||
# TODO: this is a index_range level limitation. We'll implement a more detailed limitation later.
|
||||
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
with get_io_object(file) as f:
|
||||
self.order_df = pd.read_csv(f, dtype={"datetime": np.str})
|
||||
|
||||
@@ -197,7 +197,7 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame.
|
||||
"""
|
||||
from .storage import HasingStockStorage
|
||||
from .storage import BaseHandlerStorage
|
||||
|
||||
data_storage = self._data
|
||||
if isinstance(data_storage, pd.DataFrame):
|
||||
@@ -211,10 +211,17 @@ class DataHandler(Serializable):
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
|
||||
elif isinstance(data_storage, HasingStockStorage):
|
||||
if proc_func is not None:
|
||||
raise ValueError("proc_func is not supported by the HasingStockStorage")
|
||||
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
|
||||
elif isinstance(data_storage, BaseHandlerStorage):
|
||||
if not data_storage.is_proc_func_supported():
|
||||
if proc_func is not None:
|
||||
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
|
||||
data_df = data_storage.fetch(
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
|
||||
)
|
||||
else:
|
||||
data_df = data_storage.fetch(
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
|
||||
|
||||
@@ -522,7 +529,7 @@ class DataHandlerLP(DataHandler):
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
from .storage import HasingStockStorage
|
||||
from .storage import BaseHandlerStorage
|
||||
|
||||
data_storage = self._get_df_by_key(data_key)
|
||||
if isinstance(data_storage, pd.DataFrame):
|
||||
@@ -537,10 +544,17 @@ class DataHandlerLP(DataHandler):
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
|
||||
|
||||
elif isinstance(data_storage, HasingStockStorage):
|
||||
if proc_func is not None:
|
||||
raise ValueError("proc_func is not supported by the HasingStockStorage")
|
||||
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
|
||||
elif isinstance(data_storage, BaseHandlerStorage):
|
||||
if not data_storage.is_proc_func_supported():
|
||||
if proc_func is not None:
|
||||
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
|
||||
data_df = data_storage.fetch(
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
|
||||
)
|
||||
else:
|
||||
data_df = data_storage.fetch(
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ class BaseHandlerStorage:
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
proc_func: Callable = None,
|
||||
**kwargs,
|
||||
) -> pd.DataFrame:
|
||||
"""fetch data from the data storage
|
||||
@@ -24,6 +25,7 @@ class BaseHandlerStorage:
|
||||
describe how to select data by index
|
||||
level : Union[str, int]
|
||||
which index level to select the data
|
||||
- if level is None, apply selector to df directly
|
||||
col_set : Union[str, List[str]]
|
||||
- if isinstance(col_set, str):
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
@@ -33,15 +35,24 @@ class BaseHandlerStorage:
|
||||
select several sets of meaningful columns, the returned data has multiple level
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible.
|
||||
proc_func: Callable
|
||||
please refer to the doc of DataHandler.fetch
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
the dataframe fetched
|
||||
"""
|
||||
|
||||
raise NotImplementedError("fetch is method not implemented!")
|
||||
|
||||
@staticmethod
|
||||
def from_df(df: pd.DataFrame):
|
||||
raise NotImplementedError("from_df method is not implemented!")
|
||||
|
||||
def is_proc_func_supported(self):
|
||||
"""whether the arg `proc_func` in `fetch` method is supported."""
|
||||
raise NotImplementedError("is_proc_func_supported method is not implemented!")
|
||||
|
||||
|
||||
class HasingStockStorage(BaseHandlerStorage):
|
||||
def __init__(self, df):
|
||||
@@ -105,3 +116,7 @@ class HasingStockStorage(BaseHandlerStorage):
|
||||
return fetch_stock_df_list[0]
|
||||
else:
|
||||
return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig)
|
||||
|
||||
def is_proc_func_supported(self):
|
||||
"""the arg `proc_func` in `fetch` method is not supported in HasingStockStorage"""
|
||||
return False
|
||||
|
||||
@@ -3,6 +3,8 @@ import datetime
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from functools import partial
|
||||
from typing import Tuple, List, Union, Optional, Callable
|
||||
|
||||
from . import lazy_sort_index
|
||||
@@ -263,3 +265,36 @@ def resam_ts_data(
|
||||
elif isinstance(method, str):
|
||||
return getattr(feature, method)(**method_kwargs)
|
||||
return feature
|
||||
|
||||
|
||||
def get_valid_value(series, last=True):
|
||||
"""get the first/last not nan value of pd.Series with single level index
|
||||
Parameters
|
||||
----------
|
||||
series : pd.Seires
|
||||
series should not be empty
|
||||
last : bool, optional
|
||||
wether to get the last valid value, by default True
|
||||
- if last is True, get the last valid value
|
||||
- else, get the first valid value
|
||||
|
||||
Returns
|
||||
-------
|
||||
Nan | float
|
||||
the first/last valid value
|
||||
"""
|
||||
return series.fillna(method="ffill").iloc[-1] if last else series.fillna(method="bfill").iloc[0]
|
||||
|
||||
|
||||
def _ts_data_valid(ts_feature, last=False):
|
||||
"""get the first/last not nan value of pd.Series|DataFrame with single level index"""
|
||||
if isinstance(ts_feature, pd.DataFrame):
|
||||
return ts_feature.apply(lambda column: get_valid_value(column, last=last))
|
||||
elif isinstance(ts_feature, pd.Series):
|
||||
return get_valid_value(ts_feature, last=last)
|
||||
else:
|
||||
raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}")
|
||||
|
||||
|
||||
ts_data_last = partial(_ts_data_valid, last=False)
|
||||
ts_data_first = partial(_ts_data_valid, last=True)
|
||||
|
||||
Reference in New Issue
Block a user