From 934840146b0862e07b0b37dba1f34b5a13eb799b Mon Sep 17 00:00:00 2001 From: Default Date: Mon, 27 Jun 2022 15:50:48 +0800 Subject: [PATCH] Simulator & action interpreter --- qlib/backtest/__init__.py | 2 +- qlib/backtest/decision.py | 6 +- .../from_neutrader/__init__.py | 0 .../order_execution/from_neutrader/config.py | 26 ++ .../from_neutrader/executor.py | 11 + .../order_execution/from_neutrader/feature.py | 164 ++++++++++++ .../from_neutrader/highfreq_ops.py | 223 ++++++++++++++++ .../order_execution/from_neutrader/state.py | 251 ++++++++++++++++++ .../from_neutrader/state_maintainer.py | 162 +++++++++++ .../from_neutrader/strategy.py | 91 +++++++ qlib/rl/order_execution/simulator_qlib.py | 129 +++++---- qlib/rl/order_execution/tests/__init__.py | 0 .../tests/test_simulator_qlib.py | 133 ++++++++++ 13 files changed, 1143 insertions(+), 55 deletions(-) create mode 100644 qlib/rl/order_execution/from_neutrader/__init__.py create mode 100644 qlib/rl/order_execution/from_neutrader/config.py create mode 100644 qlib/rl/order_execution/from_neutrader/executor.py create mode 100644 qlib/rl/order_execution/from_neutrader/feature.py create mode 100644 qlib/rl/order_execution/from_neutrader/highfreq_ops.py create mode 100644 qlib/rl/order_execution/from_neutrader/state.py create mode 100644 qlib/rl/order_execution/from_neutrader/state_maintainer.py create mode 100644 qlib/rl/order_execution/from_neutrader/strategy.py create mode 100644 qlib/rl/order_execution/tests/__init__.py create mode 100644 qlib/rl/order_execution/tests/test_simulator_qlib.py diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index 252631253..20fbe14a4 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -349,4 +349,4 @@ def format_decisions( return res -__all__ = ["Order", "backtest", "BaseExecutor", "CommonInfrastructure"] +__all__ = ["Order", "backtest"] diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index 9a6084214..bff59c11f 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -179,8 +179,8 @@ class OrderHelper: return Order( stock_id=code, amount=amount, - start_time=start_time if start_time is not None else pd.Timestamp(start_time), - end_time=end_time if end_time is not None else pd.Timestamp(end_time), + start_time=None if start_time is None else pd.Timestamp(start_time), + end_time=None if end_time is None else pd.Timestamp(end_time), direction=direction, ) @@ -530,7 +530,7 @@ class TradeDecisionWO(BaseTradeDecision): Besides, the time_range is also included. """ - def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None): + def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None): super().__init__(strategy, trade_range=trade_range) self.order_list = order_list start, end = strategy.trade_calendar.get_step_time() diff --git a/qlib/rl/order_execution/from_neutrader/__init__.py b/qlib/rl/order_execution/from_neutrader/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qlib/rl/order_execution/from_neutrader/config.py b/qlib/rl/order_execution/from_neutrader/config.py new file mode 100644 index 000000000..8d43dc76a --- /dev/null +++ b/qlib/rl/order_execution/from_neutrader/config.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple, Union + + +@dataclass +class RuntimeConfig: + seed: int = 42 + output_dir: Optional[Path] = None + checkpoint_dir: Optional[Path] = None + tb_log_dir: Optional[Path] = None + debug: bool = False + use_cuda: bool = True + + +@dataclass +class ExchangeConfig: + limit_threshold: Union[float, Tuple[str, str]] + deal_price: Union[str, Tuple[str, str]] + volume_threshold: dict + open_cost: float = 0.0005 + close_cost: float = 0.0015 + min_cost: float = 5. + trade_unit: Optional[float] = 100. + cash_limit: Optional[Union[Path, float]] = None + generate_report: bool = False diff --git a/qlib/rl/order_execution/from_neutrader/executor.py b/qlib/rl/order_execution/from_neutrader/executor.py new file mode 100644 index 000000000..e2aebbed1 --- /dev/null +++ b/qlib/rl/order_execution/from_neutrader/executor.py @@ -0,0 +1,11 @@ +from typing import List + +from qlib.backtest.executor import NestedExecutor +from .strategy import RLStrategyBase + + +class RLNestedExecutor(NestedExecutor): + # RL nested executor + def post_inner_exe_step(self, inner_exe_res: List[object]) -> None: + if isinstance(self.inner_strategy, RLStrategyBase): + self.inner_strategy.post_exe_step(inner_exe_res) diff --git a/qlib/rl/order_execution/from_neutrader/feature.py b/qlib/rl/order_execution/from_neutrader/feature.py new file mode 100644 index 000000000..80157b461 --- /dev/null +++ b/qlib/rl/order_execution/from_neutrader/feature.py @@ -0,0 +1,164 @@ +import collections +from dataclasses import dataclass + +import numpy as np +import pickle +from pathlib import Path +from typing import Optional, List + +import pandas as pd +import qlib +from .highfreq_ops import DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut +from qlib.contrib.ops.high_freq import DayCumsum +from qlib.config import REG_CN +from qlib.data.dataset import DatasetH + + +@dataclass +class QlibConfig: + provider_uri_day: Path + provider_uri_1min: Path + feature_root_dir: Path + feature_columns_today: List[str] + feature_columns_yesterday: List[str] + + +_dataset = None + + +class LRUCache: + def __init__(self, pool_size: int = 200): + self.pool_size = pool_size + self.contents = dict() + self.keys = collections.deque() + + def put(self, key, item): + if self.has(key): + self.keys.remove(key) + self.keys.append(key) + self.contents[key] = item + while len(self.contents) > self.pool_size: + self.contents.pop(self.keys.popleft()) + + def get(self, key): + return self.contents[key] + + def has(self, key): + return key in self.contents + + +class DataWrapper: + + def __init__(self, feature_dataset: DatasetH, backtest_dataset: DatasetH, + columns_today: List[str], columns_yesterday: List[str], _internal: bool = False): + assert _internal, 'Init function of data wrapper is for internal use only.' + + self.feature_dataset = feature_dataset + self.backtest_dataset = backtest_dataset + self.columns_today = columns_today + self.columns_yesterday = columns_yesterday + + self.feature_cache = LRUCache() + self.backtest_cache = LRUCache() + + def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False): + start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59) + + dataset = self.backtest_dataset if backtest else self.feature_dataset + + if backtest: + dataset = self.backtest_dataset + cache = self.backtest_cache + else: + dataset = self.feature_dataset + cache = self.feature_cache + + if cache.has((start_time, end_time, stock_id)): + return cache.get((start_time, end_time, stock_id)) + data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None) + cache.put((start_time, end_time, stock_id), data) + return data + + +def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None: + global _dataset + + provider_uri_map = { + "day": config.provider_uri_day.as_posix(), + "1min": config.provider_uri_1min.as_posix(), + } + qlib.init( + region=REG_CN, + auto_mount=False, + custom_ops=[DayLast, FFillNan, BFillNan, + Date, Select, IsNull, IsInf, Cut, DayCumsum], + expression_cache=None, + calendar_provider={ + "class": "LocalCalendarProvider", + "module_path": "qlib.data.data", + "kwargs": { + "backend": { + "class": "FileCalendarStorage", + "module_path": "qlib.data.storage.file_storage", + "kwargs": {"provider_uri_map": provider_uri_map}, + } + }, + }, + feature_provider={ + "class": "LocalFeatureProvider", + "module_path": "qlib.data.data", + "kwargs": { + "backend": { + "class": "FileFeatureStorage", + "module_path": "qlib.data.storage.file_storage", + "kwargs": {"provider_uri_map": provider_uri_map}, + } + }, + }, + provider_uri=provider_uri_map, + kernels=1, + redis_port=-1, + clear_mem_cache=False # init_qlib will be called for multiple times. Keep the cache for improving performance + ) + + # this won't work if it's put outside in case of multiprocessing + from qlib.data import D + + if part is None: + feature_path = config.feature_root_dir / 'feature.pkl' + backtest_path = config.feature_root_dir / 'backtest.pkl' + else: + feature_path = config.feature_root_dir / 'feature' / (part + '.pkl') + backtest_path = config.feature_root_dir / 'backtest' / (part + '.pkl') + + with feature_path.open('rb') as f: + print(feature_path) + feature_dataset = pickle.load(f) + with backtest_path.open('rb') as f: + backtest_dataset = pickle.load(f) + + _dataset = DataWrapper( + feature_dataset, + backtest_dataset, + config.feature_columns_today, + config.feature_columns_yesterday, + _internal=True + ) + + +def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False): + assert _dataset is not None, 'You must call init_qlib() before doing this.' + + if backtest: + fields = ['$close', '$volume'] + else: + fields = _dataset.columns_yesterday if yesterday else _dataset.columns_today + + data = _dataset.get(stock_id, date, backtest) + if data is None or len(data) == 0: + # create a fake index, but RL doesn't care about index + data = pd.DataFrame(0., index=np.arange(240), columns=fields, dtype=np.float32) # FIXME: hardcode here + else: + data = data.rename(columns={c: c.rstrip('0') for c in data.columns}) + data = data[fields] + return data diff --git a/qlib/rl/order_execution/from_neutrader/highfreq_ops.py b/qlib/rl/order_execution/from_neutrader/highfreq_ops.py new file mode 100644 index 000000000..cf220a384 --- /dev/null +++ b/qlib/rl/order_execution/from_neutrader/highfreq_ops.py @@ -0,0 +1,223 @@ +import numpy as np +import pandas as pd + +from qlib.data.cache import H +from qlib.data.data import Cal +from qlib.data.ops import ElemOperator, PairOperator + + +def get_calendar_day(freq="day", future=False): + """Load High-Freq Calendar Date Using Memcache. + + Parameters + ---------- + freq : str + frequency of read calendar file. + future : bool + whether including future trading day. + + Returns + ------- + _calendar: + array of date. + """ + flag = f"{freq}_future_{future}_day" + if flag in H["c"]: + _calendar = H["c"][flag] + else: + _calendar = np.array( + list(map(lambda x: x.date(), Cal.load_calendar(freq, future)))) + H["c"][flag] = _calendar + return _calendar + + +def get_calendar_minute(freq='day', future=False): + """Load High-Freq Calendar Minute Using Memcache""" + flag = f"{freq}_future_{future}_day" + if flag in H["c"]: + _calendar = H["c"][flag] + else: + _calendar = np.array( + list(map(lambda x: x.minute // 30, Cal.load_calendar(freq, future)))) + H["c"][flag] = _calendar + return _calendar + + +class DayLast(ElemOperator): + """DayLast Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + a series of that each value equals the last value of its day + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + _calendar = get_calendar_day(freq=freq) + series = self.feature.load(instrument, start_index, end_index, freq) + return series.groupby(_calendar[series.index]).transform("last") + + +class FFillNan(ElemOperator): + """FFillNan Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + a forward fill nan feature + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.fillna(method="ffill") + + +class BFillNan(ElemOperator): + """BFillNan Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + a backfoward fill nan feature + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.fillna(method="bfill") + + +class Date(ElemOperator): + """Date Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + a series of that each value is the date corresponding to feature.index + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + _calendar = get_calendar_day(freq=freq) + series = self.feature.load(instrument, start_index, end_index, freq) + return pd.Series(_calendar[series.index], index=series.index) + + +class Select(PairOperator): + """Select Operator + + Parameters + ---------- + feature_left : Expression + feature instance, select condition + feature_right : Expression + feature instance, select value + + Returns + ---------- + feature: + value(feature_right) that meets the condition(feature_left) + + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + series_condition = self.feature_left.load( + instrument, start_index, end_index, freq) + series_feature = self.feature_right.load( + instrument, start_index, end_index, freq) + return series_feature.loc[series_condition] + + +class IsNull(ElemOperator): + """IsNull Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + A series indicating whether the feature is nan + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.isnull() + + +class IsInf(ElemOperator): + """IsInf Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + A series indicating whether the feature is inf + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return np.isinf(series) + + +class Cut(ElemOperator): + """Cut Operator + + Parameters + ---------- + feature : Expression + feature instance + l : int + l > 0, delete the first l elements of feature (default is None, which means 0) + r : int + r < 0, delete the last -r elements of feature (default is None, which means 0) + Returns + ---------- + feature: + A series with the first l and last -r elements deleted from the feature. + Note: It is deleted from the raw data, not the sliced data + """ + + def __init__(self, feature, left=None, right=None): + self.left = left + self.right = right + if (self.left is not None and self.left <= 0) or (self.right is not None and self.right >= 0): + raise ValueError("Cut operator l shoud > 0 and r should < 0") + + super(Cut, self).__init__(feature) + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.iloc[self.left: self.right] + + def get_extended_window_size(self): + ll = 0 if self.left is None else self.left + rr = 0 if self.right is None else abs(self.right) + lft_etd, rght_etd = self.feature.get_extended_window_size() + lft_etd = lft_etd + ll + rght_etd = rght_etd + rr + return lft_etd, rght_etd diff --git a/qlib/rl/order_execution/from_neutrader/state.py b/qlib/rl/order_execution/from_neutrader/state.py new file mode 100644 index 000000000..c08e49c01 --- /dev/null +++ b/qlib/rl/order_execution/from_neutrader/state.py @@ -0,0 +1,251 @@ +import abc +import math +from dataclasses import dataclass, field, fields +from enum import Enum +from typing import Callable, Literal, Optional, Tuple + +import numpy as np +import pandas as pd + +EPSILON = 1e-7 + + +class FlowDirection(str, Enum): + ACQUIRE = "acquire" + LIQUIDATE = "liquidate" + + +def _round_time(time: int, granularity: int) -> int: + return time - time % granularity + + +@dataclass +class BaseEpisodicState(abc.ABC): + """ + Base class for episodic states. + """ + + # requirements + start_time: int + end_time: int + time_per_step: int + vol_limit: Optional[float] # TODO: meaning? + price_func: Callable[[str], np.ndarray] # TODO: meaning? + volume_func: Callable[[], np.ndarray] # TODO: meaning? + on_step_end: Optional[Callable[..., None]] # TODO: meaning? + on_episode_end: Optional[Callable[..., None]] # TODO: meaning? + asset_num: int # TODO: meaning? + + # agent states + num_step: int = field(init=False) # Number of steps + cur_time: int = field(init=False) # Current time + cur_step: int = field(init=False, default=0) + exec_vol: Optional[np.ndarray] = field(init=False, default=None) # Execution history + last_step_duration: int = field(init=False) + position: float = field(init=False) + position_history: np.ndarray = field(init=False) + + def __post_init__(self) -> None: + self.cur_time = self.start_time + rounded_start_time = _round_time(self.start_time, self.time_per_step) + + # TODO: why not rounding end time? + self.num_step = math.floor((self.end_time - rounded_start_time) / self.time_per_step) + + def logs(self) -> dict: + # Base logging information shared across all subclasses. + # You can call logs = super().logs() to get these default logs and use logs.update(...) to add other logging + # information or override it completely to remove these logging fields. + return { + "logs": { + "stop_time": self.cur_time - self.start_time, + "stop_step": self.cur_step, + }, + "history": { + "volume": self.execution_history(), + }, + } + + def execution_history(self) -> np.ndarray: + return np.pad(self.exec_vol, (0, self.end_time - self.start_time - len(self.exec_vol))) + + def next_duration(self) -> int: + left, right = self.next_interval() + return right - left + + def next_interval(self) -> Tuple[int, int]: + left = _round_time(self.cur_time, self.time_per_step) + right = left + self.time_per_step + return max(left, self.start_time) - self.start_time, min(right, self.end_time) - self.start_time + + @classmethod + def get_init_field_names(cls): + ret = [] + for f in fields(cls): + if f.init: + ret.append(f.name) + return ret + + @abc.abstractmethod + def step(self, *args, **kwargs): + raise NotImplementedError() + + @property + def done(self) -> bool: + return False + + +@dataclass +class IntraDaySingleAssetDataSchema: + """ + In the current context, raw should be a DataFrame with `datetime` as index and + (at least) `$vwap0`, `$volume0`, `$close0` as columns. + `processed` should be a DataFrame of 240x6, which is the same as `processed_prev`. + """ + + date: pd.Timestamp + stock_id: str + start_time: int + end_time: int + target: float + flow_dir: FlowDirection + raw: pd.DataFrame + processed: pd.DataFrame + processed_prev: pd.DataFrame + + def get_price(self, type: Literal['deal', 'close'] = 'deal'): + if type == 'deal': + return self.raw['$price'].values + elif type == 'close': + return self.raw['$close0'].values + + def get_volume(self): + return self.raw['$volume0'].values + + def get_processed_data(self, type: Literal['today', 'yesterday'] = 'today'): + if type == 'today': + return self.processed.to_numpy() + elif type == 'yesterday': + return self.processed_prev.to_numpy() + + +@dataclass +class SAOEEpisodicState(BaseEpisodicState): + """ + Global state of the whole time horizon. + """ + + # requirements + target: float + target_limit: float + flow_dir: FlowDirection + + # calculated statistics + turnover: Optional[float] = field(init=False) + baseline_twap: Optional[float] = field(init=False) + baseline_vwap: Optional[float] = field(init=False) + exec_avg_price: Optional[float] = field(init=False) + pa_twap: Optional[float] = field(init=False) + pa_vwap: Optional[float] = field(init=False) + pa_close: Optional[float] = field(init=False) + fulfill_rate: Optional[float] = field(init=False) + + market_price: np.ndarray = field(init=False) # deal price, might be different from close + market_close: np.ndarray = field(init=False) # close price + market_volume: np.ndarray = field(init=False) + + # NOTE: this is a temporary design to make it compatible with old qlib integration framework. As long as callback + # functions are passed correctly, this field should be removed from this class. + last_interval: Tuple[int, int] = field(default=(0, 0), init=False) + + def __post_init__(self) -> None: + assert self.target >= 0 + assert self.asset_num == 1 + + super().__post_init__() + + self.market_volume = self.volume_func()[self.start_time : self.end_time] + self.market_price = self.price_func("deal")[self.start_time : self.end_time] + self.market_close = self.price_func("close")[self.start_time : self.end_time] + self.position = self.target + self.position_history = np.full((self.num_step + 1), np.nan) + self.position_history[0] = self.position + self.baseline_twap = np.mean(self.market_price) + if self.market_volume.sum() == 0: + self.baseline_vwap = self.baseline_twap + else: + self.baseline_vwap = np.average(self.market_price, weights=self.market_volume) + + def update_stats(self) -> None: + market_price = self.market_price[: len(self.exec_vol)] + self.turnover = (self.exec_vol * market_price).sum() + # exec_vol can be zero + if np.isclose(self.exec_vol.sum(), 0): + self.exec_avg_price = market_price[0] + else: + self.exec_avg_price = np.average(market_price, weights=self.exec_vol) + + self.pa_twap = _price_advantage(self.exec_avg_price, self.baseline_twap, self.flow_dir) + self.pa_vwap = _price_advantage(self.exec_avg_price, self.baseline_vwap, self.flow_dir) + close_average = np.mean(self.market_close) + self.pa_close = _price_advantage(self.exec_avg_price, close_average, self.flow_dir) + + self.fulfill_rate = (self.target - self.position) / self.target_limit + if abs(self.fulfill_rate - 1.0) < EPSILON: + self.fulfill_rate = 1.0 + self.fulfill_rate *= 100 + + def logs(self) -> dict: + logs = super().logs() + logs.update( + { + "logs": { + "turnover": self.turnover, + "baseline_twap": self.baseline_twap, + "baseline_vwap": self.baseline_vwap, + "exec_avg_price": self.exec_avg_price, + "pa_twap": self.pa_twap, + "pa_vwap": self.pa_vwap, + "pa_close": self.pa_close, + "ffr": self.fulfill_rate, + } + } + ) + return logs + + def step(self, exec_vol: np.ndarray) -> None: + l, r = self.next_interval() + self.last_interval = (l, r) + assert 0 <= l < r + self.last_step_duration = len(exec_vol) + self.position -= exec_vol.sum() + assert ( + self.position > -EPSILON and (exec_vol > -EPSILON).all(), + f"Execution volume is invalid: {exec_vol} (position = {self.position})", + ) + self.cur_step += 1 + self.position_history[self.cur_step] = self.position + self.cur_time += self.last_step_duration + if self.cur_step == self.num_step: # Should reach the end of episode + assert self.cur_time == self.end_time + self.exec_vol = exec_vol if self.exec_vol is None else np.concatenate((self.exec_vol, exec_vol)) + + if self.on_step_end is not None: + self.on_step_end(l, r, self) + if self.done: + self.update_stats() + if self.on_episode_end is not None: + self.on_episode_end(self) + + @property + def done(self) -> bool: + return self.position < EPSILON or self.cur_step == self.num_step + + +def _price_advantage(exec_price: float, baseline_price: float, flow: FlowDirection) -> float: + if baseline_price == 0: + return 0.0 + if flow == FlowDirection.ACQUIRE: + return (1 - exec_price / baseline_price) * 10000 + else: + return (exec_price / baseline_price - 1) * 10000 diff --git a/qlib/rl/order_execution/from_neutrader/state_maintainer.py b/qlib/rl/order_execution/from_neutrader/state_maintainer.py new file mode 100644 index 000000000..1c2d1499f --- /dev/null +++ b/qlib/rl/order_execution/from_neutrader/state_maintainer.py @@ -0,0 +1,162 @@ +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + +from qlib.backtest.decision import Order, OrderDir, TradeDecisionWO +from qlib.backtest.exchange import Exchange +from qlib.constant import REG_CN +from qlib.rl.order_execution.from_neutrader.feature import fetch_features +from qlib.rl.order_execution.from_neutrader.state import FlowDirection, IntraDaySingleAssetDataSchema, SAOEEpisodicState +from qlib.utils.time import get_day_min_idx_range + + +class StateMaintainer: + """ + Maintain neutrader states taking qlib trade decisions as input. + + Example usage:: + + maintainer = StateMaintainer(...) # in reset + maintainer.send_execute_result(execute_result) # in step + # do something here + maintainer.generate_orders(self.get_data_cal_avail_range(rtype='step'), exec_vols) + + The states can be accessed via ``maintianer.states`` and ``maintainer.samples``. + """ + + def __init__( + self, + time_per_step: int, + date: pd.Timestamp, + full_trade_range: Tuple[int, int], + current_step: int, + outer_trade_decision: TradeDecisionWO, + trade_exchange: Exchange, + ) -> None: + # The parameters look very ad-hoc right now + self.states: Dict[Tuple[str, OrderDir], SAOEEpisodicState] = {} # explicitly make it ordered + self.samples: Dict[Tuple[str, OrderDir], IntraDaySingleAssetDataSchema] = {} + self.time_per_step: int = time_per_step + self.start_time, self.end_time = full_trade_range + self.end_time += 1 # plus 1 to align with the semantics in neutrader + self.date: pd.Timestamp = date + self.last_step_length: int = -1 + self.last_step_range: Optional[Tuple[int, int]] = None + + self.order_list: List[Order] = outer_trade_decision.order_list + self.trade_exchange: Exchange = trade_exchange + + self.num_step = ( + self.end_time - (self.start_time - self.start_time % self.time_per_step) - 1 + ) // self.time_per_step + 1 + + for order in self.order_list: + sample = self._fetch_sample_data(order) + state = self._create_single_ep_state(sample, current_step) + self.samples[order.stock_id, order.direction] = sample + self.states[order.stock_id, order.direction] = state + + def _fetch_sample_data(self, order: Order) -> IntraDaySingleAssetDataSchema: + start_time = self.date.replace(hour=0, minute=0, second=0) + end_time = self.date.replace(hour=23, minute=59, second=59) + deal_price = self.trade_exchange.get_deal_price( + stock_id=order.stock_id, start_time=start_time, end_time=end_time, direction=order.direction, method=None, + ) + backtest_data = fetch_features(order.stock_id, self.date, backtest=True) + # HACK: close means deal price here. The logic is implemented in qlib. + backtest_data["$close"] = deal_price.to_series().to_numpy() + feature_today = fetch_features(order.stock_id, self.date) + feature_yesterday = fetch_features(order.stock_id, self.date, yesterday=True) + return IntraDaySingleAssetDataSchema( + date=self.date.date(), + stock_id=order.stock_id, + start_time=self.start_time, + end_time=self.end_time, + target=max(order.amount, 0.0), # prevent target to go to -eps + flow_dir=FlowDirection.LIQUIDATE if order.direction == 0 else FlowDirection.ACQUIRE, + raw=backtest_data, + processed=feature_today, + processed_prev=feature_yesterday, + ) + + def _create_single_ep_state(self, sample: IntraDaySingleAssetDataSchema, cur_step: int) -> SAOEEpisodicState: + market_price = sample.raw["$close"].values + market_vol = sample.raw["$volume"].values + target = sample.target + + # NOTE: Previously, market_price and market_vol are passed into the state initialization directly. Therefore, + # the segment of market_price and market_vol are used instead of the lambda function here using the whole price + # and vol data. + # This refactoring is ONLY EQUIVALENT WHEN start_time/end_time passed into state is equal to + # sample.start_time/end_time. + # If one can confirm that these two are always the same, delete this note, please. + state = SAOEEpisodicState( + self.start_time, + self.end_time, + self.time_per_step, + None, + lambda x: market_price, + lambda: market_vol, + None, + None, + 1, + target, + target, + sample.flow_dir, + ) + state.cur_step = cur_step + assert state.cur_step == 0 + return state + + def _update_single_ep_state( + self, state: SAOEEpisodicState, execute_result: List[Order], length: Optional[int] = None + ) -> None: + if length is not None: + exec_vol = np.zeros(length) + for order, _, __, ___ in execute_result: + idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN) + exec_vol[idx - self.last_step_range[0]] = order.deal_amount + else: + exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result]) + + # sometimes exec_vol gets too large due to the rounding in exchange + # scale the execution volume so that position won't go below 0 + # actually this case is very rare + if exec_vol.sum() > state.position and exec_vol.sum() > 0: + assert exec_vol.sum() < state.position + 1, f"{exec_vol} too large for {state}" + exec_vol *= state.position / (exec_vol.sum()) + + state.step(exec_vol) + + def create_sub_order(self, exec_vol: float, original_order: Order) -> Order: + oh = self.trade_exchange.get_order_helper() + return oh.create(original_order.stock_id, exec_vol, original_order.direction) + + def send_execute_result(self, execute_result: Optional[List[Any]]) -> None: + if self.last_step_length < 0: + assert not execute_result + return + orders = defaultdict(list) + if execute_result is not None: + for e in execute_result: + orders[e[0].stock_id, e[0].direction].append(e) + for (stock_id, direction), state in self.states.items(): + self._update_single_ep_state(state, orders[stock_id, direction], self.last_step_length) + + def generate_orders(self, step_trade_range: Tuple[int, int], exec_vols: List[float]) -> List[Order]: + order_list = [] + + assert len(exec_vols) == len(self.order_list) + for v, o in zip(exec_vols, self.order_list): + if v > 0: + order_list.append(self.create_sub_order(v, o)) + + step_start_time, step_end_time = step_trade_range # inclusive + step_end_time += 1 + + self.last_step_length = step_end_time - step_start_time + self.last_step_range = (step_start_time, step_end_time) + + return order_list diff --git a/qlib/rl/order_execution/from_neutrader/strategy.py b/qlib/rl/order_execution/from_neutrader/strategy.py new file mode 100644 index 000000000..df6199a3b --- /dev/null +++ b/qlib/rl/order_execution/from_neutrader/strategy.py @@ -0,0 +1,91 @@ +from abc import ABCMeta +from typing import Tuple + +import pandas as pd + +from qlib.backtest.decision import BaseTradeDecision, Order, OrderHelper, TradeDecisionWO, TradeRange +from qlib.backtest.utils import CommonInfrastructure +from qlib.rl.order_execution.from_neutrader.state import IntraDaySingleAssetDataSchema, SAOEEpisodicState +from qlib.rl.order_execution.from_neutrader.state_maintainer import StateMaintainer +from qlib.strategy.base import BaseStrategy + + +class RLStrategyBase(BaseStrategy, metaclass=ABCMeta): + def post_exe_step(self, execute_result: list) -> None: + """ + post process for each step of strategy this is design for RL Strategy, + which require to update the policy state after each step + + NOTE: it is strongly coupled with RLNestedExecutor; + """ + raise NotImplementedError("Please implement the `post_exe_step` method") + + +class DecomposedStrategy(RLStrategyBase): + def __init__(self): + super(DecomposedStrategy, self).__init__() + + def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs) -> None: + super().reset(outer_trade_decision=outer_trade_decision, **kwargs) + time_per_step = int(pd.Timedelta(self.trade_calendar.get_freq()) / pd.Timedelta("1min")) + if outer_trade_decision is not None: + self.maintainer = StateMaintainer( + time_per_step, + self.trade_calendar.get_all_time()[0], + self.get_data_cal_avail_range(), + self.trade_calendar.get_trade_step(), + outer_trade_decision, + self.trade_exchange, + ) + + def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision: + return outer_trade_decision + + def post_exe_step(self, execute_result): + self.maintainer.send_execute_result(execute_result) + + @property + def sample_state_pair(self) -> Tuple[IntraDaySingleAssetDataSchema, SAOEEpisodicState]: + assert len(self.maintainer.samples) == len(self.maintainer.states) == 1 + return ( + list(self.maintainer.samples.values())[0], + list(self.maintainer.states.values())[0], + ) + + def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision: + # get a decision from the outmost loop + exec_vol = yield self + + return TradeDecisionWO( + self.maintainer.generate_orders(self.get_data_cal_avail_range(rtype="step"), [exec_vol]), self + ) + + +class SingleOrderStrategy(BaseStrategy): + # this logic is copied from FileOrderStrategy + def __init__( + self, + common_infra: CommonInfrastructure, + order: Order, + trade_range: TradeRange, + instrument: str, + ) -> None: + super().__init__(common_infra=common_infra) + self._order = order + self._trade_range = trade_range + self._instrument = instrument + + def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision: + return outer_trade_decision + + def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO: + oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper() + order_list = [ + oh.create( + code=self._instrument, + amount=self._order.amount, + direction=Order.parse_dir(self._order.direction), + ) + ] + trade_decision = TradeDecisionWO(order_list, self, self._trade_range) + return trade_decision diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index 3a6e0079c..7224d34a2 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -2,34 +2,27 @@ # Licensed under the MIT License. """Placeholder for qlib-based simulator.""" -from dataclasses import dataclass -from pathlib import Path -from plistlib import Dict -from typing import Callable, Generator, List, Optional, Tuple, Union +from typing import Callable, Generator, List, Optional, Union import pandas as pd +from gym.vector.utils import spaces -from qlib.backtest import Account, BaseExecutor, CommonInfrastructure, get_exchange -from qlib.backtest.decision import Order -from qlib.backtest.executor import NestedExecutor +from qlib.backtest import get_exchange +from qlib.backtest.account import Account +from qlib.backtest.decision import Order, TradeRange, TradeRangeByTime +from qlib.backtest.executor import BaseExecutor +from qlib.backtest.utils import CommonInfrastructure from qlib.config import QlibConfig -from qlib.rl.simulator import ActType, Simulator, StateType +from qlib.rl.interpreter import ActionInterpreter +from qlib.rl.order_execution.from_neutrader.config import ExchangeConfig +from qlib.rl.order_execution.from_neutrader.executor import RLNestedExecutor +from qlib.rl.order_execution.from_neutrader.feature import init_qlib +from qlib.rl.order_execution.from_neutrader.state import SAOEEpisodicState +from qlib.rl.order_execution.from_neutrader.strategy import DecomposedStrategy +from qlib.rl.simulator import Simulator from qlib.strategy.base import BaseStrategy -@dataclass -class ExchangeConfig: - limit_threshold: Union[float, Tuple[str, str]] - deal_price: Union[str, Tuple[str, str]] - volume_threshold: Union[float, Dict[str, Tuple[str, str]]] - open_cost: float = 0.0005 - close_cost: float = 0.0015 - min_cost: float = 5. - trade_unit: Optional[float] = 100. - cash_limit: Optional[Union[Path, float]] = None - generate_report: bool = False - - def get_common_infra( config: ExchangeConfig, trade_start_time: pd.Timestamp, @@ -69,13 +62,31 @@ def get_common_infra( return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange) -class QlibSimulator(Simulator[Order, StateType, ActType]): +class CategoricalActionInterpreter(ActionInterpreter[SAOEEpisodicState, int, float]): + def __init__(self, values: Union[int, List[float]]) -> None: + if isinstance(values, int): + values = [i / values for i in range(0, values + 1)] + self.action_values = values + + @property + def action_space(self) -> spaces.Discrete: + return spaces.Discrete(len(self.action_values)) + + def interpret(self, state: SAOEEpisodicState, action: int) -> float: + volume = min(state.position, state.target * self.action_values[action]) + if state.cur_step + 1 >= state.num_step: + volume = state.position # execute all volumes at last + return volume + + +class QlibSimulator(Simulator[Order, SAOEEpisodicState, float]): def __init__( self, - qlib_config: QlibConfig, time_per_step: str, - top_strategy: BaseStrategy, - inner_strategy_fn: Callable[[], BaseStrategy], + start_time: str, + end_time: str, + qlib_config: QlibConfig, + top_strategy_fn: Callable[[CommonInfrastructure, Order, TradeRange, str], BaseStrategy], inner_executor_fn: Callable[[CommonInfrastructure], BaseExecutor], exchange_config: ExchangeConfig, ) -> None: @@ -83,61 +94,77 @@ class QlibSimulator(Simulator[Order, StateType, ActType]): initial=None, # TODO ) - self.qlib_config = qlib_config - + self._trade_range = TradeRangeByTime(start_time, end_time) + self._qlib_config = qlib_config self._time_per_step = time_per_step - self._top_strategy = top_strategy + self._top_strategy_fn = top_strategy_fn self._inner_executor_fn = inner_executor_fn - self._inner_strategy_fn = inner_strategy_fn self._exchange_config = exchange_config - self._executor: Optional[NestedExecutor] = None + self._executor: Optional[RLNestedExecutor] = None self._collect_data_loop: Optional[Generator] = None self._done = False - def _reset( + self._inner_strategy = DecomposedStrategy() + + def reset( self, - instrument: str, - date_time: pd.Timestamp, + order: Order, + instrument: str = "SH600000", # TODO: Test only. Remove this default value later. ) -> None: - # TODO: init_qlib + init_qlib(self._qlib_config, instrument) common_infra = get_common_infra( self._exchange_config, - trade_start_time=date_time, - trade_end_time=date_time, + trade_start_time=order.start_time, + trade_end_time=order.end_time, codes=[instrument], ) - self._executor = NestedExecutor( + self._executor = RLNestedExecutor( time_per_step=self._time_per_step, inner_executor=self._inner_executor_fn(common_infra), - inner_strategy=self._inner_strategy_fn(), + inner_strategy=self._inner_strategy, track_data=True, + common_infra=common_infra, ) - self._executor.reset(start_time=date_time, end_time=date_time) - self._top_strategy.reset(level_infra=self._executor.get_level_infra()) - self._collect_data_loop = self._executor.collect_data(self._top_strategy.generate_trade_decision(), level=0) + top_strategy = self._top_strategy_fn(common_infra, order, self._trade_range, instrument) + + self._executor.reset(start_time=order.start_time, end_time=order.end_time) + top_strategy.reset(level_infra=self._executor.get_level_infra()) + + self._collect_data_loop = self._executor.collect_data(top_strategy.generate_trade_decision(), level=0) assert isinstance(self._collect_data_loop, Generator) + strategy = self._iter_strategy(action=None) + sample, ep_state = strategy.sample_state_pair + self._last_ep_state = ep_state + self._done = False - def step(self, action: ActType) -> None: + def _iter_strategy(self, action: float = None) -> DecomposedStrategy: + strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) + while not isinstance(strategy, DecomposedStrategy): + strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) + assert isinstance(strategy, DecomposedStrategy) + return strategy + + def step(self, action: float) -> None: try: - strategy = self._collect_data_loop.send(action) - while not isinstance(strategy, BaseStrategy): - strategy = self._collect_data_loop.send(action) - assert isinstance(strategy, BaseStrategy) - - # TODO: do something here - + strategy = self._iter_strategy(action=action) + sample, ep_state = strategy.sample_state_pair except StopIteration: + sample, ep_state = self._inner_strategy.sample_state_pair + assert ep_state.done + + self._last_ep_state = ep_state + if ep_state.done: self._done = True - def get_state(self) -> StateType: - pass # TODO: Collect info from executor. Generate state. + def get_state(self) -> SAOEEpisodicState: + return self._last_ep_state def done(self) -> bool: return self._done diff --git a/qlib/rl/order_execution/tests/__init__.py b/qlib/rl/order_execution/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qlib/rl/order_execution/tests/test_simulator_qlib.py b/qlib/rl/order_execution/tests/test_simulator_qlib.py new file mode 100644 index 000000000..6e93bc059 --- /dev/null +++ b/qlib/rl/order_execution/tests/test_simulator_qlib.py @@ -0,0 +1,133 @@ +import collections +from pathlib import Path + +import pandas as pd + +from qlib.backtest.decision import Order, OrderDir, TradeRange +from qlib.backtest.executor import SimulatorExecutor +from qlib.backtest.utils import CommonInfrastructure +from qlib.config import QlibConfig +from qlib.contrib.strategy import TWAPStrategy +from qlib.rl.order_execution.from_neutrader.executor import RLNestedExecutor +from qlib.rl.order_execution.from_neutrader.strategy import SingleOrderStrategy +from qlib.rl.order_execution.simulator_qlib import CategoricalActionInterpreter, ExchangeConfig, QlibSimulator + +qlib_config = QlibConfig( + { + "provider_uri_day": Path("C:/workspace/NeuTrader/data_sample/cn/qlib_amc_1d"), + "provider_uri_1min": Path("C:/workspace/NeuTrader/data_sample/cn/qlib_amc_1min"), + "feature_root_dir": Path("C:/workspace/NeuTrader/data_sample/cn/qlib_amc_handler_stock"), + "feature_columns_today": [ + "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume", + "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5", + ], + "feature_columns_yesterday": [ + "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1", + "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1", + ], + } +) + +exchange_config = ExchangeConfig( + limit_threshold=('$ask == 0', '$bid == 0'), + deal_price=('If($ask == 0, $bid, $ask)', 'If($bid == 0, $ask, $bid)'), + volume_threshold={ + 'all': ('cum', "0.2 * DayCumsum($volume, '9:45', '14:44')"), + 'buy': ('current', '$askV1'), 'sell': ('current', '$bidV1') + }, + open_cost=0.0005, + close_cost=0.0015, + min_cost=5.0, + trade_unit=None, + cash_limit=None, + generate_report=False, +) + + +def _top_strategy_fn( + common_infra: CommonInfrastructure, + order: Order, + trade_range: TradeRange, + instrument: str, +) -> SingleOrderStrategy: + return SingleOrderStrategy(common_infra, order, trade_range, instrument) + + +def _inner_executor_fn(common_infra: CommonInfrastructure) -> RLNestedExecutor: + return RLNestedExecutor( + time_per_step="30min", + inner_strategy=TWAPStrategy(), + inner_executor=SimulatorExecutor( + time_per_step="1min", + verbose=False, + trade_type=SimulatorExecutor.TT_SERIAL, + generate_report=False, + common_infra=common_infra, + track_data=True, + ), + common_infra=common_infra, + track_data=True, + ) + + +def test(): + order_infos = [ + ("2019-03-04", 1078.644160270691, 1), + ("2019-03-11", 32.440425872802734, 1), + ("2019-03-25", 40.55053234100342, 0), + ("2019-04-01", 1070.5340538024902, 0), + ("2019-05-27", 300.0739393234253, 1), + ("2019-06-03", 8.110106468200684, 0), + ("2019-06-11", 0.9360466003417968, 0), + ("2019-06-17", 794.4272003173828, 1), + ("2019-06-24", 7.865615844726562, 0), + ("2019-07-01", 1077.589370727539, 0), + ("2021-01-04", 499.7846999168396, 1), + ("2021-01-11", 14.918946266174316, 0), + ("2021-01-18", 484.8657536506653, 0), + ("2021-02-08", 537.0820655822754, 1), + ("2021-02-18", 7.459473133087158, 0), + ("2021-02-22", 7.459473133087158, 0), + ("2021-03-01", 14.918946266174316, 1), + ("2021-03-08", 872.7583565711975, 1), + + ] + orders = collections.deque([ + Order( + stock_id="", + amount=info[1], + direction=OrderDir(info[2]), + start_time=pd.Timestamp(info[0]), + end_time=pd.Timestamp(info[0]), + ) + for info in order_infos + ]) + + # fmt: off + simulator = QlibSimulator( + time_per_step="1day", + start_time="9:45", + end_time="14:44", + qlib_config=qlib_config, + top_strategy_fn=_top_strategy_fn, + inner_executor_fn=_inner_executor_fn, + exchange_config=exchange_config, + ) + # fmt: on + + action_interpreter = CategoricalActionInterpreter(values=4) + + simulator.reset(orders.popleft()) + + for i in range(10): + print(f"Step {i}") + ep_state = simulator.get_state() + action = action_interpreter(ep_state, 1) + + simulator.step(action) + if simulator.done(): + break + + +if __name__ == "__main__": + test()