1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-06 04:20:57 +08:00

RL backtest with simulator (#1299)

* RL backtest with simulator

* Minor modification in init_qlib

* Cherry pick PR 1302

* Resolve PR comments

* Fix missing data processing

* Minor bugfix

* Add TODOs and docs

* Add a comment
This commit is contained in:
Huoran Li
2022-10-12 16:44:28 +08:00
committed by GitHub
parent 54928e956d
commit 216a8ec2de
11 changed files with 354 additions and 92 deletions

View File

@@ -576,3 +576,18 @@ class TradeDecisionWO(BaseTradeDecision[Order]):
f"trade_range: {self.trade_range}; "
f"order_list[{len(self.order_list)}]"
)
class TradeDecisionWithDetails(TradeDecisionWO):
"""Decision with detail information. Detail information is used to generate execution reports.
"""
def __init__(
self,
order_list: List[Order],
strategy: BaseStrategy,
trade_range: Optional[Tuple[int, int]] = None,
details: Optional[Any] = None,
) -> None:
super().__init__(order_list, strategy, trade_range)
self.details = details

View File

@@ -2,11 +2,12 @@
# Licensed under the MIT License.
from __future__ import annotations
import argparse
import copy
import pickle
import sys
from collections import defaultdict
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import pandas as pd
@@ -14,12 +15,13 @@ import torch
from joblib import Parallel, delayed
from qlib.backtest import collect_data_loop, get_strategy_executor
from qlib.backtest.decision import TradeRangeByTime
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
from qlib.backtest.high_performance_ds import BaseOrderIndicator
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
from qlib.rl.contrib.utils import read_order_file
from qlib.rl.data.integration import init_qlib
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
@@ -41,7 +43,7 @@ def _get_multi_level_executor_config(
}
freqs = list(strategy_config.keys())
freqs.sort(key=lambda x: pd.Timedelta(x))
freqs.sort(key=pd.Timedelta)
for freq in freqs:
executor_config = {
"class": "NestedExecutor",
@@ -73,7 +75,7 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
# HACK: for qlib v0.8
value_dict = value_dict.to_series()
try:
value_dict = {k: v for k, v in value_dict.items()}
value_dict = copy.deepcopy(value_dict)
if value_dict["ffr"].empty:
continue
except Exception:
@@ -90,32 +92,177 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
return records
def _generate_report(decisions: list, report_dict: dict) -> dict:
# TODO: there should be richer annotation for the input (e.g. report) and the returned report
# TODO: For example, @ dataclass with typed fields and detailed docstrings.
def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List[dict]) -> dict:
"""Generate backtest reports
Parameters
----------
decisions:
List of trade decisions.
report_indicators
List of indicator reports.
Returns
-------
"""
indicator_dict = defaultdict(list)
indicator_his = defaultdict(list)
for report_indicator in report_indicators:
for key, value in report_indicator.items():
if key.endswith("_obj"):
indicator_his[key].append(value.order_indicator_his)
else:
indicator_dict[key].append(value)
report = {}
decision_details = pd.concat([d.details for d in decisions if hasattr(d, "details")])
for key in ["1minute", "5minute", "30minute", "1day"]:
if key not in report_dict["indicator"]:
decision_details = pd.concat([getattr(d, "details") for d in decisions if hasattr(d, "details")])
for key in ["1min", "5min", "30min", "1day"]:
if key not in indicator_dict:
continue
report[key] = report_dict["indicator"][key]
report[key + "_obj"] = _convert_indicator_to_dataframe(
report_dict["indicator"][key + "_obj"].order_indicator_his
)
cur_details = decision_details[decision_details.freq == key.rstrip("ute")].set_index(["instrument", "datetime"])
report[key] = pd.concat(indicator_dict[key])
report[key + "_obj"] = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key + "_obj"]])
cur_details = decision_details[decision_details.freq == key].set_index(["instrument", "datetime"])
if len(cur_details) > 0:
cur_details.pop("freq")
report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer")
if "1minute" in report_dict["report"]:
report["simulator"] = report_dict["report"]["1minute"][0]
return report
def single(
def single_with_simulator(
backtest_config: dict,
orders: pd.DataFrame,
split: str = "stock",
split: Literal["stock", "day"] = "stock",
cash_limit: float = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.
A new simulator will be created and used for every single-day order.
Parameters
----------
backtest_config:
Backtest config
orders:
Orders to be executed. Example format:
datetime instrument amount direction
0 2020-06-01 INST 600.0 0
1 2020-06-02 INST 700.0 1
...
split
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
cash_limit
Limitation of cash.
generate_report
Whether to generate reports.
Returns
-------
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""
if split == "stock":
stock_id = orders.iloc[0].instrument
init_qlib(backtest_config["qlib"], part=stock_id)
else:
day = orders.iloc[0].datetime
init_qlib(backtest_config["qlib"], part=day)
stocks = orders.instrument.unique().tolist()
reports = []
decisions = []
for _, row in orders.iterrows():
date = pd.Timestamp(row["datetime"])
start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day)
end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day)
order = Order(
stock_id=row["instrument"],
amount=row["amount"],
direction=OrderDir(row["direction"]),
start_time=start_time,
end_time=end_time,
)
executor_config = _get_multi_level_executor_config(
strategy_config=backtest_config["strategies"],
cash_limit=cash_limit,
generate_report=generate_report,
)
exchange_config = copy.deepcopy(backtest_config["exchange"])
exchange_config.update(
{
"codes": stocks,
"freq": "1min",
}
)
simulator = SingleAssetOrderExecution(
order=order,
executor_config=executor_config,
exchange_config=exchange_config,
qlib_config=None,
cash_limit=None,
backtest_mode=True,
)
reports.append(simulator.report_dict)
decisions += simulator.decisions
indicator = {k: v for report in reports for k, v in report["indicator"]["1day_obj"].order_indicator_his.items()}
records = _convert_indicator_to_dataframe(indicator)
assert records is None or not np.isnan(records["ffr"]).any()
if generate_report:
report = _generate_report(decisions, [report["indicator"] for report in reports])
if split == "stock":
stock_id = orders.iloc[0].instrument
report = {stock_id: report}
else:
day = orders.iloc[0].datetime
report = {day: report}
return records, report
else:
return records
def single_with_collect_data_loop(
backtest_config: dict,
orders: pd.DataFrame,
split: Literal["stock", "day"] = "stock",
cash_limit: float = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
"""Run backtest in a single thread with collect_data_loop.
Parameters
----------
backtest_config:
Backtest config
orders:
Orders to be executed. Example format:
datetime instrument amount direction
0 2020-06-01 INST 600.0 0
1 2020-06-02 INST 700.0 1
...
split
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
cash_limit
Limitation of cash.
generate_report
Whether to generate reports.
Returns
-------
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""
if split == "stock":
stock_id = orders.iloc[0].instrument
init_qlib(backtest_config["qlib"], part=stock_id)
@@ -127,7 +274,7 @@ def single(
trade_end_time = orders["datetime"].max()
stocks = orders.instrument.unique().tolist()
top_strategy_config = {
strategy_config = {
"class": "FileOrderStrategy",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {
@@ -139,14 +286,14 @@ def single(
},
}
top_executor_config = _get_multi_level_executor_config(
executor_config = _get_multi_level_executor_config(
strategy_config=backtest_config["strategies"],
cash_limit=cash_limit,
generate_report=generate_report,
)
tmp_backtest_config = copy.deepcopy(backtest_config["exchange"])
tmp_backtest_config.update(
exchange_config = copy.deepcopy(backtest_config["exchange"])
exchange_config.update(
{
"codes": stocks,
"freq": "1min",
@@ -156,11 +303,11 @@ def single(
strategy, executor = get_strategy_executor(
start_time=pd.Timestamp(trade_start_time),
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
strategy=top_strategy_config,
executor=top_executor_config,
strategy=strategy_config,
executor=executor_config,
benchmark=None,
account=cash_limit if cash_limit is not None else int(1e12),
exchange_kwargs=tmp_backtest_config,
exchange_kwargs=exchange_config,
pos_type="Position" if cash_limit is not None else "InfPosition",
)
_set_env_for_all_strategy(executor=executor)
@@ -172,7 +319,7 @@ def single(
assert records is None or not np.isnan(records["ffr"]).any()
if generate_report:
report = _generate_report(decisions, report_dict)
report = _generate_report(decisions, [report_dict["indicator"]])
if split == "stock":
stock_id = orders.iloc[0].instrument
report = {stock_id: report}
@@ -184,7 +331,7 @@ def single(
return records
def backtest(backtest_config: dict) -> pd.DataFrame:
def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame:
order_df = read_order_file(backtest_config["order_file"])
cash_limit = backtest_config["exchange"].pop("cash_limit")
@@ -193,6 +340,7 @@ def backtest(backtest_config: dict) -> pd.DataFrame:
stock_pool = order_df["instrument"].unique().tolist()
stock_pool.sort()
single = single_with_simulator if with_simulator else single_with_collect_data_loop
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
res = Parallel(**mp_config)(
@@ -227,5 +375,12 @@ if __name__ == "__main__":
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
path = sys.argv[1]
backtest(get_backtest_config_fromfile(path))
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
parser.add_argument("--use_simulator", action="store_true", help="Whether to use simulator as the backend")
args = parser.parse_args()
backtest(
backtest_config=get_backtest_config_fromfile(args.config_path),
with_simulator=args.use_simulator,
)

View File

@@ -53,7 +53,8 @@ def parse_backtest_config(path: str) -> dict:
del sys.modules[tmp_module_name]
else:
config = yaml.safe_load(open(tmp_config_file.name))
with open(tmp_config_file.name) as input_stream:
config = yaml.safe_load(input_stream)
if "_base_" in config:
base_file_name = config.pop("_base_")

View File

@@ -81,10 +81,12 @@ def init_qlib(qlib_config: dict, part: str = None) -> None:
def _convert_to_path(path: str | Path) -> Path:
return path if isinstance(path, Path) else Path(path)
provider_uri_map = {
"day": _convert_to_path(qlib_config["provider_uri_day"]).as_posix(),
"1min": _convert_to_path(qlib_config["provider_uri_1min"]).as_posix(),
}
provider_uri_map = {}
if "provider_uri_day" in qlib_config:
provider_uri_map["day"] = _convert_to_path(qlib_config["provider_uri_day"]).as_posix()
if "provider_uri_1min" in qlib_config:
provider_uri_map["1min"] = _convert_to_path(qlib_config["provider_uri_1min"]).as_posix()
qlib.init(
region=REG_CN,
auto_mount=False,

View File

@@ -9,12 +9,11 @@ import pandas as pd
from qlib.backtest import Exchange, Order
from qlib.backtest.decision import TradeRange, TradeRangeByTime
from qlib.constant import EPS_T, ONE_DAY
from qlib.rl.order_execution.utils import get_ticks_slice
from qlib.utils.index_data import IndexData
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
from .integration import fetch_features
from ...data import D
class IntradayBacktestData(BaseIntradayBacktestData):
@@ -82,18 +81,20 @@ def load_backtest_data(
trade_exchange: Exchange,
trade_range: TradeRange,
) -> IntradayBacktestData:
data = cast(
IndexData,
trade_exchange.get_deal_price(
stock_id=order.stock_id,
start_time=order.date,
end_time=order.date + ONE_DAY - EPS_T,
direction=order.direction,
method=None,
),
# TODO: making exchange return data without missing will make it more elegant. Fix this in the future.
tmp_data = D.features(
trade_exchange.codes,
trade_exchange.all_fields,
trade_exchange.start_time,
trade_exchange.end_time,
freq=trade_exchange.freq,
disk_cache=True,
)
ticks_index = pd.DatetimeIndex(data.index)
ticks_index = pd.DatetimeIndex(tmp_data.reset_index()["datetime"])
ticks_index = ticks_index[order.start_time <= ticks_index]
ticks_index = ticks_index[ticks_index <= order.end_time]
if isinstance(trade_range, TradeRangeByTime):
ticks_for_order = get_ticks_slice(
ticks_index,
@@ -122,7 +123,10 @@ class NTIntradayProcessedData(BaseIntradayProcessedData):
date: pd.Timestamp,
) -> None:
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
return df.reset_index().drop(columns=["instrument"]).set_index(["datetime"])
df = df.reset_index()
if "instrument" in df.columns:
df = df.drop(columns=["instrument"])
return df.set_index(["datetime"])
self.today = _drop_stock_id(fetch_features(stock_id, date))
self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))

View File

@@ -91,7 +91,7 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
def __init__(
self,
data_dir: Path,
data_dir: Path | str,
stock_id: str,
date: pd.Timestamp,
deal_price: DealPriceType = "close",
@@ -99,7 +99,7 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
) -> None:
super(SimpleIntradayBacktestData, self).__init__()
backtest = _read_pickle(data_dir / stock_id)
backtest = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
# No longer need for pandas >= 1.4
@@ -154,13 +154,13 @@ class IntradayProcessedData(BaseIntradayProcessedData):
def __init__(
self,
data_dir: Path,
data_dir: Path | str,
stock_id: str,
date: pd.Timestamp,
feature_dim: int,
time_index: pd.Index,
) -> None:
proc = _read_pickle(data_dir / stock_id)
proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
# We have to infer the names here because,
# unfortunately they are not included in the original data.
cnames = _infer_processed_data_column_names(feature_dim)

View File

@@ -163,6 +163,12 @@ def auto_device(module: nn.Module) -> torch.device:
def load_weight(policy: nn.Module, path: Path) -> None:
assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
loaded_weight = torch.load(path, map_location="cpu")
# TODO: this should be handled by whoever calls load_weight.
# TODO: For example, when the outer class receives a weight, it should first unpack it,
# TODO: and send the corresponding part to individual component.
if "vessel" in loaded_weight:
loaded_weight = loaded_weight["vessel"]["policy"]
try:
policy.load_state_dict(loaded_weight)
except RuntimeError:

View File

@@ -3,17 +3,18 @@
from __future__ import annotations
from typing import Generator, Optional
from typing import Generator, List, Optional
import pandas as pd
from qlib.backtest import collect_data_loop, get_strategy_executor
from qlib.backtest.decision import Order
from qlib.backtest.executor import NestedExecutor
from qlib.rl.simulator import Simulator
from qlib.backtest import collect_data_loop, get_strategy_executor
from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime
from qlib.backtest.executor import BaseExecutor, NestedExecutor
from qlib.rl.data.integration import init_qlib
from qlib.rl.simulator import Simulator
from .state import SAOEState, SAOEStateAdapter
from .strategy import SAOEStrategy
from ..utils.env_wrapper import CollectDataEnvWrapper
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
@@ -23,30 +24,42 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
----------
order
The seed to start an SAOE simulator is an order.
strategy_config
Strategy configuration
executor_config
Executor configuration
exchange_config
Exchange configuration
qlib_config
Configuration used to initialize Qlib. If it is None, Qlib will not be initialized.
cash_limit:
Cash limit.
backtest_mode
Whether the simulator is under backtest mode.
"""
def __init__(
self,
order: Order,
strategy_config: dict,
executor_config: dict,
exchange_config: dict,
qlib_config: dict = None,
cash_limit: Optional[float] = None,
backtest_mode: bool = False,
) -> None:
super().__init__(initial=order)
assert order.start_time.date() == order.end_time.date(), "Start date and end date must be the same."
strategy_config = {
"class": "SingleOrderStrategy",
"module_path": "qlib.rl.strategy.single_order",
"kwargs": {
"order": order,
"trade_range": TradeRangeByTime(order.start_time.time(), order.end_time.time()),
},
}
self._collect_data_loop: Optional[Generator] = None
self.reset(order, strategy_config, executor_config, exchange_config, qlib_config)
self.reset(order, strategy_config, executor_config, exchange_config, qlib_config, cash_limit, backtest_mode)
def reset(
self,
@@ -55,6 +68,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
executor_config: dict,
exchange_config: dict,
qlib_config: dict = None,
cash_limit: Optional[float] = None,
backtest_mode: bool = False,
) -> None:
if qlib_config is not None:
init_qlib(qlib_config, part="skip")
@@ -65,22 +80,35 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
strategy=strategy_config,
executor=executor_config,
benchmark=order.stock_id,
account=1e12,
account=cash_limit if cash_limit is not None else int(1e12),
exchange_kwargs=exchange_config,
pos_type="InfPosition",
pos_type="Position" if cash_limit is not None else "InfPosition",
)
assert isinstance(self._executor, NestedExecutor)
self.report_dict: dict = {}
self.decisions: List[BaseTradeDecision] = []
self._collect_data_loop = collect_data_loop(
start_time=order.date,
end_time=order.date,
trade_strategy=strategy,
trade_executor=self._executor,
return_value=self.report_dict,
)
assert isinstance(self._collect_data_loop, Generator)
self._last_yielded_saoe_strategy = self._iter_strategy(action=None)
# TODO: backtest_mode is not a necessary parameter if we carefully design it.
# TODO: It should disappear with CollectDataEnvWrapper in the future.
if backtest_mode:
executor: BaseExecutor = self._executor
while isinstance(executor, NestedExecutor):
if hasattr(executor.inner_strategy, "set_env"):
executor.inner_strategy.set_env(CollectDataEnvWrapper())
executor = executor.inner_executor
# Call `step()` with None action to initialize the internal generator.
self.step(action=None)
self._order = order
@@ -91,17 +119,19 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
def twap_price(self) -> float:
return self._get_adapter().twap_price
def _iter_strategy(self, action: float = None) -> SAOEStrategy:
def _iter_strategy(self, action: Optional[float] = None) -> SAOEStrategy:
"""Iterate the _collect_data_loop until we get the next yield SAOEStrategy."""
assert self._collect_data_loop is not None
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
while not isinstance(strategy, SAOEStrategy):
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
assert isinstance(strategy, SAOEStrategy)
return strategy
obj = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
while not isinstance(obj, SAOEStrategy):
if isinstance(obj, BaseTradeDecision):
self.decisions.append(obj)
obj = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
assert isinstance(obj, SAOEStrategy)
return obj
def step(self, action: float) -> None:
def step(self, action: Optional[float]) -> None:
"""Execute one step or SAOE.
Parameters

View File

@@ -4,7 +4,7 @@
from __future__ import annotations
import typing
from typing import cast, NamedTuple, Optional, Tuple
from typing import cast, Callable, List, NamedTuple, Optional, Tuple
import numpy as np
import pandas as pd
@@ -13,6 +13,7 @@ from qlib.backtest.executor import BaseExecutor
from qlib.constant import EPS, ONE_MIN, REG_CN
from qlib.rl.order_execution.utils import dataframe_append, price_advantage
from qlib.typehint import TypedDict
from qlib.utils.index_data import IndexData
from qlib.utils.time import get_day_min_idx_range
if typing.TYPE_CHECKING:
@@ -38,6 +39,37 @@ def _get_all_timestamps(
return pd.DatetimeIndex(ret)
def fill_missing_data(
original_data: np.ndarray,
total_time_list: List[pd.Timestamp],
found_time_list: List[pd.Timestamp],
fill_method: Callable = np.median,
) -> np.ndarray:
"""Fill missing data. We need this function to deal with data that have missing values in some minutes.
TODO: making exchange return data without missing will make it more elegant. Fix this in the future.
Parameters
----------
original_data
Original data without missing values.
total_time_list
All timestamps that required.
found_time_list
Timestamps found in the original data.
fill_method
Method used to fill the missing data.
Returns
-------
The filled data.
"""
assert len(original_data) == len(found_time_list)
tmp = dict(zip(found_time_list, original_data))
fill_val = fill_method(original_data)
return np.array([tmp.get(t, fill_val) for t in total_time_list])
class SAOEStateAdapter:
"""
Maintain states of the environment. SAOEStateAdapter accepts execution results and update its internal state
@@ -106,16 +138,17 @@ class SAOEStateAdapter:
assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large"
exec_vol *= self.position / (exec_vol.sum())
market_volume = np.array(
market_volume = cast(
IndexData,
self.exchange.get_volume(
self.order.stock_id,
pd.Timestamp(start_time),
pd.Timestamp(end_time),
method=None,
),
).reshape(-1)
market_price = np.array(
)
market_price = cast(
IndexData,
self.exchange.get_deal_price(
self.order.stock_id,
pd.Timestamp(start_time),
@@ -123,7 +156,11 @@ class SAOEStateAdapter:
method=None,
direction=self.order.direction,
),
).reshape(-1)
)
found_time_list = [pd.Timestamp(e) for e in list(market_volume.index)]
total_time_list = _get_all_timestamps(start_time, end_time)
market_price = fill_missing_data(np.array(market_price).reshape(-1), total_time_list, found_time_list)
market_volume = fill_missing_data(np.array(market_volume).reshape(-1), total_time_list, found_time_list)
assert market_price.shape == market_volume.shape == exec_vol.shape

View File

@@ -5,15 +5,16 @@ from __future__ import annotations
import collections
from types import GeneratorType
from typing import Any, cast, Dict, Generator, Optional, Union
from typing import Any, cast, Dict, Generator, List, Optional, Union
import numpy as np
import pandas as pd
import torch
from tianshou.data import Batch
from tianshou.policy import BasePolicy
from qlib.backtest import CommonInfrastructure, Order
from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange
from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWithDetails, TradeDecisionWO, TradeRange
from qlib.backtest.utils import LevelInfrastructure
from qlib.constant import ONE_MIN
from qlib.rl.data.native import load_backtest_data
@@ -235,6 +236,23 @@ class SAOEIntStrategy(SAOEStrategy):
if self._backtest:
self._env.reset()
def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
assert hasattr(self.outer_trade_decision, "order_list")
trade_details = []
for a, v, o in zip(act, exec_vols, getattr(self.outer_trade_decision, "order_list")):
trade_details.append(
{
"instrument": o.stock_id,
"datetime": self.trade_calendar.get_step_time()[0],
"freq": self.trade_calendar.get_freq(),
"rl_exec_vol": v,
}
)
if a is not None:
trade_details[-1]["rl_action"] = a
return pd.DataFrame.from_records(trade_details)
def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
states = []
obs_batch = []
@@ -261,4 +279,8 @@ class SAOEIntStrategy(SAOEStrategy):
order = cast(Order, decision)
order_list.append(oh.create(order.stock_id, exec_vol, order.direction))
return TradeDecisionWO(order_list=order_list, strategy=self)
return TradeDecisionWithDetails(
order_list=order_list,
strategy=self,
details=self._generate_trade_details(act, exec_vols),
)

View File

@@ -32,16 +32,7 @@ def get_order() -> Order:
)
def get_configs(order: Order) -> Tuple[dict, dict, dict]:
strategy_config = {
"class": "SingleOrderStrategy",
"module_path": "qlib.rl.strategy.single_order",
"kwargs": {
"order": order,
"trade_range": TradeRangeByTime(order.start_time.time(), order.end_time.time()),
},
}
def get_configs(order: Order) -> Tuple[dict, dict]:
executor_config = {
"class": "NestedExecutor",
"module_path": "qlib.backtest.executor",
@@ -93,7 +84,7 @@ def get_configs(order: Order) -> Tuple[dict, dict, dict]:
"trade_unit": None,
}
return strategy_config, executor_config, exchange_config
return executor_config, exchange_config
def get_simulator(order: Order) -> SingleAssetOrderExecution:
@@ -115,12 +106,11 @@ def get_simulator(order: Order) -> SingleAssetOrderExecution:
}
# fmt: on
strategy_config, executor_config, exchange_config = get_configs(order)
executor_config, exchange_config = get_configs(order)
return SingleAssetOrderExecution(
order=order,
qlib_config=qlib_config,
strategy_config=strategy_config,
executor_config=executor_config,
exchange_config=exchange_config,
)