mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
Simulator & action interpreter
This commit is contained in:
@@ -349,4 +349,4 @@ def format_decisions(
|
||||
return res
|
||||
|
||||
|
||||
__all__ = ["Order", "backtest", "BaseExecutor", "CommonInfrastructure"]
|
||||
__all__ = ["Order", "backtest"]
|
||||
|
||||
@@ -179,8 +179,8 @@ class OrderHelper:
|
||||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time if start_time is not None else pd.Timestamp(start_time),
|
||||
end_time=end_time if end_time is not None else pd.Timestamp(end_time),
|
||||
start_time=None if start_time is None else pd.Timestamp(start_time),
|
||||
end_time=None if end_time is None else pd.Timestamp(end_time),
|
||||
direction=direction,
|
||||
)
|
||||
|
||||
@@ -530,7 +530,7 @@ class TradeDecisionWO(BaseTradeDecision):
|
||||
Besides, the time_range is also included.
|
||||
"""
|
||||
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None):
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None):
|
||||
super().__init__(strategy, trade_range=trade_range)
|
||||
self.order_list = order_list
|
||||
start, end = strategy.trade_calendar.get_step_time()
|
||||
|
||||
0
qlib/rl/order_execution/from_neutrader/__init__.py
Normal file
0
qlib/rl/order_execution/from_neutrader/__init__.py
Normal file
26
qlib/rl/order_execution/from_neutrader/config.py
Normal file
26
qlib/rl/order_execution/from_neutrader/config.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeConfig:
|
||||
seed: int = 42
|
||||
output_dir: Optional[Path] = None
|
||||
checkpoint_dir: Optional[Path] = None
|
||||
tb_log_dir: Optional[Path] = None
|
||||
debug: bool = False
|
||||
use_cuda: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExchangeConfig:
|
||||
limit_threshold: Union[float, Tuple[str, str]]
|
||||
deal_price: Union[str, Tuple[str, str]]
|
||||
volume_threshold: dict
|
||||
open_cost: float = 0.0005
|
||||
close_cost: float = 0.0015
|
||||
min_cost: float = 5.
|
||||
trade_unit: Optional[float] = 100.
|
||||
cash_limit: Optional[Union[Path, float]] = None
|
||||
generate_report: bool = False
|
||||
11
qlib/rl/order_execution/from_neutrader/executor.py
Normal file
11
qlib/rl/order_execution/from_neutrader/executor.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import List
|
||||
|
||||
from qlib.backtest.executor import NestedExecutor
|
||||
from .strategy import RLStrategyBase
|
||||
|
||||
|
||||
class RLNestedExecutor(NestedExecutor):
|
||||
# RL nested executor
|
||||
def post_inner_exe_step(self, inner_exe_res: List[object]) -> None:
|
||||
if isinstance(self.inner_strategy, RLStrategyBase):
|
||||
self.inner_strategy.post_exe_step(inner_exe_res)
|
||||
164
qlib/rl/order_execution/from_neutrader/feature.py
Normal file
164
qlib/rl/order_execution/from_neutrader/feature.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
import pandas as pd
|
||||
import qlib
|
||||
from .highfreq_ops import DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut
|
||||
from qlib.contrib.ops.high_freq import DayCumsum
|
||||
from qlib.config import REG_CN
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
@dataclass
|
||||
class QlibConfig:
|
||||
provider_uri_day: Path
|
||||
provider_uri_1min: Path
|
||||
feature_root_dir: Path
|
||||
feature_columns_today: List[str]
|
||||
feature_columns_yesterday: List[str]
|
||||
|
||||
|
||||
_dataset = None
|
||||
|
||||
|
||||
class LRUCache:
|
||||
def __init__(self, pool_size: int = 200):
|
||||
self.pool_size = pool_size
|
||||
self.contents = dict()
|
||||
self.keys = collections.deque()
|
||||
|
||||
def put(self, key, item):
|
||||
if self.has(key):
|
||||
self.keys.remove(key)
|
||||
self.keys.append(key)
|
||||
self.contents[key] = item
|
||||
while len(self.contents) > self.pool_size:
|
||||
self.contents.pop(self.keys.popleft())
|
||||
|
||||
def get(self, key):
|
||||
return self.contents[key]
|
||||
|
||||
def has(self, key):
|
||||
return key in self.contents
|
||||
|
||||
|
||||
class DataWrapper:
|
||||
|
||||
def __init__(self, feature_dataset: DatasetH, backtest_dataset: DatasetH,
|
||||
columns_today: List[str], columns_yesterday: List[str], _internal: bool = False):
|
||||
assert _internal, 'Init function of data wrapper is for internal use only.'
|
||||
|
||||
self.feature_dataset = feature_dataset
|
||||
self.backtest_dataset = backtest_dataset
|
||||
self.columns_today = columns_today
|
||||
self.columns_yesterday = columns_yesterday
|
||||
|
||||
self.feature_cache = LRUCache()
|
||||
self.backtest_cache = LRUCache()
|
||||
|
||||
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False):
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
|
||||
dataset = self.backtest_dataset if backtest else self.feature_dataset
|
||||
|
||||
if backtest:
|
||||
dataset = self.backtest_dataset
|
||||
cache = self.backtest_cache
|
||||
else:
|
||||
dataset = self.feature_dataset
|
||||
cache = self.feature_cache
|
||||
|
||||
if cache.has((start_time, end_time, stock_id)):
|
||||
return cache.get((start_time, end_time, stock_id))
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
cache.put((start_time, end_time, stock_id), data)
|
||||
return data
|
||||
|
||||
|
||||
def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None:
|
||||
global _dataset
|
||||
|
||||
provider_uri_map = {
|
||||
"day": config.provider_uri_day.as_posix(),
|
||||
"1min": config.provider_uri_1min.as_posix(),
|
||||
}
|
||||
qlib.init(
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
custom_ops=[DayLast, FFillNan, BFillNan,
|
||||
Date, Select, IsNull, IsInf, Cut, DayCumsum],
|
||||
expression_cache=None,
|
||||
calendar_provider={
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileCalendarStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
feature_provider={
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileFeatureStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
provider_uri=provider_uri_map,
|
||||
kernels=1,
|
||||
redis_port=-1,
|
||||
clear_mem_cache=False # init_qlib will be called for multiple times. Keep the cache for improving performance
|
||||
)
|
||||
|
||||
# this won't work if it's put outside in case of multiprocessing
|
||||
from qlib.data import D
|
||||
|
||||
if part is None:
|
||||
feature_path = config.feature_root_dir / 'feature.pkl'
|
||||
backtest_path = config.feature_root_dir / 'backtest.pkl'
|
||||
else:
|
||||
feature_path = config.feature_root_dir / 'feature' / (part + '.pkl')
|
||||
backtest_path = config.feature_root_dir / 'backtest' / (part + '.pkl')
|
||||
|
||||
with feature_path.open('rb') as f:
|
||||
print(feature_path)
|
||||
feature_dataset = pickle.load(f)
|
||||
with backtest_path.open('rb') as f:
|
||||
backtest_dataset = pickle.load(f)
|
||||
|
||||
_dataset = DataWrapper(
|
||||
feature_dataset,
|
||||
backtest_dataset,
|
||||
config.feature_columns_today,
|
||||
config.feature_columns_yesterday,
|
||||
_internal=True
|
||||
)
|
||||
|
||||
|
||||
def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False):
|
||||
assert _dataset is not None, 'You must call init_qlib() before doing this.'
|
||||
|
||||
if backtest:
|
||||
fields = ['$close', '$volume']
|
||||
else:
|
||||
fields = _dataset.columns_yesterday if yesterday else _dataset.columns_today
|
||||
|
||||
data = _dataset.get(stock_id, date, backtest)
|
||||
if data is None or len(data) == 0:
|
||||
# create a fake index, but RL doesn't care about index
|
||||
data = pd.DataFrame(0., index=np.arange(240), columns=fields, dtype=np.float32) # FIXME: hardcode here
|
||||
else:
|
||||
data = data.rename(columns={c: c.rstrip('0') for c in data.columns})
|
||||
data = data[fields]
|
||||
return data
|
||||
223
qlib/rl/order_execution/from_neutrader/highfreq_ops.py
Normal file
223
qlib/rl/order_execution/from_neutrader/highfreq_ops.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.data.cache import H
|
||||
from qlib.data.data import Cal
|
||||
from qlib.data.ops import ElemOperator, PairOperator
|
||||
|
||||
|
||||
def get_calendar_day(freq="day", future=False):
|
||||
"""Load High-Freq Calendar Date Using Memcache.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
frequency of read calendar file.
|
||||
future : bool
|
||||
whether including future trading day.
|
||||
|
||||
Returns
|
||||
-------
|
||||
_calendar:
|
||||
array of date.
|
||||
"""
|
||||
flag = f"{freq}_future_{future}_day"
|
||||
if flag in H["c"]:
|
||||
_calendar = H["c"][flag]
|
||||
else:
|
||||
_calendar = np.array(
|
||||
list(map(lambda x: x.date(), Cal.load_calendar(freq, future))))
|
||||
H["c"][flag] = _calendar
|
||||
return _calendar
|
||||
|
||||
|
||||
def get_calendar_minute(freq='day', future=False):
|
||||
"""Load High-Freq Calendar Minute Using Memcache"""
|
||||
flag = f"{freq}_future_{future}_day"
|
||||
if flag in H["c"]:
|
||||
_calendar = H["c"][flag]
|
||||
else:
|
||||
_calendar = np.array(
|
||||
list(map(lambda x: x.minute // 30, Cal.load_calendar(freq, future))))
|
||||
H["c"][flag] = _calendar
|
||||
return _calendar
|
||||
|
||||
|
||||
class DayLast(ElemOperator):
|
||||
"""DayLast Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a series of that each value equals the last value of its day
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.groupby(_calendar[series.index]).transform("last")
|
||||
|
||||
|
||||
class FFillNan(ElemOperator):
|
||||
"""FFillNan Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a forward fill nan feature
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="ffill")
|
||||
|
||||
|
||||
class BFillNan(ElemOperator):
|
||||
"""BFillNan Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a backfoward fill nan feature
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="bfill")
|
||||
|
||||
|
||||
class Date(ElemOperator):
|
||||
"""Date Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a series of that each value is the date corresponding to feature.index
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return pd.Series(_calendar[series.index], index=series.index)
|
||||
|
||||
|
||||
class Select(PairOperator):
|
||||
"""Select Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature_left : Expression
|
||||
feature instance, select condition
|
||||
feature_right : Expression
|
||||
feature instance, select value
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
value(feature_right) that meets the condition(feature_left)
|
||||
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series_condition = self.feature_left.load(
|
||||
instrument, start_index, end_index, freq)
|
||||
series_feature = self.feature_right.load(
|
||||
instrument, start_index, end_index, freq)
|
||||
return series_feature.loc[series_condition]
|
||||
|
||||
|
||||
class IsNull(ElemOperator):
|
||||
"""IsNull Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
A series indicating whether the feature is nan
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.isnull()
|
||||
|
||||
|
||||
class IsInf(ElemOperator):
|
||||
"""IsInf Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
A series indicating whether the feature is inf
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return np.isinf(series)
|
||||
|
||||
|
||||
class Cut(ElemOperator):
|
||||
"""Cut Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
l : int
|
||||
l > 0, delete the first l elements of feature (default is None, which means 0)
|
||||
r : int
|
||||
r < 0, delete the last -r elements of feature (default is None, which means 0)
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
A series with the first l and last -r elements deleted from the feature.
|
||||
Note: It is deleted from the raw data, not the sliced data
|
||||
"""
|
||||
|
||||
def __init__(self, feature, left=None, right=None):
|
||||
self.left = left
|
||||
self.right = right
|
||||
if (self.left is not None and self.left <= 0) or (self.right is not None and self.right >= 0):
|
||||
raise ValueError("Cut operator l shoud > 0 and r should < 0")
|
||||
|
||||
super(Cut, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.iloc[self.left: self.right]
|
||||
|
||||
def get_extended_window_size(self):
|
||||
ll = 0 if self.left is None else self.left
|
||||
rr = 0 if self.right is None else abs(self.right)
|
||||
lft_etd, rght_etd = self.feature.get_extended_window_size()
|
||||
lft_etd = lft_etd + ll
|
||||
rght_etd = rght_etd + rr
|
||||
return lft_etd, rght_etd
|
||||
251
qlib/rl/order_execution/from_neutrader/state.py
Normal file
251
qlib/rl/order_execution/from_neutrader/state.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import abc
|
||||
import math
|
||||
from dataclasses import dataclass, field, fields
|
||||
from enum import Enum
|
||||
from typing import Callable, Literal, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
EPSILON = 1e-7
|
||||
|
||||
|
||||
class FlowDirection(str, Enum):
|
||||
ACQUIRE = "acquire"
|
||||
LIQUIDATE = "liquidate"
|
||||
|
||||
|
||||
def _round_time(time: int, granularity: int) -> int:
|
||||
return time - time % granularity
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseEpisodicState(abc.ABC):
|
||||
"""
|
||||
Base class for episodic states.
|
||||
"""
|
||||
|
||||
# requirements
|
||||
start_time: int
|
||||
end_time: int
|
||||
time_per_step: int
|
||||
vol_limit: Optional[float] # TODO: meaning?
|
||||
price_func: Callable[[str], np.ndarray] # TODO: meaning?
|
||||
volume_func: Callable[[], np.ndarray] # TODO: meaning?
|
||||
on_step_end: Optional[Callable[..., None]] # TODO: meaning?
|
||||
on_episode_end: Optional[Callable[..., None]] # TODO: meaning?
|
||||
asset_num: int # TODO: meaning?
|
||||
|
||||
# agent states
|
||||
num_step: int = field(init=False) # Number of steps
|
||||
cur_time: int = field(init=False) # Current time
|
||||
cur_step: int = field(init=False, default=0)
|
||||
exec_vol: Optional[np.ndarray] = field(init=False, default=None) # Execution history
|
||||
last_step_duration: int = field(init=False)
|
||||
position: float = field(init=False)
|
||||
position_history: np.ndarray = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.cur_time = self.start_time
|
||||
rounded_start_time = _round_time(self.start_time, self.time_per_step)
|
||||
|
||||
# TODO: why not rounding end time?
|
||||
self.num_step = math.floor((self.end_time - rounded_start_time) / self.time_per_step)
|
||||
|
||||
def logs(self) -> dict:
|
||||
# Base logging information shared across all subclasses.
|
||||
# You can call logs = super().logs() to get these default logs and use logs.update(...) to add other logging
|
||||
# information or override it completely to remove these logging fields.
|
||||
return {
|
||||
"logs": {
|
||||
"stop_time": self.cur_time - self.start_time,
|
||||
"stop_step": self.cur_step,
|
||||
},
|
||||
"history": {
|
||||
"volume": self.execution_history(),
|
||||
},
|
||||
}
|
||||
|
||||
def execution_history(self) -> np.ndarray:
|
||||
return np.pad(self.exec_vol, (0, self.end_time - self.start_time - len(self.exec_vol)))
|
||||
|
||||
def next_duration(self) -> int:
|
||||
left, right = self.next_interval()
|
||||
return right - left
|
||||
|
||||
def next_interval(self) -> Tuple[int, int]:
|
||||
left = _round_time(self.cur_time, self.time_per_step)
|
||||
right = left + self.time_per_step
|
||||
return max(left, self.start_time) - self.start_time, min(right, self.end_time) - self.start_time
|
||||
|
||||
@classmethod
|
||||
def get_init_field_names(cls):
|
||||
ret = []
|
||||
for f in fields(cls):
|
||||
if f.init:
|
||||
ret.append(f.name)
|
||||
return ret
|
||||
|
||||
@abc.abstractmethod
|
||||
def step(self, *args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntraDaySingleAssetDataSchema:
|
||||
"""
|
||||
In the current context, raw should be a DataFrame with `datetime` as index and
|
||||
(at least) `$vwap0`, `$volume0`, `$close0` as columns.
|
||||
`processed` should be a DataFrame of 240x6, which is the same as `processed_prev`.
|
||||
"""
|
||||
|
||||
date: pd.Timestamp
|
||||
stock_id: str
|
||||
start_time: int
|
||||
end_time: int
|
||||
target: float
|
||||
flow_dir: FlowDirection
|
||||
raw: pd.DataFrame
|
||||
processed: pd.DataFrame
|
||||
processed_prev: pd.DataFrame
|
||||
|
||||
def get_price(self, type: Literal['deal', 'close'] = 'deal'):
|
||||
if type == 'deal':
|
||||
return self.raw['$price'].values
|
||||
elif type == 'close':
|
||||
return self.raw['$close0'].values
|
||||
|
||||
def get_volume(self):
|
||||
return self.raw['$volume0'].values
|
||||
|
||||
def get_processed_data(self, type: Literal['today', 'yesterday'] = 'today'):
|
||||
if type == 'today':
|
||||
return self.processed.to_numpy()
|
||||
elif type == 'yesterday':
|
||||
return self.processed_prev.to_numpy()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SAOEEpisodicState(BaseEpisodicState):
|
||||
"""
|
||||
Global state of the whole time horizon.
|
||||
"""
|
||||
|
||||
# requirements
|
||||
target: float
|
||||
target_limit: float
|
||||
flow_dir: FlowDirection
|
||||
|
||||
# calculated statistics
|
||||
turnover: Optional[float] = field(init=False)
|
||||
baseline_twap: Optional[float] = field(init=False)
|
||||
baseline_vwap: Optional[float] = field(init=False)
|
||||
exec_avg_price: Optional[float] = field(init=False)
|
||||
pa_twap: Optional[float] = field(init=False)
|
||||
pa_vwap: Optional[float] = field(init=False)
|
||||
pa_close: Optional[float] = field(init=False)
|
||||
fulfill_rate: Optional[float] = field(init=False)
|
||||
|
||||
market_price: np.ndarray = field(init=False) # deal price, might be different from close
|
||||
market_close: np.ndarray = field(init=False) # close price
|
||||
market_volume: np.ndarray = field(init=False)
|
||||
|
||||
# NOTE: this is a temporary design to make it compatible with old qlib integration framework. As long as callback
|
||||
# functions are passed correctly, this field should be removed from this class.
|
||||
last_interval: Tuple[int, int] = field(default=(0, 0), init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.target >= 0
|
||||
assert self.asset_num == 1
|
||||
|
||||
super().__post_init__()
|
||||
|
||||
self.market_volume = self.volume_func()[self.start_time : self.end_time]
|
||||
self.market_price = self.price_func("deal")[self.start_time : self.end_time]
|
||||
self.market_close = self.price_func("close")[self.start_time : self.end_time]
|
||||
self.position = self.target
|
||||
self.position_history = np.full((self.num_step + 1), np.nan)
|
||||
self.position_history[0] = self.position
|
||||
self.baseline_twap = np.mean(self.market_price)
|
||||
if self.market_volume.sum() == 0:
|
||||
self.baseline_vwap = self.baseline_twap
|
||||
else:
|
||||
self.baseline_vwap = np.average(self.market_price, weights=self.market_volume)
|
||||
|
||||
def update_stats(self) -> None:
|
||||
market_price = self.market_price[: len(self.exec_vol)]
|
||||
self.turnover = (self.exec_vol * market_price).sum()
|
||||
# exec_vol can be zero
|
||||
if np.isclose(self.exec_vol.sum(), 0):
|
||||
self.exec_avg_price = market_price[0]
|
||||
else:
|
||||
self.exec_avg_price = np.average(market_price, weights=self.exec_vol)
|
||||
|
||||
self.pa_twap = _price_advantage(self.exec_avg_price, self.baseline_twap, self.flow_dir)
|
||||
self.pa_vwap = _price_advantage(self.exec_avg_price, self.baseline_vwap, self.flow_dir)
|
||||
close_average = np.mean(self.market_close)
|
||||
self.pa_close = _price_advantage(self.exec_avg_price, close_average, self.flow_dir)
|
||||
|
||||
self.fulfill_rate = (self.target - self.position) / self.target_limit
|
||||
if abs(self.fulfill_rate - 1.0) < EPSILON:
|
||||
self.fulfill_rate = 1.0
|
||||
self.fulfill_rate *= 100
|
||||
|
||||
def logs(self) -> dict:
|
||||
logs = super().logs()
|
||||
logs.update(
|
||||
{
|
||||
"logs": {
|
||||
"turnover": self.turnover,
|
||||
"baseline_twap": self.baseline_twap,
|
||||
"baseline_vwap": self.baseline_vwap,
|
||||
"exec_avg_price": self.exec_avg_price,
|
||||
"pa_twap": self.pa_twap,
|
||||
"pa_vwap": self.pa_vwap,
|
||||
"pa_close": self.pa_close,
|
||||
"ffr": self.fulfill_rate,
|
||||
}
|
||||
}
|
||||
)
|
||||
return logs
|
||||
|
||||
def step(self, exec_vol: np.ndarray) -> None:
|
||||
l, r = self.next_interval()
|
||||
self.last_interval = (l, r)
|
||||
assert 0 <= l < r
|
||||
self.last_step_duration = len(exec_vol)
|
||||
self.position -= exec_vol.sum()
|
||||
assert (
|
||||
self.position > -EPSILON and (exec_vol > -EPSILON).all(),
|
||||
f"Execution volume is invalid: {exec_vol} (position = {self.position})",
|
||||
)
|
||||
self.cur_step += 1
|
||||
self.position_history[self.cur_step] = self.position
|
||||
self.cur_time += self.last_step_duration
|
||||
if self.cur_step == self.num_step: # Should reach the end of episode
|
||||
assert self.cur_time == self.end_time
|
||||
self.exec_vol = exec_vol if self.exec_vol is None else np.concatenate((self.exec_vol, exec_vol))
|
||||
|
||||
if self.on_step_end is not None:
|
||||
self.on_step_end(l, r, self)
|
||||
if self.done:
|
||||
self.update_stats()
|
||||
if self.on_episode_end is not None:
|
||||
self.on_episode_end(self)
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
return self.position < EPSILON or self.cur_step == self.num_step
|
||||
|
||||
|
||||
def _price_advantage(exec_price: float, baseline_price: float, flow: FlowDirection) -> float:
|
||||
if baseline_price == 0:
|
||||
return 0.0
|
||||
if flow == FlowDirection.ACQUIRE:
|
||||
return (1 - exec_price / baseline_price) * 10000
|
||||
else:
|
||||
return (exec_price / baseline_price - 1) * 10000
|
||||
162
qlib/rl/order_execution/from_neutrader/state_maintainer.py
Normal file
162
qlib/rl/order_execution/from_neutrader/state_maintainer.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir, TradeDecisionWO
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.rl.order_execution.from_neutrader.feature import fetch_features
|
||||
from qlib.rl.order_execution.from_neutrader.state import FlowDirection, IntraDaySingleAssetDataSchema, SAOEEpisodicState
|
||||
from qlib.utils.time import get_day_min_idx_range
|
||||
|
||||
|
||||
class StateMaintainer:
|
||||
"""
|
||||
Maintain neutrader states taking qlib trade decisions as input.
|
||||
|
||||
Example usage::
|
||||
|
||||
maintainer = StateMaintainer(...) # in reset
|
||||
maintainer.send_execute_result(execute_result) # in step
|
||||
# do something here
|
||||
maintainer.generate_orders(self.get_data_cal_avail_range(rtype='step'), exec_vols)
|
||||
|
||||
The states can be accessed via ``maintianer.states`` and ``maintainer.samples``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
time_per_step: int,
|
||||
date: pd.Timestamp,
|
||||
full_trade_range: Tuple[int, int],
|
||||
current_step: int,
|
||||
outer_trade_decision: TradeDecisionWO,
|
||||
trade_exchange: Exchange,
|
||||
) -> None:
|
||||
# The parameters look very ad-hoc right now
|
||||
self.states: Dict[Tuple[str, OrderDir], SAOEEpisodicState] = {} # explicitly make it ordered
|
||||
self.samples: Dict[Tuple[str, OrderDir], IntraDaySingleAssetDataSchema] = {}
|
||||
self.time_per_step: int = time_per_step
|
||||
self.start_time, self.end_time = full_trade_range
|
||||
self.end_time += 1 # plus 1 to align with the semantics in neutrader
|
||||
self.date: pd.Timestamp = date
|
||||
self.last_step_length: int = -1
|
||||
self.last_step_range: Optional[Tuple[int, int]] = None
|
||||
|
||||
self.order_list: List[Order] = outer_trade_decision.order_list
|
||||
self.trade_exchange: Exchange = trade_exchange
|
||||
|
||||
self.num_step = (
|
||||
self.end_time - (self.start_time - self.start_time % self.time_per_step) - 1
|
||||
) // self.time_per_step + 1
|
||||
|
||||
for order in self.order_list:
|
||||
sample = self._fetch_sample_data(order)
|
||||
state = self._create_single_ep_state(sample, current_step)
|
||||
self.samples[order.stock_id, order.direction] = sample
|
||||
self.states[order.stock_id, order.direction] = state
|
||||
|
||||
def _fetch_sample_data(self, order: Order) -> IntraDaySingleAssetDataSchema:
|
||||
start_time = self.date.replace(hour=0, minute=0, second=0)
|
||||
end_time = self.date.replace(hour=23, minute=59, second=59)
|
||||
deal_price = self.trade_exchange.get_deal_price(
|
||||
stock_id=order.stock_id, start_time=start_time, end_time=end_time, direction=order.direction, method=None,
|
||||
)
|
||||
backtest_data = fetch_features(order.stock_id, self.date, backtest=True)
|
||||
# HACK: close means deal price here. The logic is implemented in qlib.
|
||||
backtest_data["$close"] = deal_price.to_series().to_numpy()
|
||||
feature_today = fetch_features(order.stock_id, self.date)
|
||||
feature_yesterday = fetch_features(order.stock_id, self.date, yesterday=True)
|
||||
return IntraDaySingleAssetDataSchema(
|
||||
date=self.date.date(),
|
||||
stock_id=order.stock_id,
|
||||
start_time=self.start_time,
|
||||
end_time=self.end_time,
|
||||
target=max(order.amount, 0.0), # prevent target to go to -eps
|
||||
flow_dir=FlowDirection.LIQUIDATE if order.direction == 0 else FlowDirection.ACQUIRE,
|
||||
raw=backtest_data,
|
||||
processed=feature_today,
|
||||
processed_prev=feature_yesterday,
|
||||
)
|
||||
|
||||
def _create_single_ep_state(self, sample: IntraDaySingleAssetDataSchema, cur_step: int) -> SAOEEpisodicState:
|
||||
market_price = sample.raw["$close"].values
|
||||
market_vol = sample.raw["$volume"].values
|
||||
target = sample.target
|
||||
|
||||
# NOTE: Previously, market_price and market_vol are passed into the state initialization directly. Therefore,
|
||||
# the segment of market_price and market_vol are used instead of the lambda function here using the whole price
|
||||
# and vol data.
|
||||
# This refactoring is ONLY EQUIVALENT WHEN start_time/end_time passed into state is equal to
|
||||
# sample.start_time/end_time.
|
||||
# If one can confirm that these two are always the same, delete this note, please.
|
||||
state = SAOEEpisodicState(
|
||||
self.start_time,
|
||||
self.end_time,
|
||||
self.time_per_step,
|
||||
None,
|
||||
lambda x: market_price,
|
||||
lambda: market_vol,
|
||||
None,
|
||||
None,
|
||||
1,
|
||||
target,
|
||||
target,
|
||||
sample.flow_dir,
|
||||
)
|
||||
state.cur_step = cur_step
|
||||
assert state.cur_step == 0
|
||||
return state
|
||||
|
||||
def _update_single_ep_state(
|
||||
self, state: SAOEEpisodicState, execute_result: List[Order], length: Optional[int] = None
|
||||
) -> None:
|
||||
if length is not None:
|
||||
exec_vol = np.zeros(length)
|
||||
for order, _, __, ___ in execute_result:
|
||||
idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN)
|
||||
exec_vol[idx - self.last_step_range[0]] = order.deal_amount
|
||||
else:
|
||||
exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])
|
||||
|
||||
# sometimes exec_vol gets too large due to the rounding in exchange
|
||||
# scale the execution volume so that position won't go below 0
|
||||
# actually this case is very rare
|
||||
if exec_vol.sum() > state.position and exec_vol.sum() > 0:
|
||||
assert exec_vol.sum() < state.position + 1, f"{exec_vol} too large for {state}"
|
||||
exec_vol *= state.position / (exec_vol.sum())
|
||||
|
||||
state.step(exec_vol)
|
||||
|
||||
def create_sub_order(self, exec_vol: float, original_order: Order) -> Order:
|
||||
oh = self.trade_exchange.get_order_helper()
|
||||
return oh.create(original_order.stock_id, exec_vol, original_order.direction)
|
||||
|
||||
def send_execute_result(self, execute_result: Optional[List[Any]]) -> None:
|
||||
if self.last_step_length < 0:
|
||||
assert not execute_result
|
||||
return
|
||||
orders = defaultdict(list)
|
||||
if execute_result is not None:
|
||||
for e in execute_result:
|
||||
orders[e[0].stock_id, e[0].direction].append(e)
|
||||
for (stock_id, direction), state in self.states.items():
|
||||
self._update_single_ep_state(state, orders[stock_id, direction], self.last_step_length)
|
||||
|
||||
def generate_orders(self, step_trade_range: Tuple[int, int], exec_vols: List[float]) -> List[Order]:
|
||||
order_list = []
|
||||
|
||||
assert len(exec_vols) == len(self.order_list)
|
||||
for v, o in zip(exec_vols, self.order_list):
|
||||
if v > 0:
|
||||
order_list.append(self.create_sub_order(v, o))
|
||||
|
||||
step_start_time, step_end_time = step_trade_range # inclusive
|
||||
step_end_time += 1
|
||||
|
||||
self.last_step_length = step_end_time - step_start_time
|
||||
self.last_step_range = (step_start_time, step_end_time)
|
||||
|
||||
return order_list
|
||||
91
qlib/rl/order_execution/from_neutrader/strategy.py
Normal file
91
qlib/rl/order_execution/from_neutrader/strategy.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from abc import ABCMeta
|
||||
from typing import Tuple
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderHelper, TradeDecisionWO, TradeRange
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.rl.order_execution.from_neutrader.state import IntraDaySingleAssetDataSchema, SAOEEpisodicState
|
||||
from qlib.rl.order_execution.from_neutrader.state_maintainer import StateMaintainer
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
|
||||
|
||||
class RLStrategyBase(BaseStrategy, metaclass=ABCMeta):
|
||||
def post_exe_step(self, execute_result: list) -> None:
|
||||
"""
|
||||
post process for each step of strategy this is design for RL Strategy,
|
||||
which require to update the policy state after each step
|
||||
|
||||
NOTE: it is strongly coupled with RLNestedExecutor;
|
||||
"""
|
||||
raise NotImplementedError("Please implement the `post_exe_step` method")
|
||||
|
||||
|
||||
class DecomposedStrategy(RLStrategyBase):
|
||||
def __init__(self):
|
||||
super(DecomposedStrategy, self).__init__()
|
||||
|
||||
def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs) -> None:
|
||||
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
time_per_step = int(pd.Timedelta(self.trade_calendar.get_freq()) / pd.Timedelta("1min"))
|
||||
if outer_trade_decision is not None:
|
||||
self.maintainer = StateMaintainer(
|
||||
time_per_step,
|
||||
self.trade_calendar.get_all_time()[0],
|
||||
self.get_data_cal_avail_range(),
|
||||
self.trade_calendar.get_trade_step(),
|
||||
outer_trade_decision,
|
||||
self.trade_exchange,
|
||||
)
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
return outer_trade_decision
|
||||
|
||||
def post_exe_step(self, execute_result):
|
||||
self.maintainer.send_execute_result(execute_result)
|
||||
|
||||
@property
|
||||
def sample_state_pair(self) -> Tuple[IntraDaySingleAssetDataSchema, SAOEEpisodicState]:
|
||||
assert len(self.maintainer.samples) == len(self.maintainer.states) == 1
|
||||
return (
|
||||
list(self.maintainer.samples.values())[0],
|
||||
list(self.maintainer.states.values())[0],
|
||||
)
|
||||
|
||||
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
|
||||
# get a decision from the outmost loop
|
||||
exec_vol = yield self
|
||||
|
||||
return TradeDecisionWO(
|
||||
self.maintainer.generate_orders(self.get_data_cal_avail_range(rtype="step"), [exec_vol]), self
|
||||
)
|
||||
|
||||
|
||||
class SingleOrderStrategy(BaseStrategy):
|
||||
# this logic is copied from FileOrderStrategy
|
||||
def __init__(
|
||||
self,
|
||||
common_infra: CommonInfrastructure,
|
||||
order: Order,
|
||||
trade_range: TradeRange,
|
||||
instrument: str,
|
||||
) -> None:
|
||||
super().__init__(common_infra=common_infra)
|
||||
self._order = order
|
||||
self._trade_range = trade_range
|
||||
self._instrument = instrument
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
return outer_trade_decision
|
||||
|
||||
def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO:
|
||||
oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
|
||||
order_list = [
|
||||
oh.create(
|
||||
code=self._instrument,
|
||||
amount=self._order.amount,
|
||||
direction=Order.parse_dir(self._order.direction),
|
||||
)
|
||||
]
|
||||
trade_decision = TradeDecisionWO(order_list, self, self._trade_range)
|
||||
return trade_decision
|
||||
@@ -2,34 +2,27 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Placeholder for qlib-based simulator."""
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from plistlib import Dict
|
||||
from typing import Callable, Generator, List, Optional, Tuple, Union
|
||||
from typing import Callable, Generator, List, Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
from gym.vector.utils import spaces
|
||||
|
||||
from qlib.backtest import Account, BaseExecutor, CommonInfrastructure, get_exchange
|
||||
from qlib.backtest.decision import Order
|
||||
from qlib.backtest.executor import NestedExecutor
|
||||
from qlib.backtest import get_exchange
|
||||
from qlib.backtest.account import Account
|
||||
from qlib.backtest.decision import Order, TradeRange, TradeRangeByTime
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.config import QlibConfig
|
||||
from qlib.rl.simulator import ActType, Simulator, StateType
|
||||
from qlib.rl.interpreter import ActionInterpreter
|
||||
from qlib.rl.order_execution.from_neutrader.config import ExchangeConfig
|
||||
from qlib.rl.order_execution.from_neutrader.executor import RLNestedExecutor
|
||||
from qlib.rl.order_execution.from_neutrader.feature import init_qlib
|
||||
from qlib.rl.order_execution.from_neutrader.state import SAOEEpisodicState
|
||||
from qlib.rl.order_execution.from_neutrader.strategy import DecomposedStrategy
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExchangeConfig:
|
||||
limit_threshold: Union[float, Tuple[str, str]]
|
||||
deal_price: Union[str, Tuple[str, str]]
|
||||
volume_threshold: Union[float, Dict[str, Tuple[str, str]]]
|
||||
open_cost: float = 0.0005
|
||||
close_cost: float = 0.0015
|
||||
min_cost: float = 5.
|
||||
trade_unit: Optional[float] = 100.
|
||||
cash_limit: Optional[Union[Path, float]] = None
|
||||
generate_report: bool = False
|
||||
|
||||
|
||||
def get_common_infra(
|
||||
config: ExchangeConfig,
|
||||
trade_start_time: pd.Timestamp,
|
||||
@@ -69,13 +62,31 @@ def get_common_infra(
|
||||
return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
|
||||
|
||||
|
||||
class QlibSimulator(Simulator[Order, StateType, ActType]):
|
||||
class CategoricalActionInterpreter(ActionInterpreter[SAOEEpisodicState, int, float]):
|
||||
def __init__(self, values: Union[int, List[float]]) -> None:
|
||||
if isinstance(values, int):
|
||||
values = [i / values for i in range(0, values + 1)]
|
||||
self.action_values = values
|
||||
|
||||
@property
|
||||
def action_space(self) -> spaces.Discrete:
|
||||
return spaces.Discrete(len(self.action_values))
|
||||
|
||||
def interpret(self, state: SAOEEpisodicState, action: int) -> float:
|
||||
volume = min(state.position, state.target * self.action_values[action])
|
||||
if state.cur_step + 1 >= state.num_step:
|
||||
volume = state.position # execute all volumes at last
|
||||
return volume
|
||||
|
||||
|
||||
class QlibSimulator(Simulator[Order, SAOEEpisodicState, float]):
|
||||
def __init__(
|
||||
self,
|
||||
qlib_config: QlibConfig,
|
||||
time_per_step: str,
|
||||
top_strategy: BaseStrategy,
|
||||
inner_strategy_fn: Callable[[], BaseStrategy],
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
qlib_config: QlibConfig,
|
||||
top_strategy_fn: Callable[[CommonInfrastructure, Order, TradeRange, str], BaseStrategy],
|
||||
inner_executor_fn: Callable[[CommonInfrastructure], BaseExecutor],
|
||||
exchange_config: ExchangeConfig,
|
||||
) -> None:
|
||||
@@ -83,61 +94,77 @@ class QlibSimulator(Simulator[Order, StateType, ActType]):
|
||||
initial=None, # TODO
|
||||
)
|
||||
|
||||
self.qlib_config = qlib_config
|
||||
|
||||
self._trade_range = TradeRangeByTime(start_time, end_time)
|
||||
self._qlib_config = qlib_config
|
||||
self._time_per_step = time_per_step
|
||||
self._top_strategy = top_strategy
|
||||
self._top_strategy_fn = top_strategy_fn
|
||||
self._inner_executor_fn = inner_executor_fn
|
||||
self._inner_strategy_fn = inner_strategy_fn
|
||||
self._exchange_config = exchange_config
|
||||
|
||||
self._executor: Optional[NestedExecutor] = None
|
||||
self._executor: Optional[RLNestedExecutor] = None
|
||||
self._collect_data_loop: Optional[Generator] = None
|
||||
|
||||
self._done = False
|
||||
|
||||
def _reset(
|
||||
self._inner_strategy = DecomposedStrategy()
|
||||
|
||||
def reset(
|
||||
self,
|
||||
instrument: str,
|
||||
date_time: pd.Timestamp,
|
||||
order: Order,
|
||||
instrument: str = "SH600000", # TODO: Test only. Remove this default value later.
|
||||
) -> None:
|
||||
# TODO: init_qlib
|
||||
init_qlib(self._qlib_config, instrument)
|
||||
|
||||
common_infra = get_common_infra(
|
||||
self._exchange_config,
|
||||
trade_start_time=date_time,
|
||||
trade_end_time=date_time,
|
||||
trade_start_time=order.start_time,
|
||||
trade_end_time=order.end_time,
|
||||
codes=[instrument],
|
||||
)
|
||||
|
||||
self._executor = NestedExecutor(
|
||||
self._executor = RLNestedExecutor(
|
||||
time_per_step=self._time_per_step,
|
||||
inner_executor=self._inner_executor_fn(common_infra),
|
||||
inner_strategy=self._inner_strategy_fn(),
|
||||
inner_strategy=self._inner_strategy,
|
||||
track_data=True,
|
||||
common_infra=common_infra,
|
||||
)
|
||||
|
||||
self._executor.reset(start_time=date_time, end_time=date_time)
|
||||
self._top_strategy.reset(level_infra=self._executor.get_level_infra())
|
||||
self._collect_data_loop = self._executor.collect_data(self._top_strategy.generate_trade_decision(), level=0)
|
||||
top_strategy = self._top_strategy_fn(common_infra, order, self._trade_range, instrument)
|
||||
|
||||
self._executor.reset(start_time=order.start_time, end_time=order.end_time)
|
||||
top_strategy.reset(level_infra=self._executor.get_level_infra())
|
||||
|
||||
self._collect_data_loop = self._executor.collect_data(top_strategy.generate_trade_decision(), level=0)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
strategy = self._iter_strategy(action=None)
|
||||
sample, ep_state = strategy.sample_state_pair
|
||||
self._last_ep_state = ep_state
|
||||
|
||||
self._done = False
|
||||
|
||||
def step(self, action: ActType) -> None:
|
||||
def _iter_strategy(self, action: float = None) -> DecomposedStrategy:
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
while not isinstance(strategy, DecomposedStrategy):
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
assert isinstance(strategy, DecomposedStrategy)
|
||||
return strategy
|
||||
|
||||
def step(self, action: float) -> None:
|
||||
try:
|
||||
strategy = self._collect_data_loop.send(action)
|
||||
while not isinstance(strategy, BaseStrategy):
|
||||
strategy = self._collect_data_loop.send(action)
|
||||
assert isinstance(strategy, BaseStrategy)
|
||||
|
||||
# TODO: do something here
|
||||
|
||||
strategy = self._iter_strategy(action=action)
|
||||
sample, ep_state = strategy.sample_state_pair
|
||||
except StopIteration:
|
||||
sample, ep_state = self._inner_strategy.sample_state_pair
|
||||
assert ep_state.done
|
||||
|
||||
self._last_ep_state = ep_state
|
||||
if ep_state.done:
|
||||
self._done = True
|
||||
|
||||
def get_state(self) -> StateType:
|
||||
pass # TODO: Collect info from executor. Generate state.
|
||||
def get_state(self) -> SAOEEpisodicState:
|
||||
return self._last_ep_state
|
||||
|
||||
def done(self) -> bool:
|
||||
return self._done
|
||||
|
||||
0
qlib/rl/order_execution/tests/__init__.py
Normal file
0
qlib/rl/order_execution/tests/__init__.py
Normal file
133
qlib/rl/order_execution/tests/test_simulator_qlib.py
Normal file
133
qlib/rl/order_execution/tests/test_simulator_qlib.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import collections
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir, TradeRange
|
||||
from qlib.backtest.executor import SimulatorExecutor
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.config import QlibConfig
|
||||
from qlib.contrib.strategy import TWAPStrategy
|
||||
from qlib.rl.order_execution.from_neutrader.executor import RLNestedExecutor
|
||||
from qlib.rl.order_execution.from_neutrader.strategy import SingleOrderStrategy
|
||||
from qlib.rl.order_execution.simulator_qlib import CategoricalActionInterpreter, ExchangeConfig, QlibSimulator
|
||||
|
||||
qlib_config = QlibConfig(
|
||||
{
|
||||
"provider_uri_day": Path("C:/workspace/NeuTrader/data_sample/cn/qlib_amc_1d"),
|
||||
"provider_uri_1min": Path("C:/workspace/NeuTrader/data_sample/cn/qlib_amc_1min"),
|
||||
"feature_root_dir": Path("C:/workspace/NeuTrader/data_sample/cn/qlib_amc_handler_stock"),
|
||||
"feature_columns_today": [
|
||||
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
|
||||
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5",
|
||||
],
|
||||
"feature_columns_yesterday": [
|
||||
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
|
||||
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
exchange_config = ExchangeConfig(
|
||||
limit_threshold=('$ask == 0', '$bid == 0'),
|
||||
deal_price=('If($ask == 0, $bid, $ask)', 'If($bid == 0, $ask, $bid)'),
|
||||
volume_threshold={
|
||||
'all': ('cum', "0.2 * DayCumsum($volume, '9:45', '14:44')"),
|
||||
'buy': ('current', '$askV1'), 'sell': ('current', '$bidV1')
|
||||
},
|
||||
open_cost=0.0005,
|
||||
close_cost=0.0015,
|
||||
min_cost=5.0,
|
||||
trade_unit=None,
|
||||
cash_limit=None,
|
||||
generate_report=False,
|
||||
)
|
||||
|
||||
|
||||
def _top_strategy_fn(
|
||||
common_infra: CommonInfrastructure,
|
||||
order: Order,
|
||||
trade_range: TradeRange,
|
||||
instrument: str,
|
||||
) -> SingleOrderStrategy:
|
||||
return SingleOrderStrategy(common_infra, order, trade_range, instrument)
|
||||
|
||||
|
||||
def _inner_executor_fn(common_infra: CommonInfrastructure) -> RLNestedExecutor:
|
||||
return RLNestedExecutor(
|
||||
time_per_step="30min",
|
||||
inner_strategy=TWAPStrategy(),
|
||||
inner_executor=SimulatorExecutor(
|
||||
time_per_step="1min",
|
||||
verbose=False,
|
||||
trade_type=SimulatorExecutor.TT_SERIAL,
|
||||
generate_report=False,
|
||||
common_infra=common_infra,
|
||||
track_data=True,
|
||||
),
|
||||
common_infra=common_infra,
|
||||
track_data=True,
|
||||
)
|
||||
|
||||
|
||||
def test():
|
||||
order_infos = [
|
||||
("2019-03-04", 1078.644160270691, 1),
|
||||
("2019-03-11", 32.440425872802734, 1),
|
||||
("2019-03-25", 40.55053234100342, 0),
|
||||
("2019-04-01", 1070.5340538024902, 0),
|
||||
("2019-05-27", 300.0739393234253, 1),
|
||||
("2019-06-03", 8.110106468200684, 0),
|
||||
("2019-06-11", 0.9360466003417968, 0),
|
||||
("2019-06-17", 794.4272003173828, 1),
|
||||
("2019-06-24", 7.865615844726562, 0),
|
||||
("2019-07-01", 1077.589370727539, 0),
|
||||
("2021-01-04", 499.7846999168396, 1),
|
||||
("2021-01-11", 14.918946266174316, 0),
|
||||
("2021-01-18", 484.8657536506653, 0),
|
||||
("2021-02-08", 537.0820655822754, 1),
|
||||
("2021-02-18", 7.459473133087158, 0),
|
||||
("2021-02-22", 7.459473133087158, 0),
|
||||
("2021-03-01", 14.918946266174316, 1),
|
||||
("2021-03-08", 872.7583565711975, 1),
|
||||
|
||||
]
|
||||
orders = collections.deque([
|
||||
Order(
|
||||
stock_id="",
|
||||
amount=info[1],
|
||||
direction=OrderDir(info[2]),
|
||||
start_time=pd.Timestamp(info[0]),
|
||||
end_time=pd.Timestamp(info[0]),
|
||||
)
|
||||
for info in order_infos
|
||||
])
|
||||
|
||||
# fmt: off
|
||||
simulator = QlibSimulator(
|
||||
time_per_step="1day",
|
||||
start_time="9:45",
|
||||
end_time="14:44",
|
||||
qlib_config=qlib_config,
|
||||
top_strategy_fn=_top_strategy_fn,
|
||||
inner_executor_fn=_inner_executor_fn,
|
||||
exchange_config=exchange_config,
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
action_interpreter = CategoricalActionInterpreter(values=4)
|
||||
|
||||
simulator.reset(orders.popleft())
|
||||
|
||||
for i in range(10):
|
||||
print(f"Step {i}")
|
||||
ep_state = simulator.get_state()
|
||||
action = action_interpreter(ep_state, 1)
|
||||
|
||||
simulator.step(action)
|
||||
if simulator.done():
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
Reference in New Issue
Block a user