mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
fix data format bug & twap peeking strategy
This commit is contained in:
@@ -405,7 +405,7 @@ class NestedExecutor(BaseExecutor):
|
||||
execute_result.extend(_inner_execute_result)
|
||||
|
||||
inner_order_indicators.append(
|
||||
self.inner_executor.trade_account.get_trade_indicator().get_order_indicator()
|
||||
self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True)
|
||||
)
|
||||
else:
|
||||
# do nothing and just step forward
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
|
||||
import logging
|
||||
from typing import List, Tuple, Union, Callable, Iterable, Dict
|
||||
from typing import List, Text, Tuple, Union, Callable, Iterable, Dict
|
||||
from collections import OrderedDict
|
||||
|
||||
import inspect
|
||||
@@ -280,6 +280,21 @@ class BaseOrderIndicator:
|
||||
|
||||
pass
|
||||
|
||||
def to_series(self) -> Dict[Text, pd.Series]:
|
||||
"""return the metrics as pandas series
|
||||
|
||||
for example: { "ffr":
|
||||
SH600068 NaN
|
||||
SH600079 1.0
|
||||
SH600266 NaN
|
||||
...
|
||||
SZ300692 NaN
|
||||
SZ300719 NaN,
|
||||
...
|
||||
}
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `to_series` method")
|
||||
|
||||
|
||||
class PandasSingleMetric:
|
||||
"""Each SingleMetric is based on pd.Series."""
|
||||
@@ -429,3 +444,6 @@ class PandasOrderIndicator(BaseOrderIndicator):
|
||||
tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
|
||||
metric_dict[metric] = tmp_metric.metric
|
||||
return metric_dict
|
||||
|
||||
def to_series(self):
|
||||
return {k: v.metric for k, v in self.data.items()}
|
||||
|
||||
@@ -274,8 +274,8 @@ class Indicator:
|
||||
# self._trade_calendar = trade_calendar
|
||||
|
||||
def record(self, trade_start_time):
|
||||
self.order_indicator_his[trade_start_time] = self.order_indicator.data
|
||||
self.trade_indicator_his[trade_start_time] = self.trade_indicator
|
||||
self.order_indicator_his[trade_start_time] = self.get_order_indicator()
|
||||
self.trade_indicator_his[trade_start_time] = self.get_trade_indicator()
|
||||
|
||||
def _update_order_trade_info(self, trade_info: list):
|
||||
amount = dict()
|
||||
@@ -587,8 +587,10 @@ class Indicator:
|
||||
)
|
||||
)
|
||||
|
||||
def get_order_indicator(self):
|
||||
return self.order_indicator
|
||||
def get_order_indicator(self, raw: bool = False):
|
||||
if raw:
|
||||
return self.order_indicator
|
||||
return self.order_indicator.to_series()
|
||||
|
||||
def get_trade_indicator(self):
|
||||
return self.trade_indicator
|
||||
|
||||
@@ -63,11 +63,11 @@ class TWAPStrategy(BaseStrategy):
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
order_list = []
|
||||
for order in self.outer_trade_decision.get_decision():
|
||||
# if not tradable, continue
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
continue
|
||||
# Don't peek the future information
|
||||
# if not self.trade_exchange.is_stock_tradable(
|
||||
# stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
# ):
|
||||
# continue
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
|
||||
stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user