mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
Migrate to SAOEState & new qlib interpreter
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import time
|
||||
from enum import IntEnum
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
@@ -246,7 +247,7 @@ class IdxTradeRange(TradeRange):
|
||||
class TradeRangeByTime(TradeRange):
|
||||
"""This is a helper function for make decisions"""
|
||||
|
||||
def __init__(self, start_time: str, end_time: str) -> None:
|
||||
def __init__(self, start_time: str | time, end_time: str | time) -> None:
|
||||
"""
|
||||
This is a callable class.
|
||||
|
||||
@@ -256,13 +257,13 @@ class TradeRangeByTime(TradeRange):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : str
|
||||
start_time : str | time
|
||||
e.g. "9:30"
|
||||
end_time : str
|
||||
end_time : str | time
|
||||
e.g. "14:30"
|
||||
"""
|
||||
self.start_time = pd.Timestamp(start_time).time()
|
||||
self.end_time = pd.Timestamp(end_time).time()
|
||||
self.start_time = pd.Timestamp(start_time).time() if isinstance(start_time, str) else start_time
|
||||
self.end_time = pd.Timestamp(end_time).time() if isinstance(end_time, str) else end_time
|
||||
assert self.start_time < self.end_time
|
||||
|
||||
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
|
||||
|
||||
@@ -472,6 +472,7 @@ class NestedExecutor(BaseExecutor):
|
||||
)
|
||||
assert isinstance(_inner_execute_result, list)
|
||||
self.post_inner_exe_step(_inner_execute_result)
|
||||
self.inner_strategy.receive_execute_result(_inner_execute_result)
|
||||
execute_result.extend(_inner_execute_result)
|
||||
|
||||
inner_order_indicators.append(
|
||||
|
||||
@@ -412,7 +412,7 @@ class Indicator:
|
||||
# NOTE: there are some zeros in the trading price. These cases are known meaningless
|
||||
# for aligning the previous logic, remove it.
|
||||
# remove zero and negative values.
|
||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(np.bool)]
|
||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
|
||||
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
||||
# ~(np.NaN < 1e-8) -> ~(False) -> True
|
||||
|
||||
|
||||
@@ -3,16 +3,6 @@ from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
@dataclass
class RuntimeConfig:
    """Container of runtime settings; every field has a default, so ``RuntimeConfig()`` is valid.

    NOTE(review): the per-field notes below are inferred from the field names only —
    confirm them against the code that consumes this config.
    """

    seed: int = 42  # presumably the random seed
    output_dir: Optional[Path] = None  # defaults to None (unset)
    checkpoint_dir: Optional[Path] = None  # defaults to None (unset)
    tb_log_dir: Optional[Path] = None  # presumably a TensorBoard log directory — TODO confirm
    debug: bool = False
    use_cuda: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExchangeConfig:
|
||||
limit_threshold: Union[float, Tuple[str, str]]
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from qlib.backtest.executor import NestedExecutor
|
||||
from .strategy import RLStrategyBase
|
||||
|
||||
|
||||
class RLNestedExecutor(NestedExecutor):
    """Nested executor that hands inner execution results to RL strategies."""

    def post_inner_exe_step(self, inner_exe_res: List[object]) -> None:
        """Forward ``inner_exe_res`` to the inner strategy's ``post_exe_step`` hook.

        Nothing happens when the inner strategy is not an ``RLStrategyBase``.
        """
        strategy = self.inner_strategy
        if not isinstance(strategy, RLStrategyBase):
            return
        strategy.post_exe_step(inner_exe_res)
|
||||
@@ -1,28 +1,14 @@
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import qlib
|
||||
from .highfreq_ops import DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut
|
||||
from qlib.contrib.ops.high_freq import DayCumsum
|
||||
from qlib.config import REG_CN
|
||||
from qlib.config import QlibConfig, REG_CN
|
||||
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
@dataclass
class QlibConfig:
    """Settings consumed by ``init_qlib`` in this module; all fields are required.

    NOTE(review): notes marked "presumably" are inferred from field names — confirm against ``init_qlib``.
    """

    provider_uri_day: Path  # presumably the qlib provider location for daily data
    provider_uri_1min: Path  # presumably the qlib provider location for 1-minute data
    feature_root_dir: Path  # init_qlib reads pickled features (e.g. 'feature.pkl') from this directory
    feature_columns_today: List[str]
    feature_columns_yesterday: List[str]  # passed by init_qlib when building the dataset
|
||||
|
||||
|
||||
_dataset = None
|
||||
|
||||
|
||||
@@ -122,7 +108,6 @@ def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None:
|
||||
)
|
||||
|
||||
# this won't work if it's put outside in case of multiprocessing
|
||||
from qlib.data import D
|
||||
|
||||
if part is None:
|
||||
feature_path = config.feature_root_dir / 'feature.pkl'
|
||||
@@ -144,21 +129,3 @@ def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None:
|
||||
config.feature_columns_yesterday,
|
||||
_internal=True
|
||||
)
|
||||
|
||||
|
||||
def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False):
    """Return the feature frame of ``stock_id`` on ``date`` from the module-level dataset.

    Parameters
    ----------
    stock_id : str
        instrument to look up
    date : pd.Timestamp
        day to look up
    yesterday : bool
        when not in backtest mode, select ``columns_yesterday`` instead of ``columns_today``
    backtest : bool
        when True, the selected columns are ``$close`` and ``$volume``

    Returns
    -------
    pd.DataFrame
        the selected columns; an all-zero float32 frame with 240 rows when the dataset has no data
    """
    assert _dataset is not None, "You must call init_qlib() before doing this."

    if backtest:
        wanted = ["$close", "$volume"]
    elif yesterday:
        wanted = _dataset.columns_yesterday
    else:
        wanted = _dataset.columns_today

    found = _dataset.get(stock_id, date, backtest)
    if found is None or len(found) == 0:
        # create a fake index, but RL doesn't care about index
        return pd.DataFrame(0.0, index=np.arange(240), columns=wanted, dtype=np.float32)  # FIXME: hardcode here

    # Column names carry trailing "0" characters in the stored data; strip them before selecting.
    stripped = found.rename(columns={name: name.rstrip("0") for name in found.columns})
    return stripped[wanted]
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.data.cache import H
|
||||
from qlib.data.data import Cal
|
||||
from qlib.data.ops import ElemOperator, PairOperator
|
||||
|
||||
|
||||
def get_calendar_day(freq="day", future=False):
    """Load the calendar of ``freq`` reduced to dates, memoized in ``H["c"]``.

    Parameters
    ----------
    freq : str
        frequency of read calendar file.
    future : bool
        whether including future trading day.

    Returns
    -------
    _calendar:
        array holding ``x.date()`` for every entry of the loaded calendar.
    """
    cache = H["c"]
    key = f"{freq}_future_{future}_day"
    if key in cache:
        return cache[key]

    days = np.array([stamp.date() for stamp in Cal.load_calendar(freq, future)])
    cache[key] = days
    return days
|
||||
|
||||
|
||||
def get_calendar_minute(freq='day', future=False):
    """Load the calendar of ``freq`` reduced to half-hour slots, memoized in ``H["c"]``.

    Parameters
    ----------
    freq : str
        frequency of read calendar file.
    future : bool
        whether including future trading day.

    Returns
    -------
    _calendar:
        array holding ``x.minute // 30`` (0 for minutes 0-29, 1 for minutes 30-59)
        for every entry of the loaded calendar.
    """
    # The key must differ from the one used by ``get_calendar_day``; sharing the "_day" suffix
    # made the two functions return each other's cached array.
    flag = f"{freq}_future_{future}_minute"
    if flag in H["c"]:
        _calendar = H["c"][flag]
    else:
        _calendar = np.array([x.minute // 30 for x in Cal.load_calendar(freq, future)])
        H["c"][flag] = _calendar
    return _calendar
|
||||
|
||||
|
||||
class DayLast(ElemOperator):
    """DayLast Operator

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    -------
    feature:
        a series in which every element is replaced by the last value of its day
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        day_lookup = get_calendar_day(freq=freq)
        loaded = self.feature.load(instrument, start_index, end_index, freq)
        # Rows are grouped by the calendar date found at their index position.
        day_keys = day_lookup[loaded.index]
        return loaded.groupby(day_keys).transform("last")
|
||||
|
||||
|
||||
class FFillNan(ElemOperator):
    """FFillNan Operator

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    -------
    feature:
        the feature with NaN values forward filled
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        series = self.feature.load(instrument, start_index, end_index, freq)
        # ``fillna(method="ffill")`` is deprecated in recent pandas; ``ffill()`` gives the same result.
        return series.ffill()
|
||||
|
||||
|
||||
class BFillNan(ElemOperator):
    """BFillNan Operator

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    -------
    feature:
        the feature with NaN values backward filled
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        series = self.feature.load(instrument, start_index, end_index, freq)
        # ``fillna(method="bfill")`` is deprecated in recent pandas; ``bfill()`` gives the same result.
        return series.bfill()
|
||||
|
||||
|
||||
class Date(ElemOperator):
    """Date Operator

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    -------
    feature:
        a series whose values are the calendar dates corresponding to feature.index
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        day_lookup = get_calendar_day(freq=freq)
        loaded = self.feature.load(instrument, start_index, end_index, freq)
        positions = loaded.index
        return pd.Series(day_lookup[positions], index=positions)
|
||||
|
||||
|
||||
class Select(PairOperator):
    """Select Operator

    Parameters
    ----------
    feature_left : Expression
        feature instance, select condition
    feature_right : Expression
        feature instance, select value

    Returns
    -------
    feature:
        the entries of feature_right picked out by ``.loc`` with the loaded feature_left

    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        load_args = (instrument, start_index, end_index, freq)
        condition = self.feature_left.load(*load_args)
        values = self.feature_right.load(*load_args)
        return values.loc[condition]
|
||||
|
||||
|
||||
class IsNull(ElemOperator):
    """IsNull Operator

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    -------
    feature:
        a boolean series that is True where the feature is NaN
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        loaded = self.feature.load(instrument, start_index, end_index, freq)
        # ``isna`` is the pandas alias of ``isnull``.
        return loaded.isna()
|
||||
|
||||
|
||||
class IsInf(ElemOperator):
    """IsInf Operator

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    -------
    feature:
        the result of ``np.isinf`` applied to the loaded feature
    """

    def _load_internal(self, instrument, start_index, end_index, freq):
        loaded = self.feature.load(instrument, start_index, end_index, freq)
        inf_mask = np.isinf(loaded)
        return inf_mask
|
||||
|
||||
|
||||
class Cut(ElemOperator):
    """Cut Operator

    Parameters
    ----------
    feature : Expression
        feature instance
    left : int
        left > 0, delete the first ``left`` elements of feature (default is None, which means 0)
    right : int
        right < 0, delete the last ``-right`` elements of feature (default is None, which means 0)

    Returns
    -------
    feature:
        A series with the first ``left`` and last ``-right`` elements deleted from the feature.
        Note: It is deleted from the raw data, not the sliced data

    Raises
    ------
    ValueError
        if ``left`` is given and not positive, or ``right`` is given and not negative.
    """

    def __init__(self, feature, left=None, right=None):
        self.left = left
        self.right = right
        if (self.left is not None and self.left <= 0) or (self.right is not None and self.right >= 0):
            raise ValueError("Cut operator left should be > 0 and right should be < 0")

        super(Cut, self).__init__(feature)

    def _load_internal(self, instrument, start_index, end_index, freq):
        series = self.feature.load(instrument, start_index, end_index, freq)
        # ``None`` bounds leave the corresponding end of the series untouched.
        return series.iloc[self.left : self.right]

    def get_extended_window_size(self):
        """Return the wrapped feature's (left, right) window extension enlarged by the cut sizes."""
        extra_left = 0 if self.left is None else self.left
        extra_right = 0 if self.right is None else abs(self.right)
        inner_left, inner_right = self.feature.get_extended_window_size()
        return inner_left + extra_left, inner_right + extra_right
|
||||
@@ -1,251 +0,0 @@
|
||||
import abc
|
||||
import math
|
||||
from dataclasses import dataclass, field, fields
|
||||
from enum import Enum
|
||||
from typing import Callable, Literal, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
EPSILON = 1e-7
|
||||
|
||||
|
||||
class FlowDirection(str, Enum):
    """Direction of an order flow.

    Members also subclass ``str``, so each one compares equal to its string value.
    """

    ACQUIRE = "acquire"
    LIQUIDATE = "liquidate"
|
||||
|
||||
|
||||
def _round_time(time: int, granularity: int) -> int:
|
||||
return time - time % granularity
|
||||
|
||||
|
||||
@dataclass
class BaseEpisodicState(abc.ABC):
    """
    Base class for episodic states.

    Times are integers; ``num_step`` is derived in ``__post_init__`` from ``start_time``,
    ``end_time`` and ``time_per_step``. Subclasses must implement ``step``.
    """

    # requirements
    start_time: int
    end_time: int
    time_per_step: int
    vol_limit: Optional[float]  # TODO: meaning?
    price_func: Callable[[str], np.ndarray]  # TODO: meaning?
    volume_func: Callable[[], np.ndarray]  # TODO: meaning?
    on_step_end: Optional[Callable[..., None]]  # TODO: meaning?
    on_episode_end: Optional[Callable[..., None]]  # TODO: meaning?
    asset_num: int  # TODO: meaning?

    # agent states
    num_step: int = field(init=False)  # Number of steps
    cur_time: int = field(init=False)  # Current time
    cur_step: int = field(init=False, default=0)
    exec_vol: Optional[np.ndarray] = field(init=False, default=None)  # Execution history
    last_step_duration: int = field(init=False)
    position: float = field(init=False)
    position_history: np.ndarray = field(init=False)

    def __post_init__(self) -> None:
        """Start the clock at ``start_time`` and derive the number of steps."""
        self.cur_time = self.start_time
        rounded_start_time = _round_time(self.start_time, self.time_per_step)

        # TODO: why not rounding end time?
        self.num_step = math.ceil((self.end_time - rounded_start_time) / self.time_per_step)

    def logs(self) -> dict:
        """Return the logging information shared across all subclasses."""
        # Base logging information shared across all subclasses.
        # You can call logs = super().logs() to get these default logs and use logs.update(...) to add other logging
        # information or override it completely to remove these logging fields.
        return {
            "logs": {
                "stop_time": self.cur_time - self.start_time,
                "stop_step": self.cur_step,
            },
            "history": {
                "volume": self.execution_history(),
            },
        }

    def execution_history(self) -> np.ndarray:
        """Return ``exec_vol`` right-padded with zeros to length ``end_time - start_time``.

        Before anything has been executed (``exec_vol`` is None) the result is all zeros.
        """
        executed = np.zeros(0) if self.exec_vol is None else self.exec_vol
        return np.pad(executed, (0, self.end_time - self.start_time - len(executed)))

    def next_duration(self) -> int:
        """Return the length of the interval given by ``next_interval``."""
        left, right = self.next_interval()
        return right - left

    def next_interval(self) -> Tuple[int, int]:
        """Return the ``time_per_step`` window containing ``cur_time``.

        The window is clipped to ``[start_time, end_time]`` and expressed relative to ``start_time``.
        """
        left = _round_time(self.cur_time, self.time_per_step)
        right = left + self.time_per_step
        return max(left, self.start_time) - self.start_time, min(right, self.end_time) - self.start_time

    @classmethod
    def get_init_field_names(cls):
        """Return the names of the dataclass fields accepted by ``__init__``."""
        return [f.name for f in fields(cls) if f.init]

    @abc.abstractmethod
    def step(self, *args, **kwargs):
        """Advance the state; must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def done(self) -> bool:
        """Whether the episode has finished; the base implementation always returns False."""
        return False
|
||||
|
||||
|
||||
@dataclass
class IntraDaySingleAssetDataSchema:
    """
    Intraday data of a single asset for one order.

    ``raw`` is a DataFrame from which the accessors below read the columns ``$price``
    (deal price), ``$close0`` (close price) and ``$volume0`` (volume).
    `processed` should be a DataFrame of 240x6, which is the same as `processed_prev`.
    """

    date: pd.Timestamp
    stock_id: str
    start_time: int
    end_time: int
    target: float
    flow_dir: "FlowDirection"
    raw: pd.DataFrame
    processed: pd.DataFrame
    processed_prev: pd.DataFrame

    def get_price(self, type: Literal['deal', 'close'] = 'deal'):
        """Return the price column of ``raw`` as an array.

        Raises
        ------
        ValueError
            if ``type`` is neither ``'deal'`` nor ``'close'``.
        """
        if type == 'deal':
            return self.raw['$price'].values
        if type == 'close':
            return self.raw['$close0'].values
        # An unknown type used to fall through and return None silently.
        raise ValueError(f"Unknown price type: {type!r}")

    def get_volume(self):
        """Return the ``$volume0`` column of ``raw`` as an array."""
        return self.raw['$volume0'].values

    def get_processed_data(self, type: Literal['today', 'yesterday'] = 'today'):
        """Return ``processed`` (today) or ``processed_prev`` (yesterday) as a numpy array.

        Raises
        ------
        ValueError
            if ``type`` is neither ``'today'`` nor ``'yesterday'``.
        """
        if type == 'today':
            return self.processed.to_numpy()
        if type == 'yesterday':
            return self.processed_prev.to_numpy()
        # An unknown type used to fall through and return None silently.
        raise ValueError(f"Unknown processed data type: {type!r}")
|
||||
|
||||
|
||||
@dataclass
class SAOEEpisodicState(BaseEpisodicState):
    """
    Global state of the whole time horizon.

    Tracks the remaining ``position`` of a single-asset order (``asset_num`` must be 1),
    starting from ``target``, and computes execution statistics once the episode is done.
    """

    # requirements
    target: float
    target_limit: float
    flow_dir: FlowDirection

    # calculated statistics
    turnover: Optional[float] = field(init=False)
    baseline_twap: Optional[float] = field(init=False)
    baseline_vwap: Optional[float] = field(init=False)
    exec_avg_price: Optional[float] = field(init=False)
    pa_twap: Optional[float] = field(init=False)
    pa_vwap: Optional[float] = field(init=False)
    pa_close: Optional[float] = field(init=False)
    fulfill_rate: Optional[float] = field(init=False)

    market_price: np.ndarray = field(init=False)  # deal price, might be different from close
    market_close: np.ndarray = field(init=False)  # close price
    market_volume: np.ndarray = field(init=False)

    # NOTE: this is a temporary design to make it compatible with old qlib integration framework. As long as callback
    # functions are passed correctly, this field should be removed from this class.
    last_interval: Tuple[int, int] = field(default=(0, 0), init=False)

    def __post_init__(self) -> None:
        """Slice market data to ``[start_time, end_time)`` and compute the TWAP/VWAP baselines."""
        assert self.target >= 0
        assert self.asset_num == 1

        super().__post_init__()

        self.market_volume = self.volume_func()[self.start_time : self.end_time]
        self.market_price = self.price_func("deal")[self.start_time : self.end_time]
        self.market_close = self.price_func("close")[self.start_time : self.end_time]
        self.position = self.target
        self.position_history = np.full((self.num_step + 1), np.nan)
        self.position_history[0] = self.position
        self.baseline_twap = np.mean(self.market_price)
        if self.market_volume.sum() == 0:
            # No volume to weight by: fall back to the unweighted mean.
            self.baseline_vwap = self.baseline_twap
        else:
            self.baseline_vwap = np.average(self.market_price, weights=self.market_volume)

    def update_stats(self) -> None:
        """Compute turnover, average execution price, price advantages and fulfill rate."""
        market_price = self.market_price[: len(self.exec_vol)]
        self.turnover = (self.exec_vol * market_price).sum()
        # exec_vol can be zero
        if np.isclose(self.exec_vol.sum(), 0):
            self.exec_avg_price = market_price[0]
        else:
            self.exec_avg_price = np.average(market_price, weights=self.exec_vol)

        self.pa_twap = _price_advantage(self.exec_avg_price, self.baseline_twap, self.flow_dir)
        self.pa_vwap = _price_advantage(self.exec_avg_price, self.baseline_vwap, self.flow_dir)
        close_average = np.mean(self.market_close)
        self.pa_close = _price_advantage(self.exec_avg_price, close_average, self.flow_dir)

        self.fulfill_rate = (self.target - self.position) / self.target_limit
        # Snap values within EPSILON of 1 to exactly 1 before converting to percent.
        if abs(self.fulfill_rate - 1.0) < EPSILON:
            self.fulfill_rate = 1.0
        self.fulfill_rate *= 100

    def logs(self) -> dict:
        """Return the base logs with the ``"logs"`` entry replaced by the SAOE statistics."""
        logs = super().logs()
        # NOTE(review): ``dict.update`` is shallow, so this replaces the base "logs" entry
        # (stop_time / stop_step) instead of extending it — confirm this is intended.
        logs.update(
            {
                "logs": {
                    "turnover": self.turnover,
                    "baseline_twap": self.baseline_twap,
                    "baseline_vwap": self.baseline_vwap,
                    "exec_avg_price": self.exec_avg_price,
                    "pa_twap": self.pa_twap,
                    "pa_vwap": self.pa_vwap,
                    "pa_close": self.pa_close,
                    "ffr": self.fulfill_rate,
                }
            }
        )
        return logs

    def step(self, exec_vol: np.ndarray) -> None:
        """Apply one step of executed volumes.

        Parameters
        ----------
        exec_vol : np.ndarray
            executed volume per time unit of this step; its length is the duration of the step.

        Raises
        ------
        AssertionError
            if the position drops below ``-EPSILON`` or any executed volume is below ``-EPSILON``.
        """
        l, r = self.next_interval()
        self.last_interval = (l, r)
        assert 0 <= l < r
        self.last_step_duration = len(exec_vol)
        self.position -= exec_vol.sum()
        # The condition and the message must not be wrapped together in parentheses:
        # ``assert (cond, msg)`` tests a non-empty tuple and therefore never fails.
        assert (
            self.position > -EPSILON and (exec_vol > -EPSILON).all()
        ), f"Execution volume is invalid: {exec_vol} (position = {self.position})"
        self.cur_step += 1
        self.position_history[self.cur_step] = self.position
        self.cur_time += self.last_step_duration
        if self.cur_step == self.num_step:  # Should reach the end of episode
            assert self.cur_time == self.end_time
        self.exec_vol = exec_vol if self.exec_vol is None else np.concatenate((self.exec_vol, exec_vol))

        if self.on_step_end is not None:
            self.on_step_end(l, r, self)
        if self.done:
            self.update_stats()
            if self.on_episode_end is not None:
                self.on_episode_end(self)

    @property
    def done(self) -> bool:
        """True once the position is (almost) zero or all steps have been taken."""
        return self.position < EPSILON or self.cur_step == self.num_step
|
||||
|
||||
|
||||
def _price_advantage(exec_price: float, baseline_price: float, flow: FlowDirection) -> float:
    """Return the advantage of ``exec_price`` over ``baseline_price``, scaled by 10000.

    When acquiring, a lower execution price gives a positive result; otherwise a higher one does.
    A zero baseline yields 0.0.
    """
    if baseline_price == 0:
        return 0.0

    is_acquire = flow == FlowDirection.ACQUIRE
    return (1 - exec_price / baseline_price) * 10000 if is_acquire else (exec_price / baseline_price - 1) * 10000
|
||||
@@ -1,162 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir, TradeDecisionWO
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.rl.order_execution.from_neutrader.feature import fetch_features
|
||||
from qlib.rl.order_execution.from_neutrader.state import FlowDirection, IntraDaySingleAssetDataSchema, SAOEEpisodicState
|
||||
from qlib.utils.time import get_day_min_idx_range
|
||||
|
||||
|
||||
class StateMaintainer:
    """
    Maintain neutrader states taking qlib trade decisions as input.

    One ``SAOEEpisodicState`` and one ``IntraDaySingleAssetDataSchema`` are kept per order of the
    outer trade decision, keyed by ``(stock_id, direction)``.

    Example usage::

        maintainer = StateMaintainer(...)  # in reset
        maintainer.send_execute_result(execute_result)  # in step
        # do something here
        maintainer.generate_orders(self.get_data_cal_avail_range(rtype='step'), exec_vols)

    The states can be accessed via ``maintainer.states`` and ``maintainer.samples``.
    """

    def __init__(
        self,
        time_per_step: int,
        date: pd.Timestamp,
        full_trade_range: Tuple[int, int],
        current_step: int,
        outer_trade_decision: TradeDecisionWO,
        trade_exchange: Exchange,
    ) -> None:
        """Build one sample and one state for every order in ``outer_trade_decision``.

        Parameters
        ----------
        time_per_step : int
            step length forwarded to every state.
        date : pd.Timestamp
            day whose data is fetched for the samples.
        full_trade_range : Tuple[int, int]
            (start, end) of the whole trade range; ``end`` is incremented by one below.
        current_step : int
            initial ``cur_step`` of every state; ``_create_single_ep_state`` asserts it is 0.
        outer_trade_decision : TradeDecisionWO
            decision whose ``order_list`` is tracked.
        trade_exchange : Exchange
            exchange used to query deal prices and to create sub-orders.
        """
        # The parameters look very ad-hoc right now
        self.states: Dict[Tuple[str, OrderDir], SAOEEpisodicState] = {}  # explicitly make it ordered
        self.samples: Dict[Tuple[str, OrderDir], IntraDaySingleAssetDataSchema] = {}
        self.time_per_step: int = time_per_step
        self.start_time, self.end_time = full_trade_range
        self.end_time += 1  # plus 1 to align with the semantics in neutrader
        self.date: pd.Timestamp = date
        # -1 marks "no step generated yet"; see send_execute_result.
        self.last_step_length: int = -1
        self.last_step_range: Optional[Tuple[int, int]] = None

        self.order_list: List[Order] = outer_trade_decision.order_list
        self.trade_exchange: Exchange = trade_exchange

        # Number of time_per_step-sized buckets between start_time (aligned down to a
        # multiple of time_per_step) and the already incremented end_time.
        self.num_step = (
            self.end_time - (self.start_time - self.start_time % self.time_per_step) - 1
        ) // self.time_per_step + 1

        for order in self.order_list:
            sample = self._fetch_sample_data(order)
            state = self._create_single_ep_state(sample, current_step)
            self.samples[order.stock_id, order.direction] = sample
            self.states[order.stock_id, order.direction] = state

    def _fetch_sample_data(self, order: Order) -> IntraDaySingleAssetDataSchema:
        """Collect deal prices and features of ``order``'s stock on ``self.date`` into a sample."""
        # Query window covering the whole calendar day of self.date.
        start_time = self.date.replace(hour=0, minute=0, second=0)
        end_time = self.date.replace(hour=23, minute=59, second=59)
        deal_price = self.trade_exchange.get_deal_price(
            stock_id=order.stock_id, start_time=start_time, end_time=end_time, direction=order.direction, method=None,
        )
        backtest_data = fetch_features(order.stock_id, self.date, backtest=True)
        # HACK: close means deal price here. The logic is implemented in qlib.
        backtest_data["$close"] = deal_price.to_series().to_numpy()
        feature_today = fetch_features(order.stock_id, self.date)
        feature_yesterday = fetch_features(order.stock_id, self.date, yesterday=True)
        return IntraDaySingleAssetDataSchema(
            date=self.date.date(),
            stock_id=order.stock_id,
            start_time=self.start_time,
            end_time=self.end_time,
            target=max(order.amount, 0.0),  # prevent target to go to -eps
            # NOTE(review): direction 0 is mapped to LIQUIDATE, anything else to ACQUIRE —
            # presumably 0 is the sell direction; confirm against OrderDir.
            flow_dir=FlowDirection.LIQUIDATE if order.direction == 0 else FlowDirection.ACQUIRE,
            raw=backtest_data,
            processed=feature_today,
            processed_prev=feature_yesterday,
        )

    def _create_single_ep_state(self, sample: IntraDaySingleAssetDataSchema, cur_step: int) -> SAOEEpisodicState:
        """Create the episodic state for ``sample``; ``cur_step`` must be 0."""
        market_price = sample.raw["$close"].values
        market_vol = sample.raw["$volume"].values
        target = sample.target

        # NOTE: Previously, market_price and market_vol are passed into the state initialization directly. Therefore,
        # the segment of market_price and market_vol are used instead of the lambda function here using the whole price
        # and vol data.
        # This refactoring is ONLY EQUIVALENT WHEN start_time/end_time passed into state is equal to
        # sample.start_time/end_time.
        # If one can confirm that these two are always the same, delete this note, please.
        state = SAOEEpisodicState(
            self.start_time,
            self.end_time,
            self.time_per_step,
            None,  # vol_limit
            lambda x: market_price,  # price_func: the same array whatever price type is requested
            lambda: market_vol,  # volume_func
            None,  # on_step_end
            None,  # on_episode_end
            1,  # asset_num
            target,  # target
            target,  # target_limit
            sample.flow_dir,
        )
        state.cur_step = cur_step
        assert state.cur_step == 0
        return state

    def _update_single_ep_state(
        self, state: SAOEEpisodicState, execute_result: List[Order], length: Optional[int] = None
    ) -> None:
        """Convert ``execute_result`` into an execution-volume array and step ``state`` with it.

        With ``length`` given, the array has ``length`` zero-initialised slots and each deal amount is
        written at the order's minute index relative to the start of the last step range. Without it,
        the array simply lists the deal amounts in order.
        """
        if length is not None:
            exec_vol = np.zeros(length)
            for order, _, __, ___ in execute_result:
                idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN)
                exec_vol[idx - self.last_step_range[0]] = order.deal_amount
        else:
            exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])

        # sometimes exec_vol gets too large due to the rounding in exchange
        # scale the execution volume so that position won't go below 0
        # actually this case is very rare
        if exec_vol.sum() > state.position and exec_vol.sum() > 0:
            assert exec_vol.sum() < state.position + 1, f"{exec_vol} too large for {state}"
            exec_vol *= state.position / (exec_vol.sum())

        state.step(exec_vol)

    def create_sub_order(self, exec_vol: float, original_order: Order) -> Order:
        """Create an order of ``exec_vol`` with the stock and direction of ``original_order``."""
        oh = self.trade_exchange.get_order_helper()
        return oh.create(original_order.stock_id, exec_vol, original_order.direction)

    def send_execute_result(self, execute_result: Optional[List[Any]]) -> None:
        """Step every tracked state with the results of the previous step.

        Does nothing (after asserting that ``execute_result`` is empty) when no step has been
        generated yet. States without any matching result are stepped with zero volumes.
        """
        if self.last_step_length < 0:
            assert not execute_result
            return
        # Group results by (stock_id, direction); the first element of each result is its order.
        orders = defaultdict(list)
        if execute_result is not None:
            for e in execute_result:
                orders[e[0].stock_id, e[0].direction].append(e)
        for (stock_id, direction), state in self.states.items():
            self._update_single_ep_state(state, orders[stock_id, direction], self.last_step_length)

    def generate_orders(self, step_trade_range: Tuple[int, int], exec_vols: List[float]) -> List[Order]:
        """Create sub-orders for the positive entries of ``exec_vols`` and record the step range.

        ``exec_vols`` must have one entry per tracked order, in the order of ``self.order_list``.
        """
        order_list = []

        assert len(exec_vols) == len(self.order_list)
        for v, o in zip(exec_vols, self.order_list):
            # Non-positive volumes produce no order.
            if v > 0:
                order_list.append(self.create_sub_order(v, o))

        step_start_time, step_end_time = step_trade_range  # inclusive
        step_end_time += 1

        # Remembered so the next send_execute_result can size and offset the volume array.
        self.last_step_length = step_end_time - step_start_time
        self.last_step_range = (step_start_time, step_end_time)

        return order_list
|
||||
@@ -1,64 +1,39 @@
|
||||
from abc import ABCMeta
|
||||
from typing import Tuple
|
||||
|
||||
import pandas as pd
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderHelper, TradeDecisionWO, TradeRange
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.rl.order_execution.from_neutrader.state import IntraDaySingleAssetDataSchema, SAOEEpisodicState
|
||||
from qlib.rl.order_execution.from_neutrader.state_maintainer import StateMaintainer
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
|
||||
|
||||
class RLStrategyBase(BaseStrategy, metaclass=ABCMeta):
    """Base class for strategies that expose a ``post_exe_step`` hook."""

    def post_exe_step(self, execute_result: list) -> None:
        """
        Post-process each step of the strategy. This is designed for RL strategies,
        which need to update the policy state after each step.

        NOTE: it is strongly coupled with RLNestedExecutor;

        Raises
        ------
        NotImplementedError
            always; subclasses are expected to override this method.
        """
        raise NotImplementedError("Please implement the `post_exe_step` method")
|
||||
|
||||
|
||||
class DecomposedStrategy(RLStrategyBase):
|
||||
def __init__(self):
|
||||
class DecomposedStrategy(BaseStrategy):
|
||||
def __init__(self) -> None:
|
||||
super(DecomposedStrategy, self).__init__()
|
||||
|
||||
def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs) -> None:
|
||||
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
time_per_step = int(pd.Timedelta(self.trade_calendar.get_freq()) / pd.Timedelta("1min"))
|
||||
if outer_trade_decision is not None:
|
||||
self.maintainer = StateMaintainer(
|
||||
time_per_step,
|
||||
self.trade_calendar.get_all_time()[0],
|
||||
self.get_data_cal_avail_range(),
|
||||
self.trade_calendar.get_trade_step(),
|
||||
outer_trade_decision,
|
||||
self.trade_exchange,
|
||||
)
|
||||
self.execute_order: Optional[Order] = None
|
||||
self.execute_result: List[Tuple[Order, float, float, float]] = []
|
||||
|
||||
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
|
||||
exec_vol = yield self
|
||||
|
||||
oh = self.trade_exchange.get_order_helper()
|
||||
order = oh.create(self._order.stock_id, exec_vol, self._order.direction)
|
||||
|
||||
self.execute_order = order
|
||||
|
||||
return TradeDecisionWO([order], self)
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
return outer_trade_decision
|
||||
|
||||
def post_exe_step(self, execute_result):
|
||||
self.maintainer.send_execute_result(execute_result)
|
||||
def receive_execute_result(self, execute_result: list) -> None:
|
||||
self.execute_result = execute_result
|
||||
|
||||
@property
|
||||
def sample_state_pair(self) -> Tuple[IntraDaySingleAssetDataSchema, SAOEEpisodicState]:
|
||||
assert len(self.maintainer.samples) == len(self.maintainer.states) == 1
|
||||
return (
|
||||
list(self.maintainer.samples.values())[0],
|
||||
list(self.maintainer.states.values())[0],
|
||||
)
|
||||
|
||||
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
|
||||
# get a decision from the outmost loop
|
||||
exec_vol = yield self
|
||||
|
||||
return TradeDecisionWO(
|
||||
self.maintainer.generate_orders(self.get_data_cal_avail_range(rtype="step"), [exec_vol]), self
|
||||
)
|
||||
def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs) -> None:
|
||||
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
order_list = outer_trade_decision.order_list
|
||||
assert len(order_list) == 1
|
||||
self._order = order_list[0]
|
||||
|
||||
|
||||
class SingleOrderStrategy(BaseStrategy):
|
||||
@@ -87,5 +62,4 @@ class SingleOrderStrategy(BaseStrategy):
|
||||
direction=Order.parse_dir(self._order.direction),
|
||||
)
|
||||
]
|
||||
trade_decision = TradeDecisionWO(order_list, self, self._trade_range)
|
||||
return trade_decision
|
||||
return TradeDecisionWO(order_list, self, self._trade_range)
|
||||
|
||||
@@ -2,32 +2,30 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Placeholder for qlib-based simulator."""
|
||||
import copy
|
||||
from typing import Callable, Generator, List, Optional, Tuple, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Generator, List, Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym.vector.utils import spaces
|
||||
from qlib.rl.order_execution.from_neutrader.feature import init_qlib
|
||||
|
||||
from qlib.backtest import get_exchange
|
||||
from qlib.backtest.account import Account
|
||||
from qlib.backtest.decision import Order, TradeRange, TradeRangeByTime
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from qlib.backtest.decision import Order, OrderDir, TradeRangeByTime
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.config import QlibConfig
|
||||
from qlib.rl.interpreter import ActionInterpreter
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.order_execution.from_neutrader.config import ExchangeConfig
|
||||
from qlib.rl.order_execution.from_neutrader.executor import RLNestedExecutor
|
||||
from qlib.rl.order_execution.from_neutrader.feature import init_qlib
|
||||
from qlib.rl.order_execution.from_neutrader.state import SAOEEpisodicState
|
||||
from qlib.rl.order_execution.from_neutrader.strategy import DecomposedStrategy
|
||||
from qlib.rl.order_execution.from_neutrader.strategy import DecomposedStrategy, SingleOrderStrategy
|
||||
from qlib.rl.order_execution.simulator_simple import ONE_SEC, SAOEMetrics, SAOEState, _float_or_ndarray
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
|
||||
|
||||
def get_common_infra(
|
||||
config: ExchangeConfig,
|
||||
trade_start_time: pd.Timestamp,
|
||||
trade_end_time: pd.Timestamp,
|
||||
trade_date: pd.Timestamp,
|
||||
codes: List[str],
|
||||
cash_limit: Optional[float] = None,
|
||||
) -> CommonInfrastructure:
|
||||
@@ -48,14 +46,14 @@ def get_common_infra(
|
||||
|
||||
exchange = get_exchange(
|
||||
codes=codes,
|
||||
freq='1min',
|
||||
freq="1min",
|
||||
limit_threshold=config.limit_threshold,
|
||||
deal_price=config.deal_price,
|
||||
open_cost=config.open_cost,
|
||||
close_cost=config.close_cost,
|
||||
min_cost=config.min_cost if config.trade_unit is not None else 0,
|
||||
start_time=pd.Timestamp(trade_start_time),
|
||||
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
|
||||
start_time=trade_date,
|
||||
end_time=trade_date + pd.DateOffset(1),
|
||||
trade_unit=config.trade_unit,
|
||||
volume_threshold=config.volume_threshold
|
||||
)
|
||||
@@ -63,114 +61,253 @@ def get_common_infra(
|
||||
return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
|
||||
|
||||
|
||||
class CategoricalActionInterpreter(ActionInterpreter[SAOEEpisodicState, int, float]):
|
||||
def __init__(self, values: Union[int, List[float]]) -> None:
|
||||
if isinstance(values, int):
|
||||
values = [i / values for i in range(0, values + 1)]
|
||||
self.action_values = values
|
||||
|
||||
@property
|
||||
def action_space(self) -> spaces.Discrete:
|
||||
return spaces.Discrete(len(self.action_values))
|
||||
|
||||
def interpret(self, state: SAOEEpisodicState, action: int) -> float:
|
||||
volume = min(state.position, state.target * self.action_values[action])
|
||||
if state.cur_step + 1 >= state.num_step:
|
||||
volume = state.position # execute all volumes at last
|
||||
return volume
|
||||
def _convert_tick_str_to_int(time_per_step: str) -> int:
|
||||
d = {
|
||||
"30min": 30,
|
||||
}
|
||||
return d[time_per_step]
|
||||
|
||||
|
||||
class QlibSimulator(Simulator[Order, Tuple[SAOEEpisodicState, dict], float]):
|
||||
def _get_ticks_slice(
|
||||
ticks_index: pd.DatetimeIndex,
|
||||
start: pd.Timestamp,
|
||||
end: pd.Timestamp,
|
||||
include_end: bool = False,
|
||||
) -> pd.DatetimeIndex:
|
||||
if not include_end:
|
||||
end = end - ONE_SEC
|
||||
return ticks_index[ticks_index.slice_indexer(start, end)]
|
||||
|
||||
|
||||
def _get_minutes(start_time: pd.Timestamp, end_time: pd.Timestamp) -> List[pd.Timestamp]:
|
||||
minutes = []
|
||||
t = start_time
|
||||
while t <= end_time:
|
||||
minutes.append(t)
|
||||
t += pd.Timedelta("1min")
|
||||
return minutes
|
||||
|
||||
|
||||
def _dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
|
||||
# dataframe.append is deprecated
|
||||
other_df = pd.DataFrame(other).set_index("datetime")
|
||||
other_df.index.name = "datetime"
|
||||
|
||||
res = pd.concat([df, other_df], axis=0)
|
||||
return res
|
||||
|
||||
|
||||
def _price_advantage(
|
||||
exec_price: _float_or_ndarray,
|
||||
baseline_price: float,
|
||||
direction: OrderDir | int,
|
||||
) -> _float_or_ndarray:
|
||||
if baseline_price == 0: # something is wrong with data. Should be nan here
|
||||
if isinstance(exec_price, float):
|
||||
return 0.0
|
||||
else:
|
||||
return np.zeros_like(exec_price)
|
||||
if direction == OrderDir.BUY:
|
||||
res = (1 - exec_price / baseline_price) * 10000
|
||||
elif direction == OrderDir.SELL:
|
||||
res = (exec_price / baseline_price - 1) * 10000
|
||||
else:
|
||||
raise ValueError(f"Unexpected order direction: {direction}")
|
||||
res_wo_nan: np.ndarray = np.nan_to_num(res, nan=0.0)
|
||||
if res_wo_nan.size == 1:
|
||||
return res_wo_nan.item()
|
||||
else:
|
||||
return cast(_float_or_ndarray, res_wo_nan)
|
||||
|
||||
|
||||
class StateMaintainer:
|
||||
def __init__(self, order: Order, tick_index: pd.DatetimeIndex, twap_price: float) -> None:
|
||||
super(StateMaintainer, self).__init__()
|
||||
|
||||
self.position = order.amount
|
||||
self._order = order
|
||||
self._tick_index = tick_index
|
||||
self._twap_price = twap_price
|
||||
|
||||
metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member
|
||||
# NOTE: can empty dataframe contain index?
|
||||
self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.metrics = None
|
||||
|
||||
def update(self, inner_executor: BaseExecutor, inner_strategy: DecomposedStrategy) -> None:
|
||||
execute_order = inner_strategy.execute_order
|
||||
execute_result = inner_strategy.execute_result
|
||||
exec_vol = np.array([e[0].deal_amount for e in execute_result])
|
||||
ticks_position = self.position - np.cumsum(exec_vol)
|
||||
self.position -= exec_vol.sum()
|
||||
|
||||
if len(execute_result) > 0:
|
||||
exchange = inner_executor.trade_exchange
|
||||
minutes = _get_minutes(execute_result[0][0].start_time, execute_result[-1][0].start_time)
|
||||
market_price = np.array([
|
||||
exchange.get_deal_price(execute_order.stock_id, t, t, direction=execute_order.direction)
|
||||
for t in minutes
|
||||
])
|
||||
market_volume = np.array([exchange.get_volume(execute_order.stock_id, t, t) for t in minutes])
|
||||
|
||||
datetime_list = _get_ticks_slice(
|
||||
self._tick_index,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
include_end=True
|
||||
)
|
||||
else:
|
||||
market_price = np.array([])
|
||||
market_volume = np.array([])
|
||||
datetime_list = pd.DatetimeIndex([])
|
||||
|
||||
assert market_price.shape == market_volume.shape == exec_vol.shape
|
||||
|
||||
self.history_exec = _dataframe_append(
|
||||
self.history_exec,
|
||||
SAOEMetrics(
|
||||
# It should have the same keys with SAOEMetrics,
|
||||
# but the values do not necessarily have the annotated type.
|
||||
# Some values could be vectorized (e.g., exec_vol).
|
||||
stock_id=self._order.stock_id,
|
||||
datetime=datetime_list,
|
||||
direction=self._order.direction,
|
||||
market_volume=market_volume,
|
||||
market_price=market_price,
|
||||
amount=exec_vol,
|
||||
inner_amount=exec_vol,
|
||||
deal_amount=exec_vol,
|
||||
trade_price=market_price,
|
||||
trade_value=market_price * exec_vol,
|
||||
position=ticks_position,
|
||||
ffr=exec_vol / self._order.amount,
|
||||
pa=_price_advantage(market_price, self._twap_price, self._order.direction),
|
||||
),
|
||||
)
|
||||
|
||||
self.history_steps = _dataframe_append(
|
||||
self.history_steps,
|
||||
[self._metrics_collect(
|
||||
execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol
|
||||
)],
|
||||
)
|
||||
|
||||
def _metrics_collect(
|
||||
self,
|
||||
order: Order,
|
||||
datetime: pd.Timestamp,
|
||||
market_vol: np.ndarray,
|
||||
market_price: np.ndarray,
|
||||
amount: float, # intended to trade such amount
|
||||
exec_vol: np.ndarray,
|
||||
) -> SAOEMetrics:
|
||||
assert len(market_vol) == len(market_price) == len(exec_vol)
|
||||
|
||||
if np.abs(np.sum(exec_vol)) < EPS:
|
||||
exec_avg_price = 0.0
|
||||
else:
|
||||
exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan
|
||||
if hasattr(exec_avg_price, "item"): # could be numpy scalar
|
||||
exec_avg_price = exec_avg_price.item() # type: ignore
|
||||
|
||||
return SAOEMetrics(
|
||||
stock_id=order.stock_id,
|
||||
datetime=datetime,
|
||||
direction=order.direction,
|
||||
market_volume=market_vol.sum(),
|
||||
market_price=market_price.mean() if len(market_price) > 0 else np.nan,
|
||||
amount=amount,
|
||||
inner_amount=exec_vol.sum(),
|
||||
deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions
|
||||
trade_price=exec_avg_price,
|
||||
trade_value=float(np.sum(market_price * exec_vol)),
|
||||
position=self.position,
|
||||
ffr=float(exec_vol.sum() / order.amount),
|
||||
pa=_price_advantage(exec_avg_price, self._twap_price, order.direction),
|
||||
)
|
||||
|
||||
|
||||
class QlibSimulator(Simulator[Order, SAOEState, float]):
|
||||
def __init__(
|
||||
self,
|
||||
order: Order,
|
||||
time_per_step: str,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
qlib_config: QlibConfig,
|
||||
top_strategy_fn: Callable[[CommonInfrastructure, Order, TradeRange, str], BaseStrategy],
|
||||
inner_executor_fn: Callable[[CommonInfrastructure], BaseExecutor],
|
||||
inner_executor_fn: Callable[[str, CommonInfrastructure], BaseExecutor],
|
||||
exchange_config: ExchangeConfig,
|
||||
) -> None:
|
||||
super(QlibSimulator, self).__init__(
|
||||
initial=None, # TODO
|
||||
)
|
||||
|
||||
self._trade_range = TradeRangeByTime(start_time, end_time)
|
||||
assert order.start_time.date() == order.end_time.date()
|
||||
|
||||
self._order = order
|
||||
self._order_date = pd.Timestamp(order.start_time.date())
|
||||
self._trade_range = TradeRangeByTime(order.start_time.time(), order.end_time.time())
|
||||
self._qlib_config = qlib_config
|
||||
self._time_per_step = time_per_step
|
||||
self._top_strategy_fn = top_strategy_fn
|
||||
self._inner_executor_fn = inner_executor_fn
|
||||
self._exchange_config = exchange_config
|
||||
|
||||
self._executor: Optional[RLNestedExecutor] = None
|
||||
self._time_per_step = time_per_step
|
||||
self._ticks_per_step = _convert_tick_str_to_int(time_per_step)
|
||||
|
||||
self._executor: Optional[NestedExecutor] = None
|
||||
self._collect_data_loop: Optional[Generator] = None
|
||||
|
||||
self._done = False
|
||||
|
||||
self._inner_strategy = DecomposedStrategy()
|
||||
|
||||
def reset(
|
||||
self,
|
||||
order: Order,
|
||||
instrument: str = "SH600000", # TODO: Test only. Remove this default value later.
|
||||
) -> None:
|
||||
self.reset(self._order)
|
||||
|
||||
def reset(self, order: Order) -> None:
|
||||
instrument = order.stock_id
|
||||
|
||||
init_qlib(self._qlib_config, instrument)
|
||||
|
||||
common_infra = get_common_infra(
|
||||
self._exchange_config,
|
||||
trade_start_time=order.start_time,
|
||||
trade_end_time=order.end_time,
|
||||
trade_date=pd.Timestamp(self._order_date),
|
||||
codes=[instrument],
|
||||
)
|
||||
|
||||
self._executor = RLNestedExecutor(
|
||||
time_per_step=self._time_per_step,
|
||||
inner_executor=self._inner_executor_fn(common_infra),
|
||||
self._inner_executor = self._inner_executor_fn(self._time_per_step, common_infra)
|
||||
self._executor = NestedExecutor(
|
||||
time_per_step="1day",
|
||||
inner_executor=self._inner_executor,
|
||||
inner_strategy=self._inner_strategy,
|
||||
track_data=True,
|
||||
common_infra=common_infra,
|
||||
)
|
||||
|
||||
top_strategy = self._top_strategy_fn(common_infra, order, self._trade_range, instrument)
|
||||
exchange = self._inner_executor.trade_exchange
|
||||
self._ticks_index = pd.DatetimeIndex([e[1] for e in list(exchange.quote_df.index)])
|
||||
self._ticks_for_order = _get_ticks_slice(self._ticks_index, self._order.start_time, self._order.end_time)
|
||||
|
||||
self._executor.reset(start_time=order.start_time, end_time=order.end_time)
|
||||
twap_price = exchange.get_deal_price(
|
||||
order.stock_id,
|
||||
pd.Timestamp(self._ticks_for_order[0]),
|
||||
pd.Timestamp(self._ticks_for_order[1]),
|
||||
direction=order.direction,
|
||||
)
|
||||
|
||||
top_strategy = SingleOrderStrategy(common_infra, order, self._trade_range, instrument)
|
||||
self._executor.reset(start_time=pd.Timestamp(self._order_date), end_time=pd.Timestamp(self._order_date))
|
||||
top_strategy.reset(level_infra=self._executor.get_level_infra())
|
||||
|
||||
self._collect_data_loop = self._executor.collect_data(top_strategy.generate_trade_decision(), level=0)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
strategy = self._iter_strategy(action=None)
|
||||
sample, ep_state = strategy.sample_state_pair
|
||||
self._last_ep_state = ep_state
|
||||
self._last_info = self._collect_info(ep_state)
|
||||
|
||||
self._iter_strategy(action=None)
|
||||
self._done = False
|
||||
|
||||
def _collect_info(self, ep_state: SAOEEpisodicState) -> dict:
|
||||
info = {
|
||||
"category": ep_state.flow_dir.value,
|
||||
# "reward": rew_info, # TODO: ignore for now
|
||||
}
|
||||
if ep_state.done:
|
||||
# info["index"] = {"stock_id": sample.stock_id, "date": sample.date} # TODO: ignore for now
|
||||
# info["history"] = {"action": self.action_history} # TODO: ignore for now
|
||||
info.update(ep_state.logs())
|
||||
|
||||
try:
|
||||
# done but loop is not exhausted
|
||||
# exhaust the loop manually
|
||||
while True:
|
||||
self._collect_data_loop.send(0.)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
info["qlib"] = {}
|
||||
for key, val in list(
|
||||
self._executor.trade_account.get_trade_indicator().order_indicator_his.values()
|
||||
)[0].to_series().items():
|
||||
info["qlib"][key] = val.item()
|
||||
|
||||
return info
|
||||
self._maintainer = StateMaintainer(
|
||||
order=self._order,
|
||||
tick_index=self._ticks_index,
|
||||
twap_price=twap_price,
|
||||
)
|
||||
|
||||
def _iter_strategy(self, action: float = None) -> DecomposedStrategy:
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
@@ -181,20 +318,28 @@ class QlibSimulator(Simulator[Order, Tuple[SAOEEpisodicState, dict], float]):
|
||||
|
||||
def step(self, action: float) -> None:
|
||||
try:
|
||||
strategy = self._iter_strategy(action=action)
|
||||
sample, ep_state = strategy.sample_state_pair
|
||||
self._iter_strategy(action=action)
|
||||
except StopIteration:
|
||||
sample, ep_state = self._inner_strategy.sample_state_pair
|
||||
assert ep_state.done
|
||||
|
||||
self._last_ep_state = ep_state
|
||||
self._last_info = self._collect_info(ep_state)
|
||||
|
||||
if ep_state.done:
|
||||
self._done = True
|
||||
|
||||
def get_state(self) -> Tuple[SAOEEpisodicState, dict]:
|
||||
return self._last_ep_state, self._last_info
|
||||
self._maintainer.update(
|
||||
inner_executor=self._inner_executor,
|
||||
inner_strategy=self._inner_strategy,
|
||||
)
|
||||
|
||||
def get_state(self) -> SAOEState:
|
||||
return SAOEState(
|
||||
order=self._order,
|
||||
cur_time=self._inner_executor.trade_calendar.get_step_time()[0],
|
||||
position=self._maintainer.position,
|
||||
history_exec=self._maintainer.history_exec,
|
||||
history_steps=self._maintainer.history_steps,
|
||||
metrics=self._maintainer.metrics,
|
||||
backtest_data=None,
|
||||
ticks_per_step=self._ticks_per_step,
|
||||
ticks_index=self._ticks_index,
|
||||
ticks_for_order=self._ticks_for_order,
|
||||
)
|
||||
|
||||
def done(self) -> bool:
|
||||
return self._done
|
||||
|
||||
@@ -39,34 +39,34 @@ class SAOEMetrics(TypedDict):
|
||||
"""Direction of the order. 0 for sell, 1 for buy."""
|
||||
|
||||
# Market information.
|
||||
market_volume: float
|
||||
market_volume: np.ndarray | float
|
||||
"""(total) market volume traded in the period."""
|
||||
market_price: float
|
||||
market_price: np.ndarray | float
|
||||
"""Deal price. If it's a period of time, this is the average market deal price."""
|
||||
|
||||
# Strategy records.
|
||||
|
||||
amount: float
|
||||
amount: np.ndarray | float
|
||||
"""Total amount (volume) strategy intends to trade."""
|
||||
inner_amount: float
|
||||
inner_amount: np.ndarray | float
|
||||
"""Total amount that the lower-level strategy intends to trade
|
||||
(might be larger than amount, e.g., to ensure ffr)."""
|
||||
|
||||
deal_amount: float
|
||||
deal_amount: np.ndarray | float
|
||||
"""Amount that successfully takes effect (must be less than inner_amount)."""
|
||||
trade_price: float
|
||||
trade_price: np.ndarray | float
|
||||
"""The average deal price for this strategy."""
|
||||
trade_value: float
|
||||
trade_value: np.ndarray | float
|
||||
"""Total worth of trading. In the simple simulation, trade_value = deal_amount * price."""
|
||||
position: float
|
||||
position: np.ndarray | float
|
||||
"""Position left after this "period"."""
|
||||
|
||||
# Accumulated metrics
|
||||
|
||||
ffr: float
|
||||
ffr: np.ndarray | float
|
||||
"""Completed how much percent of the daily order."""
|
||||
|
||||
pa: float
|
||||
pa: np.ndarray | float
|
||||
"""Price advantage compared to baseline (i.e., trade with baseline market price).
|
||||
The baseline is trade price when using TWAP strategy to execute this order.
|
||||
Please note that there could be data leak here).
|
||||
@@ -231,9 +231,9 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
direction=self.order.direction,
|
||||
market_volume=self.market_vol,
|
||||
market_price=self.market_price,
|
||||
amount=exec_vol.sum(), # TODO: check this logic with Yuge & Xiao
|
||||
inner_amount=exec_vol.sum(), # TODO: check this logic with Yuge & Xiao
|
||||
deal_amount=exec_vol.sum(), # TODO: check this logic with Yuge & Xiao
|
||||
amount=exec_vol,
|
||||
inner_amount=exec_vol,
|
||||
deal_amount=exec_vol,
|
||||
trade_price=self.market_price,
|
||||
trade_value=self.market_price * exec_vol,
|
||||
position=ticks_position,
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
import collections
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir, TradeRange
|
||||
from qlib.backtest.executor import SimulatorExecutor
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.backtest.executor import NestedExecutor, SimulatorExecutor
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.config import QlibConfig
|
||||
from qlib.contrib.strategy import TWAPStrategy
|
||||
from qlib.rl.order_execution.from_neutrader.executor import RLNestedExecutor
|
||||
from qlib.rl.order_execution.from_neutrader.strategy import SingleOrderStrategy
|
||||
from qlib.rl.order_execution.simulator_qlib import CategoricalActionInterpreter, ExchangeConfig, QlibSimulator
|
||||
from qlib.rl.order_execution import CategoricalActionInterpreter
|
||||
from qlib.rl.order_execution.simulator_qlib import ExchangeConfig, QlibSimulator
|
||||
|
||||
# fmt: off
|
||||
qlib_config = QlibConfig(
|
||||
{
|
||||
"provider_uri_day": Path("C:/workspace/NeuTrader/data_sample/cn/qlib_amc_1d"),
|
||||
@@ -27,6 +26,7 @@ qlib_config = QlibConfig(
|
||||
],
|
||||
}
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
exchange_config = ExchangeConfig(
|
||||
limit_threshold=('$ask == 0', '$bid == 0'),
|
||||
@@ -44,18 +44,9 @@ exchange_config = ExchangeConfig(
|
||||
)
|
||||
|
||||
|
||||
def _top_strategy_fn(
|
||||
common_infra: CommonInfrastructure,
|
||||
order: Order,
|
||||
trade_range: TradeRange,
|
||||
instrument: str,
|
||||
) -> SingleOrderStrategy:
|
||||
return SingleOrderStrategy(common_infra, order, trade_range, instrument)
|
||||
|
||||
|
||||
def _inner_executor_fn(common_infra: CommonInfrastructure) -> RLNestedExecutor:
|
||||
return RLNestedExecutor(
|
||||
time_per_step="30min",
|
||||
def _inner_executor_fn(time_per_step: str, common_infra: CommonInfrastructure) -> NestedExecutor:
|
||||
return NestedExecutor(
|
||||
time_per_step=time_per_step,
|
||||
inner_strategy=TWAPStrategy(),
|
||||
inner_executor=SimulatorExecutor(
|
||||
time_per_step="1min",
|
||||
@@ -71,67 +62,36 @@ def _inner_executor_fn(common_infra: CommonInfrastructure) -> RLNestedExecutor:
|
||||
|
||||
|
||||
def test():
|
||||
order_infos = [
|
||||
("2019-03-04", 1078.644160270691, 1),
|
||||
("2019-03-11", 32.440425872802734, 1),
|
||||
("2019-03-25", 40.55053234100342, 0),
|
||||
("2019-04-01", 1070.5340538024902, 0),
|
||||
("2019-05-27", 300.0739393234253, 1),
|
||||
("2019-06-03", 8.110106468200684, 0),
|
||||
("2019-06-11", 0.9360466003417968, 0),
|
||||
("2019-06-17", 794.4272003173828, 1),
|
||||
("2019-06-24", 7.865615844726562, 0),
|
||||
("2019-07-01", 1077.589370727539, 0),
|
||||
("2021-01-04", 499.7846999168396, 1),
|
||||
("2021-01-11", 14.918946266174316, 0),
|
||||
("2021-01-18", 484.8657536506653, 0),
|
||||
("2021-02-08", 537.0820655822754, 1),
|
||||
("2021-02-18", 7.459473133087158, 0),
|
||||
("2021-02-22", 7.459473133087158, 0),
|
||||
("2021-03-01", 14.918946266174316, 1),
|
||||
("2021-03-08", 872.7583565711975, 1),
|
||||
order = Order(
|
||||
stock_id="SH600000",
|
||||
amount=1078.644160270691,
|
||||
direction=OrderDir(1),
|
||||
start_time=pd.Timestamp("2019-03-04 09:45:00"),
|
||||
end_time=pd.Timestamp("2019-03-04 14:44:00"),
|
||||
)
|
||||
|
||||
]
|
||||
orders = collections.deque([
|
||||
Order(
|
||||
stock_id="",
|
||||
amount=info[1],
|
||||
direction=OrderDir(info[2]),
|
||||
start_time=pd.Timestamp(info[0]),
|
||||
end_time=pd.Timestamp(info[0]),
|
||||
)
|
||||
for info in order_infos
|
||||
])
|
||||
|
||||
# fmt: off
|
||||
simulator = QlibSimulator(
|
||||
time_per_step="1day",
|
||||
start_time="9:45",
|
||||
end_time="14:44",
|
||||
order=order,
|
||||
time_per_step="30min",
|
||||
qlib_config=qlib_config,
|
||||
top_strategy_fn=_top_strategy_fn,
|
||||
inner_executor_fn=_inner_executor_fn,
|
||||
exchange_config=exchange_config,
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
action_interpreter = CategoricalActionInterpreter(values=4)
|
||||
|
||||
simulator.reset(orders.popleft())
|
||||
interpreter_action = CategoricalActionInterpreter(values=4)
|
||||
|
||||
state = simulator.get_state()
|
||||
print(state.position)
|
||||
for i in range(10):
|
||||
print(f"Step {i}")
|
||||
ep_state, info = simulator.get_state()
|
||||
action = action_interpreter(ep_state, 1)
|
||||
simulator.step(interpreter_action(state, 1))
|
||||
|
||||
state = simulator.get_state()
|
||||
print(state.position)
|
||||
|
||||
simulator.step(action)
|
||||
if simulator.done():
|
||||
break
|
||||
|
||||
ep_state, info = simulator.get_state()
|
||||
print(info["logs"])
|
||||
print(info["qlib"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
|
||||
@@ -204,6 +204,9 @@ class BaseStrategy:
|
||||
range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype)
|
||||
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
|
||||
|
||||
def receive_execute_result(self, execute_result: list) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class RLStrategy(BaseStrategy, metaclass=ABCMeta):
|
||||
"""RL-based strategy"""
|
||||
|
||||
@@ -269,7 +269,7 @@ class LocIndexer:
|
||||
if isinstance(_indexing, IndexData):
|
||||
_indexing = _indexing.data
|
||||
assert _indexing.ndim == 1
|
||||
if _indexing.dtype != np.bool:
|
||||
if _indexing.dtype != bool:
|
||||
_indexing = np.array(list(index.index(i) for i in _indexing))
|
||||
else:
|
||||
_indexing = index.index(_indexing)
|
||||
@@ -429,7 +429,7 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
|
||||
# The code below could be simpler like methods in __getattribute__
|
||||
def __invert__(self):
|
||||
return self.__class__(~self.data.astype(np.bool), *self.indices)
|
||||
return self.__class__(~self.data.astype(bool), *self.indices)
|
||||
|
||||
def abs(self):
|
||||
"""get the abs of data except np.NaN."""
|
||||
|
||||
Reference in New Issue
Block a user