mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Migrate backtest logic from NT (#1263)
* Backtest migration * Minor bug fix in test * Reorganize file to avoid loop import * Fix test SAOE bug * Remove unnecessary names * Resolve PR comments; remove private classes; * Fix CI error * Resolve PR comments * Refactor data interfaces * Remove convert_instance_config and change config * Pylint issue * Pylint issue * Fix tempfile warning * Resolve PR comments * Add more comments
This commit is contained in:
@@ -114,7 +114,7 @@ def get_exchange(
|
||||
def create_account_instance(
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
benchmark: str,
|
||||
benchmark: Optional[str],
|
||||
account: Union[float, int, dict],
|
||||
pos_type: str = "Position",
|
||||
) -> Account:
|
||||
@@ -163,7 +163,9 @@ def create_account_instance(
|
||||
init_cash=init_cash,
|
||||
position_dict=position_dict,
|
||||
pos_type=pos_type,
|
||||
benchmark_config={
|
||||
benchmark_config={}
|
||||
if benchmark is None
|
||||
else {
|
||||
"benchmark": benchmark,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
@@ -176,7 +178,7 @@ def get_strategy_executor(
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
benchmark: Optional[str] = "SH000300",
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
|
||||
231
qlib/rl/contrib/backtest.py
Normal file
231
qlib/rl/contrib/backtest.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import pickle
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from qlib.backtest import collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest.decision import TradeRangeByTime
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
|
||||
from qlib.backtest.high_performance_ds import BaseOrderIndicator
|
||||
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
|
||||
from qlib.rl.contrib.utils import read_order_file
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
|
||||
|
||||
|
||||
def _get_multi_level_executor_config(
    strategy_config: dict,
    cash_limit: float = None,
    generate_report: bool = False,
) -> dict:
    """Build the configuration of a multi-level (nested) executor.

    The innermost level is always a 1-minute ``SimulatorExecutor``. Each key of
    ``strategy_config`` then adds one ``NestedExecutor`` layer around the current
    configuration, going from the shortest frequency to the longest, so the
    returned dict describes the outermost (longest-frequency) executor.

    Parameters
    ----------
    strategy_config
        Mapping from a frequency string (anything ``pd.Timedelta`` can parse) to
        the strategy configuration used as ``inner_strategy`` at that level.
    cash_limit
        When not ``None`` the simulator uses ``SimulatorExecutor.TT_PARAL``,
        otherwise ``SimulatorExecutor.TT_SERIAL``.
    generate_report
        Forwarded to the innermost simulator as ``generate_report``.

    Returns
    -------
    dict
        Executor configuration with ``class`` / ``module_path`` / ``kwargs`` keys.
    """
    if cash_limit is None:
        trade_type = SimulatorExecutor.TT_SERIAL
    else:
        trade_type = SimulatorExecutor.TT_PARAL

    # Innermost level: the simulator that actually executes the orders.
    config = {
        "class": "SimulatorExecutor",
        "module_path": "qlib.backtest.executor",
        "kwargs": {
            "time_per_step": "1min",
            "verbose": False,
            "trade_type": trade_type,
            "generate_report": generate_report,
            "track_data": True,
        },
    }

    # Wrap from the finest frequency outwards; each pass nests the previous level.
    for freq in sorted(strategy_config, key=pd.Timedelta):
        nested_kwargs = {
            "time_per_step": freq,
            "inner_strategy": strategy_config[freq],
            "inner_executor": config,
            "track_data": True,
        }
        config = {
            "class": "NestedExecutor",
            "module_path": "qlib.backtest.executor",
            "kwargs": nested_kwargs,
        }

    return config
|
||||
|
||||
|
||||
def _set_env_for_all_strategy(executor: BaseExecutor) -> None:
    """Attach a fresh ``CollectDataEnvWrapper`` to every inner strategy that accepts one.

    Walks down the chain of ``NestedExecutor`` instances starting at ``executor``.
    At each level, if the inner strategy exposes ``set_env``, a newly created and
    reset ``CollectDataEnvWrapper`` is handed to it. The walk stops at the first
    executor that is not a ``NestedExecutor``.

    Parameters
    ----------
    executor
        Outermost executor of the (possibly nested) executor chain.
    """
    level = executor
    while isinstance(level, NestedExecutor):
        if hasattr(level.inner_strategy, "set_env"):
            wrapper = CollectDataEnvWrapper()
            wrapper.reset()
            level.inner_strategy.set_env(wrapper)
        level = level.inner_executor
|
||||
|
||||
|
||||
def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
    """Flatten a per-timestamp indicator history into a single DataFrame.

    Parameters
    ----------
    indicator
        Mapping from a timestamp to the indicators recorded at that time. A value
        may be a ``BaseOrderIndicator`` (converted through ``to_series()``) or any
        object offering ``items()`` that yields ``(metric_name, values)`` pairs.

    Returns
    -------
    Optional[pd.DataFrame]
        One row per (instrument, datetime) with one column per metric, indexed by
        ``["instrument", "datetime"]``; ``None`` when no timestamp contributed a
        record.
    """
    record_list = []
    for time, value_dict in indicator.items():
        if isinstance(value_dict, BaseOrderIndicator):
            # HACK: for qlib v0.8
            value_dict = value_dict.to_series()
        try:
            value_dict = dict(value_dict.items())
            # Timestamps whose "ffr" entry is empty carry no record; skip them.
            if value_dict["ffr"].empty:
                continue
        except Exception:
            # NOTE(review): the broad catch is kept from the original. The fallback
            # drops the "pa" metric before building the frame; the exact failure it
            # guards against is not visible here -- confirm before narrowing.
            value_dict = {k: v for k, v in value_dict.items() if k != "pa"}
        value_dict = pd.DataFrame(value_dict)
        value_dict["datetime"] = time
        record_list.append(value_dict)

    if not record_list:
        return None

    # `axis` must be passed by keyword: the positional form is deprecated in
    # pandas 1.x and raises TypeError from pandas 2.0 on.
    records: pd.DataFrame = pd.concat(record_list, axis=0).reset_index().rename(columns={"index": "instrument"})
    records = records.set_index(["instrument", "datetime"])
    return records
|
||||
|
||||
|
||||
def _generate_report(decisions: list, report_dict: dict) -> dict:
    """Assemble the per-frequency report of one backtest run.

    Parameters
    ----------
    decisions
        Decisions collected during the run. Only those having a ``details``
        attribute contribute; their ``details`` frames are concatenated.
    report_dict
        Dict with an ``"indicator"`` and a ``"report"`` entry, as filled in by the
        data-collection loop.

    Returns
    -------
    dict
        For each of the frequencies ``1minute`` / ``5minute`` / ``30minute`` /
        ``1day`` present in ``report_dict["indicator"]``: the raw indicator under
        ``<freq>`` and a DataFrame under ``<freq>_obj`` (outer-joined with the
        matching decision details when there are any). ``"simulator"`` is added
        when ``report_dict["report"]`` has a ``"1minute"`` entry.
    """
    report = {}
    detail_frames = [d.details for d in decisions if hasattr(d, "details")]
    decision_details = pd.concat(detail_frames)
    indicators = report_dict["indicator"]

    for key in ("1minute", "5minute", "30minute", "1day"):
        if key in indicators:
            obj_key = key + "_obj"
            report[key] = indicators[key]
            report[obj_key] = _convert_indicator_to_dataframe(indicators[obj_key].order_indicator_his)

            # NOTE: str.rstrip("ute") removes trailing characters from the set
            # {"u", "t", "e"} rather than the literal suffix; for the four keys above
            # this yields "1min", "5min", "30min" and "1day".
            freq_name = key.rstrip("ute")
            matching = decision_details[decision_details.freq == freq_name]
            cur_details = matching.set_index(["instrument", "datetime"])
            if len(cur_details) > 0:
                cur_details.pop("freq")
                report[obj_key] = report[obj_key].join(cur_details, how="outer")

    simulator_reports = report_dict["report"]
    if "1minute" in simulator_reports:
        report["simulator"] = simulator_reports["1minute"][0]
    return report
|
||||
|
||||
|
||||
def single(
    backtest_config: dict,
    orders: pd.DataFrame,
    split: str = "stock",
    cash_limit: float = None,
    generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
    """Run the backtest for one partition of the order table.

    Parameters
    ----------
    backtest_config
        Full backtest configuration; the ``qlib``, ``start_time``, ``end_time``,
        ``strategies`` and ``exchange`` entries are read here.
    orders
        Orders of this partition, with at least ``instrument`` and ``datetime``
        columns.
    split
        ``"stock"`` when the partition holds a single instrument; any other value
        is treated as a per-day partition. Decides which field of the first order
        is passed to ``init_qlib`` and used as the report key.
    cash_limit
        Initial cash. ``None`` means no limit: the account starts with
        ``int(1e12)`` and uses the ``"InfPosition"`` position type.
    generate_report
        Whether to also build and return a report.

    Returns
    -------
    pd.DataFrame or tuple
        The daily execution records (possibly ``None``); when ``generate_report``
        is true, a ``(records, {partition_key: report})`` tuple instead.
    """
    by_stock = split == "stock"

    head = orders.iloc[0]
    init_qlib(backtest_config["qlib"], part=head.instrument if by_stock else head.datetime)

    order_times = orders["datetime"]
    trade_start_time = order_times.min()
    trade_end_time = order_times.max()
    stocks = orders.instrument.unique().tolist()

    # Only the time-of-day part of the configured start / end bounds the trading range.
    trade_range = TradeRangeByTime(
        pd.Timestamp(backtest_config["start_time"]).time(),
        pd.Timestamp(backtest_config["end_time"]).time(),
    )
    top_strategy_config = {
        "class": "FileOrderStrategy",
        "module_path": "qlib.contrib.strategy.rule_strategy",
        "kwargs": {
            "file": orders,
            "trade_range": trade_range,
        },
    }

    top_executor_config = _get_multi_level_executor_config(
        strategy_config=backtest_config["strategies"],
        cash_limit=cash_limit,
        generate_report=generate_report,
    )

    # Work on a private copy so the shared exchange configuration is left untouched.
    tmp_backtest_config = copy.deepcopy(backtest_config["exchange"])
    tmp_backtest_config.update({"codes": stocks, "freq": "1min"})

    unlimited_cash = cash_limit is None
    strategy, executor = get_strategy_executor(
        start_time=pd.Timestamp(trade_start_time),
        end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
        strategy=top_strategy_config,
        executor=top_executor_config,
        benchmark=None,
        account=int(1e12) if unlimited_cash else cash_limit,
        exchange_kwargs=tmp_backtest_config,
        pos_type="InfPosition" if unlimited_cash else "Position",
    )
    _set_env_for_all_strategy(executor=executor)

    report_dict: dict = {}
    decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))

    daily_history = report_dict["indicator"]["1day_obj"].order_indicator_his
    records = _convert_indicator_to_dataframe(daily_history)
    assert records is None or not np.isnan(records["ffr"]).any()

    if not generate_report:
        return records

    report = _generate_report(decisions, report_dict)
    # The key is read from `orders` again at this point (not reused from the top of
    # the function): the frame was handed to the strategy in between -- presumably it
    # may have been altered there; keep the late read unless confirmed otherwise.
    tail_head = orders.iloc[0]
    partition_key = tail_head.instrument if by_stock else tail_head.datetime
    return records, {partition_key: report}
|
||||
|
||||
|
||||
def backtest(backtest_config: dict) -> pd.DataFrame:
    """Run the backtest for every instrument of the order file and persist the results.

    One ``single`` job is run per instrument through ``joblib.Parallel``. The
    concatenated records are written to ``<output_dir>/summary.csv``; when report
    generation is enabled the merged per-instrument reports are additionally
    pickled to ``<output_dir>/report.pkl``.

    Parameters
    ----------
    backtest_config
        Backtest configuration. Reads ``order_file``, ``exchange``,
        ``concurrency`` and ``output_dir``. Note that ``cash_limit`` and
        ``generate_report`` are popped from ``backtest_config["exchange"]``, i.e.
        the passed dict is modified.

    Returns
    -------
    pd.DataFrame
        Execution records of all instruments.
    """
    order_df = read_order_file(backtest_config["order_file"])

    # These two entries are run options rather than exchange settings; `single`
    # forwards the remaining "exchange" mapping as exchange keyword arguments.
    cash_limit = backtest_config["exchange"].pop("cash_limit")
    generate_report = backtest_config["exchange"].pop("generate_report")

    stock_pool = sorted(order_df["instrument"].unique().tolist())

    # Create the output directory up front so a bad path fails before the
    # expensive parallel run instead of after it.
    output_path = Path(backtest_config["output_dir"])
    output_path.mkdir(parents=True, exist_ok=True)

    mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
    torch.set_num_threads(1)  # https://github.com/pytorch/pytorch/issues/17199
    res = Parallel(**mp_config)(
        delayed(single)(
            backtest_config=backtest_config,
            orders=order_df[order_df["instrument"] == stock].copy(),
            split="stock",
            cash_limit=cash_limit,
            generate_report=generate_report,
        )
        for stock in stock_pool
    )

    if generate_report:
        # With reports enabled every job returns a (records, report) pair.
        report = {}
        for _, stock_report in res:
            report.update(stock_report)
        with (output_path / "report.pkl").open("wb") as f:
            pickle.dump(report, f)
        # `axis` is keyword-only from pandas 2.0 on.
        res = pd.concat([records for records, _ in res], axis=0)
    else:
        res = pd.concat(res, axis=0)

    res.to_csv(output_path / "summary.csv")
    return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import warnings

    # Command-line runs silence these two warning categories.
    for _ignored_category in (DeprecationWarning, RuntimeWarning):
        warnings.filterwarnings("ignore", category=_ignored_category)

    # The first command-line argument is the path of the backtest configuration file.
    path = sys.argv[1]
    backtest(get_backtest_config_fromfile(path))
|
||||
103
qlib/rl/contrib/naive_config_parser.py
Normal file
103
qlib/rl/contrib/naive_config_parser.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from importlib import import_module
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def merge_a_into_b(a: dict, b: dict) -> dict:
    """Recursively merge ``a`` into a copy of ``b`` and return the result.

    Values from ``a`` take precedence. When a key is present on both sides and
    both values are dicts, they are merged recursively; in every other case the
    value from ``a`` replaces the one from ``b``. Neither input is modified.

    A ``"_delete_"`` marker inside a dict value of ``a`` is dropped when that key
    also exists in ``b``; the marker has no other effect.
    TODO: make the ``_delete_`` handling more elegant.

    Parameters
    ----------
    a
        Overriding configuration.
    b
        Base configuration.

    Returns
    -------
    dict
        The merged configuration (a new top-level dict).
    """
    merged = b.copy()
    for key, value in a.items():
        if isinstance(value, dict) and key in merged:
            # Strip the marker on a copy instead of popping it from the caller's dict.
            value = {k: v for k, v in value.items() if k != "_delete_"}
            base = merged[key]
            # A non-dict base (e.g. a `None` default) cannot be merged into:
            # the override simply wins.
            merged[key] = merge_a_into_b(value, base) if isinstance(base, dict) else value
        else:
            merged[key] = value
    return merged
|
||||
|
||||
|
||||
def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist') -> None:
    """Raise ``FileNotFoundError`` unless ``filename`` is an existing regular file.

    Parameters
    ----------
    filename
        Path to check.
    msg_tmpl
        ``str.format`` template of the error message; receives ``filename``.

    Raises
    ------
    FileNotFoundError
        If ``os.path.isfile(filename)`` is false (missing path or not a file).
    """
    if os.path.isfile(filename):
        return
    message = msg_tmpl.format(filename)
    raise FileNotFoundError(message)
|
||||
|
||||
|
||||
def parse_backtest_config(path: str) -> dict:
    """Load a backtest configuration file, resolving its ``_base_`` files.

    The file is copied into a temporary directory first. A ``.py`` file is
    imported as a module and every module attribute whose name does not start
    with ``"__"`` becomes a config entry; every other supported extension
    (``.json`` / ``.yaml`` / ``.yml``) is parsed with ``yaml.safe_load``.

    If the loaded config has a ``_base_`` entry (one path or a list of paths,
    relative to the directory of ``path``), each base file is parsed recursively
    and the current config is merged on top of it.

    Parameters
    ----------
    path
        Path of the configuration file.

    Returns
    -------
    dict
        The parsed and merged configuration.

    Raises
    ------
    FileNotFoundError
        If ``path`` is not an existing file.
    IOError
        If the file extension is not one of py / json / yaml / yml.
    """
    abs_path = os.path.abspath(path)
    check_file_exist(abs_path)

    file_ext_name = os.path.splitext(abs_path)[1]
    if file_ext_name not in (".py", ".json", ".yaml", ".yml"):
        raise IOError("Only py/yml/yaml/json type are supported now!")

    with tempfile.TemporaryDirectory() as tmp_config_dir:
        with tempfile.NamedTemporaryFile(dir=tmp_config_dir, suffix=file_ext_name) as tmp_config_file:
            # Presumably needed because Windows does not allow a named temporary
            # file to be opened a second time while it is still open.
            if platform.system() == "Windows":
                tmp_config_file.close()

            tmp_config_name = os.path.basename(tmp_config_file.name)
            shutil.copyfile(abs_path, tmp_config_file.name)

            if abs_path.endswith(".py"):
                tmp_module_name = os.path.splitext(tmp_config_name)[0]
                sys.path.insert(0, tmp_config_dir)
                try:
                    module = import_module(tmp_module_name)
                    config = {k: v for k, v in module.__dict__.items() if not k.startswith("__")}
                finally:
                    # Always undo the import-system changes, even when the config
                    # module raises. Remove our own entry by value: the imported
                    # module may itself have touched sys.path.
                    if tmp_config_dir in sys.path:
                        sys.path.remove(tmp_config_dir)
                    sys.modules.pop(tmp_module_name, None)
            else:
                # Close the handle deterministically instead of leaking it.
                with open(tmp_config_file.name) as config_stream:
                    config = yaml.safe_load(config_stream)

    if "_base_" in config:
        base_file_name = config.pop("_base_")
        if not isinstance(base_file_name, list):
            base_file_name = [base_file_name]

        for f in base_file_name:
            base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f))
            config = merge_a_into_b(a=config, b=base_config)

    return config
|
||||
|
||||
|
||||
def _convert_all_list_to_tuple(config: dict) -> dict:
    """Replace, in place, every list value of ``config`` by a tuple.

    Nested dicts are processed recursively. Lists that sit inside another list
    are not visited. The same ``config`` object is returned.
    """
    for key in list(config):
        value = config[key]
        if isinstance(value, dict):
            config[key] = _convert_all_list_to_tuple(value)
        elif isinstance(value, list):
            config[key] = tuple(value)
    return config
|
||||
|
||||
|
||||
def get_backtest_config_fromfile(path: str) -> dict:
    """Parse a backtest configuration file and fill in default values.

    Parameters
    ----------
    path
        Path of the configuration file, handed to ``parse_backtest_config``.

    Returns
    -------
    dict
        The configuration with defaults applied. Entries of the ``exchange``
        section missing from the file take the exchange defaults below, and list
        values in that section are converted to tuples; missing top-level entries
        take the top-level defaults below.
    """
    parsed = parse_backtest_config(path)

    exchange_defaults = {
        "open_cost": 0.0005,
        "close_cost": 0.0015,
        "min_cost": 5.0,
        "trade_unit": 100.0,
        "cash_limit": None,
        "generate_report": False,
    }
    exchange_config = merge_a_into_b(a=parsed["exchange"], b=exchange_defaults)
    parsed["exchange"] = _convert_all_list_to_tuple(exchange_config)

    top_level_defaults = {
        "debug_single_stock": None,
        "debug_single_day": None,
        "concurrency": -1,
        "multiplier": 1.0,
        "output_dir": "outputs/",
    }
    return merge_a_into_b(a=parsed, b=top_level_defaults)
|
||||
29
qlib/rl/contrib/utils.py
Normal file
29
qlib/rl/contrib/utils.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def read_order_file(order_file: Path | pd.DataFrame) -> pd.DataFrame:
    """Load the order table from a file, or pass an existing DataFrame through.

    Parameters
    ----------
    order_file
        Either a DataFrame, which is returned unchanged, or the path of a
        ``.pkl`` or ``.csv`` file.

    Returns
    -------
    pd.DataFrame
        The orders. For file inputs, the legacy columns ``date`` / ``order_type``
        are renamed to ``datetime`` / ``direction`` and the ``datetime`` column
        is converted to ``str``.

    Raises
    ------
    TypeError
        If the file suffix is neither ``.pkl`` nor ``.csv``.
    """
    if isinstance(order_file, pd.DataFrame):
        return order_file

    order_path = Path(order_file)
    suffix = order_path.suffix
    if suffix == ".csv":
        order_df = pd.read_csv(order_path)
    elif suffix == ".pkl":
        # NOTE: unpickling can execute arbitrary code; only load trusted order files.
        order_df = pd.read_pickle(order_path).reset_index()
    else:
        raise TypeError(f"Unsupported order file type: {order_path}")

    # Legacy dataframes use "date" / "order_type" column names.
    legacy_columns = {"date": "datetime", "order_type": "direction"}
    if "date" in order_df.columns:
        order_df = order_df.rename(columns=legacy_columns)
    order_df["datetime"] = order_df["datetime"].astype(str)

    return order_df
|
||||
65
qlib/rl/data/base.py
Normal file
65
qlib/rl/data/base.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class BaseIntradayBacktestData:
    """
    Raw market data that is often used in backtesting (thus called BacktestData).

    Common base of every backtest data type; currently each kind of simulator
    comes with its own backtest data type.

    NOTE: the class does not derive from ``abc.ABC``, so ``@abstractmethod`` does
    not block instantiation here; a method that is not overridden fails with
    ``NotImplementedError`` when it is called.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a textual representation of the data (to be implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the length of the data (to be implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def get_deal_price(self) -> pd.Series:
        """Return the deal prices as a ``pd.Series`` (to be implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def get_volume(self) -> pd.Series:
        """Return the volumes as a ``pd.Series`` (to be implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def get_time_index(self) -> pd.DatetimeIndex:
        """Return the time index of the data (to be implemented by subclasses)."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class BaseIntradayProcessedData:
    """Market data after data cleanup and feature engineering.

    Both the processed data of "today" and of "yesterday" are kept, because some
    algorithms use the market information of the previous day to assist their
    decision making.
    """

    today: pd.DataFrame
    """Processed data of "today".
    It must have ``time_length`` records and ``feature_dim`` columns."""

    yesterday: pd.DataFrame
    """Processed data of "yesterday".
    It must have ``time_length`` records and ``feature_dim`` columns."""
|
||||
|
||||
|
||||
class ProcessedDataProvider:
    """Provider of processed data"""

    def get_data(
        self,
        stock_id: str,
        date: pd.Timestamp,
        feature_dim: int,
        time_index: pd.Index,
    ) -> BaseIntradayProcessedData:
        """Return the processed data of one stock on one date.

        This base implementation always raises ``NotImplementedError``; concrete
        providers override it.

        Parameters
        ----------
        stock_id
            Identifier of the stock.
        date
            Date of the requested data.
        feature_dim
            Number of feature columns of the data.
        time_index
            Time index of the data.
        """
        raise NotImplementedError
|
||||
@@ -41,7 +41,7 @@ class DataWrapper:
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100),
|
||||
key=lambda stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
|
||||
key=lambda _, stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
|
||||
)
|
||||
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
@@ -8,10 +9,12 @@ import pandas as pd
|
||||
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.decision import TradeRange, TradeRangeByTime
|
||||
from qlib.constant import ONE_DAY, EPS_T
|
||||
from qlib.constant import EPS_T, ONE_DAY
|
||||
from qlib.rl.order_execution.utils import get_ticks_slice
|
||||
from qlib.utils.index_data import IndexData
|
||||
from .pickle_styled import BaseIntradayBacktestData
|
||||
|
||||
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
from .integration import fetch_features
|
||||
|
||||
|
||||
class IntradayBacktestData(BaseIntradayBacktestData):
|
||||
@@ -74,7 +77,7 @@ class IntradayBacktestData(BaseIntradayBacktestData):
|
||||
cache=cachetools.LRUCache(100),
|
||||
key=lambda order, _, __: order.key_by_day,
|
||||
)
|
||||
def load_qlib_backtest_data(
|
||||
def load_backtest_data(
|
||||
order: Order,
|
||||
trade_exchange: Exchange,
|
||||
trade_range: TradeRange,
|
||||
@@ -108,3 +111,40 @@ def load_qlib_backtest_data(
|
||||
ticks_for_order=ticks_for_order,
|
||||
)
|
||||
return backtest_data
|
||||
|
||||
|
||||
class NTIntradayProcessedData(BaseIntradayProcessedData):
    """Subclass of BaseIntradayProcessedData. Used to handle NT style data."""

    def __init__(
        self,
        stock_id: str,
        date: pd.Timestamp,
    ) -> None:
        """Fetch the features of ``stock_id`` for ``date`` and for the day before.

        Both frames are obtained through ``fetch_features`` and re-indexed by
        ``datetime`` only, with the ``instrument`` level removed.
        """

        def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
            # Flatten the index, discard the instrument column, index by time only.
            flat = df.reset_index()
            without_instrument = flat.drop(columns=["instrument"])
            return without_instrument.set_index(["datetime"])

        today_features = fetch_features(stock_id, date)
        self.today = _drop_stock_id(today_features)
        yesterday_features = fetch_features(stock_id, date, yesterday=True)
        self.yesterday = _drop_stock_id(yesterday_features)

    def __repr__(self) -> str:
        """Return the class name with compact, info-style summaries of both frames."""
        display_options = ("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info")
        with pd.option_context(*display_options):
            return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
|
||||
|
||||
|
||||
@cachetools.cached(  # type: ignore
    cache=cachetools.LRUCache(100),  # 100 * 50K = 5MB
)
def load_nt_intraday_processed_data(stock_id: str, date: pd.Timestamp) -> NTIntradayProcessedData:
    """Build the NT-style processed data of ``stock_id`` on ``date``.

    Results are memoized in an LRU cache of 100 entries, keyed by the call
    arguments.
    """
    processed = NTIntradayProcessedData(stock_id, date)
    return processed
|
||||
|
||||
|
||||
class NTProcessedDataProvider(ProcessedDataProvider):
    """Processed data provider backed by ``load_nt_intraday_processed_data``."""

    def get_data(
        self,
        stock_id: str,
        date: pd.Timestamp,
        feature_dim: int,
        time_index: pd.Index,
    ) -> BaseIntradayProcessedData:
        """Return the (cached) NT-style processed data of ``stock_id`` on ``date``.

        ``feature_dim`` and ``time_index`` belong to the provider interface but
        are not used by this implementation.
        """
        processed = load_nt_intraday_processed_data(stock_id, date)
        return processed
|
||||
@@ -19,7 +19,6 @@ This file shows resemblance to qlib.backtest.high_performance_ds. We might merge
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List, Sequence, cast
|
||||
@@ -30,6 +29,7 @@ import pandas as pd
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
from qlib.typehint import Literal
|
||||
|
||||
DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
|
||||
@@ -86,35 +86,6 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
|
||||
return pd.read_pickle(_find_pickle(filename_without_suffix))
|
||||
|
||||
|
||||
class BaseIntradayBacktestData:
|
||||
"""
|
||||
Raw market data that is often used in backtesting (thus called BacktestData).
|
||||
|
||||
Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest
|
||||
data type.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_volume(self) -> pd.Series:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_time_index(self) -> pd.DatetimeIndex:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SimpleIntradayBacktestData(BaseIntradayBacktestData):
|
||||
"""Backtest data for simple simulator"""
|
||||
|
||||
@@ -178,20 +149,8 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
|
||||
return cast(pd.DatetimeIndex, self.data.index)
|
||||
|
||||
|
||||
class IntradayProcessedData:
|
||||
"""Processed market data after data cleanup and feature engineering.
|
||||
|
||||
It contains both processed data for "today" and "yesterday", as some algorithms
|
||||
might use the market information of the previous day to assist decision making.
|
||||
"""
|
||||
|
||||
today: pd.DataFrame
|
||||
"""Processed data for "today".
|
||||
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
|
||||
|
||||
yesterday: pd.DataFrame
|
||||
"""Processed data for "yesterday".
|
||||
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
|
||||
class IntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of BaseIntradayProcessedData. Used to handle Dataset Handler style data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -246,18 +205,40 @@ def load_simple_intraday_backtest_data(
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
|
||||
key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
|
||||
)
|
||||
def load_intraday_processed_data(
|
||||
def load_pickled_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> IntradayProcessedData:
|
||||
) -> BaseIntradayProcessedData:
|
||||
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
|
||||
|
||||
class PickleProcessedDataProvider(ProcessedDataProvider):
    """Processed data provider that loads pickle-styled data from a directory."""

    def __init__(self, data_dir: Path) -> None:
        """Remember ``data_dir``, the directory passed to the loader on every request."""
        super().__init__()

        self._data_dir = data_dir

    def get_data(
        self,
        stock_id: str,
        date: pd.Timestamp,
        feature_dim: int,
        time_index: pd.Index,
    ) -> BaseIntradayProcessedData:
        """Return the processed data of ``stock_id`` on ``date``.

        All arguments, together with the configured data directory, are forwarded
        by keyword to ``load_pickled_intraday_processed_data``.
        """
        processed = load_pickled_intraday_processed_data(
            data_dir=self._data_dir,
            stock_id=stock_id,
            date=date,
            feature_dim=feature_dim,
            time_index=time_index,
        )
        return processed
|
||||
|
||||
|
||||
def load_orders(
|
||||
order_path: Path,
|
||||
start_time: pd.Timestamp = None,
|
||||
|
||||
@@ -3,16 +3,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
from .simulator import ActType, StateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils.env_wrapper import EnvWrapper
|
||||
from .utils.env_wrapper import BaseEnvWrapper
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
@@ -40,7 +39,7 @@ class Interpreter:
|
||||
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
|
||||
|
||||
env: Optional[EnvWrapper] = None
|
||||
env: Optional[BaseEnvWrapper] = None
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gym.Space:
|
||||
@@ -74,7 +73,7 @@ class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
|
||||
"""Action Interpreter that interpret rl agent action into qlib orders"""
|
||||
|
||||
env: Optional[EnvWrapper] = None
|
||||
env: Optional[BaseEnvWrapper] = None
|
||||
|
||||
@property
|
||||
def action_space(self) -> gym.Space:
|
||||
|
||||
@@ -4,15 +4,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, List, cast
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym import spaces
|
||||
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.data.base import ProcessedDataProvider
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution.state import SAOEState
|
||||
from qlib.typehint import TypedDict
|
||||
@@ -25,6 +24,8 @@ __all__ = [
|
||||
"FullHistoryObs",
|
||||
]
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
|
||||
def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict:
|
||||
"""To 32-bit numeric types. Recursively."""
|
||||
@@ -57,8 +58,6 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_dir
|
||||
Path to load data after feature engineering.
|
||||
max_step
|
||||
Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps.
|
||||
data_ticks
|
||||
@@ -66,21 +65,37 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
|
||||
the total ticks is the length of day in minutes.
|
||||
data_dim
|
||||
Number of dimensions in data.
|
||||
processed_data_provider
|
||||
Provider of the processed data.
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: Path, max_step: int, data_ticks: int, data_dim: int) -> None:
|
||||
self.data_dir = data_dir
|
||||
# TODO: All implementations related to `data_dir` is coupled with the specific data format for that specific case.
|
||||
# TODO: So it should be redesigned after the data interface is well-designed.
|
||||
def __init__(
|
||||
self,
|
||||
max_step: int,
|
||||
data_ticks: int,
|
||||
data_dim: int,
|
||||
processed_data_provider: dict | ProcessedDataProvider,
|
||||
) -> None:
|
||||
self.max_step = max_step
|
||||
self.data_ticks = data_ticks
|
||||
self.data_dim = data_dim
|
||||
self.processed_data_provider: ProcessedDataProvider = init_instance_by_config(
|
||||
processed_data_provider,
|
||||
accept_types=ProcessedDataProvider,
|
||||
)
|
||||
|
||||
def interpret(self, state: SAOEState) -> FullHistoryObs:
|
||||
processed = pickle_styled.load_intraday_processed_data(
|
||||
self.data_dir,
|
||||
state.order.stock_id,
|
||||
pd.Timestamp(state.order.start_time.date()),
|
||||
self.data_dim,
|
||||
state.ticks_index,
|
||||
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
|
||||
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
|
||||
# way to decompose interpreter and EnvWrapper in the future.
|
||||
|
||||
processed = self.processed_data_provider.get_data(
|
||||
stock_id=state.order.stock_id,
|
||||
date=pd.Timestamp(state.order.start_time.date()),
|
||||
feature_dim=self.data_dim,
|
||||
time_index=state.ticks_index,
|
||||
)
|
||||
|
||||
position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32)
|
||||
@@ -96,15 +111,15 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
|
||||
FullHistoryObs,
|
||||
canonicalize(
|
||||
{
|
||||
"data_processed": self._mask_future_info(processed.today, state.cur_time),
|
||||
"data_processed_prev": processed.yesterday,
|
||||
"acquiring": state.order.direction == state.order.BUY,
|
||||
"cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1),
|
||||
"cur_step": min(self.env.status["cur_step"], self.max_step - 1),
|
||||
"num_step": self.max_step,
|
||||
"target": state.order.amount,
|
||||
"position": state.position,
|
||||
"position_history": position_history[: self.max_step],
|
||||
"data_processed": np.array(self._mask_future_info(processed.today, state.cur_time)),
|
||||
"data_processed_prev": np.array(processed.yesterday),
|
||||
"acquiring": _to_int32(state.order.direction == state.order.BUY),
|
||||
"cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)),
|
||||
"cur_step": _to_int32(min(self.env.status["cur_step"], self.max_step - 1)),
|
||||
"num_step": _to_int32(self.max_step),
|
||||
"target": _to_float32(state.order.amount),
|
||||
"position": _to_float32(state.position),
|
||||
"position_history": _to_float32(position_history[: self.max_step]),
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -162,6 +177,10 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
|
||||
return spaces.Dict(space)
|
||||
|
||||
def interpret(self, state: SAOEState) -> CurrentStateObs:
|
||||
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
|
||||
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
|
||||
# way to decompose interpreter and EnvWrapper in the future.
|
||||
|
||||
assert self.env is not None
|
||||
assert self.env.status["cur_step"] <= self.max_step
|
||||
obs = CurrentStateObs(
|
||||
@@ -184,20 +203,31 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
|
||||
Then when the policy gives decision $x$, $a_x$ times order amount is the output.
|
||||
It can also be an integer $n$, in which case the list of length $n+1$ is auto-generated,
|
||||
i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
|
||||
max_step
|
||||
Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps.
|
||||
"""
|
||||
|
||||
def __init__(self, values: int | List[float]) -> None:
|
||||
def __init__(self, values: int | List[float], max_step: Optional[int] = None) -> None:
|
||||
if isinstance(values, int):
|
||||
values = [i / values for i in range(0, values + 1)]
|
||||
self.action_values = values
|
||||
self.max_step = max_step
|
||||
|
||||
@property
|
||||
def action_space(self) -> spaces.Discrete:
|
||||
return spaces.Discrete(len(self.action_values))
|
||||
|
||||
def interpret(self, state: SAOEState, action: int) -> float:
    """Translate a discrete policy action into the volume to execute.

    The chosen ratio ``self.action_values[action]`` is multiplied by the order amount and
    capped at the remaining position. When ``max_step`` is set and the env's ``cur_step``
    has reached ``max_step - 1`` or beyond, the whole remaining position is returned instead.
    """
    # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
    # backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
    # way to decompose interpreter and EnvWrapper in the future.

    assert 0 <= action < len(self.action_values)
    assert self.env is not None
    if self.max_step is not None and self.env.status["cur_step"] >= self.max_step - 1:
        return state.position
    else:
        return min(state.position, state.order.amount * self.action_values[action])
|
||||
|
||||
|
||||
class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
|
||||
@@ -214,7 +244,19 @@ class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
|
||||
return spaces.Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def interpret(self, state: SAOEState, action: float) -> float:
    """Translate a TWAP-relative policy action into the volume to execute.

    The TWAP volume is the remaining position divided by the estimated number of remaining
    steps; the result is ``action`` times that volume, capped at the remaining position.
    """
    # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
    # backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
    # way to decompose interpreter and EnvWrapper in the future.

    assert self.env is not None
    estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step)
    # Clamp to at least one remaining step: once cur_step reaches the estimate, the unclamped
    # denominator would be zero (division error) or negative (negative volume).
    remaining_steps = max(estimated_total_steps - self.env.status["cur_step"], 1)
    twap_volume = state.position / remaining_steps
    return min(state.position, twap_volume * action)
|
||||
|
||||
|
||||
def _to_int32(val) -> np.ndarray:
    """Convert ``val`` to a Python ``int`` and wrap it in a zero-dimensional ``np.int32`` array."""
    return np.array(int(val), dtype=np.int32)
|
||||
|
||||
|
||||
def _to_float32(val) -> np.ndarray:
    """Wrap ``val`` (a scalar or an array-like) in an ``np.float32`` array."""
    return np.array(val, dtype=np.float32)
|
||||
|
||||
@@ -117,3 +117,24 @@ class Recurrent(nn.Module):
|
||||
|
||||
out = torch.cat(sources, -1)
|
||||
return self.fc(out)
|
||||
|
||||
|
||||
class Attention(nn.Module):
    """Single-head dot-product attention with learned linear projections.

    ``Q``, ``K`` and ``V`` are each passed through their own ``nn.Linear(in_dim, out_dim)``.
    The attention weights are the softmax, over the key axis, of the raw query-key dot
    products (no scaling factor is applied); the output is the weighted sum of the projected
    values. The einsum patterns require three-dimensional inputs whose first axis is shared.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.q_net = nn.Linear(in_dim, out_dim)
        self.k_net = nn.Linear(in_dim, out_dim)
        self.v_net = nn.Linear(in_dim, out_dim)

    def forward(self, Q, K, V):
        """Return the attention output; its last axis has size ``out_dim``."""
        q = self.q_net(Q)
        k = self.k_net(K)
        v = self.v_net(V)

        # Scores between every query position j and key position l. They are already on the
        # device of the projections, so no explicit device transfer is needed.
        attn = torch.einsum("ijk,ilk->ijl", q, k)
        attn_prob = torch.softmax(attn, dim=-1)

        attn_vec = torch.einsum("ijk,ikl->ijl", attn_prob, v)

        return attn_vec
|
||||
|
||||
@@ -11,7 +11,7 @@ from qlib.backtest.decision import Order
|
||||
from qlib.backtest.executor import NestedExecutor
|
||||
from qlib.rl.simulator import Simulator
|
||||
|
||||
from .integration import init_qlib
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from .state import SAOEState, SAOEStateAdapter
|
||||
from .strategy import SAOEStrategy
|
||||
|
||||
|
||||
@@ -18,10 +18,10 @@ from .state import SAOEMetrics, SAOEState
|
||||
|
||||
# TODO: Integrating Qlib's native data with simulator_simple
|
||||
|
||||
__all__ = ["SingleAssetOrderExecution"]
|
||||
__all__ = ["SingleAssetOrderExecutionSimple"]
|
||||
|
||||
|
||||
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
"""Single-asset order execution (SAOE) simulator.
|
||||
|
||||
As there's no "calendar" in the simple simulator, ticks are used to trade.
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import cast, NamedTuple, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -10,11 +11,13 @@ import pandas as pd
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from qlib.constant import EPS, ONE_MIN, REG_CN
|
||||
from qlib.rl.data.exchange_wrapper import IntradayBacktestData
|
||||
from qlib.rl.data.pickle_styled import BaseIntradayBacktestData
|
||||
from qlib.rl.order_execution.utils import dataframe_append, price_advantage
|
||||
from qlib.typehint import TypedDict
|
||||
from qlib.utils.time import get_day_min_idx_range
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from qlib.rl.data.base import BaseIntradayBacktestData
|
||||
from qlib.rl.data.native import IntradayBacktestData
|
||||
|
||||
|
||||
def _get_all_timestamps(
|
||||
|
||||
@@ -5,17 +5,23 @@ from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from types import GeneratorType
|
||||
from typing import Any, Optional, Union, cast, Dict, Generator
|
||||
from typing import Any, cast, Dict, Generator, Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tianshou.data import Batch
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.backtest import CommonInfrastructure, Order
|
||||
from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange
|
||||
from qlib.backtest.utils import LevelInfrastructure
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.data.exchange_wrapper import load_qlib_backtest_data
|
||||
from qlib.rl.order_execution.state import SAOEStateAdapter, SAOEState
|
||||
from qlib.rl.data.native import load_backtest_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution.state import SAOEState, SAOEStateAdapter
|
||||
from qlib.rl.utils.env_wrapper import BaseEnvWrapper
|
||||
from qlib.strategy.base import RLStrategy
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
|
||||
class SAOEStrategy(RLStrategy):
|
||||
@@ -41,7 +47,7 @@ class SAOEStrategy(RLStrategy):
|
||||
self._last_step_range = (0, 0)
|
||||
|
||||
def _create_qlib_backtest_adapter(self, order: Order, trade_range: TradeRange) -> SAOEStateAdapter:
|
||||
backtest_data = load_qlib_backtest_data(order, self.trade_exchange, trade_range)
|
||||
backtest_data = load_backtest_data(order, self.trade_exchange, trade_range)
|
||||
|
||||
return SAOEStateAdapter(
|
||||
order=order,
|
||||
@@ -106,7 +112,10 @@ class SAOEStrategy(RLStrategy):
|
||||
|
||||
return decision
|
||||
|
||||
def _generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
|
||||
def _generate_trade_decision(
|
||||
self,
|
||||
execute_result: list = None,
|
||||
) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -146,3 +155,110 @@ class ProxySAOEStrategy(SAOEStrategy):
|
||||
order_list = outer_trade_decision.order_list
|
||||
assert len(order_list) == 1
|
||||
self._order = order_list[0]
|
||||
|
||||
|
||||
class SAOEIntStrategy(SAOEStrategy):
    """(SAOE)state based strategy with (Int)preters.

    For every order of the outer trade decision, the SAOE state is turned into an observation
    by the state interpreter, the policy is run on the batch of observations, and each
    resulting action is turned into an execution volume by the action interpreter.
    """

    def __init__(
        self,
        policy: dict | BasePolicy,
        state_interpreter: dict | StateInterpreter,
        action_interpreter: dict | ActionInterpreter,
        network: object = None,  # TODO: add accurate typehint later.
        outer_trade_decision: BaseTradeDecision = None,
        level_infra: LevelInfrastructure = None,
        common_infra: CommonInfrastructure = None,
        backtest: bool = False,
        **kwargs: Any,
    ) -> None:
        """Build the interpreters and the policy.

        Parameters
        ----------
        policy
            A ``BasePolicy`` instance, or a config dict. When a dict is given, ``network`` is
            required, and ``obs_space``, ``action_space`` and ``network`` are injected into
            the policy's kwargs before instantiation.
        state_interpreter, action_interpreter
            Instances or config dicts, resolved with ``init_instance_by_config``.
        network
            A network instance, or a config dict that receives ``obs_space`` in its kwargs.
        backtest
            When true, ``reset`` and ``_generate_trade_decision`` drive the env set via
            ``set_env`` manually.

        Raises
        ------
        ValueError
            If ``policy`` is neither a dict nor a ``BasePolicy``.
        """
        super(SAOEIntStrategy, self).__init__(
            policy=policy,
            outer_trade_decision=outer_trade_decision,
            level_infra=level_infra,
            common_infra=common_infra,
            **kwargs,
        )

        self._backtest = backtest

        self._state_interpreter: StateInterpreter = init_instance_by_config(
            state_interpreter,
            accept_types=StateInterpreter,
        )
        self._action_interpreter: ActionInterpreter = init_instance_by_config(
            action_interpreter,
            accept_types=ActionInterpreter,
        )

        if isinstance(policy, dict):
            assert network is not None

            if isinstance(network, dict):
                # Build a new config instead of updating the caller's dict in place, so a
                # config object can be reused to create several strategies.
                network_config = {
                    **network,
                    "kwargs": {
                        **(network.get("kwargs") or {}),
                        "obs_space": self._state_interpreter.observation_space,
                    },
                }
                network_inst = init_instance_by_config(network_config)
            else:
                network_inst = network

            policy_config = {
                **policy,
                "kwargs": {
                    **(policy.get("kwargs") or {}),
                    "obs_space": self._state_interpreter.observation_space,
                    "action_space": self._action_interpreter.action_space,
                    "network": network_inst,
                },
            }
            self._policy = init_instance_by_config(policy_config)
        elif isinstance(policy, BasePolicy):
            self._policy = policy
        else:
            raise ValueError(f"Unsupported policy type: {type(policy)}.")

        if self._policy is not None:
            self._policy.eval()

    def set_env(self, env: BaseEnvWrapper) -> None:
        """Attach ``env`` to this strategy and to both interpreters."""
        # TODO: This method is used to set EnvWrapper for interpreters since they rely on EnvWrapper.
        # We should decompose the interpreters with EnvWrapper in the future and we should remove this method
        # after that.

        self._env = env
        self._state_interpreter.env = self._action_interpreter.env = self._env

    def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
        """Reset the strategy and, in backtest mode, the attached env as well."""
        super().reset(outer_trade_decision=outer_trade_decision, **kwargs)

        # In backtest, env.reset() needs to be manually called since there is no outer trainer to call it
        if self._backtest:
            self._env.reset()

    def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
        """Run the policy on every outer order and return the orders to place.

        Orders whose interpreted execution volume is zero are left out of the decision.
        """
        states = []
        obs_batch = []
        for decision in self.outer_trade_decision.get_decision():
            order = cast(Order, decision)
            state = self.get_saoe_state_by_order(order)

            states.append(state)
            obs_batch.append({"obs": self._state_interpreter.interpret(state)})

        with torch.no_grad():
            policy_out = self._policy(Batch(obs_batch))
        # Tensor.numpy() only works on CPU tensors, so move the action to host memory first.
        if torch.is_tensor(policy_out.act):
            act = policy_out.act.detach().cpu().numpy()
        else:
            act = policy_out.act
        exec_vols = [self._action_interpreter.interpret(s, a) for s, a in zip(states, act)]

        # In backtest, env.step() needs to be manually called since there is no outer trainer to call it
        if self._backtest:
            self._env.step(None)

        oh = self.trade_exchange.get_order_helper()
        order_list = []
        for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):
            if exec_vol != 0:
                order = cast(Order, decision)
                order_list.append(oh.create(order.stock_id, exec_vol, order.direction))

        return TradeDecisionWO(order_list=order_list, strategy=self)
|
||||
|
||||
@@ -4,6 +4,6 @@
|
||||
"""Train, test, inference utilities."""
|
||||
|
||||
from .api import backtest, train
|
||||
from .callbacks import EarlyStopping, Checkpoint
|
||||
from .callbacks import Checkpoint, EarlyStopping
|
||||
from .trainer import Trainer
|
||||
from .vessel import TrainingVessel, TrainingVesselBase
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, cast
|
||||
from typing import Any, Callable, cast, Dict, Generic, Iterable, Iterator, Optional, Tuple
|
||||
|
||||
import gym
|
||||
from gym import Space
|
||||
@@ -14,7 +14,6 @@ from qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, State
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .finite_env import generate_nan_observation
|
||||
from .log import LogCollector, LogLevel
|
||||
|
||||
@@ -49,9 +48,24 @@ class EnvWrapperStatus(TypedDict):
|
||||
reward_history: list
|
||||
|
||||
|
||||
class EnvWrapper(
|
||||
class BaseEnvWrapper(
    gym.Env[ObsType, PolicyActType],
    Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType],
):
    """Base env wrapper for RL environments. It has two implementations:
    - EnvWrapper: Qlib-based RL environment used in training.
    - CollectDataEnvWrapper: Dummy environment used in collect_data_loop.
    """

    def __init__(self) -> None:
        # No status is available until a subclass assigns one; the cast only keeps the
        # declared attribute type while the placeholder value is None.
        self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)

    def render(self, mode: str = "human") -> None:
        """Rendering is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("Render is not implemented in BaseEnvWrapper.")
|
||||
|
||||
|
||||
class EnvWrapper(
|
||||
BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType],
|
||||
):
|
||||
"""Qlib-based RL environment, subclassing ``gym.Env``.
|
||||
A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
|
||||
@@ -115,6 +129,8 @@ class EnvWrapper(
|
||||
# 3. Avoid circular reference.
|
||||
# 4. When the components get serialized, we can throw away the env without any burden.
|
||||
# (though this part is not implemented yet)
|
||||
super().__init__()
|
||||
|
||||
for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]:
|
||||
if obj is not None:
|
||||
obj.env = weakref.proxy(self) # type: ignore
|
||||
@@ -247,5 +263,19 @@ class EnvWrapper(
|
||||
info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
|
||||
return obs, rew, done, info_dict
|
||||
|
||||
def render(self, mode: str = "human") -> None:
|
||||
raise NotImplementedError("Render is not implemented in EnvWrapper.")
|
||||
|
||||
class CollectDataEnvWrapper(BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType]):
    """Dummy EnvWrapper for collect_data_loop. It only has minimum interfaces to support the collect_data_loop."""

    def reset(self, **kwargs: Any) -> None:
        """Replace ``self.status`` with a fresh status: step 0, not done, empty histories."""
        self.status = EnvWrapperStatus(
            cur_step=0,
            done=False,
            initial_state=None,
            obs_history=[],
            action_history=[],
            reward_history=[],
        )

    def step(self, policy_action: Any = None, **kwargs: Any) -> None:
        """Advance the step counter by one; ``policy_action`` is ignored.

        ``reset`` must have been called first so that ``self.status`` is populated.
        """
        self.status["cur_step"] += 1
|
||||
|
||||
@@ -11,6 +11,7 @@ from qlib.backtest.decision import Order, OrderDir, TradeRangeByTime
|
||||
from qlib.backtest.executor import SimulatorExecutor
|
||||
from qlib.rl.order_execution import CategoricalActionInterpreter
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
|
||||
|
||||
TOTAL_POSITION = 2100.0
|
||||
|
||||
@@ -192,6 +193,8 @@ def test_interpreter() -> None:
|
||||
order = get_order()
|
||||
simulator = get_simulator(order)
|
||||
interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
|
||||
interpreter_action.env = CollectDataEnvWrapper()
|
||||
interpreter_action.env.reset()
|
||||
|
||||
NUM_STEPS = 7
|
||||
state = simulator.get_state()
|
||||
|
||||
@@ -16,9 +16,11 @@ from qlib.backtest import Order
|
||||
from qlib.config import C
|
||||
from qlib.log import set_log_with_config
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.data.pickle_styled import PickleProcessedDataProvider
|
||||
from qlib.rl.order_execution import *
|
||||
from qlib.rl.trainer import backtest, train
|
||||
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
|
||||
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
|
||||
|
||||
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
|
||||
|
||||
@@ -40,16 +42,15 @@ def test_pickle_data_inspect():
|
||||
data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
|
||||
assert len(data) == 390
|
||||
|
||||
data = pickle_styled.load_intraday_processed_data(
|
||||
DATA_DIR / "processed", "AAL", "2013-12-11", 5, data.get_time_index()
|
||||
)
|
||||
provider = PickleProcessedDataProvider(DATA_DIR / "processed")
|
||||
data = provider.get_data("AAL", "2013-12-11", 5, data.get_time_index())
|
||||
assert len(data.today) == len(data.yesterday) == 390
|
||||
|
||||
|
||||
def test_simulator_first_step():
|
||||
order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
state = simulator.get_state()
|
||||
assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00")
|
||||
assert state.position == 30.0
|
||||
@@ -83,7 +84,7 @@ def test_simulator_first_step():
|
||||
def test_simulator_stop_twap():
|
||||
order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
for _ in range(13):
|
||||
simulator.step(1.0)
|
||||
|
||||
@@ -106,10 +107,10 @@ def test_simulator_stop_early():
|
||||
order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator.step(2.0)
|
||||
|
||||
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator.step(1.0)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
@@ -119,7 +120,7 @@ def test_simulator_stop_early():
|
||||
def test_simulator_start_middle():
|
||||
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
assert len(simulator.ticks_for_order) == 330
|
||||
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
|
||||
simulator.step(2.0)
|
||||
@@ -138,7 +139,7 @@ def test_simulator_start_middle():
|
||||
def test_interpreter():
|
||||
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
assert len(simulator.ticks_for_order) == 330
|
||||
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
|
||||
|
||||
@@ -146,7 +147,7 @@ def test_interpreter():
|
||||
class EmulateEnvWrapper(NamedTuple):
|
||||
status: EnvWrapperStatus
|
||||
|
||||
interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
|
||||
interpreter = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
|
||||
interpreter_step = CurrentStepStateInterpreter(13)
|
||||
interpreter_action = CategoricalActionInterpreter(20)
|
||||
interpreter_action_twap = TwapRelativeActionInterpreter()
|
||||
@@ -185,6 +186,10 @@ def test_interpreter():
|
||||
assert np.sum(obs["data_processed"][60:]) == 0
|
||||
|
||||
# second step: action
|
||||
interpreter_action.env = CollectDataEnvWrapper()
|
||||
interpreter_action_twap.env = CollectDataEnvWrapper()
|
||||
interpreter_action.env.reset()
|
||||
interpreter_action_twap.env.reset()
|
||||
action = interpreter_action(simulator.get_state(), 1)
|
||||
assert action == 15 / 20
|
||||
|
||||
@@ -219,13 +224,13 @@ def test_network_sanity():
|
||||
# we won't check the correctness of networks here
|
||||
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
assert len(simulator.ticks_for_order) == 390
|
||||
|
||||
class EmulateEnvWrapper(NamedTuple):
|
||||
status: EnvWrapperStatus
|
||||
|
||||
interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
|
||||
interpreter = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
|
||||
action_interp = CategoricalActionInterpreter(13)
|
||||
|
||||
wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
|
||||
@@ -253,13 +258,15 @@ def test_twap_strategy(finite_env_type):
|
||||
orders = pickle_styled.load_orders(ORDER_DIR)
|
||||
assert len(orders) == 248
|
||||
|
||||
state_interp = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
|
||||
state_interp = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
|
||||
action_interp = TwapRelativeActionInterpreter()
|
||||
action_interp.env = CollectDataEnvWrapper()
|
||||
action_interp.env.reset()
|
||||
policy = AllOne(state_interp.observation_space, action_interp.action_space)
|
||||
csv_writer = CsvWriter(Path(__file__).parent / ".output")
|
||||
|
||||
backtest(
|
||||
partial(SingleAssetOrderExecution, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
state_interp,
|
||||
action_interp,
|
||||
orders,
|
||||
@@ -282,15 +289,17 @@ def test_cn_ppo_strategy():
|
||||
orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
|
||||
assert len(orders) == 40
|
||||
|
||||
state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
|
||||
state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
|
||||
action_interp = CategoricalActionInterpreter(4)
|
||||
action_interp.env = CollectDataEnvWrapper()
|
||||
action_interp.env.reset()
|
||||
network = Recurrent(state_interp.observation_space)
|
||||
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
|
||||
policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"))
|
||||
csv_writer = CsvWriter(Path(__file__).parent / ".output")
|
||||
|
||||
backtest(
|
||||
partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
state_interp,
|
||||
action_interp,
|
||||
orders,
|
||||
@@ -313,13 +322,15 @@ def test_ppo_train():
|
||||
orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
|
||||
assert len(orders) == 40
|
||||
|
||||
state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
|
||||
state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
|
||||
action_interp = CategoricalActionInterpreter(4)
|
||||
action_interp.env = CollectDataEnvWrapper()
|
||||
action_interp.env.reset()
|
||||
network = Recurrent(state_interp.observation_space)
|
||||
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
|
||||
|
||||
train(
|
||||
partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
state_interp,
|
||||
action_interp,
|
||||
orders,
|
||||
|
||||
Reference in New Issue
Block a user