mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-06 04:20:57 +08:00
RL backtest with simulator (#1299)
* RL backtest with simulator * Minor modification in init_qlib * Cherry pick PR 1302 * Resolve PR comments * Fix missing data processing * Minor bugfix * Add TODOs and docs * Add a comment
This commit is contained in:
@@ -576,3 +576,18 @@ class TradeDecisionWO(BaseTradeDecision[Order]):
|
||||
f"trade_range: {self.trade_range}; "
|
||||
f"order_list[{len(self.order_list)}]"
|
||||
)
|
||||
|
||||
|
||||
class TradeDecisionWithDetails(TradeDecisionWO):
    """A ``TradeDecisionWO`` that additionally carries detail information.

    The detail information is used to generate execution reports.

    Parameters
    ----------
    order_list
        Orders that make up the decision; forwarded to the parent initializer.
    strategy
        Strategy the decision belongs to; forwarded to the parent initializer.
    trade_range
        Optional trade range; forwarded to the parent initializer.
    details
        Optional detail payload. It is stored unchanged as ``self.details``.
    """

    def __init__(
        self,
        order_list: List[Order],
        strategy: BaseStrategy,
        trade_range: Optional[Tuple[int, int]] = None,
        details: Optional[Any] = None,
    ) -> None:
        # The parent takes care of everything except the extra payload.
        super().__init__(order_list, strategy, trade_range)

        self.details = details
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import pickle
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -14,12 +15,13 @@ import torch
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from qlib.backtest import collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest.decision import TradeRangeByTime
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
|
||||
from qlib.backtest.high_performance_ds import BaseOrderIndicator
|
||||
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
|
||||
from qlib.rl.contrib.utils import read_order_file
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
|
||||
|
||||
|
||||
@@ -41,7 +43,7 @@ def _get_multi_level_executor_config(
|
||||
}
|
||||
|
||||
freqs = list(strategy_config.keys())
|
||||
freqs.sort(key=lambda x: pd.Timedelta(x))
|
||||
freqs.sort(key=pd.Timedelta)
|
||||
for freq in freqs:
|
||||
executor_config = {
|
||||
"class": "NestedExecutor",
|
||||
@@ -73,7 +75,7 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
|
||||
# HACK: for qlib v0.8
|
||||
value_dict = value_dict.to_series()
|
||||
try:
|
||||
value_dict = {k: v for k, v in value_dict.items()}
|
||||
value_dict = copy.deepcopy(value_dict)
|
||||
if value_dict["ffr"].empty:
|
||||
continue
|
||||
except Exception:
|
||||
@@ -90,32 +92,177 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
|
||||
return records
|
||||
|
||||
|
||||
def _generate_report(decisions: list, report_dict: dict) -> dict:
|
||||
# TODO: there should be richer annotation for the input (e.g. report) and the returned report
|
||||
# TODO: For example, @ dataclass with typed fields and detailed docstrings.
|
||||
def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List[dict]) -> dict:
|
||||
"""Generate backtest reports
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decisions:
|
||||
List of trade decisions.
|
||||
report_indicators
|
||||
List of indicator reports.
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
indicator_dict = defaultdict(list)
|
||||
indicator_his = defaultdict(list)
|
||||
for report_indicator in report_indicators:
|
||||
for key, value in report_indicator.items():
|
||||
if key.endswith("_obj"):
|
||||
indicator_his[key].append(value.order_indicator_his)
|
||||
else:
|
||||
indicator_dict[key].append(value)
|
||||
|
||||
report = {}
|
||||
decision_details = pd.concat([d.details for d in decisions if hasattr(d, "details")])
|
||||
for key in ["1minute", "5minute", "30minute", "1day"]:
|
||||
if key not in report_dict["indicator"]:
|
||||
decision_details = pd.concat([getattr(d, "details") for d in decisions if hasattr(d, "details")])
|
||||
for key in ["1min", "5min", "30min", "1day"]:
|
||||
if key not in indicator_dict:
|
||||
continue
|
||||
report[key] = report_dict["indicator"][key]
|
||||
report[key + "_obj"] = _convert_indicator_to_dataframe(
|
||||
report_dict["indicator"][key + "_obj"].order_indicator_his
|
||||
)
|
||||
cur_details = decision_details[decision_details.freq == key.rstrip("ute")].set_index(["instrument", "datetime"])
|
||||
|
||||
report[key] = pd.concat(indicator_dict[key])
|
||||
report[key + "_obj"] = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key + "_obj"]])
|
||||
|
||||
cur_details = decision_details[decision_details.freq == key].set_index(["instrument", "datetime"])
|
||||
if len(cur_details) > 0:
|
||||
cur_details.pop("freq")
|
||||
report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer")
|
||||
if "1minute" in report_dict["report"]:
|
||||
report["simulator"] = report_dict["report"]["1minute"][0]
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def single(
|
||||
def single_with_simulator(
|
||||
backtest_config: dict,
|
||||
orders: pd.DataFrame,
|
||||
split: str = "stock",
|
||||
split: Literal["stock", "day"] = "stock",
|
||||
cash_limit: float = None,
|
||||
generate_report: bool = False,
|
||||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
|
||||
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.
|
||||
A new simulator will be created and used for every single-day order.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backtest_config:
|
||||
Backtest config
|
||||
orders:
|
||||
Orders to be executed. Example format:
|
||||
datetime instrument amount direction
|
||||
0 2020-06-01 INST 600.0 0
|
||||
1 2020-06-02 INST 700.0 1
|
||||
...
|
||||
split
|
||||
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
|
||||
cash_limit
|
||||
Limitation of cash.
|
||||
generate_report
|
||||
Whether to generate reports.
|
||||
|
||||
Returns
|
||||
-------
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
init_qlib(backtest_config["qlib"], part=stock_id)
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
init_qlib(backtest_config["qlib"], part=day)
|
||||
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
|
||||
reports = []
|
||||
decisions = []
|
||||
for _, row in orders.iterrows():
|
||||
date = pd.Timestamp(row["datetime"])
|
||||
start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day)
|
||||
end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day)
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(row["direction"]),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=backtest_config["strategies"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
)
|
||||
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
"codes": stocks,
|
||||
"freq": "1min",
|
||||
}
|
||||
)
|
||||
|
||||
simulator = SingleAssetOrderExecution(
|
||||
order=order,
|
||||
executor_config=executor_config,
|
||||
exchange_config=exchange_config,
|
||||
qlib_config=None,
|
||||
cash_limit=None,
|
||||
backtest_mode=True,
|
||||
)
|
||||
|
||||
reports.append(simulator.report_dict)
|
||||
decisions += simulator.decisions
|
||||
|
||||
indicator = {k: v for report in reports for k, v in report["indicator"]["1day_obj"].order_indicator_his.items()}
|
||||
records = _convert_indicator_to_dataframe(indicator)
|
||||
assert records is None or not np.isnan(records["ffr"]).any()
|
||||
|
||||
if generate_report:
|
||||
report = _generate_report(decisions, [report["indicator"] for report in reports])
|
||||
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
report = {stock_id: report}
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
report = {day: report}
|
||||
|
||||
return records, report
|
||||
else:
|
||||
return records
|
||||
|
||||
|
||||
def single_with_collect_data_loop(
|
||||
backtest_config: dict,
|
||||
orders: pd.DataFrame,
|
||||
split: Literal["stock", "day"] = "stock",
|
||||
cash_limit: float = None,
|
||||
generate_report: bool = False,
|
||||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
|
||||
"""Run backtest in a single thread with collect_data_loop.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backtest_config:
|
||||
Backtest config
|
||||
orders:
|
||||
Orders to be executed. Example format:
|
||||
datetime instrument amount direction
|
||||
0 2020-06-01 INST 600.0 0
|
||||
1 2020-06-02 INST 700.0 1
|
||||
...
|
||||
split
|
||||
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
|
||||
cash_limit
|
||||
Limitation of cash.
|
||||
generate_report
|
||||
Whether to generate reports.
|
||||
|
||||
Returns
|
||||
-------
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
init_qlib(backtest_config["qlib"], part=stock_id)
|
||||
@@ -127,7 +274,7 @@ def single(
|
||||
trade_end_time = orders["datetime"].max()
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
|
||||
top_strategy_config = {
|
||||
strategy_config = {
|
||||
"class": "FileOrderStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
@@ -139,14 +286,14 @@ def single(
|
||||
},
|
||||
}
|
||||
|
||||
top_executor_config = _get_multi_level_executor_config(
|
||||
executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=backtest_config["strategies"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
)
|
||||
|
||||
tmp_backtest_config = copy.deepcopy(backtest_config["exchange"])
|
||||
tmp_backtest_config.update(
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
"codes": stocks,
|
||||
"freq": "1min",
|
||||
@@ -156,11 +303,11 @@ def single(
|
||||
strategy, executor = get_strategy_executor(
|
||||
start_time=pd.Timestamp(trade_start_time),
|
||||
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
|
||||
strategy=top_strategy_config,
|
||||
executor=top_executor_config,
|
||||
strategy=strategy_config,
|
||||
executor=executor_config,
|
||||
benchmark=None,
|
||||
account=cash_limit if cash_limit is not None else int(1e12),
|
||||
exchange_kwargs=tmp_backtest_config,
|
||||
exchange_kwargs=exchange_config,
|
||||
pos_type="Position" if cash_limit is not None else "InfPosition",
|
||||
)
|
||||
_set_env_for_all_strategy(executor=executor)
|
||||
@@ -172,7 +319,7 @@ def single(
|
||||
assert records is None or not np.isnan(records["ffr"]).any()
|
||||
|
||||
if generate_report:
|
||||
report = _generate_report(decisions, report_dict)
|
||||
report = _generate_report(decisions, [report_dict["indicator"]])
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
report = {stock_id: report}
|
||||
@@ -184,7 +331,7 @@ def single(
|
||||
return records
|
||||
|
||||
|
||||
def backtest(backtest_config: dict) -> pd.DataFrame:
|
||||
def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame:
|
||||
order_df = read_order_file(backtest_config["order_file"])
|
||||
|
||||
cash_limit = backtest_config["exchange"].pop("cash_limit")
|
||||
@@ -193,6 +340,7 @@ def backtest(backtest_config: dict) -> pd.DataFrame:
|
||||
stock_pool = order_df["instrument"].unique().tolist()
|
||||
stock_pool.sort()
|
||||
|
||||
single = single_with_simulator if with_simulator else single_with_collect_data_loop
|
||||
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
|
||||
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
|
||||
res = Parallel(**mp_config)(
|
||||
@@ -227,5 +375,12 @@ if __name__ == "__main__":
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
path = sys.argv[1]
|
||||
backtest(get_backtest_config_fromfile(path))
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
parser.add_argument("--use_simulator", action="store_true", help="Whether to use simulator as the backend")
|
||||
args = parser.parse_args()
|
||||
|
||||
backtest(
|
||||
backtest_config=get_backtest_config_fromfile(args.config_path),
|
||||
with_simulator=args.use_simulator,
|
||||
)
|
||||
|
||||
@@ -53,7 +53,8 @@ def parse_backtest_config(path: str) -> dict:
|
||||
|
||||
del sys.modules[tmp_module_name]
|
||||
else:
|
||||
config = yaml.safe_load(open(tmp_config_file.name))
|
||||
with open(tmp_config_file.name) as input_stream:
|
||||
config = yaml.safe_load(input_stream)
|
||||
|
||||
if "_base_" in config:
|
||||
base_file_name = config.pop("_base_")
|
||||
|
||||
@@ -81,10 +81,12 @@ def init_qlib(qlib_config: dict, part: str = None) -> None:
|
||||
def _convert_to_path(path: str | Path) -> Path:
|
||||
return path if isinstance(path, Path) else Path(path)
|
||||
|
||||
provider_uri_map = {
|
||||
"day": _convert_to_path(qlib_config["provider_uri_day"]).as_posix(),
|
||||
"1min": _convert_to_path(qlib_config["provider_uri_1min"]).as_posix(),
|
||||
}
|
||||
provider_uri_map = {}
|
||||
if "provider_uri_day" in qlib_config:
|
||||
provider_uri_map["day"] = _convert_to_path(qlib_config["provider_uri_day"]).as_posix()
|
||||
if "provider_uri_1min" in qlib_config:
|
||||
provider_uri_map["1min"] = _convert_to_path(qlib_config["provider_uri_1min"]).as_posix()
|
||||
|
||||
qlib.init(
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
|
||||
@@ -9,12 +9,11 @@ import pandas as pd
|
||||
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.decision import TradeRange, TradeRangeByTime
|
||||
from qlib.constant import EPS_T, ONE_DAY
|
||||
from qlib.rl.order_execution.utils import get_ticks_slice
|
||||
from qlib.utils.index_data import IndexData
|
||||
|
||||
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
from .integration import fetch_features
|
||||
from ...data import D
|
||||
|
||||
|
||||
class IntradayBacktestData(BaseIntradayBacktestData):
|
||||
@@ -82,18 +81,20 @@ def load_backtest_data(
|
||||
trade_exchange: Exchange,
|
||||
trade_range: TradeRange,
|
||||
) -> IntradayBacktestData:
|
||||
data = cast(
|
||||
IndexData,
|
||||
trade_exchange.get_deal_price(
|
||||
stock_id=order.stock_id,
|
||||
start_time=order.date,
|
||||
end_time=order.date + ONE_DAY - EPS_T,
|
||||
direction=order.direction,
|
||||
method=None,
|
||||
),
|
||||
# TODO: making exchange return data without missing will make it more elegant. Fix this in the future.
|
||||
tmp_data = D.features(
|
||||
trade_exchange.codes,
|
||||
trade_exchange.all_fields,
|
||||
trade_exchange.start_time,
|
||||
trade_exchange.end_time,
|
||||
freq=trade_exchange.freq,
|
||||
disk_cache=True,
|
||||
)
|
||||
|
||||
ticks_index = pd.DatetimeIndex(data.index)
|
||||
ticks_index = pd.DatetimeIndex(tmp_data.reset_index()["datetime"])
|
||||
ticks_index = ticks_index[order.start_time <= ticks_index]
|
||||
ticks_index = ticks_index[ticks_index <= order.end_time]
|
||||
|
||||
if isinstance(trade_range, TradeRangeByTime):
|
||||
ticks_for_order = get_ticks_slice(
|
||||
ticks_index,
|
||||
@@ -122,7 +123,10 @@ class NTIntradayProcessedData(BaseIntradayProcessedData):
|
||||
date: pd.Timestamp,
|
||||
) -> None:
|
||||
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
|
||||
return df.reset_index().drop(columns=["instrument"]).set_index(["datetime"])
|
||||
df = df.reset_index()
|
||||
if "instrument" in df.columns:
|
||||
df = df.drop(columns=["instrument"])
|
||||
return df.set_index(["datetime"])
|
||||
|
||||
self.today = _drop_stock_id(fetch_features(stock_id, date))
|
||||
self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))
|
||||
|
||||
@@ -91,7 +91,7 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
data_dir: Path | str,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
deal_price: DealPriceType = "close",
|
||||
@@ -99,7 +99,7 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
|
||||
) -> None:
|
||||
super(SimpleIntradayBacktestData, self).__init__()
|
||||
|
||||
backtest = _read_pickle(data_dir / stock_id)
|
||||
backtest = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
|
||||
backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
|
||||
# No longer needed for pandas >= 1.4
|
||||
@@ -154,13 +154,13 @@ class IntradayProcessedData(BaseIntradayProcessedData):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
data_dir: Path | str,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> None:
|
||||
proc = _read_pickle(data_dir / stock_id)
|
||||
proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
|
||||
# We have to infer the names here because,
|
||||
# unfortunately they are not included in the original data.
|
||||
cnames = _infer_processed_data_column_names(feature_dim)
|
||||
|
||||
@@ -163,6 +163,12 @@ def auto_device(module: nn.Module) -> torch.device:
|
||||
def load_weight(policy: nn.Module, path: Path) -> None:
|
||||
assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
|
||||
loaded_weight = torch.load(path, map_location="cpu")
|
||||
|
||||
# TODO: this should be handled by whoever calls load_weight.
|
||||
# TODO: For example, when the outer class receives a weight, it should first unpack it,
|
||||
# TODO: and send the corresponding part to individual component.
|
||||
if "vessel" in loaded_weight:
|
||||
loaded_weight = loaded_weight["vessel"]["policy"]
|
||||
try:
|
||||
policy.load_state_dict(loaded_weight)
|
||||
except RuntimeError:
|
||||
|
||||
@@ -3,17 +3,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generator, Optional
|
||||
from typing import Generator, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
from qlib.backtest import collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest.decision import Order
|
||||
from qlib.backtest.executor import NestedExecutor
|
||||
from qlib.rl.simulator import Simulator
|
||||
|
||||
from qlib.backtest import collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.simulator import Simulator
|
||||
from .state import SAOEState, SAOEStateAdapter
|
||||
from .strategy import SAOEStrategy
|
||||
from ..utils.env_wrapper import CollectDataEnvWrapper
|
||||
|
||||
|
||||
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
@@ -23,30 +24,42 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
----------
|
||||
order
|
||||
The seed to start an SAOE simulator is an order.
|
||||
strategy_config
|
||||
Strategy configuration
|
||||
executor_config
|
||||
Executor configuration
|
||||
exchange_config
|
||||
Exchange configuration
|
||||
qlib_config
|
||||
Configuration used to initialize Qlib. If it is None, Qlib will not be initialized.
|
||||
cash_limit:
|
||||
Cash limit.
|
||||
backtest_mode
|
||||
Whether the simulator is under backtest mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
order: Order,
|
||||
strategy_config: dict,
|
||||
executor_config: dict,
|
||||
exchange_config: dict,
|
||||
qlib_config: dict = None,
|
||||
cash_limit: Optional[float] = None,
|
||||
backtest_mode: bool = False,
|
||||
) -> None:
|
||||
super().__init__(initial=order)
|
||||
|
||||
assert order.start_time.date() == order.end_time.date(), "Start date and end date must be the same."
|
||||
|
||||
strategy_config = {
|
||||
"class": "SingleOrderStrategy",
|
||||
"module_path": "qlib.rl.strategy.single_order",
|
||||
"kwargs": {
|
||||
"order": order,
|
||||
"trade_range": TradeRangeByTime(order.start_time.time(), order.end_time.time()),
|
||||
},
|
||||
}
|
||||
|
||||
self._collect_data_loop: Optional[Generator] = None
|
||||
self.reset(order, strategy_config, executor_config, exchange_config, qlib_config)
|
||||
self.reset(order, strategy_config, executor_config, exchange_config, qlib_config, cash_limit, backtest_mode)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
@@ -55,6 +68,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
executor_config: dict,
|
||||
exchange_config: dict,
|
||||
qlib_config: dict = None,
|
||||
cash_limit: Optional[float] = None,
|
||||
backtest_mode: bool = False,
|
||||
) -> None:
|
||||
if qlib_config is not None:
|
||||
init_qlib(qlib_config, part="skip")
|
||||
@@ -65,22 +80,35 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
strategy=strategy_config,
|
||||
executor=executor_config,
|
||||
benchmark=order.stock_id,
|
||||
account=1e12,
|
||||
account=cash_limit if cash_limit is not None else int(1e12),
|
||||
exchange_kwargs=exchange_config,
|
||||
pos_type="InfPosition",
|
||||
pos_type="Position" if cash_limit is not None else "InfPosition",
|
||||
)
|
||||
|
||||
assert isinstance(self._executor, NestedExecutor)
|
||||
|
||||
self.report_dict: dict = {}
|
||||
self.decisions: List[BaseTradeDecision] = []
|
||||
self._collect_data_loop = collect_data_loop(
|
||||
start_time=order.date,
|
||||
end_time=order.date,
|
||||
trade_strategy=strategy,
|
||||
trade_executor=self._executor,
|
||||
return_value=self.report_dict,
|
||||
)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
self._last_yielded_saoe_strategy = self._iter_strategy(action=None)
|
||||
# TODO: backtest_mode is not a necessary parameter if we carefully design it.
|
||||
# TODO: It should disappear with CollectDataEnvWrapper in the future.
|
||||
if backtest_mode:
|
||||
executor: BaseExecutor = self._executor
|
||||
while isinstance(executor, NestedExecutor):
|
||||
if hasattr(executor.inner_strategy, "set_env"):
|
||||
executor.inner_strategy.set_env(CollectDataEnvWrapper())
|
||||
executor = executor.inner_executor
|
||||
|
||||
# Call `step()` with None action to initialize the internal generator.
|
||||
self.step(action=None)
|
||||
|
||||
self._order = order
|
||||
|
||||
@@ -91,17 +119,19 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
def twap_price(self) -> float:
|
||||
return self._get_adapter().twap_price
|
||||
|
||||
def _iter_strategy(self, action: float = None) -> SAOEStrategy:
|
||||
def _iter_strategy(self, action: Optional[float] = None) -> SAOEStrategy:
|
||||
"""Iterate the _collect_data_loop until we get the next yield SAOEStrategy."""
|
||||
assert self._collect_data_loop is not None
|
||||
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
while not isinstance(strategy, SAOEStrategy):
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
assert isinstance(strategy, SAOEStrategy)
|
||||
return strategy
|
||||
obj = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
while not isinstance(obj, SAOEStrategy):
|
||||
if isinstance(obj, BaseTradeDecision):
|
||||
self.decisions.append(obj)
|
||||
obj = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
assert isinstance(obj, SAOEStrategy)
|
||||
return obj
|
||||
|
||||
def step(self, action: float) -> None:
|
||||
def step(self, action: Optional[float]) -> None:
|
||||
"""Execute one step or SAOE.
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import cast, NamedTuple, Optional, Tuple
|
||||
from typing import cast, Callable, List, NamedTuple, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -13,6 +13,7 @@ from qlib.backtest.executor import BaseExecutor
|
||||
from qlib.constant import EPS, ONE_MIN, REG_CN
|
||||
from qlib.rl.order_execution.utils import dataframe_append, price_advantage
|
||||
from qlib.typehint import TypedDict
|
||||
from qlib.utils.index_data import IndexData
|
||||
from qlib.utils.time import get_day_min_idx_range
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
@@ -38,6 +39,37 @@ def _get_all_timestamps(
|
||||
return pd.DatetimeIndex(ret)
|
||||
|
||||
|
||||
def fill_missing_data(
    original_data: np.ndarray,
    total_time_list: List[pd.Timestamp],
    found_time_list: List[pd.Timestamp],
    fill_method: Callable = np.median,
) -> np.ndarray:
    """Expand data onto a full list of timestamps, filling the gaps with one summary value.

    This is needed to deal with data that has no value for some minutes.

    TODO: making exchange return data without missing will make it more elegant. Fix this in the future.

    Parameters
    ----------
    original_data
        The values that are available; must have one entry per timestamp in ``found_time_list``.
    total_time_list
        All timestamps that are required in the output.
    found_time_list
        Timestamps the entries of ``original_data`` belong to (same order).
    fill_method
        Callable applied to ``original_data`` to compute the value used for every missing timestamp.

    Returns
    -------
    An array with one entry per timestamp in ``total_time_list``.
    """
    assert len(original_data) == len(found_time_list)

    # Map every known timestamp to its value (a later duplicate overrides an earlier one).
    value_by_time = {}
    for timestamp, value in zip(found_time_list, original_data):
        value_by_time[timestamp] = value

    # A single replacement value, derived from all available data, is shared by all gaps.
    default_value = fill_method(original_data)

    filled = []
    for timestamp in total_time_list:
        if timestamp in value_by_time:
            filled.append(value_by_time[timestamp])
        else:
            filled.append(default_value)
    return np.array(filled)
|
||||
|
||||
|
||||
class SAOEStateAdapter:
|
||||
"""
|
||||
Maintain states of the environment. SAOEStateAdapter accepts execution results and update its internal state
|
||||
@@ -106,16 +138,17 @@ class SAOEStateAdapter:
|
||||
assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large"
|
||||
exec_vol *= self.position / (exec_vol.sum())
|
||||
|
||||
market_volume = np.array(
|
||||
market_volume = cast(
|
||||
IndexData,
|
||||
self.exchange.get_volume(
|
||||
self.order.stock_id,
|
||||
pd.Timestamp(start_time),
|
||||
pd.Timestamp(end_time),
|
||||
method=None,
|
||||
),
|
||||
).reshape(-1)
|
||||
|
||||
market_price = np.array(
|
||||
)
|
||||
market_price = cast(
|
||||
IndexData,
|
||||
self.exchange.get_deal_price(
|
||||
self.order.stock_id,
|
||||
pd.Timestamp(start_time),
|
||||
@@ -123,7 +156,11 @@ class SAOEStateAdapter:
|
||||
method=None,
|
||||
direction=self.order.direction,
|
||||
),
|
||||
).reshape(-1)
|
||||
)
|
||||
found_time_list = [pd.Timestamp(e) for e in list(market_volume.index)]
|
||||
total_time_list = _get_all_timestamps(start_time, end_time)
|
||||
market_price = fill_missing_data(np.array(market_price).reshape(-1), total_time_list, found_time_list)
|
||||
market_volume = fill_missing_data(np.array(market_volume).reshape(-1), total_time_list, found_time_list)
|
||||
|
||||
assert market_price.shape == market_volume.shape == exec_vol.shape
|
||||
|
||||
|
||||
@@ -5,15 +5,16 @@ from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from types import GeneratorType
|
||||
from typing import Any, cast, Dict, Generator, Optional, Union
|
||||
from typing import Any, cast, Dict, Generator, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tianshou.data import Batch
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.backtest import CommonInfrastructure, Order
|
||||
from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange
|
||||
from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWithDetails, TradeDecisionWO, TradeRange
|
||||
from qlib.backtest.utils import LevelInfrastructure
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.data.native import load_backtest_data
|
||||
@@ -235,6 +236,23 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
if self._backtest:
|
||||
self._env.reset()
|
||||
|
||||
def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
    """Build a table with one record per order of the outer trade decision.

    ``act``, ``exec_vols`` and the outer decision's ``order_list`` are iterated in parallel;
    iteration stops with the shortest of the three.

    Parameters
    ----------
    act
        Actions, one per order. An entry that is ``None`` yields a record without ``rl_action``.
    exec_vols
        Execution volumes, one per order; stored as ``rl_exec_vol``.

    Returns
    -------
    A DataFrame built from the records. Each record holds ``instrument``, ``datetime``, ``freq``,
    ``rl_exec_vol`` and, when the action is not ``None``, ``rl_action``.
    """
    assert hasattr(self.outer_trade_decision, "order_list")
    outer_orders = getattr(self.outer_trade_decision, "order_list")

    records = []
    for action, volume, order in zip(act, exec_vols, outer_orders):
        record = {
            "instrument": order.stock_id,
            # First element of the current step time reported by the trade calendar.
            "datetime": self.trade_calendar.get_step_time()[0],
            "freq": self.trade_calendar.get_freq(),
            "rl_exec_vol": volume,
        }
        if action is not None:
            record["rl_action"] = action
        records.append(record)
    return pd.DataFrame.from_records(records)
|
||||
|
||||
def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
|
||||
states = []
|
||||
obs_batch = []
|
||||
@@ -261,4 +279,8 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
order = cast(Order, decision)
|
||||
order_list.append(oh.create(order.stock_id, exec_vol, order.direction))
|
||||
|
||||
return TradeDecisionWO(order_list=order_list, strategy=self)
|
||||
return TradeDecisionWithDetails(
|
||||
order_list=order_list,
|
||||
strategy=self,
|
||||
details=self._generate_trade_details(act, exec_vols),
|
||||
)
|
||||
|
||||
@@ -32,16 +32,7 @@ def get_order() -> Order:
|
||||
)
|
||||
|
||||
|
||||
def get_configs(order: Order) -> Tuple[dict, dict, dict]:
|
||||
strategy_config = {
|
||||
"class": "SingleOrderStrategy",
|
||||
"module_path": "qlib.rl.strategy.single_order",
|
||||
"kwargs": {
|
||||
"order": order,
|
||||
"trade_range": TradeRangeByTime(order.start_time.time(), order.end_time.time()),
|
||||
},
|
||||
}
|
||||
|
||||
def get_configs(order: Order) -> Tuple[dict, dict]:
|
||||
executor_config = {
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
@@ -93,7 +84,7 @@ def get_configs(order: Order) -> Tuple[dict, dict, dict]:
|
||||
"trade_unit": None,
|
||||
}
|
||||
|
||||
return strategy_config, executor_config, exchange_config
|
||||
return executor_config, exchange_config
|
||||
|
||||
|
||||
def get_simulator(order: Order) -> SingleAssetOrderExecution:
|
||||
@@ -115,12 +106,11 @@ def get_simulator(order: Order) -> SingleAssetOrderExecution:
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
strategy_config, executor_config, exchange_config = get_configs(order)
|
||||
executor_config, exchange_config = get_configs(order)
|
||||
|
||||
return SingleAssetOrderExecution(
|
||||
order=order,
|
||||
qlib_config=qlib_config,
|
||||
strategy_config=strategy_config,
|
||||
executor_config=executor_config,
|
||||
exchange_config=exchange_config,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user