1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00

Simulator & action interpreter

This commit is contained in:
Default
2022-06-27 15:50:48 +08:00
committed by Huoran Li
parent e23504c1d7
commit 934840146b
13 changed files with 1143 additions and 55 deletions

View File

@@ -349,4 +349,4 @@ def format_decisions(
return res
__all__ = ["Order", "backtest", "BaseExecutor", "CommonInfrastructure"]
__all__ = ["Order", "backtest"]

View File

@@ -179,8 +179,8 @@ class OrderHelper:
return Order(
stock_id=code,
amount=amount,
start_time=start_time if start_time is not None else pd.Timestamp(start_time),
end_time=end_time if end_time is not None else pd.Timestamp(end_time),
start_time=None if start_time is None else pd.Timestamp(start_time),
end_time=None if end_time is None else pd.Timestamp(end_time),
direction=direction,
)
@@ -530,7 +530,7 @@ class TradeDecisionWO(BaseTradeDecision):
Besides, the time_range is also included.
"""
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None):
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None):
super().__init__(strategy, trade_range=trade_range)
self.order_list = order_list
start, end = strategy.trade_calendar.get_step_time()

View File

@@ -0,0 +1,26 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union
@dataclass
class RuntimeConfig:
    """Runtime options bundle.

    Every field carries a default, so ``RuntimeConfig()`` is a valid instance.
    The three directory fields default to ``None``.
    """

    seed: int = 42
    # Filesystem locations; ``None`` is the default for each of them.
    output_dir: Union[Path, None] = None
    checkpoint_dir: Union[Path, None] = None
    tb_log_dir: Union[Path, None] = None
    # Boolean switches.
    debug: bool = False
    use_cuda: bool = True
@dataclass
class ExchangeConfig:
    """Exchange settings bundle.

    ``limit_threshold``, ``deal_price`` and ``volume_threshold`` are required;
    the remaining fields fall back to the defaults declared below.
    """

    # Required fields (no defaults).
    limit_threshold: Union[float, Tuple[str, str]]
    deal_price: Union[str, Tuple[str, str]]
    volume_threshold: dict

    # Optional fields.
    open_cost: float = 0.0005
    close_cost: float = 0.0015
    min_cost: float = 5.0
    trade_unit: Optional[float] = 100.0
    cash_limit: Optional[Union[Path, float]] = None
    generate_report: bool = False

View File

@@ -0,0 +1,11 @@
from typing import List
from qlib.backtest.executor import NestedExecutor
from .strategy import RLStrategyBase
class RLNestedExecutor(NestedExecutor):
    """Nested executor that forwards inner execution results to RL strategies."""

    def post_inner_exe_step(self, inner_exe_res: List[object]) -> None:
        """Pass ``inner_exe_res`` to the inner strategy's ``post_exe_step``.

        Only strategies deriving from ``RLStrategyBase`` receive the callback;
        any other inner strategy is left untouched.
        """
        strategy = self.inner_strategy
        if not isinstance(strategy, RLStrategyBase):
            return
        strategy.post_exe_step(inner_exe_res)

View File

@@ -0,0 +1,164 @@
import collections
from dataclasses import dataclass
import numpy as np
import pickle
from pathlib import Path
from typing import Optional, List
import pandas as pd
import qlib
from .highfreq_ops import DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut
from qlib.contrib.ops.high_freq import DayCumsum
from qlib.config import REG_CN
from qlib.data.dataset import DatasetH
@dataclass
class QlibConfig:
    """Paths and feature-column lists consumed by ``init_qlib`` in this module.

    All fields are required (none has a default).
    """

    # Provider URIs; ``init_qlib`` registers them under the "day" and "1min" keys.
    provider_uri_day: Path
    provider_uri_1min: Path
    # Directory holding the pickled feature/backtest datasets loaded by ``init_qlib``.
    feature_root_dir: Path
    # Column lists handed to ``DataWrapper`` as ``columns_today`` / ``columns_yesterday``.
    feature_columns_today: List[str]
    feature_columns_yesterday: List[str]
_dataset = None
class LRUCache:
    """Bounded key/value cache that evicts the least recently used entry.

    Both ``put`` and ``get`` mark the touched key as most recently used. Once
    more than ``pool_size`` entries are stored, the stalest ones are dropped.

    Parameters
    ----------
    pool_size : int
        Maximum number of entries kept in the cache.
    """

    def __init__(self, pool_size: int = 200):
        self.pool_size = pool_size
        # Insertion order doubles as recency order: the oldest entry comes first.
        self.contents = collections.OrderedDict()

    @property
    def keys(self):
        """Cached keys, least recently used first (a fresh ``deque`` snapshot)."""
        return collections.deque(self.contents)

    def put(self, key, item):
        """Store ``item`` under ``key`` and mark it as most recently used."""
        self.contents[key] = item
        self.contents.move_to_end(key)
        # Evict from the stale end until the size bound holds again.
        while self.contents and len(self.contents) > self.pool_size:
            self.contents.popitem(last=False)

    def get(self, key):
        """Return the item stored under ``key`` and refresh its recency.

        Raises
        ------
        KeyError
            If ``key`` is not cached.
        """
        item = self.contents[key]
        # A read counts as a use; without this the eviction order would be FIFO.
        self.contents.move_to_end(key)
        return item

    def has(self, key):
        """Return whether ``key`` is currently cached (does not affect recency)."""
        return key in self.contents
class DataWrapper:
    """Per-day, per-stock accessor over a feature dataset and a backtest dataset.

    Fetched frames are memoised in two separate ``LRUCache`` instances, one per
    dataset. Instances are meant to be created by this module only, hence the
    ``_internal`` guard in the constructor.
    """

    def __init__(self, feature_dataset: DatasetH, backtest_dataset: DatasetH,
                 columns_today: List[str], columns_yesterday: List[str], _internal: bool = False):
        assert _internal, 'Init function of data wrapper is for internal use only.'

        self.columns_today, self.columns_yesterday = columns_today, columns_yesterday
        self.feature_dataset, self.feature_cache = feature_dataset, LRUCache()
        self.backtest_dataset, self.backtest_cache = backtest_dataset, LRUCache()

    def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False):
        """Return the rows of ``stock_id`` between 00:00:00 and 23:59:59 of ``date``.

        ``backtest`` selects the backtest dataset (and its cache) instead of the
        feature dataset. A cache hit short-circuits the handler fetch.
        """
        day_begin = date.replace(hour=0, minute=0, second=0)
        day_end = date.replace(hour=23, minute=59, second=59)

        if backtest:
            source, pool = self.backtest_dataset, self.backtest_cache
        else:
            source, pool = self.feature_dataset, self.feature_cache

        cache_key = (day_begin, day_end, stock_id)
        if pool.has(cache_key):
            return pool.get(cache_key)

        fetched = source.handler.fetch(pd.IndexSlice[stock_id, day_begin:day_end], level=None)
        pool.put(cache_key, fetched)
        return fetched
def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None:
    """Initialise qlib and load the pickled datasets into the module-level ``_dataset``.

    Parameters
    ----------
    config : QlibConfig
        Supplies the "day"/"1min" provider URIs, the feature root directory and
        the today/yesterday feature column lists.
    part : Optional[str]
        When ``None``, ``feature.pkl`` and ``backtest.pkl`` are read directly from
        ``config.feature_root_dir``. Otherwise ``feature/<part>.pkl`` and
        ``backtest/<part>.pkl`` under that directory are read.

    Notes
    -----
    Rebinds the module global ``_dataset`` to a new ``DataWrapper``.
    """
    global _dataset

    # Both the calendar and the feature storage backends share this frequency -> URI map.
    provider_uri_map = {
        "day": config.provider_uri_day.as_posix(),
        "1min": config.provider_uri_1min.as_posix(),
    }
    qlib.init(
        region=REG_CN,
        auto_mount=False,
        custom_ops=[DayLast, FFillNan, BFillNan,
                    Date, Select, IsNull, IsInf, Cut, DayCumsum],
        expression_cache=None,
        calendar_provider={
            "class": "LocalCalendarProvider",
            "module_path": "qlib.data.data",
            "kwargs": {
                "backend": {
                    "class": "FileCalendarStorage",
                    "module_path": "qlib.data.storage.file_storage",
                    "kwargs": {"provider_uri_map": provider_uri_map},
                }
            },
        },
        feature_provider={
            "class": "LocalFeatureProvider",
            "module_path": "qlib.data.data",
            "kwargs": {
                "backend": {
                    "class": "FileFeatureStorage",
                    "module_path": "qlib.data.storage.file_storage",
                    "kwargs": {"provider_uri_map": provider_uri_map},
                }
            },
        },
        provider_uri=provider_uri_map,
        kernels=1,
        redis_port=-1,
        clear_mem_cache=False  # init_qlib will be called for multiple times. Keep the cache for improving performance
    )

    # this won't work if it's put outside in case of multiprocessing
    from qlib.data import D

    # Pick the dataset files: a single pair at the root, or one pair per `part`.
    if part is None:
        feature_path = config.feature_root_dir / 'feature.pkl'
        backtest_path = config.feature_root_dir / 'backtest.pkl'
    else:
        feature_path = config.feature_root_dir / 'feature' / (part + '.pkl')
        backtest_path = config.feature_root_dir / 'backtest' / (part + '.pkl')

    # NOTE(review): pickle.load executes arbitrary code embedded in the file;
    # these paths must only ever point at trusted, locally produced datasets.
    with feature_path.open('rb') as f:
        print(feature_path)
        feature_dataset = pickle.load(f)
    with backtest_path.open('rb') as f:
        backtest_dataset = pickle.load(f)

    _dataset = DataWrapper(
        feature_dataset,
        backtest_dataset,
        config.feature_columns_today,
        config.feature_columns_yesterday,
        _internal=True
    )
def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False):
    """Return the requested columns of one stock on one day as a DataFrame.

    Column selection:

    * ``backtest=True``: ``$close`` and ``$volume`` (``yesterday`` is ignored);
    * ``yesterday=True``: the wrapper's ``columns_yesterday``;
    * otherwise: the wrapper's ``columns_today``.

    When the wrapper yields no data (``None`` or empty), a zero-filled float32
    frame with 240 rows and the selected columns is returned instead.

    Requires ``init_qlib()`` to have populated the module-level ``_dataset``.
    """
    assert _dataset is not None, 'You must call init_qlib() before doing this.'

    if backtest:
        fields = ['$close', '$volume']
    elif yesterday:
        fields = _dataset.columns_yesterday
    else:
        fields = _dataset.columns_today

    data = _dataset.get(stock_id, date, backtest)
    if data is None or len(data) == 0:
        # create a fake index, but RL doesn't care about index
        return pd.DataFrame(0., index=np.arange(240), columns=fields, dtype=np.float32)  # FIXME: hardcode here

    # NOTE(review): rstrip('0') removes *every* trailing '0' of a column name,
    # not just one suffix character -- confirm no column legitimately ends in 0.
    renamed = data.rename(columns={c: c.rstrip('0') for c in data.columns})
    return renamed[fields]

View File

@@ -0,0 +1,223 @@
import numpy as np
import pandas as pd
from qlib.data.cache import H
from qlib.data.data import Cal
from qlib.data.ops import ElemOperator, PairOperator
def get_calendar_day(freq="day", future=False):
    """Load the calendar as an array of dates, memoised in ``H["c"]``.

    Parameters
    ----------
    freq : str
        frequency of the calendar file to read.
    future : bool
        whether future trading days are included.

    Returns
    -------
    _calendar:
        numpy array holding ``.date()`` of every calendar entry.
    """
    flag = f"{freq}_future_{future}_day"
    if flag in H["c"]:
        return H["c"][flag]

    dates = np.array([stamp.date() for stamp in Cal.load_calendar(freq, future)])
    H["c"][flag] = dates
    return dates
def get_calendar_minute(freq='day', future=False):
    """Load the calendar as an array of half-hour indices, memoised in ``H["c"]``.

    Parameters
    ----------
    freq : str
        frequency of the calendar file to read.
    future : bool
        whether future trading days are included.

    Returns
    -------
    _calendar:
        numpy array holding ``minute // 30`` of every calendar entry.
    """
    # The key must differ from the "..._day" key used by get_calendar_day:
    # sharing it made whichever function ran first poison the other's result.
    flag = f"{freq}_future_{future}_minute"
    if flag in H["c"]:
        _calendar = H["c"][flag]
    else:
        _calendar = np.array([stamp.minute // 30 for stamp in Cal.load_calendar(freq, future)])
        H["c"][flag] = _calendar
    return _calendar
class DayLast(ElemOperator):
    """DayLast Operator

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    ----------
    feature:
        a series in which every value is replaced by the last value of its day
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        _calendar = get_calendar_day(freq=freq)
        series = self.feature.load(instrument, start_index, end_index, freq)
        # The series index selects each row's date out of the calendar array.
        day_of_row = _calendar[series.index]
        per_day = series.groupby(day_of_row)
        return per_day.transform("last")
class FFillNan(ElemOperator):
    """FFillNan Operator

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    ----------
    feature:
        the feature with NaN values forward-filled
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        series = self.feature.load(instrument, start_index, end_index, freq)
        # `fillna(method="ffill")` is deprecated in pandas 2.x; `ffill()` is the
        # equivalent call that works on both old and new pandas.
        return series.ffill()
class BFillNan(ElemOperator):
    """BFillNan Operator

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    ----------
    feature:
        the feature with NaN values backward-filled
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        series = self.feature.load(instrument, start_index, end_index, freq)
        # `fillna(method="bfill")` is deprecated in pandas 2.x; `bfill()` is the
        # equivalent call that works on both old and new pandas.
        return series.bfill()
class Date(ElemOperator):
    """Date Operator

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    ----------
    feature:
        a series whose values are the calendar dates looked up by feature.index
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        _calendar = get_calendar_day(freq=freq)
        series = self.feature.load(instrument, start_index, end_index, freq)
        row_index = series.index
        # Only the index of the loaded feature is used; its values are discarded.
        return pd.Series(_calendar[row_index], index=row_index)
class Select(PairOperator):
    """Select Operator

    Parameters
    ----------
    feature_left : Expression
        feature instance, select condition
    feature_right : Expression
        feature instance, select value

    Returns
    ----------
    feature:
        the entries of feature_right picked out by ``.loc[feature_left]``
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        load_args = (instrument, start_index, end_index, freq)
        # Load order (condition first, then value) is kept as in the original.
        mask = self.feature_left.load(*load_args)
        values = self.feature_right.load(*load_args)
        return values.loc[mask]
class IsNull(ElemOperator):
    """IsNull Operator

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    ----------
    feature:
        the result of ``isnull()`` on the loaded feature
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        loaded = self.feature.load(instrument, start_index, end_index, freq)
        null_mask = loaded.isnull()
        return null_mask
class IsInf(ElemOperator):
    """IsInf Operator

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    ----------
    feature:
        the result of ``np.isinf`` applied to the loaded feature
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        loaded = self.feature.load(instrument, start_index, end_index, freq)
        inf_mask = np.isinf(loaded)
        return inf_mask
class Cut(ElemOperator):
    """Cut Operator

    Parameters
    ----------
    feature : Expression
        feature instance
    left : int
        left > 0, delete the first ``left`` elements of feature (default is None, which means 0)
    right : int
        right < 0, delete the last ``-right`` elements of feature (default is None, which means 0)

    Returns
    ----------
    feature:
        A series with the first ``left`` and last ``-right`` elements deleted from the feature.
        Note: It is deleted from the raw data, not the sliced data
    """

    def __init__(self, feature, left=None, right=None):
        self.left = left
        self.right = right
        left_invalid = self.left is not None and self.left <= 0
        # `or` short-circuits, so `right` is only inspected when `left` is acceptable.
        if left_invalid or (self.right is not None and self.right >= 0):
            raise ValueError("Cut operator l shoud > 0 and r should < 0")
        super(Cut, self).__init__(feature)

    def _load_internal(self, instrument, start_index, end_index, freq):
        loaded = self.feature.load(instrument, start_index, end_index, freq)
        window = slice(self.left, self.right)
        return loaded.iloc[window]

    def get_extended_window_size(self):
        """Widen the wrapped feature's window by the number of elements cut on each side."""
        extra_left = self.left if self.left is not None else 0
        extra_right = abs(self.right) if self.right is not None else 0
        inner_left, inner_right = self.feature.get_extended_window_size()
        return inner_left + extra_left, inner_right + extra_right

View File

@@ -0,0 +1,251 @@
import abc
import math
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import Callable, Literal, Optional, Tuple
import numpy as np
import pandas as pd
EPSILON = 1e-7
class FlowDirection(str, Enum):
    """Direction of an order flow; members are also ``str`` instances."""

    # `_price_advantage` in this module uses the buy-side formula for ACQUIRE
    # and the opposite sign for every other value.
    ACQUIRE = "acquire"
    LIQUIDATE = "liquidate"
def _round_time(time: int, granularity: int) -> int:
    """Round ``time`` down to a multiple of ``granularity``."""
    remainder = time % granularity
    return time - remainder
@dataclass
class BaseEpisodicState(abc.ABC):
    """
    Base class for episodic states.

    The constructor takes the "requirements" fields below. The "agent states"
    fields are excluded from ``__init__`` and are filled in by ``__post_init__``
    or by subclasses.
    """

    # requirements
    start_time: int
    end_time: int
    time_per_step: int
    vol_limit: Optional[float]  # TODO: meaning?
    price_func: Callable[[str], np.ndarray]  # TODO: meaning?
    volume_func: Callable[[], np.ndarray]  # TODO: meaning?
    on_step_end: Optional[Callable[..., None]]  # TODO: meaning?
    on_episode_end: Optional[Callable[..., None]]  # TODO: meaning?
    asset_num: int  # TODO: meaning?

    # agent states
    num_step: int = field(init=False)  # Number of steps
    cur_time: int = field(init=False)  # Current time
    cur_step: int = field(init=False, default=0)
    exec_vol: Optional[np.ndarray] = field(init=False, default=None)  # Execution history
    last_step_duration: int = field(init=False)
    position: float = field(init=False)
    position_history: np.ndarray = field(init=False)

    def __post_init__(self) -> None:
        """Start the clock at ``start_time`` and derive the number of steps."""
        self.cur_time = self.start_time
        rounded_start_time = _round_time(self.start_time, self.time_per_step)
        # TODO: why not rounding end time?
        self.num_step = math.floor((self.end_time - rounded_start_time) / self.time_per_step)

    def logs(self) -> dict:
        """Return the base logging payload with a ``"logs"`` and a ``"history"`` section."""
        # Base logging information shared across all subclasses.
        # You can call logs = super().logs() to get these default logs and use logs.update(...) to add other logging
        # information or override it completely to remove these logging fields.
        return {
            "logs": {
                "stop_time": self.cur_time - self.start_time,
                "stop_step": self.cur_step,
            },
            "history": {
                "volume": self.execution_history(),
            },
        }

    def execution_history(self) -> np.ndarray:
        """Return ``exec_vol`` right-padded with zeros to ``end_time - start_time`` entries.

        Before the first step ``exec_vol`` is still ``None``; that is treated as
        "nothing executed yet" and yields an all-zero history instead of failing.
        """
        executed = np.zeros(0) if self.exec_vol is None else self.exec_vol
        return np.pad(executed, (0, self.end_time - self.start_time - len(executed)))

    def next_duration(self) -> int:
        """Return the length of the interval given by ``next_interval``."""
        left, right = self.next_interval()
        return right - left

    def next_interval(self) -> Tuple[int, int]:
        """Return the step interval containing ``cur_time``, relative to ``start_time``.

        The interval is aligned to ``time_per_step`` and clipped to
        ``[start_time, end_time]`` before being shifted.
        """
        left = _round_time(self.cur_time, self.time_per_step)
        right = left + self.time_per_step
        return max(left, self.start_time) - self.start_time, min(right, self.end_time) - self.start_time

    @classmethod
    def get_init_field_names(cls):
        """Return the names of the dataclass fields accepted by ``__init__``."""
        return [f.name for f in fields(cls) if f.init]

    @abc.abstractmethod
    def step(self, *args, **kwargs):
        """Advance the state by one step; must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def done(self) -> bool:
        """Whether the episode has finished; the base implementation never finishes."""
        return False
@dataclass
class IntraDaySingleAssetDataSchema:
    """
    One stock on one day: raw market data plus processed features.

    In the current context, raw should be a DataFrame with `datetime` as index and
    (at least) `$vwap0`, `$volume0`, `$close0` as columns.

    `processed` should be a DataFrame of 240x6, which is the same as `processed_prev`.
    """

    date: pd.Timestamp
    stock_id: str
    start_time: int
    end_time: int
    target: float
    flow_dir: FlowDirection
    raw: pd.DataFrame
    processed: pd.DataFrame
    processed_prev: pd.DataFrame

    def get_price(self, type: Literal['deal', 'close'] = 'deal'):
        """Return the ``$price`` (deal) or ``$close0`` (close) column of ``raw`` as an array.

        Any other ``type`` falls through and yields ``None``.
        """
        # NOTE(review): '$price' is not among the columns listed in the class
        # docstring -- confirm that callers really provide it.
        if type == 'deal':
            return self.raw['$price'].values
        if type == 'close':
            return self.raw['$close0'].values
        return None

    def get_volume(self):
        """Return the ``$volume0`` column of ``raw`` as an array."""
        volume_column = self.raw['$volume0']
        return volume_column.values

    def get_processed_data(self, type: Literal['today', 'yesterday'] = 'today'):
        """Return ``processed`` (today) or ``processed_prev`` (yesterday) as a numpy array.

        Any other ``type`` falls through and yields ``None``.
        """
        if type == 'today':
            return self.processed.to_numpy()
        if type == 'yesterday':
            return self.processed_prev.to_numpy()
        return None
@dataclass
class SAOEEpisodicState(BaseEpisodicState):
    """
    Global state of the whole time horizon.

    Tracks the remaining ``position`` of a single-asset order (it starts at
    ``target`` and decreases by the executed volume on every ``step``) and,
    once the episode is done, computes execution statistics.
    """

    # requirements
    target: float
    target_limit: float
    flow_dir: FlowDirection

    # calculated statistics
    turnover: Optional[float] = field(init=False)
    baseline_twap: Optional[float] = field(init=False)
    baseline_vwap: Optional[float] = field(init=False)
    exec_avg_price: Optional[float] = field(init=False)
    pa_twap: Optional[float] = field(init=False)
    pa_vwap: Optional[float] = field(init=False)
    pa_close: Optional[float] = field(init=False)
    fulfill_rate: Optional[float] = field(init=False)
    market_price: np.ndarray = field(init=False)  # deal price, might be different from close
    market_close: np.ndarray = field(init=False)  # close price
    market_volume: np.ndarray = field(init=False)

    # NOTE: this is a temporary design to make it compatible with old qlib integration framework. As long as callback
    # functions are passed correctly, this field should be removed from this class.
    last_interval: Tuple[int, int] = field(default=(0, 0), init=False)

    def __post_init__(self) -> None:
        """Slice the market data to the trading window and compute the baselines."""
        assert self.target >= 0
        assert self.asset_num == 1

        super().__post_init__()

        self.market_volume = self.volume_func()[self.start_time : self.end_time]
        self.market_price = self.price_func("deal")[self.start_time : self.end_time]
        self.market_close = self.price_func("close")[self.start_time : self.end_time]

        self.position = self.target
        # One slot per step plus the initial position; unfilled slots stay NaN.
        self.position_history = np.full((self.num_step + 1), np.nan)
        self.position_history[0] = self.position

        self.baseline_twap = np.mean(self.market_price)
        if self.market_volume.sum() == 0:
            # np.average cannot normalise all-zero weights; fall back to TWAP.
            self.baseline_vwap = self.baseline_twap
        else:
            self.baseline_vwap = np.average(self.market_price, weights=self.market_volume)

    def update_stats(self) -> None:
        """Compute turnover, average execution price, price advantages and fulfill rate."""
        market_price = self.market_price[: len(self.exec_vol)]
        self.turnover = (self.exec_vol * market_price).sum()

        # exec_vol can be zero
        if np.isclose(self.exec_vol.sum(), 0):
            self.exec_avg_price = market_price[0]
        else:
            self.exec_avg_price = np.average(market_price, weights=self.exec_vol)

        self.pa_twap = _price_advantage(self.exec_avg_price, self.baseline_twap, self.flow_dir)
        self.pa_vwap = _price_advantage(self.exec_avg_price, self.baseline_vwap, self.flow_dir)
        close_average = np.mean(self.market_close)
        self.pa_close = _price_advantage(self.exec_avg_price, close_average, self.flow_dir)

        # NOTE(review): divides by target_limit without a zero check -- confirm
        # that a zero target/target_limit can never reach this point.
        self.fulfill_rate = (self.target - self.position) / self.target_limit
        if abs(self.fulfill_rate - 1.0) < EPSILON:
            self.fulfill_rate = 1.0  # snap float noise to exactly 100%
        self.fulfill_rate *= 100

    def logs(self) -> dict:
        """Return the base payload with its ``"logs"`` section replaced by the statistics."""
        logs = super().logs()
        # NOTE(review): dict.update is shallow, so this replaces the base "logs"
        # section (stop_time / stop_step) instead of merging into it -- confirm intent.
        logs.update(
            {
                "logs": {
                    "turnover": self.turnover,
                    "baseline_twap": self.baseline_twap,
                    "baseline_vwap": self.baseline_vwap,
                    "exec_avg_price": self.exec_avg_price,
                    "pa_twap": self.pa_twap,
                    "pa_vwap": self.pa_vwap,
                    "pa_close": self.pa_close,
                    "ffr": self.fulfill_rate,
                }
            }
        )
        return logs

    def step(self, exec_vol: np.ndarray) -> None:
        """Apply one step of executed volumes.

        Parameters
        ----------
        exec_vol : np.ndarray
            Executed volume for each time unit of this step; its length is taken
            as the step duration and its sum is subtracted from ``position``.

        Raises
        ------
        AssertionError
            If the step interval is empty, if the position or any executed volume
            goes below ``-EPSILON``, or if the last step does not end exactly at
            ``end_time``.
        """
        left, right = self.next_interval()
        self.last_interval = (left, right)
        assert 0 <= left < right

        self.last_step_duration = len(exec_vol)
        self.position -= exec_vol.sum()
        # The condition and the message must be separate operands of `assert`.
        # Wrapped together in parentheses they form a non-empty tuple, which is
        # always truthy, so the check could never fire.
        assert (
            self.position > -EPSILON and (exec_vol > -EPSILON).all()
        ), f"Execution volume is invalid: {exec_vol} (position = {self.position})"

        self.cur_step += 1
        self.position_history[self.cur_step] = self.position
        self.cur_time += self.last_step_duration

        if self.cur_step == self.num_step:  # Should reach the end of episode
            assert self.cur_time == self.end_time

        self.exec_vol = exec_vol if self.exec_vol is None else np.concatenate((self.exec_vol, exec_vol))

        if self.on_step_end is not None:
            self.on_step_end(left, right, self)
        if self.done:
            self.update_stats()
            if self.on_episode_end is not None:
                self.on_episode_end(self)

    @property
    def done(self) -> bool:
        """True once the position is (almost) zero or the last step has been taken."""
        return self.position < EPSILON or self.cur_step == self.num_step
def _price_advantage(exec_price: float, baseline_price: float, flow: FlowDirection) -> float:
if baseline_price == 0:
return 0.0
if flow == FlowDirection.ACQUIRE:
return (1 - exec_price / baseline_price) * 10000
else:
return (exec_price / baseline_price - 1) * 10000

View File

@@ -0,0 +1,162 @@
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from qlib.backtest.decision import Order, OrderDir, TradeDecisionWO
from qlib.backtest.exchange import Exchange
from qlib.constant import REG_CN
from qlib.rl.order_execution.from_neutrader.feature import fetch_features
from qlib.rl.order_execution.from_neutrader.state import FlowDirection, IntraDaySingleAssetDataSchema, SAOEEpisodicState
from qlib.utils.time import get_day_min_idx_range
class StateMaintainer:
    """
    Maintain neutrader states taking qlib trade decisions as input.

    Example usage::

        maintainer = StateMaintainer(...)  # in reset
        maintainer.send_execute_result(execute_result)  # in step
        # do something here
        maintainer.generate_orders(self.get_data_cal_avail_range(rtype='step'), exec_vols)

    The states can be accessed via ``maintainer.states`` and ``maintainer.samples``.
    Both are keyed by ``(stock_id, direction)`` of the orders in the outer decision.
    """

    def __init__(
        self,
        time_per_step: int,
        date: pd.Timestamp,
        full_trade_range: Tuple[int, int],
        current_step: int,
        outer_trade_decision: TradeDecisionWO,
        trade_exchange: Exchange,
    ) -> None:
        """Build one sample and one episodic state per order of ``outer_trade_decision``.

        ``full_trade_range`` is unpacked as ``(start, end)`` with an inclusive end;
        it is stored with ``end + 1``. ``current_step`` is written to every new
        state and is asserted to be 0 there.
        """
        # The parameters look very ad-hoc right now

        self.states: Dict[Tuple[str, OrderDir], SAOEEpisodicState] = {}  # explicitly make it ordered
        self.samples: Dict[Tuple[str, OrderDir], IntraDaySingleAssetDataSchema] = {}

        self.time_per_step: int = time_per_step
        self.start_time, self.end_time = full_trade_range
        self.end_time += 1  # plus 1 to align with the semantics in neutrader
        self.date: pd.Timestamp = date

        # -1 marks "no orders generated yet"; see send_execute_result.
        self.last_step_length: int = -1
        self.last_step_range: Optional[Tuple[int, int]] = None

        self.order_list: List[Order] = outer_trade_decision.order_list
        self.trade_exchange: Exchange = trade_exchange

        # Ceiling division of the window length, measured from start_time rounded
        # down to a multiple of time_per_step.
        self.num_step = (
            self.end_time - (self.start_time - self.start_time % self.time_per_step) - 1
        ) // self.time_per_step + 1

        for order in self.order_list:
            sample = self._fetch_sample_data(order)
            state = self._create_single_ep_state(sample, current_step)
            self.samples[order.stock_id, order.direction] = sample
            self.states[order.stock_id, order.direction] = state

    def _fetch_sample_data(self, order: Order) -> IntraDaySingleAssetDataSchema:
        """Collect deal prices, backtest data and today/yesterday features for ``order``."""
        start_time = self.date.replace(hour=0, minute=0, second=0)
        end_time = self.date.replace(hour=23, minute=59, second=59)
        deal_price = self.trade_exchange.get_deal_price(
            stock_id=order.stock_id, start_time=start_time, end_time=end_time, direction=order.direction, method=None,
        )
        backtest_data = fetch_features(order.stock_id, self.date, backtest=True)
        # HACK: close means deal price here. The logic is implemented in qlib.
        backtest_data["$close"] = deal_price.to_series().to_numpy()
        feature_today = fetch_features(order.stock_id, self.date)
        feature_yesterday = fetch_features(order.stock_id, self.date, yesterday=True)

        return IntraDaySingleAssetDataSchema(
            date=self.date.date(),
            stock_id=order.stock_id,
            start_time=self.start_time,
            end_time=self.end_time,
            target=max(order.amount, 0.0),  # prevent target to go to -eps
            # NOTE(review): direction 0 is mapped to LIQUIDATE -- presumably the
            # sell direction of Order; confirm against qlib.backtest.decision.
            flow_dir=FlowDirection.LIQUIDATE if order.direction == 0 else FlowDirection.ACQUIRE,
            raw=backtest_data,
            processed=feature_today,
            processed_prev=feature_yesterday,
        )

    def _create_single_ep_state(self, sample: IntraDaySingleAssetDataSchema, cur_step: int) -> SAOEEpisodicState:
        """Create the episodic state for one sample; ``cur_step`` must be 0."""
        market_price = sample.raw["$close"].values
        market_vol = sample.raw["$volume"].values
        target = sample.target

        # NOTE: Previously, market_price and market_vol are passed into the state initialization directly. Therefore,
        # the segment of market_price and market_vol are used instead of the lambda function here using the whole price
        # and vol data.
        # This refactoring is ONLY EQUIVALENT WHEN start_time/end_time passed into state is equal to
        # sample.start_time/end_time.
        # If one can confirm that these two are always the same, delete this note, please.
        # Positional arguments, in order: start_time, end_time, time_per_step,
        # vol_limit, price_func, volume_func, on_step_end, on_episode_end,
        # asset_num, target, target_limit, flow_dir.
        state = SAOEEpisodicState(
            self.start_time,
            self.end_time,
            self.time_per_step,
            None,
            lambda x: market_price,
            lambda: market_vol,
            None,
            None,
            1,
            target,
            target,
            sample.flow_dir,
        )
        state.cur_step = cur_step
        assert state.cur_step == 0
        return state

    def _update_single_ep_state(
        self, state: SAOEEpisodicState, execute_result: List[Order], length: Optional[int] = None
    ) -> None:
        """Turn ``execute_result`` into a volume array and advance ``state`` by one step.

        With ``length`` given, the array has ``length`` zeros and each deal amount
        is placed at its minute index relative to the start of the last step range.
        Without it, the deal amounts are used in result order.
        """
        if length is not None:
            exec_vol = np.zeros(length)
            for order, _, __, ___ in execute_result:
                idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN)
                exec_vol[idx - self.last_step_range[0]] = order.deal_amount
        else:
            exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])

        # sometimes exec_vol gets too large due to the rounding in exchange
        # scale the execution volume so that position won't go below 0
        # actually this case is very rare
        if exec_vol.sum() > state.position and exec_vol.sum() > 0:
            assert exec_vol.sum() < state.position + 1, f"{exec_vol} too large for {state}"
            exec_vol *= state.position / (exec_vol.sum())

        state.step(exec_vol)

    def create_sub_order(self, exec_vol: float, original_order: Order) -> Order:
        """Create an order of ``exec_vol`` with the stock and direction of ``original_order``."""
        oh = self.trade_exchange.get_order_helper()
        return oh.create(original_order.stock_id, exec_vol, original_order.direction)

    def send_execute_result(self, execute_result: Optional[List[Any]]) -> None:
        """Feed the execution results of the previous step into every maintained state.

        Before the first ``generate_orders`` call there is nothing to update;
        ``execute_result`` must be empty or ``None`` in that case.
        """
        if self.last_step_length < 0:
            assert not execute_result
            return

        # Group results by (stock_id, direction); states without results get an empty list.
        orders = defaultdict(list)
        if execute_result is not None:
            for e in execute_result:
                orders[e[0].stock_id, e[0].direction].append(e)

        for (stock_id, direction), state in self.states.items():
            self._update_single_ep_state(state, orders[stock_id, direction], self.last_step_length)

    def generate_orders(self, step_trade_range: Tuple[int, int], exec_vols: List[float]) -> List[Order]:
        """Create sub-orders for this step and remember the step range.

        ``exec_vols`` must have one entry per order of the outer decision, in the
        same order; entries that are not positive produce no sub-order.
        ``step_trade_range`` has an inclusive end and is stored as a half-open range.
        """
        order_list = []
        assert len(exec_vols) == len(self.order_list)
        for v, o in zip(exec_vols, self.order_list):
            if v > 0:
                order_list.append(self.create_sub_order(v, o))

        step_start_time, step_end_time = step_trade_range  # inclusive
        step_end_time += 1
        self.last_step_length = step_end_time - step_start_time
        self.last_step_range = (step_start_time, step_end_time)

        return order_list

View File

@@ -0,0 +1,91 @@
from abc import ABCMeta
from typing import Tuple
import pandas as pd
from qlib.backtest.decision import BaseTradeDecision, Order, OrderHelper, TradeDecisionWO, TradeRange
from qlib.backtest.utils import CommonInfrastructure
from qlib.rl.order_execution.from_neutrader.state import IntraDaySingleAssetDataSchema, SAOEEpisodicState
from qlib.rl.order_execution.from_neutrader.state_maintainer import StateMaintainer
from qlib.strategy.base import BaseStrategy
class RLStrategyBase(BaseStrategy, metaclass=ABCMeta):
    """Base class for RL strategies that receive a callback after each execution step."""

    def post_exe_step(self, execute_result: list) -> None:
        """
        Post-processing hook called after each step of the strategy. It is designed
        for RL strategies, which need to update the policy state after each step.

        NOTE: it is strongly coupled with RLNestedExecutor;

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override the method.
        """
        raise NotImplementedError("Please implement the `post_exe_step` method")
class DecomposedStrategy(RLStrategyBase):
    """RL strategy whose per-step execution volume is supplied from outside.

    ``generate_trade_decision`` is a generator: it yields the strategy itself,
    waits for the volume to be sent in, and returns the resulting decision.
    """

    def __init__(self):
        super().__init__()

    def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs) -> None:
        """Reset the strategy; build a new ``StateMaintainer`` when a decision is given."""
        super().reset(outer_trade_decision=outer_trade_decision, **kwargs)

        # Step length expressed in minutes.
        time_per_step = int(pd.Timedelta(self.trade_calendar.get_freq()) / pd.Timedelta("1min"))
        if outer_trade_decision is None:
            return

        self.maintainer = StateMaintainer(
            time_per_step,
            self.trade_calendar.get_all_time()[0],
            self.get_data_cal_avail_range(),
            self.trade_calendar.get_trade_step(),
            outer_trade_decision,
            self.trade_exchange,
        )

    def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
        """Return the outer decision unchanged."""
        return outer_trade_decision

    def post_exe_step(self, execute_result):
        """Forward the execution result of the last step to the state maintainer."""
        self.maintainer.send_execute_result(execute_result)

    @property
    def sample_state_pair(self) -> Tuple[IntraDaySingleAssetDataSchema, SAOEEpisodicState]:
        """The single (sample, state) pair held by the maintainer.

        Asserts that exactly one order is being maintained.
        """
        samples, states = self.maintainer.samples, self.maintainer.states
        assert len(samples) == len(states) == 1
        only_sample = list(samples.values())[0]
        only_state = list(states.values())[0]
        return (only_sample, only_state)

    def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
        """Yield ``self``, receive the execution volume, and return the step's decision."""
        # get a decision from the outmost loop
        exec_vol = yield self

        step_range = self.get_data_cal_avail_range(rtype="step")
        sub_orders = self.maintainer.generate_orders(step_range, [exec_vol])
        return TradeDecisionWO(sub_orders, self)
class SingleOrderStrategy(BaseStrategy):
    """Strategy that emits exactly one order, restricted to a fixed trade range."""

    # this logic is copied from FileOrderStrategy
    def __init__(
        self,
        common_infra: CommonInfrastructure,
        order: Order,
        trade_range: TradeRange,
        instrument: str,
    ) -> None:
        super().__init__(common_infra=common_infra)

        self._instrument = instrument
        self._order = order
        self._trade_range = trade_range

    def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
        """Return the outer decision unchanged."""
        return outer_trade_decision

    def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO:
        """Wrap the configured order into a ``TradeDecisionWO`` bound to the trade range.

        The order is re-created through the exchange's order helper, using the
        configured instrument together with the amount and the parsed direction
        of the original order.
        """
        exchange = self.common_infra.get("trade_exchange")
        oh: OrderHelper = exchange.get_order_helper()
        single_order = oh.create(
            code=self._instrument,
            amount=self._order.amount,
            direction=Order.parse_dir(self._order.direction),
        )
        return TradeDecisionWO([single_order], self, self._trade_range)

View File

@@ -2,34 +2,27 @@
# Licensed under the MIT License.
"""Placeholder for qlib-based simulator."""
from dataclasses import dataclass
from pathlib import Path
from plistlib import Dict
from typing import Callable, Generator, List, Optional, Tuple, Union
from typing import Callable, Generator, List, Optional, Union
import pandas as pd
from gym.vector.utils import spaces
from qlib.backtest import Account, BaseExecutor, CommonInfrastructure, get_exchange
from qlib.backtest.decision import Order
from qlib.backtest.executor import NestedExecutor
from qlib.backtest import get_exchange
from qlib.backtest.account import Account
from qlib.backtest.decision import Order, TradeRange, TradeRangeByTime
from qlib.backtest.executor import BaseExecutor
from qlib.backtest.utils import CommonInfrastructure
from qlib.config import QlibConfig
from qlib.rl.simulator import ActType, Simulator, StateType
from qlib.rl.interpreter import ActionInterpreter
from qlib.rl.order_execution.from_neutrader.config import ExchangeConfig
from qlib.rl.order_execution.from_neutrader.executor import RLNestedExecutor
from qlib.rl.order_execution.from_neutrader.feature import init_qlib
from qlib.rl.order_execution.from_neutrader.state import SAOEEpisodicState
from qlib.rl.order_execution.from_neutrader.strategy import DecomposedStrategy
from qlib.rl.simulator import Simulator
from qlib.strategy.base import BaseStrategy
@dataclass
class ExchangeConfig:
limit_threshold: Union[float, Tuple[str, str]]
deal_price: Union[str, Tuple[str, str]]
volume_threshold: Union[float, Dict[str, Tuple[str, str]]]
open_cost: float = 0.0005
close_cost: float = 0.0015
min_cost: float = 5.
trade_unit: Optional[float] = 100.
cash_limit: Optional[Union[Path, float]] = None
generate_report: bool = False
def get_common_infra(
config: ExchangeConfig,
trade_start_time: pd.Timestamp,
@@ -69,13 +62,31 @@ def get_common_infra(
return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
class QlibSimulator(Simulator[Order, StateType, ActType]):
class CategoricalActionInterpreter(ActionInterpreter[SAOEEpisodicState, int, float]):
def __init__(self, values: Union[int, List[float]]) -> None:
if isinstance(values, int):
values = [i / values for i in range(0, values + 1)]
self.action_values = values
@property
def action_space(self) -> spaces.Discrete:
return spaces.Discrete(len(self.action_values))
def interpret(self, state: SAOEEpisodicState, action: int) -> float:
volume = min(state.position, state.target * self.action_values[action])
if state.cur_step + 1 >= state.num_step:
volume = state.position # execute all volumes at last
return volume
class QlibSimulator(Simulator[Order, SAOEEpisodicState, float]):
def __init__(
self,
qlib_config: QlibConfig,
time_per_step: str,
top_strategy: BaseStrategy,
inner_strategy_fn: Callable[[], BaseStrategy],
start_time: str,
end_time: str,
qlib_config: QlibConfig,
top_strategy_fn: Callable[[CommonInfrastructure, Order, TradeRange, str], BaseStrategy],
inner_executor_fn: Callable[[CommonInfrastructure], BaseExecutor],
exchange_config: ExchangeConfig,
) -> None:
@@ -83,61 +94,77 @@ class QlibSimulator(Simulator[Order, StateType, ActType]):
initial=None, # TODO
)
self.qlib_config = qlib_config
self._trade_range = TradeRangeByTime(start_time, end_time)
self._qlib_config = qlib_config
self._time_per_step = time_per_step
self._top_strategy = top_strategy
self._top_strategy_fn = top_strategy_fn
self._inner_executor_fn = inner_executor_fn
self._inner_strategy_fn = inner_strategy_fn
self._exchange_config = exchange_config
self._executor: Optional[NestedExecutor] = None
self._executor: Optional[RLNestedExecutor] = None
self._collect_data_loop: Optional[Generator] = None
self._done = False
def _reset(
self._inner_strategy = DecomposedStrategy()
# Start a new episode for `order`: initialise qlib, build the common
# infrastructure and a fresh executor, then advance the data-collection
# generator to its first DecomposedStrategy and cache the resulting state.
# NOTE(review): this span looks like a rendered diff — superseded lines (the
# `date_time` / `self._top_strategy` variant) sit next to their replacements
# (the `order` / `top_strategy` variant). Confirm against the real file.
def reset(
self,
instrument: str,
date_time: pd.Timestamp,
order: Order,
instrument: str = "SH600000", # TODO: Test only. Remove this default value later.
) -> None:
# TODO: init_qlib
init_qlib(self._qlib_config, instrument)
# Account/exchange infrastructure limited to this instrument and time window.
common_infra = get_common_infra(
self._exchange_config,
trade_start_time=date_time,
trade_end_time=date_time,
trade_start_time=order.start_time,
trade_end_time=order.end_time,
codes=[instrument],
)
# A new executor is created on every reset; the inner executor comes from the
# user-supplied factory.
self._executor = NestedExecutor(
self._executor = RLNestedExecutor(
time_per_step=self._time_per_step,
inner_executor=self._inner_executor_fn(common_infra),
inner_strategy=self._inner_strategy_fn(),
inner_strategy=self._inner_strategy,
track_data=True,
common_infra=common_infra,
)
self._executor.reset(start_time=date_time, end_time=date_time)
self._top_strategy.reset(level_infra=self._executor.get_level_infra())
self._collect_data_loop = self._executor.collect_data(self._top_strategy.generate_trade_decision(), level=0)
# The top-level strategy is rebuilt per order via the user-supplied factory.
top_strategy = self._top_strategy_fn(common_infra, order, self._trade_range, instrument)
self._executor.reset(start_time=order.start_time, end_time=order.end_time)
top_strategy.reset(level_infra=self._executor.get_level_infra())
self._collect_data_loop = self._executor.collect_data(top_strategy.generate_trade_decision(), level=0)
assert isinstance(self._collect_data_loop, Generator)
# Prime the generator (no action yet) so get_state() has a state to return
# before the first step().
strategy = self._iter_strategy(action=None)
sample, ep_state = strategy.sample_state_pair
self._last_ep_state = ep_state
self._done = False
def step(self, action: ActType) -> None:
def _iter_strategy(self, action: float = None) -> DecomposedStrategy:
    """Advance the data-collection generator to its next ``DecomposedStrategy``.

    With ``action is None`` the generator is advanced with ``next``; otherwise
    ``action`` is passed in with ``send``. Yielded objects of any other type
    are discarded, and the same ``action`` is sent again on each extra
    iteration.

    A ``StopIteration`` raised by the generator propagates to the caller.
    """
    while True:
        if action is None:
            yielded = next(self._collect_data_loop)
        else:
            yielded = self._collect_data_loop.send(action)
        if isinstance(yielded, DecomposedStrategy):
            break
    assert isinstance(yielded, DecomposedStrategy)
    return yielded
# Apply `action` to the running episode and cache the resulting episodic state
# in `self._last_ep_state`; sets `self._done` once that state reports done.
# NOTE(review): this span looks like a rendered diff — the
# `self._collect_data_loop.send(action)` / `BaseStrategy` lines appear to be the
# superseded body and `self._iter_strategy(action=action)` its replacement.
# Confirm against the real file.
def step(self, action: float) -> None:
try:
strategy = self._collect_data_loop.send(action)
while not isinstance(strategy, BaseStrategy):
strategy = self._collect_data_loop.send(action)
assert isinstance(strategy, BaseStrategy)
# TODO: do something here
strategy = self._iter_strategy(action=action)
sample, ep_state = strategy.sample_state_pair
except StopIteration:
# The generator is exhausted: read the final state from the inner strategy
# held by the simulator; it is expected to be marked done.
sample, ep_state = self._inner_strategy.sample_state_pair
assert ep_state.done
self._last_ep_state = ep_state
if ep_state.done:
self._done = True
# NOTE(review): two definitions of get_state appear back to back; the first
# looks like the superseded stub from the diff — confirm against the real file.
def get_state(self) -> StateType:
pass # TODO: Collect info from executor. Generate state.
# Return the episodic state cached by the most recent reset() or step().
def get_state(self) -> SAOEEpisodicState:
return self._last_ep_state
# Report whether the episode has finished: reset() clears the flag and step()
# sets it once the episodic state is done.
def done(self) -> bool:
return self._done

View File

@@ -0,0 +1,133 @@
import collections
from pathlib import Path
import pandas as pd
from qlib.backtest.decision import Order, OrderDir, TradeRange
from qlib.backtest.executor import SimulatorExecutor
from qlib.backtest.utils import CommonInfrastructure
from qlib.config import QlibConfig
from qlib.contrib.strategy import TWAPStrategy
from qlib.rl.order_execution.from_neutrader.executor import RLNestedExecutor
from qlib.rl.order_execution.from_neutrader.strategy import SingleOrderStrategy
from qlib.rl.order_execution.simulator_qlib import CategoricalActionInterpreter, ExchangeConfig, QlibSimulator
# Feature expressions for the current trading day.
_FEATURE_COLUMNS_TODAY = [
    "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
    "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5",
]

# The same feature set with the ``_1`` suffix, used under "feature_columns_yesterday".
_FEATURE_COLUMNS_YESTERDAY = [
    "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
    "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
]

# Data locations and feature columns used by the sample run below.
# The paths point at a local sample data set.
qlib_config = QlibConfig(
    {
        "provider_uri_day": Path("C:/workspace/NeuTrader/data_sample/cn/qlib_amc_1d"),
        "provider_uri_1min": Path("C:/workspace/NeuTrader/data_sample/cn/qlib_amc_1min"),
        "feature_root_dir": Path("C:/workspace/NeuTrader/data_sample/cn/qlib_amc_handler_stock"),
        "feature_columns_today": _FEATURE_COLUMNS_TODAY,
        "feature_columns_yesterday": _FEATURE_COLUMNS_YESTERDAY,
    }
)
# Exchange settings for the sample run. The cost fields, cash_limit and
# generate_report repeat the ExchangeConfig dataclass defaults; trade_unit=None
# overrides the dataclass default of 100.
exchange_config = ExchangeConfig(
# Pair of expression strings (the Tuple[str, str] form of the field).
# NOTE(review): presumably (buy, sell) order — confirm against the exchange.
limit_threshold=('$ask == 0', '$bid == 0'),
# Pair of price expressions, each falling back to the other quote side when
# its own side is 0.
deal_price=('If($ask == 0, $bid, $ask)', 'If($bid == 0, $ask, $bid)'),
# Volume limits keyed by 'all' / 'buy' / 'sell'; each value is a
# (mode, expression) pair with mode 'cum' or 'current'.
volume_threshold={
'all': ('cum', "0.2 * DayCumsum($volume, '9:45', '14:44')"),
'buy': ('current', '$askV1'), 'sell': ('current', '$bidV1')
},
open_cost=0.0005,
close_cost=0.0015,
min_cost=5.0,
trade_unit=None,
cash_limit=None,
generate_report=False,
)
def _top_strategy_fn(
    common_infra: CommonInfrastructure,
    order: Order,
    trade_range: TradeRange,
    instrument: str,
) -> SingleOrderStrategy:
    """Build the top-level strategy for one order.

    Passed to ``QlibSimulator`` as ``top_strategy_fn``; all four arguments are
    forwarded positionally to ``SingleOrderStrategy`` in the same order.
    """
    strategy_args = (common_infra, order, trade_range, instrument)
    return SingleOrderStrategy(*strategy_args)
def _inner_executor_fn(common_infra: CommonInfrastructure) -> RLNestedExecutor:
    """Build the inner executor passed to ``QlibSimulator`` as ``inner_executor_fn``.

    The result is a 30-minute ``RLNestedExecutor`` driven by a ``TWAPStrategy``
    that wraps a 1-minute ``SimulatorExecutor``; both levels share
    ``common_infra`` and have data tracking enabled.
    """
    twap_strategy = TWAPStrategy()
    minute_executor = SimulatorExecutor(
        time_per_step="1min",
        verbose=False,
        trade_type=SimulatorExecutor.TT_SERIAL,
        generate_report=False,
        common_infra=common_infra,
        track_data=True,
    )
    return RLNestedExecutor(
        time_per_step="30min",
        inner_strategy=twap_strategy,
        inner_executor=minute_executor,
        common_infra=common_infra,
        track_data=True,
    )
def test():
    """Smoke-run ``QlibSimulator`` on the first sample order.

    Builds the sample orders, resets the simulator with the first one and
    always feeds action index 1 through the action interpreter, for at most
    10 steps or until the simulator reports done.
    """
    # (date, amount, direction) triples; direction is converted with OrderDir.
    order_infos = [
        ("2019-03-04", 1078.644160270691, 1),
        ("2019-03-11", 32.440425872802734, 1),
        ("2019-03-25", 40.55053234100342, 0),
        ("2019-04-01", 1070.5340538024902, 0),
        ("2019-05-27", 300.0739393234253, 1),
        ("2019-06-03", 8.110106468200684, 0),
        ("2019-06-11", 0.9360466003417968, 0),
        ("2019-06-17", 794.4272003173828, 1),
        ("2019-06-24", 7.865615844726562, 0),
        ("2019-07-01", 1077.589370727539, 0),
        ("2021-01-04", 499.7846999168396, 1),
        ("2021-01-11", 14.918946266174316, 0),
        ("2021-01-18", 484.8657536506653, 0),
        ("2021-02-08", 537.0820655822754, 1),
        ("2021-02-18", 7.459473133087158, 0),
        ("2021-02-22", 7.459473133087158, 0),
        ("2021-03-01", 14.918946266174316, 1),
        ("2021-03-08", 872.7583565711975, 1),
    ]

    # Each order starts and ends on the same date.
    orders = collections.deque()
    for date, amount, direction in order_infos:
        orders.append(
            Order(
                stock_id="",
                amount=amount,
                direction=OrderDir(direction),
                start_time=pd.Timestamp(date),
                end_time=pd.Timestamp(date),
            )
        )

    # fmt: off
    simulator = QlibSimulator(
        time_per_step="1day",
        start_time="9:45",
        end_time="14:44",
        qlib_config=qlib_config,
        top_strategy_fn=_top_strategy_fn,
        inner_executor_fn=_inner_executor_fn,
        exchange_config=exchange_config,
    )
    # fmt: on

    action_interpreter = CategoricalActionInterpreter(values=4)

    first_order = orders.popleft()
    simulator.reset(first_order)
    for step_idx in range(10):
        print(f"Step {step_idx}")
        current_state = simulator.get_state()
        volume = action_interpreter(current_state, 1)
        simulator.step(volume)
        if simulator.done():
            break
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
test()