mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Qlib simulator refinement (redo of PR 1244) (#1262)
* Use dict-like configuration * Rename from_neutrader to integration * SAOE strategy * Optimize file structure * Optimize code * Format code * create_state_maintainer_recursive * Remove explicit time_per_step * CI test passed * Resolve PR comments * Pass all CI * Minor test issue * Refine SAOE adapter logic * Minor bugfix * Cherry pick updates * Resolve PR comments * CI issues * Refine adapter & saoe_data logic * Resolve PR comments * Resolve PR comments * Rename ONE_SEC to EPS_T; complete backtest loop * CI issue * Resolve Yuge's PR comments
This commit is contained in:
@@ -345,4 +345,4 @@ def format_decisions(
|
||||
return res
|
||||
|
||||
|
||||
__all__ = ["Order", "backtest"]
|
||||
__all__ = ["Order", "backtest", "get_strategy_executor"]
|
||||
|
||||
@@ -83,7 +83,9 @@ def collect_data_loop(
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
_execute_result = yield from trade_executor.collect_data(_trade_decision, level=0)
|
||||
trade_strategy.post_exe_step(_execute_result)
|
||||
bar.update(1)
|
||||
trade_strategy.post_upper_level_exe_step()
|
||||
|
||||
if return_value is not None:
|
||||
all_executors = trade_executor.get_all_executors()
|
||||
|
||||
@@ -135,6 +135,21 @@ class Order:
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
@property
def key_by_day(self) -> tuple:
    """Return a hashable key identifying this order at day granularity.

    The key is the tuple ``(stock_id, date, direction)``.
    """
    day_key = (self.stock_id, self.date, self.direction)
    return day_key
|
||||
|
||||
@property
def key(self) -> tuple:
    """Return a hashable key identifying this order.

    The key is the tuple ``(stock_id, start_time, end_time, direction)``.
    """
    full_key = (self.stock_id, self.start_time, self.end_time, self.direction)
    return full_key
|
||||
|
||||
@property
def date(self) -> pd.Timestamp:
    """Date of the order, i.e. ``start_time`` truncated to midnight.

    Returns
    -------
    pd.Timestamp
        ``start_time`` with the whole time-of-day component zeroed.
    """
    # `normalize()` clears hour/minute/second *and* sub-second parts, so two orders on the
    # same day always share the same date (and hence the same `key_by_day`).
    return pd.Timestamp(self.start_time).normalize()
|
||||
|
||||
|
||||
class OrderHelper:
|
||||
"""
|
||||
|
||||
@@ -114,7 +114,7 @@ class BaseExecutor:
|
||||
self.track_data = track_data
|
||||
self._trade_exchange = trade_exchange
|
||||
self.level_infra = LevelInfrastructure()
|
||||
self.level_infra.reset_infra(common_infra=common_infra)
|
||||
self.level_infra.reset_infra(common_infra=common_infra, executor=self)
|
||||
self._settle_type = settle_type
|
||||
self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra)
|
||||
if common_infra is None:
|
||||
@@ -134,6 +134,8 @@ class BaseExecutor:
|
||||
else:
|
||||
self.common_infra.update(common_infra)
|
||||
|
||||
self.level_infra.reset_infra(common_infra=self.common_infra)
|
||||
|
||||
if common_infra.has("trade_account"):
|
||||
# NOTE: there is a trick in the code.
|
||||
# shallow copy is used instead of deepcopy.
|
||||
@@ -256,6 +258,7 @@ class BaseExecutor:
|
||||
object
|
||||
trade decision
|
||||
"""
|
||||
|
||||
if self.track_data:
|
||||
yield trade_decision
|
||||
|
||||
@@ -296,6 +299,7 @@ class BaseExecutor:
|
||||
|
||||
if return_value is not None:
|
||||
return_value.update({"execute_result": res})
|
||||
|
||||
return res
|
||||
|
||||
def get_all_executors(self) -> List[BaseExecutor]:
|
||||
@@ -396,7 +400,7 @@ class NestedExecutor(BaseExecutor):
|
||||
trade_decision = updated_trade_decision
|
||||
# NEW UPDATE
|
||||
# create a hook for inner strategy to update outer decision
|
||||
self.inner_strategy.alter_outer_trade_decision(trade_decision)
|
||||
trade_decision = self.inner_strategy.alter_outer_trade_decision(trade_decision)
|
||||
return trade_decision
|
||||
|
||||
def _collect_data(
|
||||
@@ -473,6 +477,9 @@ class NestedExecutor(BaseExecutor):
|
||||
# do nothing and just step forward
|
||||
sub_cal.step()
|
||||
|
||||
# Let inner strategy know that the outer level execution is done.
|
||||
self.inner_strategy.post_upper_level_exe_step()
|
||||
|
||||
return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}
|
||||
|
||||
def post_inner_exe_step(self, inner_exe_res: List[object]) -> None:
|
||||
|
||||
@@ -3,9 +3,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Set, Tuple, Union
|
||||
from typing import Any, Set, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -184,8 +183,8 @@ class TradeCalendarManager:
|
||||
Tuple[int, int]:
|
||||
the index of the range. **the left and right are closed**
|
||||
"""
|
||||
left = bisect.bisect_right(list(self._calendar), start_time) - 1
|
||||
right = bisect.bisect_right(list(self._calendar), end_time) - 1
|
||||
left = np.searchsorted(self._calendar, start_time, side="right") - 1
|
||||
right = np.searchsorted(self._calendar, end_time, side="right") - 1
|
||||
left -= self.start_index
|
||||
right -= self.start_index
|
||||
|
||||
@@ -248,7 +247,7 @@ class LevelInfrastructure(BaseInfrastructure):
|
||||
sub_level_infra:
|
||||
- **NOTE**: this will only work after _init_sub_trading !!!
|
||||
"""
|
||||
return {"trade_calendar", "sub_level_infra", "common_infra"}
|
||||
return {"trade_calendar", "sub_level_infra", "common_infra", "executor"}
|
||||
|
||||
def reset_cal(
|
||||
self,
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# REGION CONST
|
||||
from typing import TypeVar
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
REG_CN = "cn"
|
||||
REG_US = "us"
|
||||
REG_TW = "tw"
|
||||
@@ -10,4 +15,8 @@ REG_TW = "tw"
|
||||
EPS = 1e-12
|
||||
|
||||
# Infinity in integer
|
||||
INF = 10**18
|
||||
INF = int(1e18)
|
||||
ONE_DAY = pd.Timedelta("1day")
|
||||
ONE_MIN = pd.Timedelta("1min")
|
||||
EPS_T = pd.Timedelta("1s") # use 1 second to exclude the right interval point
|
||||
float_or_ndarray = TypeVar("float_or_ndarray", float, np.ndarray)
|
||||
|
||||
@@ -615,4 +615,4 @@ class TSDatasetH(DatasetH):
|
||||
return tsds
|
||||
|
||||
|
||||
__all__ = ["Optional"]
|
||||
__all__ = ["Optional", "Dataset", "DatasetH"]
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, TYPE_CHECKING, Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Generic, Optional, TypeVar
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
|
||||
@@ -3,21 +3,33 @@
|
||||
|
||||
from typing import cast
|
||||
|
||||
import cachetools
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import Exchange, Order
|
||||
from .pickle_styled import IntradayBacktestData
|
||||
from qlib.backtest.decision import TradeRange, TradeRangeByTime
|
||||
from qlib.constant import ONE_DAY, EPS_T
|
||||
from qlib.rl.order_execution.utils import get_ticks_slice
|
||||
from qlib.utils.index_data import IndexData
|
||||
from .pickle_styled import BaseIntradayBacktestData
|
||||
|
||||
|
||||
class QlibIntradayBacktestData(IntradayBacktestData):
|
||||
class IntradayBacktestData(BaseIntradayBacktestData):
|
||||
"""Backtest data for Qlib simulator"""
|
||||
|
||||
def __init__(self, order: Order, exchange: Exchange, start_time: pd.Timestamp, end_time: pd.Timestamp) -> None:
|
||||
super(QlibIntradayBacktestData, self).__init__()
|
||||
def __init__(
|
||||
self,
|
||||
order: Order,
|
||||
exchange: Exchange,
|
||||
ticks_index: pd.DatetimeIndex,
|
||||
ticks_for_order: pd.DatetimeIndex,
|
||||
) -> None:
|
||||
self._order = order
|
||||
self._exchange = exchange
|
||||
self._start_time = start_time
|
||||
self._end_time = end_time
|
||||
self._start_time = ticks_for_order[0]
|
||||
self._end_time = ticks_for_order[-1]
|
||||
self.ticks_index = ticks_index
|
||||
self.ticks_for_order = ticks_for_order
|
||||
|
||||
self._deal_price = cast(
|
||||
pd.Series,
|
||||
@@ -56,3 +68,43 @@ class QlibIntradayBacktestData(IntradayBacktestData):
|
||||
|
||||
def get_time_index(self) -> pd.DatetimeIndex:
    """Return the second element of every entry in the exchange's quote index as a DatetimeIndex."""
    quote_entries = list(self._exchange.quote_df.index)
    timestamps = [entry[1] for entry in quote_entries]
    return pd.DatetimeIndex(timestamps)
|
||||
|
||||
|
||||
# NOTE(review): the cache key only uses `order.key_by_day`; `trade_range` is not part of it, so a
# second call for the same stock/day/direction with a different trade range would presumably get
# the first call's `ticks_for_order` — confirm this is intended.
@cachetools.cached(  # type: ignore
    cache=cachetools.LRUCache(100),
    key=lambda order, _, __: order.key_by_day,
)
def load_qlib_backtest_data(
    order: Order,
    trade_exchange: Exchange,
    trade_range: TradeRange,
) -> IntradayBacktestData:
    """Build the intraday backtest data of ``order`` from a Qlib exchange.

    Parameters
    ----------
    order
        The order whose stock, date and direction select the deal prices to load.
    trade_exchange
        Exchange that provides the deal prices.
    trade_range
        Range that limits the ticks usable by the order. Only ``TradeRangeByTime`` is supported.

    Returns
    -------
    IntradayBacktestData
        Data holding the ticks of the whole day and the ticks available to the order.

    Raises
    ------
    NotImplementedError
        If ``trade_range`` is not a ``TradeRangeByTime``.
    """
    if not isinstance(trade_range, TradeRangeByTime):
        # FIXME: implement this logic for other kinds of trade range.
        # Fail early and explicitly: the data object indexes `ticks_for_order`, so passing
        # `None` would only surface later as an obscure TypeError.
        raise NotImplementedError(
            f"Trade range of type {type(trade_range).__name__} is not supported; "
            "only TradeRangeByTime is implemented."
        )

    # Deal prices of the order's whole day; `EPS_T` excludes the right end point (next midnight).
    data = cast(
        IndexData,
        trade_exchange.get_deal_price(
            stock_id=order.stock_id,
            start_time=order.date,
            end_time=order.date + ONE_DAY - EPS_T,
            direction=order.direction,
            method=None,
        ),
    )

    ticks_index = pd.DatetimeIndex(data.index)
    ticks_for_order = get_ticks_slice(
        ticks_index,
        trade_range.start_time,
        trade_range.end_time,
        include_end=True,
    )

    return IntradayBacktestData(
        order=order,
        exchange=trade_exchange,
        ticks_index=ticks_index,
        ticks_for_order=ticks_for_order,
    )
|
||||
|
||||
@@ -86,7 +86,7 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
|
||||
return pd.read_pickle(_find_pickle(filename_without_suffix))
|
||||
|
||||
|
||||
class IntradayBacktestData:
|
||||
class BaseIntradayBacktestData:
|
||||
"""
|
||||
Raw market data that is often used in backtesting (thus called BacktestData).
|
||||
|
||||
@@ -115,7 +115,7 @@ class IntradayBacktestData:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SimpleIntradayBacktestData(IntradayBacktestData):
|
||||
class SimpleIntradayBacktestData(BaseIntradayBacktestData):
|
||||
"""Backtest data for simple simulator"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
# TODO: In the future we should merge the dataclass-based config with Qlib's dict-based config.
|
||||
@dataclass
class ExchangeConfig:
    """Dataclass-based bundle of exchange settings.

    NOTE(review): the meaning of each field is not visible in this file; presumably the fields
    mirror the keyword arguments used to create Qlib's ``Exchange`` — confirm against that class.
    """

    # Required settings (no default value).
    limit_threshold: Union[float, Tuple[str, str]]
    deal_price: Union[str, Tuple[str, str]]
    volume_threshold: dict
    # Optional settings with default values.
    open_cost: float = 0.0005
    close_cost: float = 0.0015
    min_cost: float = 5.0
    trade_unit: Optional[float] = 100.0
    cash_limit: Optional[Union[Path, float]] = None
    generate_report: bool = False
|
||||
@@ -1,109 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import collections
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
class LRUCache:
    """Bounded key/value cache that evicts the least recently used entry.

    Both ``put`` and ``get`` mark the key as most recently used. When the number of entries
    exceeds ``pool_size``, the least recently used entries are dropped.

    Parameters
    ----------
    pool_size
        Maximum number of entries kept in the cache.
    """

    def __init__(self, pool_size: int = 200):
        self.pool_size = pool_size
        # Ordered from least recently used (front) to most recently used (back).
        self.contents: collections.OrderedDict = collections.OrderedDict()

    @property
    def keys(self) -> collections.deque:
        """Snapshot of the cached keys, least recently used first."""
        return collections.deque(self.contents)

    def put(self, key, item):
        """Store ``item`` under ``key`` and evict old entries if the cache is over capacity."""
        self.contents[key] = item
        self.contents.move_to_end(key)
        while len(self.contents) > self.pool_size:
            self.contents.popitem(last=False)

    def get(self, key):
        """Return the item stored under ``key``; raise ``KeyError`` if it is absent."""
        # Refresh recency on access; without this the eviction order would be plain FIFO.
        self.contents.move_to_end(key)
        return self.contents[key]

    def has(self, key):
        """Return whether ``key`` is currently cached (does not affect recency)."""
        return key in self.contents
|
||||
|
||||
|
||||
class DataWrapper:
    """Hold a feature dataset and a backtest dataset, and serve cached per-stock daily slices."""

    def __init__(
        self,
        feature_dataset: DatasetH,
        backtest_dataset: DatasetH,
        columns_today: List[str],
        columns_yesterday: List[str],
        _internal: bool = False,
    ):
        assert _internal, "Init function of data wrapper is for internal use only."

        self.feature_dataset, self.backtest_dataset = feature_dataset, backtest_dataset
        self.columns_today, self.columns_yesterday = columns_today, columns_yesterday

        # TODO: We might have the chance to merge them.
        self.feature_cache = LRUCache()
        self.backtest_cache = LRUCache()

    def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
        """Fetch the data of ``stock_id`` on the day of ``date``, going through the matching cache.

        ``backtest`` selects the backtest dataset/cache instead of the feature dataset/cache.
        """
        day_begin = date.replace(hour=0, minute=0, second=0)
        day_end = date.replace(hour=23, minute=59, second=59)
        cache_key = (day_begin, day_end, stock_id)

        source, pool = (
            (self.backtest_dataset, self.backtest_cache)
            if backtest
            else (self.feature_dataset, self.feature_cache)
        )

        if pool.has(cache_key):
            return pool.get(cache_key)

        fetched = source.handler.fetch(pd.IndexSlice[stock_id, day_begin:day_end], level=None)
        pool.put(cache_key, fetched)
        return fetched
|
||||
|
||||
|
||||
def init_qlib(config: dict, part: Optional[str] = None) -> None:
    """Initialize Qlib with local, file-based calendar and feature providers.

    Parameters
    ----------
    config
        Must contain ``provider_uri_day`` and ``provider_uri_1min``; ``as_posix()`` is called
        on both values.
    part
        Not used by this function.
    """
    provider_uri_map = {
        "day": config["provider_uri_day"].as_posix(),
        "1min": config["provider_uri_1min"].as_posix(),
    }

    def _file_backed_provider(provider_cls: str, storage_cls: str) -> dict:
        # Provider description whose backend is a file storage reading from `provider_uri_map`.
        return {
            "class": provider_cls,
            "module_path": "qlib.data.data",
            "kwargs": {
                "backend": {
                    "class": storage_cls,
                    "module_path": "qlib.data.storage.file_storage",
                    "kwargs": {"provider_uri_map": provider_uri_map},
                },
            },
        }

    qlib.init(
        region=REG_CN,
        auto_mount=False,
        custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum],
        expression_cache=None,
        calendar_provider=_file_backed_provider("LocalCalendarProvider", "FileCalendarStorage"),
        feature_provider=_file_backed_provider("LocalFeatureProvider", "FileFeatureStorage"),
        provider_uri=provider_uri_map,
        kernels=1,
        redis_port=-1,
        clear_mem_cache=False,  # init_qlib will be called for multiple times. Keep the cache for improving performance
    )
|
||||
163
qlib/rl/order_execution/integration.py
Normal file
163
qlib/rl/order_execution/integration.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
TODO: This file is used to integrate NeuTrader with Qlib to run the existing projects.
|
||||
TODO: The implementation here is somewhat ad hoc. It would be better to design a more uniform and general implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import qlib
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
dataset = None
|
||||
|
||||
|
||||
class DataWrapper:
    """Hold a feature dataset and a backtest dataset, and serve cached per-stock daily slices.

    Parameters
    ----------
    feature_dataset
        Dataset queried when ``get`` is called with ``backtest=False``.
    backtest_dataset
        Dataset queried when ``get`` is called with ``backtest=True``.
    columns_today
        Stored as ``self.columns_today``.
    columns_yesterday
        Stored as ``self.columns_yesterday``.
    _internal
        Must be True; guards against direct construction outside this module.
    """

    def __init__(
        self,
        feature_dataset: DatasetH,
        backtest_dataset: DatasetH,
        columns_today: List[str],
        columns_yesterday: List[str],
        _internal: bool = False,
    ):
        assert _internal, "Init function of data wrapper is for internal use only."

        self.feature_dataset = feature_dataset
        self.backtest_dataset = backtest_dataset
        self.columns_today = columns_today
        self.columns_yesterday = columns_yesterday

        # Per-instance cache. A decorator-level cache on the method would be shared by all
        # instances, and its key function would also receive `self` as the first argument.
        self._cache = cachetools.LRUCache(100)

    def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
        """Fetch the data of ``stock_id`` on the day of ``date``.

        Results are cached per ``(stock_id, day, backtest)``, so repeated requests for the same
        day do not hit the dataset handler again.

        Parameters
        ----------
        stock_id
            Instrument to fetch.
        date
            Any timestamp within the requested day.
        backtest
            If True, read from the backtest dataset; otherwise from the feature dataset.

        Returns
        -------
        pd.DataFrame
            Whatever the dataset handler's ``fetch`` returns for that stock and day.
        """
        start_time = date.replace(hour=0, minute=0, second=0)
        end_time = date.replace(hour=23, minute=59, second=59)

        cache_key = (stock_id, start_time, backtest)
        try:
            return self._cache[cache_key]
        except KeyError:
            pass

        dataset = self.backtest_dataset if backtest else self.feature_dataset
        data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
        self._cache[cache_key] = data
        return data
|
||||
|
||||
|
||||
def init_qlib(qlib_config: dict, part: str = None) -> None:
    """Initialize Qlib and load the pickled feature/backtest datasets into the module-level ``dataset``.

    Parameters
    ----------
    qlib_config:
        Qlib configuration.

        Example::

            {
                "provider_uri_day": DATA_ROOT_DIR / "qlib_1d",
                "provider_uri_1min": DATA_ROOT_DIR / "qlib_1min",
                "feature_root_dir": DATA_ROOT_DIR / "qlib_handler_stock",
                "feature_columns_today": [
                    "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
                    "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5",
                ],
                "feature_columns_yesterday": [
                    "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
                    "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
                ],
            }
    part
        Identifying which part (stock / date) to load:

        - ``"skip"``: only initialize Qlib, do not load any dataset.
        - ``None``: load ``feature.pkl`` and ``backtest.pkl`` from ``feature_root_dir``.
        - anything else: load ``feature/<part>.pkl`` and ``backtest/<part>.pkl``.
    """

    global dataset  # pylint: disable=W0603

    # `Path(...)` accepts both str and Path, so no explicit type dispatch is needed.
    provider_uri_map = {
        "day": Path(qlib_config["provider_uri_day"]).as_posix(),
        "1min": Path(qlib_config["provider_uri_1min"]).as_posix(),
    }

    def _file_backed_provider(provider_cls: str, storage_cls: str) -> dict:
        # Provider description whose backend is a file storage reading from `provider_uri_map`.
        return {
            "class": provider_cls,
            "module_path": "qlib.data.data",
            "kwargs": {
                "backend": {
                    "class": storage_cls,
                    "module_path": "qlib.data.storage.file_storage",
                    "kwargs": {"provider_uri_map": provider_uri_map},
                },
            },
        }

    qlib.init(
        region=REG_CN,
        auto_mount=False,
        custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum],
        expression_cache=None,
        calendar_provider=_file_backed_provider("LocalCalendarProvider", "FileCalendarStorage"),
        feature_provider=_file_backed_provider("LocalFeatureProvider", "FileFeatureStorage"),
        provider_uri=provider_uri_map,
        kernels=1,
        redis_port=-1,
        clear_mem_cache=False,  # init_qlib will be called for multiple times. Keep the cache for improving performance
    )

    if part == "skip":
        return

    # this won't work if it's put outside in case of multiprocessing
    from qlib.data import D  # noqa pylint: disable=C0415,W0611

    feature_root = Path(qlib_config["feature_root_dir"])
    if part is None:
        feature_path = feature_root / "feature.pkl"
        backtest_path = feature_root / "backtest.pkl"
    else:
        feature_path = feature_root / "feature" / (part + ".pkl")
        backtest_path = feature_root / "backtest" / (part + ".pkl")

    # NOTE(review): pickle executes arbitrary code on load; these files must come from a trusted source.
    with feature_path.open("rb") as feature_file:
        feature_dataset = pickle.load(feature_file)
    with backtest_path.open("rb") as backtest_file:
        backtest_dataset = pickle.load(backtest_file)

    dataset = DataWrapper(
        feature_dataset,
        backtest_dataset,
        qlib_config["feature_columns_today"],
        qlib_config["feature_columns_yesterday"],
        _internal=True,
    )
|
||||
|
||||
|
||||
def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False) -> pd.DataFrame:
    """Return the requested columns of ``stock_id`` on ``date`` from the module-level ``dataset``.

    The columns are ``["$close", "$volume"]`` when ``backtest`` is set, otherwise the dataset's
    ``columns_yesterday`` or ``columns_today`` depending on ``yesterday``. When the dataset has
    no data, an all-zero float32 frame with 240 rows is returned instead.
    """
    assert dataset is not None, "You must call init_qlib() before doing this."

    if backtest:
        fields = ["$close", "$volume"]
    elif yesterday:
        fields = dataset.columns_yesterday
    else:
        fields = dataset.columns_today

    raw = dataset.get(stock_id, date, backtest)
    if raw is None or len(raw) == 0:
        # create a fake index, but RL doesn't care about index
        return pd.DataFrame(0.0, index=np.arange(240), columns=fields, dtype=np.float32)  # FIXME: hardcode here

    # NOTE(review): `rstrip("0")` removes *all* trailing "0" characters of a column name —
    # presumably the stored columns carry a "0" suffix; confirm against the dataset.
    renamed = raw.rename(columns={col: col.rstrip("0") for col in raw.columns})
    return renamed[fields]
|
||||
@@ -14,15 +14,15 @@ from gym import spaces
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution.state import SAOEState
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .simulator_simple import SAOEState
|
||||
|
||||
__all__ = [
|
||||
"FullHistoryStateInterpreter",
|
||||
"CurrentStepStateInterpreter",
|
||||
"CategoricalActionInterpreter",
|
||||
"TwapRelativeActionInterpreter",
|
||||
"FullHistoryObs",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
@@ -7,10 +7,9 @@ from typing import cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qlib.rl.order_execution.state import SAOEMetrics, SAOEState
|
||||
from qlib.rl.reward import Reward
|
||||
|
||||
from .simulator_simple import SAOEMetrics, SAOEState
|
||||
|
||||
__all__ = ["PAPenaltyReward"]
|
||||
|
||||
|
||||
|
||||
@@ -3,381 +3,102 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, cast, Generator, List, Optional, Tuple
|
||||
from typing import Generator, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderHelper, TradeDecisionWO, TradeRange, TradeRangeByTime
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.data.exchange_wrapper import QlibIntradayBacktestData
|
||||
from qlib.rl.from_neutrader.config import ExchangeConfig
|
||||
from qlib.rl.from_neutrader.feature import init_qlib
|
||||
from qlib.rl.order_execution.simulator_simple import SAOEMetrics, SAOEState
|
||||
from qlib.rl.order_execution.utils import (
|
||||
dataframe_append,
|
||||
get_common_infra,
|
||||
get_portfolio_and_indicator,
|
||||
get_ticks_slice,
|
||||
price_advantage,
|
||||
)
|
||||
from qlib.backtest import collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest.decision import Order
|
||||
from qlib.backtest.executor import NestedExecutor
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
|
||||
from .integration import init_qlib
|
||||
from .state import SAOEState, SAOEStateAdapter
|
||||
from .strategy import SAOEStrategy
|
||||
|
||||
|
||||
class DecomposedStrategy(BaseStrategy):
    """Strategy that yields itself to the caller and trades the volume sent back to it."""

    def __init__(self) -> None:
        super().__init__()

        # Order created by the latest `generate_trade_decision` call.
        self.execute_order: Optional[Order] = None
        # Result recorded by the latest `post_exe_step` call.
        self.execute_result: List[Tuple[Order, float, float, float]] = []

    def generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
        """Suspend until a volume is sent in, then return a decision holding one order of that volume."""
        # Yielding `self` hands this strategy to whoever drives the executor and suspends the
        # execution here. The value passed to `send()` on resume becomes `exec_vol`; this is how
        # the outside policy communicates with this inner-level strategy.
        exec_vol = yield self

        order_helper = self.trade_exchange.get_order_helper()
        self.execute_order = order_helper.create(self._order.stock_id, exec_vol, self._order.direction)

        return TradeDecisionWO([self.execute_order], self)

    def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
        """Return the outer decision unchanged."""
        return outer_trade_decision

    def post_exe_step(self, execute_result: list) -> None:
        """Remember the execution result of the step that just finished."""
        self.execute_result = execute_result

    def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs: Any) -> None:
        """Reset the strategy; when an outer decision is given, take its single order as the target."""
        super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
        if outer_trade_decision is None:
            return

        outer_orders = outer_trade_decision.order_list
        assert len(outer_orders) == 1
        self._order = outer_orders[0]
|
||||
|
||||
|
||||
class SingleOrderStrategy(BaseStrategy):
    """Strategy whose every decision contains one order built from a fixed order and trade range."""

    # this logic is copied from FileOrderStrategy
    def __init__(
        self,
        common_infra: CommonInfrastructure,
        order: Order,
        trade_range: TradeRange,
        instrument: str,
    ) -> None:
        super().__init__(common_infra=common_infra)
        self._order, self._trade_range, self._instrument = order, trade_range, instrument

    def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
        """Return the outer decision unchanged."""
        return outer_trade_decision

    def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO:
        """Create one order (instrument, amount and direction of the stored order) and wrap it in a decision."""
        order_helper: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
        single_order = order_helper.create(
            code=self._instrument,
            amount=self._order.amount,
            direction=self._order.direction,
        )
        return TradeDecisionWO([single_order], self, self._trade_range)
|
||||
|
||||
|
||||
# TODO: move these to the configuration files
|
||||
FINEST_GRANULARITY = "1min"
|
||||
COARSEST_GRANULARITY = "1day"
|
||||
|
||||
|
||||
class StateMaintainer:
    """
    Maintain states of the environment.

    The maintainer tracks the remaining position of an order and accumulates two metric tables:
    ``history_exec`` (one row per executed tick) and ``history_steps`` (one row per ``update`` call).
    When ``update`` is called with ``done=True``, the overall metrics are stored in ``metrics``.

    Example usage::

        maintainer = StateMaintainer(...)  # in reset
        maintainer.update(...)  # in step
        # get states in get_state from maintainer
    """

    def __init__(self, order: Order, time_per_step: str, tick_index: pd.DatetimeIndex, twap_price: float) -> None:
        """Start with the whole order amount as the position and with empty metric histories.

        Parameters
        ----------
        order
            The order being executed; its ``amount`` is the initial position.
        time_per_step
            Key used to look up the per-step indicators in ``update`` (to read ``pa``).
        tick_index
            Its first element is used as the datetime of the final metrics.
        twap_price
            Reference price passed to ``price_advantage`` for the aggregated metrics.
        """
        super().__init__()

        self.position = order.amount
        self._order = order
        self._time_per_step = time_per_step
        self._tick_index = tick_index
        self._twap_price = twap_price

        # Both history tables share the keys of SAOEMetrics as columns and are indexed by datetime.
        metric_keys = list(SAOEMetrics.__annotations__.keys())  # pylint: disable=no-member
        self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
        self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
        self.metrics: Optional[SAOEMetrics] = None

    def update(
        self,
        inner_executor: BaseExecutor,
        inner_strategy: DecomposedStrategy,
        done: bool,
        all_indicators: dict,
    ) -> None:
        """Record the outcome of the latest step and reduce the remaining position.

        Parameters
        ----------
        inner_executor
            Executor whose ``trade_exchange`` provides the market volume.
        inner_strategy
            Strategy holding the latest ``execute_order`` and ``execute_result``.
        done
            If True, the overall metrics are computed and stored in ``self.metrics``.
        all_indicators
            Indicator tables keyed by granularity; the ``FINEST_GRANULARITY`` table and the
            ``time_per_step`` table are read.
        """
        execute_order = inner_strategy.execute_order
        execute_result = inner_strategy.execute_result
        # The first element of every result entry is an order; its deal amount is the executed volume.
        exec_vol = np.array([e[0].deal_amount for e in execute_result])
        num_step = len(execute_result)

        assert execute_order is not None

        if num_step == 0:
            market_volume = np.array([])
            market_price = np.array([])
            datetime_list = pd.DatetimeIndex([])
        else:
            # Market volume between the start times of the first and the last executed order.
            market_volume = np.array(
                inner_executor.trade_exchange.get_volume(
                    execute_order.stock_id,
                    execute_result[0][0].start_time,
                    execute_result[-1][0].start_time,
                    method=None,
                ),
            )

            # The last `num_step` rows of the finest-granularity indicators belong to this step.
            trade_value = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["value"].values
            deal_amount = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["deal_amount"].values
            # NOTE(review): a zero `deal_amount` presumably yields nan/inf here — confirm this is acceptable.
            market_price = trade_value / deal_amount

            datetime_list = all_indicators[FINEST_GRANULARITY].index[-num_step:]

        assert market_price.shape == market_volume.shape == exec_vol.shape

        # Per-tick (vectorized) metrics of this step.
        self.history_exec = dataframe_append(
            self.history_exec,
            self._collect_multi_order_metric(
                order=self._order,
                datetime=datetime_list,
                market_vol=market_volume,
                market_price=market_price,
                exec_vol=exec_vol,
                pa=all_indicators[self._time_per_step].iloc[-1]["pa"],
            ),
        )

        # One aggregated row for this step.
        self.history_steps = dataframe_append(
            self.history_steps,
            [
                self._collect_single_order_metric(
                    execute_order,
                    execute_order.start_time,
                    market_volume,
                    market_price,
                    exec_vol.sum(),
                    exec_vol,
                ),
            ],
        )

        if done:
            # Aggregate over the whole execution history.
            self.metrics = self._collect_single_order_metric(
                self._order,
                self._tick_index[0],  # start time
                self.history_exec["market_volume"],
                self.history_exec["market_price"],
                self.history_steps["amount"].sum(),
                self.history_exec["deal_amount"],
            )

        # TODO: check whether we need this. Can we get this information from Account?
        # Do this at the end
        # (the metric helpers above read `self.position` as the position before this step).
        self.position -= exec_vol.sum()

    def _collect_multi_order_metric(
        self,
        order: Order,
        datetime: pd.Timestamp,
        market_vol: np.ndarray,
        market_price: np.ndarray,
        exec_vol: np.ndarray,
        pa: float,
    ) -> SAOEMetrics:
        """Build per-tick metrics; most values are arrays with one element per tick.

        NOTE(review): ``update`` passes ``datetime_list`` (several timestamps) as ``datetime``,
        so the ``pd.Timestamp`` annotation presumably does not describe the actual value.
        """
        return SAOEMetrics(
            # It should have the same keys with SAOEMetrics,
            # but the values do not necessarily have the annotated type.
            # Some values could be vectorized (e.g., exec_vol).
            stock_id=order.stock_id,
            datetime=datetime,
            direction=order.direction,
            market_volume=market_vol,
            market_price=market_price,
            amount=exec_vol,
            inner_amount=exec_vol,
            deal_amount=exec_vol,
            trade_price=market_price,
            trade_value=market_price * exec_vol,
            # Remaining position after each tick.
            position=self.position - np.cumsum(exec_vol),
            ffr=exec_vol / order.amount,
            pa=pa,
        )

    def _collect_single_order_metric(
        self,
        order: Order,
        datetime: pd.Timestamp,
        market_vol: np.ndarray,
        market_price: np.ndarray,
        amount: float,  # intended to trade such amount
        exec_vol: np.ndarray,
    ) -> SAOEMetrics:
        """Aggregate per-tick values into a single metrics record.

        The three array arguments must have the same length (asserted below).
        """
        assert len(market_vol) == len(market_price) == len(exec_vol)

        # With (almost) nothing executed there are no weights to average with; use 0.0 as the price.
        if np.abs(np.sum(exec_vol)) < EPS:
            exec_avg_price = 0.0
        else:
            # Volume-weighted average of the market price.
            exec_avg_price = cast(float, np.average(market_price, weights=exec_vol))  # could be nan
            if hasattr(exec_avg_price, "item"):  # could be numpy scalar
                exec_avg_price = exec_avg_price.item()  # type: ignore

        exec_sum = exec_vol.sum()
        return SAOEMetrics(
            stock_id=order.stock_id,
            datetime=datetime,
            direction=order.direction,
            market_volume=market_vol.sum(),
            market_price=market_price.mean() if len(market_price) > 0 else np.nan,
            amount=amount,
            inner_amount=exec_sum,
            deal_amount=exec_sum,  # in this simulator, there's no other restrictions
            trade_price=exec_avg_price,
            trade_value=float(np.sum(market_price * exec_vol)),
            position=self.position - exec_sum,
            ffr=float(exec_sum / order.amount),
            pa=price_advantage(exec_avg_price, self._twap_price, order.direction),
        )
|
||||
|
||||
|
||||
class SingleAssetOrderExecutionQlib(Simulator[Order, SAOEState, float]):
|
||||
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
"""Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
order (Order):
|
||||
order
|
||||
The seed to start an SAOE simulator is an order.
|
||||
time_per_step (str):
|
||||
A string to describe the time granularity of each step. Currently supports "1min", "30min", and "1day".
|
||||
qlib_config (dict):
|
||||
Configuration used to initialize Qlib.
|
||||
inner_executor_fn (Callable[[str, CommonInfrastructure], BaseExecutor]):
|
||||
Function used to get the inner level executor.
|
||||
exchange_config (ExchangeConfig):
|
||||
Configuration used to create the Exchange instance.
|
||||
strategy_config
|
||||
Strategy configuration
|
||||
executor_config
|
||||
Executor configuration
|
||||
exchange_config
|
||||
Exchange configuration
|
||||
qlib_config
|
||||
Configuration used to initialize Qlib. If it is None, Qlib will not be initialized.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
order: Order,
|
||||
time_per_step: str, # "1min", "30min", "1day"
|
||||
qlib_config: dict,
|
||||
inner_executor_fn: Callable[[str, CommonInfrastructure], BaseExecutor],
|
||||
exchange_config: ExchangeConfig,
|
||||
strategy_config: dict,
|
||||
executor_config: dict,
|
||||
exchange_config: dict,
|
||||
qlib_config: dict = None,
|
||||
) -> None:
|
||||
assert time_per_step in ("1min", "30min", "1day")
|
||||
|
||||
super().__init__(initial=order)
|
||||
|
||||
assert order.start_time.date() == order.end_time.date(), "Start date and end date must be the same."
|
||||
|
||||
self._order = order
|
||||
self._order_date = pd.Timestamp(order.start_time.date())
|
||||
self._trade_range = TradeRangeByTime(order.start_time.time(), order.end_time.time())
|
||||
self._qlib_config = qlib_config
|
||||
self._inner_executor_fn = inner_executor_fn
|
||||
self._exchange_config = exchange_config
|
||||
|
||||
self._time_per_step = time_per_step
|
||||
self._ticks_per_step = int(pd.Timedelta(time_per_step).total_seconds() // 60)
|
||||
|
||||
self._executor: Optional[NestedExecutor] = None
|
||||
self._collect_data_loop: Optional[Generator] = None
|
||||
self.reset(order, strategy_config, executor_config, exchange_config, qlib_config)
|
||||
|
||||
self._done = False
|
||||
def reset(
|
||||
self,
|
||||
order: Order,
|
||||
strategy_config: dict,
|
||||
executor_config: dict,
|
||||
exchange_config: dict,
|
||||
qlib_config: dict = None,
|
||||
) -> None:
|
||||
if qlib_config is not None:
|
||||
init_qlib(qlib_config, part="skip")
|
||||
|
||||
self._inner_strategy = DecomposedStrategy()
|
||||
|
||||
self.reset(self._order)
|
||||
|
||||
def reset(self, order: Order) -> None:
|
||||
instrument = order.stock_id
|
||||
|
||||
# TODO: Check this logic. Make sure we need to do this every time we reset the simulator.
|
||||
init_qlib(self._qlib_config, instrument)
|
||||
|
||||
common_infra = get_common_infra(
|
||||
self._exchange_config,
|
||||
trade_date=pd.Timestamp(self._order_date),
|
||||
codes=[instrument],
|
||||
strategy, self._executor = get_strategy_executor(
|
||||
start_time=order.date,
|
||||
end_time=order.date + pd.DateOffset(1),
|
||||
strategy=strategy_config,
|
||||
executor=executor_config,
|
||||
benchmark=order.stock_id,
|
||||
account=1e12,
|
||||
exchange_kwargs=exchange_config,
|
||||
pos_type="InfPosition",
|
||||
)
|
||||
|
||||
# TODO: We can leverage interfaces like (https://tinyurl.com/y8f8fhv4) to create trading environment.
|
||||
# TODO: By aligning the interface to create environments with Qlib, it will be easier to share the config and
|
||||
# TODO: code between backtesting and training.
|
||||
self._inner_executor = self._inner_executor_fn(self._time_per_step, common_infra)
|
||||
self._executor = NestedExecutor(
|
||||
time_per_step=COARSEST_GRANULARITY,
|
||||
inner_executor=self._inner_executor,
|
||||
inner_strategy=self._inner_strategy,
|
||||
track_data=True,
|
||||
common_infra=common_infra,
|
||||
assert isinstance(self._executor, NestedExecutor)
|
||||
|
||||
self._collect_data_loop = collect_data_loop(
|
||||
start_time=order.date,
|
||||
end_time=order.date,
|
||||
trade_strategy=strategy,
|
||||
trade_executor=self._executor,
|
||||
)
|
||||
|
||||
exchange = self._inner_executor.trade_exchange
|
||||
self._ticks_index = pd.DatetimeIndex([e[1] for e in list(exchange.quote_df.index)])
|
||||
self._ticks_for_order = get_ticks_slice(
|
||||
self._ticks_index,
|
||||
self._order.start_time,
|
||||
self._order.end_time,
|
||||
include_end=True,
|
||||
)
|
||||
|
||||
self._backtest_data = QlibIntradayBacktestData(
|
||||
order=self._order,
|
||||
exchange=exchange,
|
||||
start_time=self._ticks_for_order[0],
|
||||
end_time=self._ticks_for_order[-1],
|
||||
)
|
||||
|
||||
self.twap_price = self._backtest_data.get_deal_price().mean()
|
||||
|
||||
top_strategy = SingleOrderStrategy(common_infra, order, self._trade_range, instrument)
|
||||
self._executor.reset(start_time=pd.Timestamp(self._order_date), end_time=pd.Timestamp(self._order_date))
|
||||
top_strategy.reset(level_infra=self._executor.get_level_infra())
|
||||
|
||||
self._collect_data_loop = self._executor.collect_data(top_strategy.generate_trade_decision(), level=0)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
self._iter_strategy(action=None)
|
||||
self._done = False
|
||||
self._last_yielded_saoe_strategy = self._iter_strategy(action=None)
|
||||
|
||||
self._maintainer = StateMaintainer(
|
||||
order=self._order,
|
||||
time_per_step=self._time_per_step,
|
||||
tick_index=self._ticks_index,
|
||||
twap_price=self.twap_price,
|
||||
)
|
||||
self._order = order
|
||||
|
||||
def _iter_strategy(self, action: float = None) -> DecomposedStrategy:
|
||||
"""Iterate the _collect_data_loop until we get the next yield DecomposedStrategy."""
|
||||
def _get_adapter(self) -> SAOEStateAdapter:
|
||||
return self._last_yielded_saoe_strategy.adapter_dict[self._order.key_by_day]
|
||||
|
||||
@property
|
||||
def twap_price(self) -> float:
|
||||
return self._get_adapter().twap_price
|
||||
|
||||
def _iter_strategy(self, action: float = None) -> SAOEStrategy:
|
||||
"""Iterate the _collect_data_loop until we get the next yield SAOEStrategy."""
|
||||
assert self._collect_data_loop is not None
|
||||
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
while not isinstance(strategy, DecomposedStrategy):
|
||||
while not isinstance(strategy, SAOEStrategy):
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
assert isinstance(strategy, DecomposedStrategy)
|
||||
assert isinstance(strategy, SAOEStrategy)
|
||||
return strategy
|
||||
|
||||
def step(self, action: float) -> None:
|
||||
@@ -389,36 +110,17 @@ class SingleAssetOrderExecutionQlib(Simulator[Order, SAOEState, float]):
|
||||
The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt.
|
||||
"""
|
||||
|
||||
assert not self._done, "Simulator has already done!"
|
||||
assert not self.done(), "Simulator has already done!"
|
||||
|
||||
try:
|
||||
self._iter_strategy(action=action)
|
||||
self._last_yielded_saoe_strategy = self._iter_strategy(action=action)
|
||||
except StopIteration:
|
||||
self._done = True
|
||||
pass
|
||||
|
||||
assert self._executor is not None
|
||||
_, all_indicators = get_portfolio_and_indicator(self._executor)
|
||||
|
||||
self._maintainer.update(
|
||||
inner_executor=self._inner_executor,
|
||||
inner_strategy=self._inner_strategy,
|
||||
done=self._done,
|
||||
all_indicators=all_indicators,
|
||||
)
|
||||
|
||||
def get_state(self) -> SAOEState:
|
||||
return SAOEState(
|
||||
order=self._order,
|
||||
cur_time=self._inner_executor.trade_calendar.get_step_time()[0],
|
||||
position=self._maintainer.position,
|
||||
history_exec=self._maintainer.history_exec,
|
||||
history_steps=self._maintainer.history_steps,
|
||||
metrics=self._maintainer.metrics,
|
||||
backtest_data=self._backtest_data,
|
||||
ticks_per_step=self._ticks_per_step,
|
||||
ticks_index=self._ticks_index,
|
||||
ticks_for_order=self._ticks_for_order,
|
||||
)
|
||||
return self._get_adapter().saoe_state
|
||||
|
||||
def done(self) -> bool:
|
||||
return self._done
|
||||
return self._executor.finished()
|
||||
|
||||
@@ -4,107 +4,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, NamedTuple, Optional, TypeVar, cast
|
||||
from typing import Any, cast, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.data.pickle_styled import DealPriceType, IntradayBacktestData, load_simple_intraday_backtest_data
|
||||
from qlib.constant import EPS, EPS_T, float_or_ndarray
|
||||
from qlib.rl.data.pickle_styled import DealPriceType, load_simple_intraday_backtest_data
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.utils import LogLevel
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .state import SAOEMetrics, SAOEState
|
||||
|
||||
# TODO: Integrating Qlib's native data with simulator_simple
|
||||
|
||||
__all__ = ["SAOEMetrics", "SAOEState", "SingleAssetOrderExecution"]
|
||||
|
||||
ONE_SEC = pd.Timedelta("1s") # use 1 second to exclude the right interval point
|
||||
|
||||
|
||||
class SAOEMetrics(TypedDict):
|
||||
"""Metrics for SAOE accumulated for a "period".
|
||||
It could be accumulated for a day, or a period of time (e.g., 30min), or calculated separately for every minute.
|
||||
|
||||
Warnings
|
||||
--------
|
||||
The type hints are for single elements. In lots of times, they can be vectorized.
|
||||
For example, ``market_volume`` could be a list of float (or ndarray) rather than a single float.
|
||||
"""
|
||||
|
||||
stock_id: str
|
||||
"""Stock ID of this record."""
|
||||
datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this
|
||||
"""Datetime of this record (this is index in the dataframe)."""
|
||||
direction: int
|
||||
"""Direction of the order. 0 for sell, 1 for buy."""
|
||||
|
||||
# Market information.
|
||||
market_volume: np.ndarray | float
|
||||
"""(total) market volume traded in the period."""
|
||||
market_price: np.ndarray | float
|
||||
"""Deal price. If it's a period of time, this is the average market deal price."""
|
||||
|
||||
# Strategy records.
|
||||
|
||||
amount: np.ndarray | float
|
||||
"""Total amount (volume) strategy intends to trade."""
|
||||
inner_amount: np.ndarray | float
|
||||
"""Total amount that the lower-level strategy intends to trade
|
||||
(might be larger than amount, e.g., to ensure ffr)."""
|
||||
|
||||
deal_amount: np.ndarray | float
|
||||
"""Amount that successfully takes effect (must be less than inner_amount)."""
|
||||
trade_price: np.ndarray | float
|
||||
"""The average deal price for this strategy."""
|
||||
trade_value: np.ndarray | float
|
||||
"""Total worth of trading. In the simple simulation, trade_value = deal_amount * price."""
|
||||
position: np.ndarray | float
|
||||
"""Position left after this "period"."""
|
||||
|
||||
# Accumulated metrics
|
||||
|
||||
ffr: np.ndarray | float
|
||||
"""Completed how much percent of the daily order."""
|
||||
|
||||
pa: np.ndarray | float
|
||||
"""Price advantage compared to baseline (i.e., trade with baseline market price).
|
||||
The baseline is trade price when using TWAP strategy to execute this order.
|
||||
Please note that there could be data leak here).
|
||||
Unit is BP (basis point, 1/10000)."""
|
||||
|
||||
|
||||
class SAOEState(NamedTuple):
|
||||
"""Data structure holding a state for SAOE simulator."""
|
||||
|
||||
order: Order
|
||||
"""The order we are dealing with."""
|
||||
cur_time: pd.Timestamp
|
||||
"""Current time, e.g., 9:30."""
|
||||
position: float
|
||||
"""Current remaining volume to execute."""
|
||||
history_exec: pd.DataFrame
|
||||
"""See :attr:`SingleAssetOrderExecution.history_exec`."""
|
||||
history_steps: pd.DataFrame
|
||||
"""See :attr:`SingleAssetOrderExecution.history_steps`."""
|
||||
|
||||
metrics: Optional[SAOEMetrics]
|
||||
"""Daily metric, only available when the trading is in "done" state."""
|
||||
|
||||
backtest_data: IntradayBacktestData
|
||||
"""Backtest data is included in the state.
|
||||
Actually, only the time index of this data is needed, at this moment.
|
||||
I include the full data so that algorithms (e.g., VWAP) that relies on the raw data can be implemented.
|
||||
Interpreter can use this as they wish, but they should be careful not to leak future data.
|
||||
"""
|
||||
|
||||
ticks_per_step: int
|
||||
"""How many ticks for each step."""
|
||||
ticks_index: pd.DatetimeIndex
|
||||
"""Trading ticks in all day, NOT sliced by order (defined in data). e.g., [9:30, 9:31, ..., 14:59]."""
|
||||
ticks_for_order: pd.DatetimeIndex
|
||||
"""Trading ticks sliced by order, e.g., [9:45, 9:46, ..., 14:44]."""
|
||||
__all__ = ["SingleAssetOrderExecution"]
|
||||
|
||||
|
||||
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
@@ -326,8 +240,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
next_time = self._next_time()
|
||||
|
||||
# get the backtest data for next interval
|
||||
self.market_vol = self.backtest_data.get_volume().loc[self.cur_time : next_time - ONE_SEC].to_numpy()
|
||||
self.market_price = self.backtest_data.get_deal_price().loc[self.cur_time : next_time - ONE_SEC].to_numpy()
|
||||
self.market_vol = self.backtest_data.get_volume().loc[self.cur_time : next_time - EPS_T].to_numpy()
|
||||
self.market_price = self.backtest_data.get_deal_price().loc[self.cur_time : next_time - EPS_T].to_numpy()
|
||||
|
||||
assert self.market_vol is not None and self.market_price is not None
|
||||
|
||||
@@ -380,7 +294,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
|
||||
def _get_ticks_slice(self, start: pd.Timestamp, end: pd.Timestamp, include_end: bool = False) -> pd.DatetimeIndex:
|
||||
if not include_end:
|
||||
end = end - ONE_SEC
|
||||
end = end - EPS_T
|
||||
return self.ticks_index[self.ticks_index.slice_indexer(start, end)]
|
||||
|
||||
@staticmethod
|
||||
@@ -391,14 +305,11 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
return pd.concat([df, other_df], axis=0)
|
||||
|
||||
|
||||
_float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
|
||||
|
||||
|
||||
def price_advantage(
|
||||
exec_price: _float_or_ndarray,
|
||||
exec_price: float_or_ndarray,
|
||||
baseline_price: float,
|
||||
direction: OrderDir | int,
|
||||
) -> _float_or_ndarray:
|
||||
) -> float_or_ndarray:
|
||||
if baseline_price == 0: # something is wrong with data. Should be nan here
|
||||
if isinstance(exec_price, float):
|
||||
return 0.0
|
||||
@@ -414,4 +325,4 @@ def price_advantage(
|
||||
if res_wo_nan.size == 1:
|
||||
return res_wo_nan.item()
|
||||
else:
|
||||
return cast(_float_or_ndarray, res_wo_nan)
|
||||
return cast(float_or_ndarray, res_wo_nan)
|
||||
|
||||
334
qlib/rl/order_execution/state.py
Normal file
334
qlib/rl/order_execution/state.py
Normal file
@@ -0,0 +1,334 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast, NamedTuple, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from qlib.constant import EPS, ONE_MIN, REG_CN
|
||||
from qlib.rl.data.exchange_wrapper import IntradayBacktestData
|
||||
from qlib.rl.data.pickle_styled import BaseIntradayBacktestData
|
||||
from qlib.rl.order_execution.utils import dataframe_append, price_advantage
|
||||
from qlib.utils.time import get_day_min_idx_range
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
def _get_all_timestamps(
    start: pd.Timestamp,
    end: pd.Timestamp,
    granularity: pd.Timedelta = ONE_MIN,
    include_end: bool = True,
) -> pd.DatetimeIndex:
    """Enumerate the timestamps ``start, start + granularity, ...`` that do not exceed ``end``.

    Parameters
    ----------
    start
        First timestamp of the sequence.
    end
        Upper bound of the sequence.
    granularity
        Distance between two consecutive timestamps. Must be strictly positive.
    include_end
        When False and the sequence lands exactly on ``end``, that last timestamp is dropped.

    Returns
    -------
    pd.DatetimeIndex
        The generated timestamps. Empty when ``start > end``, or when ``start == end`` and
        ``include_end`` is False.

    Raises
    ------
    ValueError
        If ``granularity`` is not strictly positive (the enumeration would never terminate).
    """
    if granularity <= pd.Timedelta(0):
        raise ValueError(f"granularity must be positive, got {granularity}")

    ret = []
    cur = start
    while cur <= end:
        ret.append(cur)
        cur += granularity

    # `ret` may be empty (start > end), so guard before peeking at the last element.
    # Every collected timestamp is <= end by construction, hence only the equality case matters.
    if ret and ret[-1] == end and not include_end:
        ret.pop()
    return pd.DatetimeIndex(ret)
|
||||
|
||||
|
||||
class SAOEStateAdapter:
    """
    Maintain states of the environment. SAOEStateAdapter accepts execution results and updates its internal state
    according to the execution results with additional information acquired from executors & exchange. For example,
    it gets the dealt order amount from execution results, and gets the corresponding market price / volume from
    exchange.

    Example usage::

        adapter = SAOEStateAdapter(...)
        adapter.update(...)
        state = adapter.saoe_state
    """

    def __init__(
        self,
        order: Order,
        executor: BaseExecutor,
        exchange: Exchange,
        ticks_per_step: int,
        backtest_data: IntradayBacktestData,
    ) -> None:
        """Bind the adapter to one order and initialise empty histories.

        Parameters
        ----------
        order
            The order being executed; its ``amount`` is the initial remaining position.
        executor
            Executor whose trade account is queried for trade indicators in :meth:`update`.
        exchange
            Exchange queried for market volume and deal price in :meth:`update`.
        ticks_per_step
            Number of ticks covered by one step; used by :meth:`_next_time`.
        backtest_data
            Provides ``ticks_index``, ``ticks_for_order`` and ``get_deal_price()``.
        """
        self.position = order.amount
        self.order = order
        self.executor = executor
        self.exchange = exchange
        self.backtest_data = backtest_data

        # Baseline used for price advantage: the plain mean of the deal prices in the backtest data.
        self.twap_price = self.backtest_data.get_deal_price().mean()

        # Both history frames share the SAOEMetrics keys as columns and are indexed by "datetime".
        metric_keys = list(SAOEMetrics.__annotations__.keys())  # pylint: disable=no-member
        self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
        self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
        self.metrics: Optional[SAOEMetrics] = None

        self.cur_time = max(backtest_data.ticks_for_order[0], order.start_time)
        self.ticks_per_step = ticks_per_step

    def _next_time(self) -> pd.Timestamp:
        """Return the start time of the next step, capped at the order's end time.

        The next location is the current tick location advanced by ``ticks_per_step`` and then rounded
        down to a multiple of ``ticks_per_step``. If that location is outside ``ticks_index`` or its
        tick is not strictly before ``order.end_time``, ``order.end_time`` is returned instead.
        """
        current_loc = self.backtest_data.ticks_index.get_loc(self.cur_time)
        next_loc = current_loc + self.ticks_per_step
        next_loc = next_loc - next_loc % self.ticks_per_step
        if (
            next_loc < len(self.backtest_data.ticks_index)
            and self.backtest_data.ticks_index[next_loc] < self.order.end_time
        ):
            return self.backtest_data.ticks_index[next_loc]
        else:
            return self.order.end_time

    def update(
        self,
        execute_result: list,
        last_step_range: Tuple[int, int],
    ) -> None:
        """Fold the execution results of the last step into the histories and advance the clock.

        Parameters
        ----------
        execute_result
            Iterable of 4-tuples whose first element is the executed order (its ``start_time``,
            ``end_time`` and ``deal_amount`` are read); the other three elements are ignored.
        last_step_range
            Inclusive ``(first, last)`` positions in ``backtest_data.ticks_index`` covered by the step.
        """
        last_step_size = last_step_range[1] - last_step_range[0] + 1
        start_time = self.backtest_data.ticks_index[last_step_range[0]]
        end_time = self.backtest_data.ticks_index[last_step_range[1]]

        # One slot per tick of the step; slots without an executed order stay at zero.
        exec_vol = np.zeros(last_step_size)
        for order, _, __, ___ in execute_result:
            # presumably `idx` is the minute index of the order's start within the trading day -- TODO confirm
            idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN)
            exec_vol[idx - last_step_range[0]] = order.deal_amount

        # If slightly more than the remaining position was dealt (asserted to be by less than 1),
        # rescale so that the executed volumes sum to exactly the remaining position.
        if exec_vol.sum() > self.position and exec_vol.sum() > 0.0:
            assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large"
            exec_vol *= self.position / (exec_vol.sum())

        market_volume = np.array(
            self.exchange.get_volume(
                self.order.stock_id,
                pd.Timestamp(start_time),
                pd.Timestamp(end_time),
                method=None,
            ),
        ).reshape(-1)

        market_price = np.array(
            self.exchange.get_deal_price(
                self.order.stock_id,
                pd.Timestamp(start_time),
                pd.Timestamp(end_time),
                method=None,
                direction=self.order.direction,
            ),
        ).reshape(-1)

        assert market_price.shape == market_volume.shape == exec_vol.shape

        # Get data from the current level executor's indicator
        current_trade_account = self.executor.trade_account
        current_df = current_trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
        self.history_exec = dataframe_append(
            self.history_exec,
            self._collect_multi_order_metric(
                order=self.order,
                datetime=_get_all_timestamps(start_time, end_time, include_end=True),
                market_vol=market_volume,
                market_price=market_price,
                exec_vol=exec_vol,
                pa=current_df.iloc[-1]["pa"],  # "pa" of the last row of the indicator frame
            ),
        )

        self.history_steps = dataframe_append(
            self.history_steps,
            [
                self._collect_single_order_metric(
                    self.order,
                    self.cur_time,
                    market_volume,
                    market_price,
                    exec_vol.sum(),
                    exec_vol,
                ),
            ],
        )

        # TODO: check whether we need this. Can we get this information from Account?
        # Do this at the end: the metric collectors above expect `self.position` to still be the
        # position *before* this step's executions are deducted.
        self.position -= exec_vol.sum()

        self.cur_time = self._next_time()

    def generate_metrics_after_done(self) -> None:
        """Generate metrics once the upper level execution is done"""

        # By now `update` has already deducted every executed volume from `self.position`, so the
        # summary must report the position as-is (`final=True`) instead of subtracting it again.
        self.metrics = self._collect_single_order_metric(
            self.order,
            self.backtest_data.ticks_index[0],  # start time
            self.history_exec["market_volume"],
            self.history_exec["market_price"],
            self.history_steps["amount"].sum(),
            self.history_exec["deal_amount"],
            final=True,
        )

    def _collect_multi_order_metric(
        self,
        order: Order,
        datetime: pd.DatetimeIndex,
        market_vol: np.ndarray,
        market_price: np.ndarray,
        exec_vol: np.ndarray,
        pa: float,
    ) -> SAOEMetrics:
        """Build vectorized (one entry per tick) metrics for the ticks of one step.

        Must be called before the step's executed volume is deducted from ``self.position``:
        the per-tick position is ``self.position - cumsum(exec_vol)``.
        """
        return SAOEMetrics(
            # It should have the same keys with SAOEMetrics,
            # but the values do not necessarily have the annotated type.
            # Some values could be vectorized (e.g., exec_vol).
            stock_id=order.stock_id,
            datetime=datetime,
            direction=order.direction,
            market_volume=market_vol,
            market_price=market_price,
            amount=exec_vol,
            inner_amount=exec_vol,
            deal_amount=exec_vol,
            trade_price=market_price,
            trade_value=market_price * exec_vol,
            position=self.position - np.cumsum(exec_vol),
            ffr=exec_vol / order.amount,
            pa=pa,
        )

    def _collect_single_order_metric(
        self,
        order: Order,
        datetime: pd.Timestamp,
        market_vol: np.ndarray,
        market_price: np.ndarray,
        amount: float,  # intended to trade such amount
        exec_vol: np.ndarray,
        final: bool = False,
    ) -> SAOEMetrics:
        """Aggregate per-tick data into a single metrics record.

        Parameters
        ----------
        order
            The order the record belongs to.
        datetime
            Timestamp stored in the record.
        market_vol, market_price, exec_vol
            Equally long per-tick sequences of market volume, market price and executed volume.
        amount
            Amount the strategy intended to trade in the aggregated period.
        final
            False (default): ``self.position`` has not yet been reduced by ``exec_vol``, so the
            reported position is ``self.position - exec_vol.sum()``.
            True: ``self.position`` already reflects all executions (end-of-order summary), so it is
            reported unchanged.
        """
        assert len(market_vol) == len(market_price) == len(exec_vol)

        if np.abs(np.sum(exec_vol)) < EPS:
            # Nothing was executed; avoid a weighted average with (near-)zero total weight.
            exec_avg_price = 0.0
        else:
            exec_avg_price = cast(float, np.average(market_price, weights=exec_vol))  # could be nan
            if hasattr(exec_avg_price, "item"):  # could be numpy scalar
                exec_avg_price = exec_avg_price.item()  # type: ignore

        exec_sum = exec_vol.sum()
        return SAOEMetrics(
            stock_id=order.stock_id,
            datetime=datetime,
            direction=order.direction,
            market_volume=market_vol.sum(),
            market_price=market_price.mean() if len(market_price) > 0 else np.nan,
            amount=amount,
            inner_amount=exec_sum,
            deal_amount=exec_sum,  # in this simulator, there's no other restrictions
            trade_price=exec_avg_price,
            trade_value=float(np.sum(market_price * exec_vol)),
            position=self.position if final else self.position - exec_sum,
            ffr=float(exec_sum / order.amount),
            pa=price_advantage(exec_avg_price, self.twap_price, order.direction),
        )

    @property
    def saoe_state(self) -> SAOEState:
        """Snapshot of the adapter's current fields packed into an :class:`SAOEState`."""
        return SAOEState(
            order=self.order,
            cur_time=self.cur_time,
            position=self.position,
            history_exec=self.history_exec,
            history_steps=self.history_steps,
            metrics=self.metrics,
            backtest_data=self.backtest_data,
            ticks_per_step=self.ticks_per_step,
            ticks_index=self.backtest_data.ticks_index,
            ticks_for_order=self.backtest_data.ticks_for_order,
        )
|
||||
|
||||
|
||||
class SAOEMetrics(TypedDict):
    """Metrics for SAOE accumulated for a "period".
    It could be accumulated for a day, or a period of time (e.g., 30min), or calculated separately for every minute.

    Warnings
    --------
    The type hints are for single elements. In lots of times, they can be vectorized.
    For example, ``market_volume`` could be a list of float (or ndarray) rather than a single float.
    """

    stock_id: str
    """Stock ID of this record."""
    datetime: pd.Timestamp | pd.DatetimeIndex  # TODO: check this
    """Datetime of this record (this is index in the dataframe)."""
    direction: int
    """Direction of the order. 0 for sell, 1 for buy."""

    # Market information.
    market_volume: np.ndarray | float
    """(total) market volume traded in the period."""
    market_price: np.ndarray | float
    """Deal price. If it's a period of time, this is the average market deal price."""

    # Strategy records.

    amount: np.ndarray | float
    """Total amount (volume) strategy intends to trade."""
    inner_amount: np.ndarray | float
    """Total amount that the lower-level strategy intends to trade
    (might be larger than amount, e.g., to ensure ffr)."""

    deal_amount: np.ndarray | float
    """Amount that successfully takes effect (must be less than inner_amount)."""
    trade_price: np.ndarray | float
    """The average deal price for this strategy."""
    trade_value: np.ndarray | float
    """Total worth of trading. In the simple simulation, trade_value = deal_amount * price."""
    position: np.ndarray | float
    """Position left after this "period"."""

    # Accumulated metrics

    ffr: np.ndarray | float
    """Completed how much percent of the daily order."""

    pa: np.ndarray | float
    """Price advantage compared to baseline (i.e., trade with baseline market price).
    The baseline is trade price when using TWAP strategy to execute this order.
    Please note that there could be data leak here).
    Unit is BP (basis point, 1/10000)."""
|
||||
|
||||
|
||||
class SAOEState(NamedTuple):
    """Data structure holding a state for SAOE simulator.

    As a ``NamedTuple``, its fields cannot be rebound after construction. The objects stored in the
    fields (e.g., the history data frames) are held by reference, not copied.
    """

    order: Order
    """The order we are dealing with."""
    cur_time: pd.Timestamp
    """Current time, e.g., 9:30."""
    position: float
    """Current remaining volume to execute."""
    history_exec: pd.DataFrame
    """See :attr:`SingleAssetOrderExecution.history_exec`."""
    history_steps: pd.DataFrame
    """See :attr:`SingleAssetOrderExecution.history_steps`."""

    metrics: Optional[SAOEMetrics]
    """Daily metric, only available when the trading is in "done" state."""

    backtest_data: BaseIntradayBacktestData
    """Backtest data is included in the state.
    Actually, only the time index of this data is needed, at this moment.
    I include the full data so that algorithms (e.g., VWAP) that relies on the raw data can be implemented.
    Interpreter can use this as they wish, but they should be careful not to leak future data.
    """

    ticks_per_step: int
    """How many ticks for each step."""
    ticks_index: pd.DatetimeIndex
    """Trading ticks in all day, NOT sliced by order (defined in data). e.g., [9:30, 9:31, ..., 14:59]."""
    ticks_for_order: pd.DatetimeIndex
    """Trading ticks sliced by order, e.g., [9:45, 9:46, ..., 14:44]."""
|
||||
148
qlib/rl/order_execution/strategy.py
Normal file
148
qlib/rl/order_execution/strategy.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from types import GeneratorType
|
||||
from typing import Any, Optional, Union, cast, Dict, Generator
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import CommonInfrastructure, Order
|
||||
from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange
|
||||
from qlib.backtest.utils import LevelInfrastructure
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.data.exchange_wrapper import load_qlib_backtest_data
|
||||
from qlib.rl.order_execution.state import SAOEStateAdapter, SAOEState
|
||||
from qlib.strategy.base import RLStrategy
|
||||
|
||||
|
||||
class SAOEStrategy(RLStrategy):
|
||||
"""RL-based strategies that use SAOEState as state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: object, # TODO: add accurate typehint later.
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super(SAOEStrategy, self).__init__(
|
||||
policy=policy,
|
||||
outer_trade_decision=outer_trade_decision,
|
||||
level_infra=level_infra,
|
||||
common_infra=common_infra,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.adapter_dict: Dict[tuple, SAOEStateAdapter] = {}
|
||||
self._last_step_range = (0, 0)
|
||||
|
||||
def _create_qlib_backtest_adapter(self, order: Order, trade_range: TradeRange) -> SAOEStateAdapter:
|
||||
backtest_data = load_qlib_backtest_data(order, self.trade_exchange, trade_range)
|
||||
|
||||
return SAOEStateAdapter(
|
||||
order=order,
|
||||
executor=self.executor,
|
||||
exchange=self.trade_exchange,
|
||||
ticks_per_step=int(pd.Timedelta(self.trade_calendar.get_freq()) / ONE_MIN),
|
||||
backtest_data=backtest_data,
|
||||
)
|
||||
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
|
||||
super(SAOEStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
|
||||
self.adapter_dict = {}
|
||||
self._last_step_range = (0, 0)
|
||||
|
||||
if outer_trade_decision is not None and not outer_trade_decision.empty():
|
||||
trade_range = outer_trade_decision.trade_range
|
||||
assert trade_range is not None
|
||||
|
||||
self.adapter_dict = {}
|
||||
for decision in outer_trade_decision.get_decision():
|
||||
order = cast(Order, decision)
|
||||
self.adapter_dict[order.key_by_day] = self._create_qlib_backtest_adapter(order, trade_range)
|
||||
|
||||
def get_saoe_state_by_order(self, order: Order) -> SAOEState:
|
||||
return self.adapter_dict[order.key_by_day].saoe_state
|
||||
|
||||
def post_upper_level_exe_step(self) -> None:
|
||||
for adapter in self.adapter_dict.values():
|
||||
adapter.generate_metrics_after_done()
|
||||
|
||||
def post_exe_step(self, execute_result: Optional[list]) -> None:
|
||||
last_step_length = self._last_step_range[1] - self._last_step_range[0]
|
||||
if last_step_length <= 0:
|
||||
assert not execute_result
|
||||
return
|
||||
|
||||
results = collections.defaultdict(list)
|
||||
if execute_result is not None:
|
||||
for e in execute_result:
|
||||
results[e[0].key_by_day].append(e)
|
||||
|
||||
for key, adapter in self.adapter_dict.items():
|
||||
adapter.update(results[key], self._last_step_range)
|
||||
|
||||
def generate_trade_decision(
|
||||
self,
|
||||
execute_result: list = None,
|
||||
) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
|
||||
"""
|
||||
For SAOEStrategy, we need to update the `self._last_step_range` every time a decision is generated.
|
||||
This operation should be invisible to developers, so we implement it in `generate_trade_decision()`
|
||||
The concrete logic to generate decisions should be implemented in `_generate_trade_decision()`.
|
||||
In other words, all subclass of `SAOEStrategy` should overwrite `_generate_trade_decision()` instead of
|
||||
`generate_trade_decision()`.
|
||||
"""
|
||||
self._last_step_range = self.get_data_cal_avail_range(rtype="step")
|
||||
|
||||
decision = self._generate_trade_decision(execute_result)
|
||||
if isinstance(decision, GeneratorType):
|
||||
decision = yield from decision
|
||||
|
||||
return decision
|
||||
|
||||
def _generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
    """Hook holding the concrete decision logic; this base version only raises.

    Raises
    ------
    NotImplementedError
        Always, until a subclass overrides this method.
    """
    raise NotImplementedError()
|
||||
|
||||
|
||||
class ProxySAOEStrategy(SAOEStrategy):
    """Proxy strategy that uses SAOEState.

    It is called a 'proxy' strategy because it does not make any decisions by itself. Instead, when the
    strategy is required to generate a decision, it yields itself to the outside and lets the outside
    agent make the decision. Please refer to `_generate_trade_decision` for more details.
    """

    def __init__(
        self,
        outer_trade_decision: BaseTradeDecision = None,
        level_infra: LevelInfrastructure = None,
        common_infra: CommonInfrastructure = None,
        **kwargs: Any,
    ) -> None:
        # The parent's first positional argument is deliberately passed as None
        # (presumably the policy slot, which a proxy does not need -- confirm against SAOEStrategy).
        super().__init__(None, outer_trade_decision, level_infra, common_infra, **kwargs)

    def _generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
        """Yield this strategy to the outside and turn the volume sent back into a one-order decision.

        The value passed to the generator through `send()` is used as the order amount; the stock and
        direction come from the order stored by `reset()`.
        """
        # Once the following line is executed, this ProxySAOEStrategy (self) will be yielded to the outside
        # of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`,
        # the item will be captured by `exec_vol`. The outside policy could communicate with the inner
        # level strategy through this way.
        exec_vol = yield self

        oh = self.trade_exchange.get_order_helper()
        order = oh.create(self._order.stock_id, exec_vol, self._order.direction)

        return TradeDecisionWO([order], self)

    def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
        """Reset the strategy and remember the single order carried by `outer_trade_decision`.

        Parameters
        ----------
        outer_trade_decision : BaseTradeDecision
            When given, it must be a `TradeDecisionWO` holding exactly one order, which becomes the order
            this proxy executes. When omitted (``None``), the previously stored order is left untouched.

        Raises
        ------
        AssertionError
            If a decision is given but is not a `TradeDecisionWO`, or does not hold exactly one order.
        """
        super().reset(outer_trade_decision=outer_trade_decision, **kwargs)

        if outer_trade_decision is None:
            # Nothing to extract. The type check must not run here: `None` is the declared default
            # and would otherwise always fail the isinstance assertion.
            return

        assert isinstance(outer_trade_decision, TradeDecisionWO)
        order_list = outer_trade_decision.order_list
        assert len(order_list) == 1
        self._order = order_list[0]
|
||||
@@ -3,52 +3,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Tuple, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import CommonInfrastructure, get_exchange
|
||||
from qlib.backtest.account import Account
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from qlib.rl.from_neutrader.config import ExchangeConfig
|
||||
from qlib.rl.order_execution.simulator_simple import ONE_SEC, _float_or_ndarray
|
||||
from qlib.utils.time import Freq
|
||||
|
||||
|
||||
def get_common_infra(
|
||||
config: ExchangeConfig,
|
||||
trade_date: pd.Timestamp,
|
||||
codes: List[str],
|
||||
cash_limit: float = None,
|
||||
) -> CommonInfrastructure:
|
||||
# need to specify a range here for acceleration
|
||||
if cash_limit is None:
|
||||
trade_account = Account(init_cash=int(1e12), benchmark_config={}, pos_type="InfPosition")
|
||||
else:
|
||||
trade_account = Account(
|
||||
init_cash=cash_limit,
|
||||
benchmark_config={},
|
||||
pos_type="Position",
|
||||
position_dict={code: {"amount": 1e12, "price": 1.0} for code in codes},
|
||||
)
|
||||
|
||||
exchange = get_exchange(
|
||||
codes=codes,
|
||||
freq="1min",
|
||||
limit_threshold=config.limit_threshold,
|
||||
deal_price=config.deal_price,
|
||||
open_cost=config.open_cost,
|
||||
close_cost=config.close_cost,
|
||||
min_cost=config.min_cost if config.trade_unit is not None else 0,
|
||||
start_time=trade_date,
|
||||
end_time=trade_date + pd.DateOffset(1),
|
||||
trade_unit=config.trade_unit,
|
||||
volume_threshold=config.volume_threshold,
|
||||
)
|
||||
|
||||
return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
|
||||
from qlib.constant import EPS_T, float_or_ndarray
|
||||
|
||||
|
||||
def get_ticks_slice(
|
||||
@@ -58,7 +20,7 @@ def get_ticks_slice(
|
||||
include_end: bool = False,
|
||||
) -> pd.DatetimeIndex:
|
||||
if not include_end:
|
||||
end = end - ONE_SEC
|
||||
end = end - EPS_T
|
||||
return ticks_index[ticks_index.slice_indexer(start, end)]
|
||||
|
||||
|
||||
@@ -72,10 +34,10 @@ def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
|
||||
|
||||
|
||||
def price_advantage(
|
||||
exec_price: _float_or_ndarray,
|
||||
exec_price: float_or_ndarray,
|
||||
baseline_price: float,
|
||||
direction: OrderDir | int,
|
||||
) -> _float_or_ndarray:
|
||||
) -> float_or_ndarray:
|
||||
if baseline_price == 0: # something is wrong with data. Should be nan here
|
||||
if isinstance(exec_price, float):
|
||||
return 0.0
|
||||
@@ -91,21 +53,11 @@ def price_advantage(
|
||||
if res_wo_nan.size == 1:
|
||||
return res_wo_nan.item()
|
||||
else:
|
||||
return cast(_float_or_ndarray, res_wo_nan)
|
||||
return cast(float_or_ndarray, res_wo_nan)
|
||||
|
||||
|
||||
def get_portfolio_and_indicator(executor: BaseExecutor) -> Tuple[dict, dict]:
|
||||
all_executors = executor.get_all_executors()
|
||||
all_portfolio_metrics = {
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
|
||||
for _executor in all_executors
|
||||
if _executor.trade_account.is_port_metr_enabled()
|
||||
}
|
||||
|
||||
all_indicators = {}
|
||||
for _executor in all_executors:
|
||||
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
|
||||
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
|
||||
|
||||
return all_portfolio_metrics, all_indicators
|
||||
def get_simulator_executor(executor: BaseExecutor) -> SimulatorExecutor:
    """Descend through nested executors and return the innermost one.

    Starting from ``executor``, follows ``inner_executor`` for as long as the current executor is a
    ``NestedExecutor``. The executor reached at the bottom is asserted to be a ``SimulatorExecutor``.
    """
    current = executor
    while True:
        if not isinstance(current, NestedExecutor):
            break
        current = current.inner_executor

    assert isinstance(current, SimulatorExecutor)
    return current
|
||||
|
||||
@@ -1,4 +1,2 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# TODO: find a better way to organize contents under this module.
|
||||
31
qlib/rl/strategy/single_order.py
Normal file
31
qlib/rl/strategy/single_order.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderHelper, TradeDecisionWO, TradeRange
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
|
||||
|
||||
class SingleOrderStrategy(BaseStrategy):
    """Strategy used to generate a trade decision with exactly one order."""

    def __init__(self, order: Order, trade_range: TradeRange = None) -> None:
        """Store the template order and the optional trade range attached to every decision."""
        super().__init__()

        self._trade_range = trade_range
        self._order = order

    def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO:
        """Return a decision holding one order copied from the template order.

        The order is created through the order helper of the exchange registered in
        ``common_infra`` under ``"trade_exchange"``; ``execute_result`` is not used.
        """
        exchange = self.common_infra.get("trade_exchange")
        helper: OrderHelper = exchange.get_order_helper()

        template = self._order
        single_order = helper.create(
            code=template.stock_id,
            amount=template.amount,
            direction=template.direction,
        )
        return TradeDecisionWO([single_order], self, self._trade_range)
|
||||
@@ -11,13 +11,14 @@ from __future__ import annotations
|
||||
import copy
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union, cast
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
|
||||
|
||||
from qlib.typehint import Literal
|
||||
|
||||
from .log import LogWriter
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Generator, Optional, TYPE_CHECKING, Union
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.position import BasePosition
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
@@ -55,6 +56,10 @@ class BaseStrategy:
|
||||
self._reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)
|
||||
self._trade_exchange = trade_exchange
|
||||
|
||||
@property
|
||||
def executor(self) -> BaseExecutor:
|
||||
return self.level_infra.get("executor")
|
||||
|
||||
@property
|
||||
def trade_calendar(self) -> TradeCalendarManager:
|
||||
return self.level_infra.get("trade_calendar")
|
||||
@@ -85,7 +90,7 @@ class BaseStrategy:
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
**kwargs, # TODO: remove this?
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
- reset `level_infra`, used to reset trade calendar, .etc
|
||||
@@ -136,47 +141,6 @@ class BaseStrategy:
|
||||
"""
|
||||
raise NotImplementedError("generate_trade_decision is not implemented!")
|
||||
|
||||
@staticmethod
|
||||
def update_trade_decision(
|
||||
trade_decision: BaseTradeDecision,
|
||||
trade_calendar: TradeCalendarManager,
|
||||
) -> Optional[BaseTradeDecision]:
|
||||
"""
|
||||
update trade decision in each step of inner execution, this method enable all order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : BaseTradeDecision
|
||||
the trade decision that will be updated
|
||||
trade_calendar : TradeCalendarManager
|
||||
The calendar of the **inner strategy**!!!!!
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseTradeDecision:
|
||||
"""
|
||||
# default to return None, which indicates that the trade decision is not changed
|
||||
return None
|
||||
|
||||
# FIXME: do not define this method as an abstract one since it is never implemented
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
"""
|
||||
A method for updating the outer_trade_decision.
|
||||
The outer strategy may change its decision during updating.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the decision updated by the outer strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseTradeDecision
|
||||
"""
|
||||
# default to reset the decision directly
|
||||
# NOTE: normally, user should do something to the strategy due to the change of outer decision
|
||||
raise NotImplementedError(f"Please implement the `alter_outer_trade_decision` method")
|
||||
|
||||
# helper methods: not necessary but for convenience
|
||||
def get_data_cal_avail_range(self, rtype: str = "full") -> Tuple[int, int]:
|
||||
"""
|
||||
@@ -207,7 +171,58 @@ class BaseStrategy:
|
||||
range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype)
|
||||
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
|
||||
|
||||
def post_exe_step(self, execute_result: list) -> None:
|
||||
"""
|
||||
The following methods are used to do cross-level communications in nested execution.
|
||||
You do not need to care about them if you are implementing a single-level execution.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def update_trade_decision(
|
||||
trade_decision: BaseTradeDecision,
|
||||
trade_calendar: TradeCalendarManager,
|
||||
) -> Optional[BaseTradeDecision]:
|
||||
"""
|
||||
update trade decision in each step of inner execution, this method enable all order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : BaseTradeDecision
|
||||
the trade decision that will be updated
|
||||
trade_calendar : TradeCalendarManager
|
||||
The calendar of the **inner strategy**!!!!!
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseTradeDecision:
|
||||
"""
|
||||
# default to return None, which indicates that the trade decision is not changed
|
||||
return None
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
"""
|
||||
A method for updating the outer_trade_decision.
|
||||
The outer strategy may change its decision during updating.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the decision updated by the outer strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseTradeDecision
|
||||
"""
|
||||
# default to reset the decision directly
|
||||
# NOTE: normally, user should do something to the strategy due to the change of outer decision
|
||||
return outer_trade_decision
|
||||
|
||||
def post_upper_level_exe_step(self) -> None:
|
||||
"""
|
||||
A hook for doing something after the upper level executor finished its execution (for example, finalize
|
||||
the metrics collection).
|
||||
"""
|
||||
|
||||
def post_exe_step(self, execute_result: Optional[list]) -> None:
|
||||
"""
|
||||
A hook for doing something after the corresponding executor finished its execution.
|
||||
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.backtest.executor import NestedExecutor, SimulatorExecutor
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.contrib.strategy import TWAPStrategy
|
||||
from qlib.backtest.decision import Order, OrderDir, TradeRangeByTime
|
||||
from qlib.backtest.executor import SimulatorExecutor
|
||||
from qlib.rl.order_execution import CategoricalActionInterpreter
|
||||
from qlib.rl.order_execution.simulator_qlib import ExchangeConfig, SingleAssetOrderExecutionQlib
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
|
||||
TOTAL_POSITION = 2100.0
|
||||
|
||||
@@ -32,23 +31,71 @@ def get_order() -> Order:
|
||||
)
|
||||
|
||||
|
||||
def get_simulator(order: Order) -> SingleAssetOrderExecutionQlib:
|
||||
def _inner_executor_fn(time_per_step: str, common_infra: CommonInfrastructure) -> NestedExecutor:
|
||||
return NestedExecutor(
|
||||
time_per_step=time_per_step,
|
||||
inner_strategy=TWAPStrategy(),
|
||||
inner_executor=SimulatorExecutor(
|
||||
time_per_step="1min",
|
||||
verbose=False,
|
||||
trade_type=SimulatorExecutor.TT_SERIAL,
|
||||
generate_report=False,
|
||||
common_infra=common_infra,
|
||||
track_data=True,
|
||||
),
|
||||
common_infra=common_infra,
|
||||
track_data=True,
|
||||
)
|
||||
def get_configs(order: Order) -> Tuple[dict, dict, dict]:
    """Build the configurations needed to simulate the execution of ``order``.

    Returns
    -------
    Tuple[dict, dict, dict]
        ``(strategy_config, executor_config, exchange_config)``. The executor configuration nests
        three levels: a "1day" executor, a "30min" executor and a "1min" simulator executor.
    """
    # Top-level strategy: emits the single order, restricted to the order's own time window.
    strategy_config = {
        "class": "SingleOrderStrategy",
        "module_path": "qlib.rl.strategy.single_order",
        "kwargs": {
            "order": order,
            "trade_range": TradeRangeByTime(order.start_time.time(), order.end_time.time()),
        },
    }

    # The executor tree is assembled from the innermost level outwards.
    executor_module = "qlib.backtest.executor"

    minute_executor = {
        "class": "SimulatorExecutor",
        "module_path": executor_module,
        "kwargs": {
            "time_per_step": "1min",
            "verbose": False,
            "trade_type": SimulatorExecutor.TT_SERIAL,
            "generate_report": False,
            "track_data": True,
        },
    }

    half_hour_executor = {
        "class": "NestedExecutor",
        "module_path": executor_module,
        "kwargs": {
            "time_per_step": "30min",
            "inner_strategy": {
                "class": "TWAPStrategy",
                "module_path": "qlib.contrib.strategy.rule_strategy",
            },
            "inner_executor": minute_executor,
            "track_data": True,
        },
    }

    executor_config = {
        "class": "NestedExecutor",
        "module_path": executor_module,
        "kwargs": {
            "time_per_step": "1day",
            "inner_strategy": {"class": "ProxySAOEStrategy", "module_path": "qlib.rl.order_execution.strategy"},
            "track_data": True,
            "inner_executor": half_hour_executor,
            # Start and end are both the calendar day on which the order starts.
            "start_time": pd.Timestamp(order.start_time.date()),
            "end_time": pd.Timestamp(order.start_time.date()),
        },
    }

    volume_threshold = {
        "all": ("cum", "0.2 * DayCumsum($volume, '9:30', '14:29')"),
        "buy": ("current", "$askV1"),
        "sell": ("current", "$bidV1"),
    }
    exchange_config = {
        "freq": "1min",
        "codes": [order.stock_id],
        "limit_threshold": ("$ask == 0", "$bid == 0"),
        "deal_price": ("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"),
        "volume_threshold": volume_threshold,
        "open_cost": 0.0005,
        "close_cost": 0.0015,
        "min_cost": 5.0,
        "trade_unit": None,
    }

    return strategy_config, executor_config, exchange_config
|
||||
|
||||
|
||||
def get_simulator(order: Order) -> SingleAssetOrderExecution:
|
||||
DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "qlib_simulator"
|
||||
|
||||
# fmt: off
|
||||
@@ -67,27 +114,13 @@ def get_simulator(order: Order) -> SingleAssetOrderExecutionQlib:
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
exchange_config = ExchangeConfig(
|
||||
limit_threshold=("$ask == 0", "$bid == 0"),
|
||||
deal_price=("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"),
|
||||
volume_threshold={
|
||||
"all": ("cum", "0.2 * DayCumsum($volume, '9:30', '14:29')"),
|
||||
"buy": ("current", "$askV1"),
|
||||
"sell": ("current", "$bidV1"),
|
||||
},
|
||||
open_cost=0.0005,
|
||||
close_cost=0.0015,
|
||||
min_cost=5.0,
|
||||
trade_unit=None,
|
||||
cash_limit=None,
|
||||
generate_report=False,
|
||||
)
|
||||
strategy_config, executor_config, exchange_config = get_configs(order)
|
||||
|
||||
return SingleAssetOrderExecutionQlib(
|
||||
return SingleAssetOrderExecution(
|
||||
order=order,
|
||||
time_per_step="30min",
|
||||
qlib_config=qlib_config,
|
||||
inner_executor_fn=_inner_executor_fn,
|
||||
strategy_config=strategy_config,
|
||||
executor_config=executor_config,
|
||||
exchange_config=exchange_config,
|
||||
)
|
||||
|
||||
@@ -115,12 +148,12 @@ def test_simulator_first_step():
|
||||
assert is_close(state.history_exec["trade_price"].iloc[0], 149.566483)
|
||||
assert is_close(state.history_exec["trade_value"].iloc[0], 1495.664825)
|
||||
assert is_close(state.history_exec["position"].iloc[0], TOTAL_POSITION - AMOUNT / 30)
|
||||
# assert state.history_exec["ffr"].iloc[0] == 1 / 60 # FIXME
|
||||
assert is_close(state.history_exec["ffr"].iloc[0], AMOUNT / TOTAL_POSITION / 30)
|
||||
|
||||
assert is_close(state.history_steps["market_volume"].iloc[0], 1254848.5756835938)
|
||||
assert state.history_steps["amount"].iloc[0] == AMOUNT
|
||||
assert state.history_steps["deal_amount"].iloc[0] == AMOUNT
|
||||
assert state.history_steps["ffr"].iloc[0] == 1.0
|
||||
assert state.history_steps["ffr"].iloc[0] == AMOUNT / TOTAL_POSITION
|
||||
assert is_close(
|
||||
state.history_steps["pa"].iloc[0] * (1.0 if order.direction == OrderDir.SELL else -1.0),
|
||||
(state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000,
|
||||
@@ -169,9 +202,3 @@ def test_interpreter() -> None:
|
||||
position_history.append(state.position)
|
||||
|
||||
assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_simulator_first_step()
|
||||
test_simulator_stop_twap()
|
||||
test_interpreter()
|
||||
|
||||
Reference in New Issue
Block a user