1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Migrate backtest logic from NT (#1263)

* Backtest migration

* Minor bug fix in test

* Reorganize file to avoid loop import

* Fix test SAOE bug

* Remove unnecessary names

* Resolve PR comments; remove private classes;

* Fix CI error

* Resolve PR comments

* Refactor data interfaces

* Remove convert_instance_config and change config

* Pylint issue

* Pylint issue

* Fix tempfile warning

* Resolve PR comments

* Add more comments
This commit is contained in:
Huoran Li
2022-09-19 14:54:26 +08:00
committed by GitHub
parent e762548295
commit bee05f56ef
19 changed files with 794 additions and 118 deletions

View File

@@ -114,7 +114,7 @@ def get_exchange(
def create_account_instance(
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
benchmark: str,
benchmark: Optional[str],
account: Union[float, int, dict],
pos_type: str = "Position",
) -> Account:
@@ -163,7 +163,9 @@ def create_account_instance(
init_cash=init_cash,
position_dict=position_dict,
pos_type=pos_type,
benchmark_config={
benchmark_config={}
if benchmark is None
else {
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
@@ -176,7 +178,7 @@ def get_strategy_executor(
end_time: Union[pd.Timestamp, str],
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
benchmark: Optional[str] = "SH000300",
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",

231
qlib/rl/contrib/backtest.py Normal file
View File

@@ -0,0 +1,231 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import copy
import pickle
import sys
from pathlib import Path
from typing import Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
from joblib import Parallel, delayed
from qlib.backtest import collect_data_loop, get_strategy_executor
from qlib.backtest.decision import TradeRangeByTime
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
from qlib.backtest.high_performance_ds import BaseOrderIndicator
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
from qlib.rl.contrib.utils import read_order_file
from qlib.rl.data.integration import init_qlib
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
def _get_multi_level_executor_config(
    strategy_config: dict,
    cash_limit: float = None,
    generate_report: bool = False,
) -> dict:
    """Build a nested executor config with one ``NestedExecutor`` layer per strategy frequency.

    Parameters
    ----------
    strategy_config
        Mapping from a frequency string (anything ``pd.Timedelta`` can parse) to the config of
        the strategy that runs inside the executor of that frequency.
    cash_limit
        When not None the innermost simulator uses ``SimulatorExecutor.TT_PARAL``,
        otherwise ``SimulatorExecutor.TT_SERIAL``.
    generate_report
        Forwarded to the innermost ``SimulatorExecutor``.

    Returns
    -------
    dict
        Config of the outermost executor.
    """
    if cash_limit is not None:
        trade_type = SimulatorExecutor.TT_PARAL
    else:
        trade_type = SimulatorExecutor.TT_SERIAL

    # Innermost layer: the 1-minute simulator.
    config = {
        "class": "SimulatorExecutor",
        "module_path": "qlib.backtest.executor",
        "kwargs": {
            "time_per_step": "1min",
            "verbose": False,
            "trade_type": trade_type,
            "generate_report": generate_report,
            "track_data": True,
        },
    }

    # Wrap from the shortest frequency outwards, so the longest frequency ends up outermost.
    for freq in sorted(strategy_config.keys(), key=pd.Timedelta):
        config = {
            "class": "NestedExecutor",
            "module_path": "qlib.backtest.executor",
            "kwargs": {
                "time_per_step": freq,
                "inner_strategy": strategy_config[freq],
                "inner_executor": config,
                "track_data": True,
            },
        }

    return config
def _set_env_for_all_strategy(executor: BaseExecutor) -> None:
    """Give every inner strategy that exposes ``set_env`` its own freshly reset env wrapper.

    Walks down the chain of ``NestedExecutor`` layers starting at ``executor`` and stops at the
    first executor that is not a ``NestedExecutor``.
    """
    current = executor
    while isinstance(current, NestedExecutor):
        inner_strategy = current.inner_strategy
        if hasattr(inner_strategy, "set_env"):
            env = CollectDataEnvWrapper()
            env.reset()
            inner_strategy.set_env(env)
        current = current.inner_executor
def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
    """Flatten a ``{time: indicator values}`` mapping into one dataframe.

    Parameters
    ----------
    indicator
        Mapping from a timestamp to either a ``BaseOrderIndicator`` or a dict-like object of
        per-metric values.

    Returns
    -------
    Optional[pd.DataFrame]
        Frame indexed by ``(instrument, datetime)``, or None when no step produced a record.
    """
    record_list = []
    for time, value_dict in indicator.items():
        if isinstance(value_dict, BaseOrderIndicator):
            # HACK: for qlib v0.8
            value_dict = value_dict.to_series()
        try:
            value_dict = dict(value_dict.items())
            # Steps whose "ffr" entry is empty contribute nothing.
            if value_dict["ffr"].empty:
                continue
        except Exception:
            # Best-effort fallback kept from the original implementation: when the check above
            # cannot be performed, keep every metric except "pa".
            value_dict = {k: v for k, v in value_dict.items() if k != "pa"}
        frame = pd.DataFrame(value_dict)
        frame["datetime"] = time
        record_list.append(frame)

    if not record_list:
        return None

    # `axis` must be passed by keyword: pandas 2.x no longer accepts it positionally.
    records: pd.DataFrame = pd.concat(record_list, axis=0).reset_index().rename(columns={"index": "instrument"})
    records = records.set_index(["instrument", "datetime"])
    return records
def _generate_report(decisions: list, report_dict: dict) -> dict:
    """Assemble the per-frequency report from collected decisions and executor outputs.

    Parameters
    ----------
    decisions
        Decisions collected during the backtest; only those with a ``details`` attribute are used.
    report_dict
        Dict filled by the backtest loop; its ``"indicator"`` and ``"report"`` entries are read.

    Returns
    -------
    dict
        For every available frequency key: the raw indicator under ``key`` and a dataframe
        (joined with the matching decision details, if any) under ``key + "_obj"``.
        ``"simulator"`` is added when a ``"1minute"`` report exists.
    """
    indicators = report_dict["indicator"]
    all_details = pd.concat([d.details for d in decisions if hasattr(d, "details")])

    report = {}
    for key in ["1minute", "5minute", "30minute", "1day"]:
        if key not in indicators:
            continue

        obj_key = key + "_obj"
        report[key] = indicators[key]
        report[obj_key] = _convert_indicator_to_dataframe(indicators[obj_key].order_indicator_his)

        # "1minute" -> "1min" etc.; "1day" is left untouched by the strip.
        freq_mask = all_details.freq == key.rstrip("ute")
        cur_details = all_details[freq_mask].set_index(["instrument", "datetime"])
        if len(cur_details) == 0:
            continue
        cur_details.pop("freq")
        report[obj_key] = report[obj_key].join(cur_details, how="outer")

    if "1minute" in report_dict["report"]:
        report["simulator"] = report_dict["report"]["1minute"][0]
    return report
def single(
    backtest_config: dict,
    orders: pd.DataFrame,
    split: str = "stock",
    cash_limit: float = None,
    generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
    """Run the backtest for one shard of orders.

    Parameters
    ----------
    backtest_config
        Backtest config; the ``"qlib"``, ``"start_time"``, ``"end_time"``, ``"strategies"`` and
        ``"exchange"`` entries are read.
    orders
        Orders of this shard, with at least ``instrument`` and ``datetime`` columns.
    split
        ``"stock"`` keys the shard by the first order's instrument; any other value keys it by
        the first order's datetime.
    cash_limit
        Initial cash. When None, an ``InfPosition`` account with ``int(1e12)`` cash is used.
    generate_report
        Whether to also return a report dict keyed by the shard key.

    Returns
    -------
    The ``1day`` order-indicator records, plus ``{shard_key: report}`` when ``generate_report``.
    """
    part_column = "instrument" if split == "stock" else "datetime"
    init_qlib(backtest_config["qlib"], part=getattr(orders.iloc[0], part_column))

    trade_start_time = orders["datetime"].min()
    trade_end_time = orders["datetime"].max()
    stocks = orders.instrument.unique().tolist()

    trade_range = TradeRangeByTime(
        pd.Timestamp(backtest_config["start_time"]).time(),
        pd.Timestamp(backtest_config["end_time"]).time(),
    )
    top_strategy_config = {
        "class": "FileOrderStrategy",
        "module_path": "qlib.contrib.strategy.rule_strategy",
        "kwargs": {"file": orders, "trade_range": trade_range},
    }

    top_executor_config = _get_multi_level_executor_config(
        strategy_config=backtest_config["strategies"],
        cash_limit=cash_limit,
        generate_report=generate_report,
    )

    # Copy so the per-shard additions never leak back into the shared config.
    exchange_kwargs = copy.deepcopy(backtest_config["exchange"])
    exchange_kwargs.update({"codes": stocks, "freq": "1min"})

    limited_cash = cash_limit is not None
    strategy, executor = get_strategy_executor(
        start_time=pd.Timestamp(trade_start_time),
        end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
        strategy=top_strategy_config,
        executor=top_executor_config,
        benchmark=None,
        account=cash_limit if limited_cash else int(1e12),
        exchange_kwargs=exchange_kwargs,
        pos_type="Position" if limited_cash else "InfPosition",
    )
    _set_env_for_all_strategy(executor=executor)

    report_dict: dict = {}
    decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))

    records = _convert_indicator_to_dataframe(report_dict["indicator"]["1day_obj"].order_indicator_his)
    assert records is None or not np.isnan(records["ffr"]).any()

    if not generate_report:
        return records

    report = _generate_report(decisions, report_dict)
    # `orders` was handed to the strategy by reference, so the key is looked up again here
    # (as the original code does) instead of reusing the value read before the run.
    shard_key = getattr(orders.iloc[0], part_column)
    return records, {shard_key: report}
def backtest(backtest_config: dict) -> pd.DataFrame:
    """Run the backtest for every instrument in the order file and save the results.

    One :func:`single` job is dispatched per instrument through joblib. The concatenated
    records are written to ``<output_dir>/summary.csv``; when the exchange config enables
    ``generate_report`` the merged per-instrument reports are pickled to
    ``<output_dir>/report.pkl``.

    Parameters
    ----------
    backtest_config
        Full backtest config. ``"order_file"``, ``"exchange"`` (which must contain
        ``"cash_limit"`` and ``"generate_report"``), ``"concurrency"`` and ``"output_dir"``
        are read here. The dict is not modified.

    Returns
    -------
    pd.DataFrame
        Concatenated records of all instruments.
    """
    order_df = read_order_file(backtest_config["order_file"])

    # Take the two control flags out of a *copy* of the exchange section: the remaining keys are
    # forwarded as exchange kwargs, and the caller's config stays reusable for another call.
    exchange_config = dict(backtest_config["exchange"])
    cash_limit = exchange_config.pop("cash_limit")
    generate_report = exchange_config.pop("generate_report")
    worker_config = {**backtest_config, "exchange": exchange_config}

    # Create the output directory up front so a bad path fails before the expensive run.
    output_path = Path(backtest_config["output_dir"])
    output_path.mkdir(parents=True, exist_ok=True)

    stock_pool = sorted(order_df["instrument"].unique().tolist())

    mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
    torch.set_num_threads(1)  # https://github.com/pytorch/pytorch/issues/17199
    res = Parallel(**mp_config)(
        delayed(single)(
            backtest_config=worker_config,
            orders=order_df[order_df["instrument"] == stock].copy(),
            split="stock",
            cash_limit=cash_limit,
            generate_report=generate_report,
        )
        for stock in stock_pool
    )

    if generate_report:
        report = {}
        for _, stock_report in res:
            report.update(stock_report)
        with (output_path / "report.pkl").open("wb") as f:
            pickle.dump(report, f)
        # `axis` must be passed by keyword: pandas 2.x no longer accepts it positionally.
        res = pd.concat([records for records, _ in res], axis=0)
    else:
        res = pd.concat(res)

    res.to_csv(output_path / "summary.csv")
    return res
if __name__ == "__main__":
    import warnings

    # Silence the two noisy warning categories before running the backtest.
    for _category in (DeprecationWarning, RuntimeWarning):
        warnings.filterwarnings("ignore", category=_category)

    # The only command-line argument is the path of the backtest config file.
    path = sys.argv[1]
    backtest(get_backtest_config_fromfile(path))

View File

@@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import platform
import shutil
import sys
import tempfile
from importlib import import_module
import yaml
def merge_a_into_b(a: dict, b: dict) -> dict:
    """Return a copy of ``b`` recursively overridden by ``a``.

    Neither input is modified. When both sides hold a dict under the same key the two are merged
    recursively, and the ``"_delete_"`` marker of the overriding dict is dropped from the result.
    In every other case the value from ``a`` replaces the one from ``b``; this includes a dict in
    ``a`` overriding a non-dict value (e.g. ``None``) in ``b``.

    Parameters
    ----------
    a
        Overriding config.
    b
        Base config.

    Returns
    -------
    dict
        The merged config.
    """
    merged = b.copy()
    for key, value in a.items():
        base_value = merged.get(key)
        if isinstance(value, dict) and key in merged and isinstance(base_value, dict):
            # TODO: make this more elegant
            # Strip the marker on a filtered copy instead of popping it from the caller's dict.
            overrides = {k: v for k, v in value.items() if k != "_delete_"}
            merged[key] = merge_a_into_b(overrides, base_value)
        else:
            merged[key] = value
    return merged
def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist') -> None:
    """Raise ``FileNotFoundError`` unless ``filename`` is an existing regular file.

    Parameters
    ----------
    filename
        Path to check.
    msg_tmpl
        Error message template; ``{}`` is replaced by ``filename``.
    """
    if os.path.isfile(filename):
        return
    raise FileNotFoundError(msg_tmpl.format(filename))
def parse_backtest_config(path: str) -> dict:
    """Load a backtest config file, resolving its ``_base_`` references recursively.

    The file is copied into a temporary directory first. ``.py`` files are imported as a module
    and every module-level name not starting with ``__`` becomes a config entry; ``.json``,
    ``.yaml`` and ``.yml`` files are parsed with ``yaml.safe_load``. An empty YAML/JSON file is
    treated as an empty config.

    Parameters
    ----------
    path
        Path of the config file.

    Returns
    -------
    dict
        The config, merged on top of each file listed in ``_base_`` (paths relative to the
        directory of ``path``).

    Raises
    ------
    FileNotFoundError
        If ``path`` is not an existing file.
    IOError
        If the file extension is not supported.
    """
    abs_path = os.path.abspath(path)
    check_file_exist(abs_path)

    file_ext_name = os.path.splitext(abs_path)[1]
    if file_ext_name not in (".py", ".json", ".yaml", ".yml"):
        raise IOError("Only py/yml/yaml/json type are supported now!")

    with tempfile.TemporaryDirectory() as tmp_config_dir:
        with tempfile.NamedTemporaryFile(dir=tmp_config_dir, suffix=file_ext_name) as tmp_config_file:
            # On Windows the handle is closed before the file is overwritten and read by name.
            if platform.system() == "Windows":
                tmp_config_file.close()

            tmp_config_name = os.path.basename(tmp_config_file.name)
            shutil.copyfile(abs_path, tmp_config_file.name)
            if abs_path.endswith(".py"):
                tmp_module_name = os.path.splitext(tmp_config_name)[0]
                sys.path.insert(0, tmp_config_dir)
                try:
                    module = import_module(tmp_module_name)
                    config = {k: v for k, v in module.__dict__.items() if not k.startswith("__")}
                finally:
                    # Always undo the import machinery changes, even when the config raises.
                    if tmp_config_dir in sys.path:
                        sys.path.remove(tmp_config_dir)
                    sys.modules.pop(tmp_module_name, None)
            else:
                # Close the handle explicitly so the temporary directory can be removed.
                with open(tmp_config_file.name) as config_stream:
                    config = yaml.safe_load(config_stream)

    if config is None:
        config = {}

    if "_base_" in config:
        base_file_name = config.pop("_base_")
        if not isinstance(base_file_name, list):
            base_file_name = [base_file_name]

        base_dir = os.path.dirname(abs_path)
        for f in base_file_name:
            base_config = parse_backtest_config(os.path.join(base_dir, f))
            config = merge_a_into_b(a=config, b=base_config)

    return config
def _convert_all_list_to_tuple(config: dict) -> dict:
for k, v in config.items():
if isinstance(v, list):
config[k] = tuple(v)
elif isinstance(v, dict):
config[k] = _convert_all_list_to_tuple(v)
return config
def get_backtest_config_fromfile(path: str) -> dict:
    """Parse the config file at ``path`` and fill in default values.

    Defaults are applied to the ``"exchange"`` section first (its list values are then turned
    into tuples), and afterwards to the top level of the config.
    """
    exchange_defaults = {
        "open_cost": 0.0005,
        "close_cost": 0.0015,
        "min_cost": 5.0,
        "trade_unit": 100.0,
        "cash_limit": None,
        "generate_report": False,
    }
    top_level_defaults = {
        "debug_single_stock": None,
        "debug_single_day": None,
        "concurrency": -1,
        "multiplier": 1.0,
        "output_dir": "outputs/",
        # "runtime": {},
    }

    config = parse_backtest_config(path)

    exchange = merge_a_into_b(a=config["exchange"], b=exchange_defaults)
    config["exchange"] = _convert_all_list_to_tuple(exchange)

    return merge_a_into_b(a=config, b=top_level_defaults)

29
qlib/rl/contrib/utils.py Normal file
View File

@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from pathlib import Path
import pandas as pd
def read_order_file(order_file: Path | pd.DataFrame) -> pd.DataFrame:
    """Load orders from a ``.pkl`` or ``.csv`` file.

    A dataframe argument is returned untouched. For files, legacy column names (``date``,
    ``order_type``) are renamed to ``datetime`` and ``direction``, and the ``datetime`` column
    is converted to strings.

    Raises
    ------
    TypeError
        If the file suffix is neither ``.pkl`` nor ``.csv``.
    """
    if isinstance(order_file, pd.DataFrame):
        return order_file

    order_file = Path(order_file)
    suffix = order_file.suffix
    if suffix not in (".pkl", ".csv"):
        raise TypeError(f"Unsupported order file type: {order_file}")

    if suffix == ".pkl":
        order_df = pd.read_pickle(order_file).reset_index()
    else:
        order_df = pd.read_csv(order_file)

    legacy_names = {"date": "datetime", "order_type": "direction"}
    if "date" in order_df.columns:
        # legacy dataframe columns
        order_df = order_df.rename(columns=legacy_names)

    order_df["datetime"] = order_df["datetime"].astype(str)
    return order_df

65
qlib/rl/data/base.py Normal file
View File

@@ -0,0 +1,65 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from abc import abstractmethod
import pandas as pd
class BaseIntradayBacktestData:
    """
    Interface for raw intraday market data consumed by backtests (hence "BacktestData").

    Every concrete backtest data type derives from this class; at the moment each kind of
    simulator comes with its own backtest data type.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a printable summary of the data."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the length of the data."""
        raise NotImplementedError

    @abstractmethod
    def get_deal_price(self) -> pd.Series:
        """Return the deal price series."""
        raise NotImplementedError

    @abstractmethod
    def get_volume(self) -> pd.Series:
        """Return the volume series."""
        raise NotImplementedError

    @abstractmethod
    def get_time_index(self) -> pd.DatetimeIndex:
        """Return the time index of the data."""
        raise NotImplementedError
class BaseIntradayProcessedData:
    """Market data after cleanup and feature engineering.

    Holds the processed data of both "today" and "yesterday", because some algorithms use
    the previous day's market information to assist decision making.
    """

    today: pd.DataFrame
    """Processed data of "today".
    It must have ``time_length`` records and ``feature_dim`` columns."""

    yesterday: pd.DataFrame
    """Processed data of "yesterday".
    It must have ``time_length`` records and ``feature_dim`` columns."""
class ProcessedDataProvider:
    """Interface of objects that supply processed intraday data."""

    def get_data(
        self,
        stock_id: str,
        date: pd.Timestamp,
        feature_dim: int,
        time_index: pd.Index,
    ) -> BaseIntradayProcessedData:
        """Return the processed data of ``stock_id`` on ``date``.

        Subclasses must override this method; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

View File

@@ -41,7 +41,7 @@ class DataWrapper:
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100),
key=lambda stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
key=lambda _, stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
)
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)

View File

@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import cast
@@ -8,10 +9,12 @@ import pandas as pd
from qlib.backtest import Exchange, Order
from qlib.backtest.decision import TradeRange, TradeRangeByTime
from qlib.constant import ONE_DAY, EPS_T
from qlib.constant import EPS_T, ONE_DAY
from qlib.rl.order_execution.utils import get_ticks_slice
from qlib.utils.index_data import IndexData
from .pickle_styled import BaseIntradayBacktestData
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
from .integration import fetch_features
class IntradayBacktestData(BaseIntradayBacktestData):
@@ -74,7 +77,7 @@ class IntradayBacktestData(BaseIntradayBacktestData):
cache=cachetools.LRUCache(100),
key=lambda order, _, __: order.key_by_day,
)
def load_qlib_backtest_data(
def load_backtest_data(
order: Order,
trade_exchange: Exchange,
trade_range: TradeRange,
@@ -108,3 +111,40 @@ def load_qlib_backtest_data(
ticks_for_order=ticks_for_order,
)
return backtest_data
class NTIntradayProcessedData(BaseIntradayProcessedData):
    """``BaseIntradayProcessedData`` implementation for NT style data, loaded via ``fetch_features``."""

    def __init__(
        self,
        stock_id: str,
        date: pd.Timestamp,
    ) -> None:
        def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
            # Keep only "datetime" as index; the "instrument" level is dropped.
            flat = df.reset_index()
            without_instrument = flat.drop(columns=["instrument"])
            return without_instrument.set_index(["datetime"])

        self.today = _drop_stock_id(fetch_features(stock_id, date))
        self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))

    def __repr__(self) -> str:
        # Render both frames in the compact "info" representation.
        compact_options = ("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info")
        with pd.option_context(*compact_options):
            return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
@cachetools.cached(  # type: ignore
    cache=cachetools.LRUCache(maxsize=100),  # 100 * 50K = 5MB
)
def load_nt_intraday_processed_data(stock_id: str, date: pd.Timestamp) -> NTIntradayProcessedData:
    """Build the NT processed data of ``stock_id`` on ``date``, cached in an LRU cache of 100 entries."""
    processed = NTIntradayProcessedData(stock_id, date)
    return processed
class NTProcessedDataProvider(ProcessedDataProvider):
    """Processed-data provider backed by ``load_nt_intraday_processed_data``."""

    def get_data(
        self,
        stock_id: str,
        date: pd.Timestamp,
        feature_dim: int,
        time_index: pd.Index,
    ) -> BaseIntradayProcessedData:
        """Return the cached NT data; ``feature_dim`` and ``time_index`` are not used here."""
        # Positional call on purpose: it keeps the loader's cache keys unchanged.
        processed = load_nt_intraday_processed_data(stock_id, date)
        return processed

View File

@@ -19,7 +19,6 @@ This file shows resemblence to qlib.backtest.high_performance_ds. We might merge
from __future__ import annotations
from abc import abstractmethod
from functools import lru_cache
from pathlib import Path
from typing import List, Sequence, cast
@@ -30,6 +29,7 @@ import pandas as pd
from cachetools.keys import hashkey
from qlib.backtest.decision import Order, OrderDir
from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
from qlib.typehint import Literal
DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
@@ -86,35 +86,6 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
return pd.read_pickle(_find_pickle(filename_without_suffix))
class BaseIntradayBacktestData:
"""
Raw market data that is often used in backtesting (thus called BacktestData).
Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest
data type.
"""
@abstractmethod
def __repr__(self) -> str:
raise NotImplementedError
@abstractmethod
def __len__(self) -> int:
raise NotImplementedError
@abstractmethod
def get_deal_price(self) -> pd.Series:
raise NotImplementedError
@abstractmethod
def get_volume(self) -> pd.Series:
raise NotImplementedError
@abstractmethod
def get_time_index(self) -> pd.DatetimeIndex:
raise NotImplementedError
class SimpleIntradayBacktestData(BaseIntradayBacktestData):
"""Backtest data for simple simulator"""
@@ -178,20 +149,8 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
return cast(pd.DatetimeIndex, self.data.index)
class IntradayProcessedData:
"""Processed market data after data cleanup and feature engineering.
It contains both processed data for "today" and "yesterday", as some algorithms
might use the market information of the previous day to assist decision making.
"""
today: pd.DataFrame
"""Processed data for "today".
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
yesterday: pd.DataFrame
"""Processed data for "yesterday".
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
class IntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle Dataset Handler style data."""
def __init__(
self,
@@ -246,18 +205,40 @@ def load_simple_intraday_backtest_data(
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
)
def load_intraday_processed_data(
def load_pickled_intraday_processed_data(
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_dim: int,
time_index: pd.Index,
) -> IntradayProcessedData:
) -> BaseIntradayProcessedData:
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
class PickleProcessedDataProvider(ProcessedDataProvider):
    """Processed-data provider backed by ``load_pickled_intraday_processed_data``."""

    def __init__(self, data_dir: Path) -> None:
        """Remember ``data_dir``, which is forwarded to the loader on every ``get_data`` call."""
        super().__init__()

        self._data_dir = data_dir

    def get_data(
        self,
        stock_id: str,
        date: pd.Timestamp,
        feature_dim: int,
        time_index: pd.Index,
    ) -> BaseIntradayProcessedData:
        """Return the processed data of ``stock_id`` on ``date`` from the configured directory."""
        processed = load_pickled_intraday_processed_data(
            data_dir=self._data_dir,
            stock_id=stock_id,
            date=date,
            feature_dim=feature_dim,
            time_index=time_index,
        )
        return processed
def load_orders(
order_path: Path,
start_time: pd.Timestamp = None,

View File

@@ -3,16 +3,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar
import numpy as np
from qlib.typehint import final
from .simulator import ActType, StateType
if TYPE_CHECKING:
from .utils.env_wrapper import EnvWrapper
from .utils.env_wrapper import BaseEnvWrapper
import gym
from gym import spaces
@@ -40,7 +39,7 @@ class Interpreter:
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
env: Optional[EnvWrapper] = None
env: Optional[BaseEnvWrapper] = None
@property
def observation_space(self) -> gym.Space:
@@ -74,7 +73,7 @@ class StateInterpreter(Generic[StateType, ObsType], Interpreter):
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
"""Action Interpreter that interpret rl agent action into qlib orders"""
env: Optional[EnvWrapper] = None
env: Optional[BaseEnvWrapper] = None
@property
def action_space(self) -> gym.Space:

View File

@@ -4,15 +4,14 @@
from __future__ import annotations
import math
from pathlib import Path
from typing import Any, List, cast
from typing import Any, List, Optional, cast
import numpy as np
import pandas as pd
from gym import spaces
from qlib.constant import EPS
from qlib.rl.data import pickle_styled
from qlib.rl.data.base import ProcessedDataProvider
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution.state import SAOEState
from qlib.typehint import TypedDict
@@ -25,6 +24,8 @@ __all__ = [
"FullHistoryObs",
]
from qlib.utils import init_instance_by_config
def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict:
"""To 32-bit numeric types. Recursively."""
@@ -57,8 +58,6 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
Parameters
----------
data_dir
Path to load data after feature engineering.
max_step
Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps.
data_ticks
@@ -66,21 +65,37 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
the total ticks is the length of day in minutes.
data_dim
Number of dimensions in data.
processed_data_provider
Provider of the processed data.
"""
def __init__(self, data_dir: Path, max_step: int, data_ticks: int, data_dim: int) -> None:
self.data_dir = data_dir
# TODO: All implementations related to `data_dir` is coupled with the specific data format for that specific case.
# TODO: So it should be redesigned after the data interface is well-designed.
def __init__(
self,
max_step: int,
data_ticks: int,
data_dim: int,
processed_data_provider: dict | ProcessedDataProvider,
) -> None:
self.max_step = max_step
self.data_ticks = data_ticks
self.data_dim = data_dim
self.processed_data_provider: ProcessedDataProvider = init_instance_by_config(
processed_data_provider,
accept_types=ProcessedDataProvider,
)
def interpret(self, state: SAOEState) -> FullHistoryObs:
processed = pickle_styled.load_intraday_processed_data(
self.data_dir,
state.order.stock_id,
pd.Timestamp(state.order.start_time.date()),
self.data_dim,
state.ticks_index,
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
# way to decompose interpreter and EnvWrapper in the future.
processed = self.processed_data_provider.get_data(
stock_id=state.order.stock_id,
date=pd.Timestamp(state.order.start_time.date()),
feature_dim=self.data_dim,
time_index=state.ticks_index,
)
position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32)
@@ -96,15 +111,15 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
FullHistoryObs,
canonicalize(
{
"data_processed": self._mask_future_info(processed.today, state.cur_time),
"data_processed_prev": processed.yesterday,
"acquiring": state.order.direction == state.order.BUY,
"cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1),
"cur_step": min(self.env.status["cur_step"], self.max_step - 1),
"num_step": self.max_step,
"target": state.order.amount,
"position": state.position,
"position_history": position_history[: self.max_step],
"data_processed": np.array(self._mask_future_info(processed.today, state.cur_time)),
"data_processed_prev": np.array(processed.yesterday),
"acquiring": _to_int32(state.order.direction == state.order.BUY),
"cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)),
"cur_step": _to_int32(min(self.env.status["cur_step"], self.max_step - 1)),
"num_step": _to_int32(self.max_step),
"target": _to_float32(state.order.amount),
"position": _to_float32(state.position),
"position_history": _to_float32(position_history[: self.max_step]),
},
),
)
@@ -162,6 +177,10 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
return spaces.Dict(space)
def interpret(self, state: SAOEState) -> CurrentStateObs:
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
# way to decompose interpreter and EnvWrapper in the future.
assert self.env is not None
assert self.env.status["cur_step"] <= self.max_step
obs = CurrentStateObs(
@@ -184,20 +203,31 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
Then when policy givens decision $x$, $a_x$ times order amount is the output.
It can also be an integer $n$, in which case the list of length $n+1$ is auto-generated,
i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
max_step
Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps.
"""
def __init__(self, values: int | List[float]) -> None:
def __init__(self, values: int | List[float], max_step: Optional[int] = None) -> None:
if isinstance(values, int):
values = [i / values for i in range(0, values + 1)]
self.action_values = values
self.max_step = max_step
@property
def action_space(self) -> spaces.Discrete:
return spaces.Discrete(len(self.action_values))
def interpret(self, state: SAOEState, action: int) -> float:
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
# way to decompose interpreter and EnvWrapper in the future.
assert 0 <= action < len(self.action_values)
return min(state.position, state.order.amount * self.action_values[action])
assert self.env is not None
if self.max_step is not None and self.env.status["cur_step"] >= self.max_step - 1:
return state.position
else:
return min(state.position, state.order.amount * self.action_values[action])
class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
@@ -214,7 +244,19 @@ class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
return spaces.Box(0, np.inf, shape=(), dtype=np.float32)
def interpret(self, state: SAOEState, action: float) -> float:
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
# way to decompose interpreter and EnvWrapper in the future.
assert self.env is not None
estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step)
twap_volume = state.position / (estimated_total_steps - self.env.status["cur_step"])
return min(state.position, twap_volume * action)
def _to_int32(val):
return np.array(int(val), dtype=np.int32)
def _to_float32(val):
return np.array(val, dtype=np.float32)

View File

@@ -117,3 +117,24 @@ class Recurrent(nn.Module):
out = torch.cat(sources, -1)
return self.fc(out)
class Attention(nn.Module):
    """Single-head, unscaled dot-product attention with linear projections of Q, K and V."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # One linear projection per role, each mapping in_dim -> out_dim.
        self.q_net = nn.Linear(in_dim, out_dim)
        self.k_net = nn.Linear(in_dim, out_dim)
        self.v_net = nn.Linear(in_dim, out_dim)

    def forward(self, Q, K, V):
        """Attend from ``Q`` over ``K`` / ``V``; all three inputs are rank-3 tensors.

        The result has the leading two dimensions of ``Q`` and ``out_dim`` as last dimension.
        """
        query = self.q_net(Q)
        key = self.k_net(K)
        value = self.v_net(V)

        # Pairwise dot products between every query position and every key position.
        scores = torch.einsum("ijk,ilk->ijl", query, key).to(Q.device)
        weights = torch.softmax(scores, dim=-1)

        return torch.einsum("ijk,ikl->ijl", weights, value)

View File

@@ -11,7 +11,7 @@ from qlib.backtest.decision import Order
from qlib.backtest.executor import NestedExecutor
from qlib.rl.simulator import Simulator
from .integration import init_qlib
from qlib.rl.data.integration import init_qlib
from .state import SAOEState, SAOEStateAdapter
from .strategy import SAOEStrategy

View File

@@ -18,10 +18,10 @@ from .state import SAOEMetrics, SAOEState
# TODO: Integrating Qlib's native data with simulator_simple
__all__ = ["SingleAssetOrderExecution"]
__all__ = ["SingleAssetOrderExecutionSimple"]
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
"""Single-asset order execution (SAOE) simulator.
As there's no "calendar" in the simple simulator, ticks are used to trade.

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import typing
from typing import cast, NamedTuple, Optional, Tuple
import numpy as np
@@ -10,11 +11,13 @@ import pandas as pd
from qlib.backtest import Exchange, Order
from qlib.backtest.executor import BaseExecutor
from qlib.constant import EPS, ONE_MIN, REG_CN
from qlib.rl.data.exchange_wrapper import IntradayBacktestData
from qlib.rl.data.pickle_styled import BaseIntradayBacktestData
from qlib.rl.order_execution.utils import dataframe_append, price_advantage
from qlib.typehint import TypedDict
from qlib.utils.time import get_day_min_idx_range
from typing_extensions import TypedDict
if typing.TYPE_CHECKING:
from qlib.rl.data.base import BaseIntradayBacktestData
from qlib.rl.data.native import IntradayBacktestData
def _get_all_timestamps(

View File

@@ -5,17 +5,23 @@ from __future__ import annotations
import collections
from types import GeneratorType
from typing import Any, Optional, Union, cast, Dict, Generator
from typing import Any, cast, Dict, Generator, Optional, Union
import pandas as pd
import torch
from tianshou.data import Batch
from tianshou.policy import BasePolicy
from qlib.backtest import CommonInfrastructure, Order
from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange
from qlib.backtest.utils import LevelInfrastructure
from qlib.constant import ONE_MIN
from qlib.rl.data.exchange_wrapper import load_qlib_backtest_data
from qlib.rl.order_execution.state import SAOEStateAdapter, SAOEState
from qlib.rl.data.native import load_backtest_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution.state import SAOEState, SAOEStateAdapter
from qlib.rl.utils.env_wrapper import BaseEnvWrapper
from qlib.strategy.base import RLStrategy
from qlib.utils import init_instance_by_config
class SAOEStrategy(RLStrategy):
@@ -41,7 +47,7 @@ class SAOEStrategy(RLStrategy):
self._last_step_range = (0, 0)
def _create_qlib_backtest_adapter(self, order: Order, trade_range: TradeRange) -> SAOEStateAdapter:
backtest_data = load_qlib_backtest_data(order, self.trade_exchange, trade_range)
backtest_data = load_backtest_data(order, self.trade_exchange, trade_range)
return SAOEStateAdapter(
order=order,
@@ -106,7 +112,10 @@ class SAOEStrategy(RLStrategy):
return decision
def _generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
def _generate_trade_decision(
self,
execute_result: list = None,
) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
raise NotImplementedError
@@ -146,3 +155,110 @@ class ProxySAOEStrategy(SAOEStrategy):
order_list = outer_trade_decision.order_list
assert len(order_list) == 1
self._order = order_list[0]
class SAOEIntStrategy(SAOEStrategy):
    """(SAOE)state based strategy with (Int)preters.

    For every order in the outer trade decision, the strategy turns the order's SAOE state into an
    observation with the state interpreter, feeds the batch of observations to the policy, and turns
    each resulting action into an execution volume with the action interpreter.
    """

    def __init__(
        self,
        policy: dict | BasePolicy,
        state_interpreter: dict | StateInterpreter,
        action_interpreter: dict | ActionInterpreter,
        network: object = None,  # TODO: add accurate typehint later.
        outer_trade_decision: BaseTradeDecision = None,
        level_infra: LevelInfrastructure = None,
        common_infra: CommonInfrastructure = None,
        backtest: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Parameters
        ----------
        policy
            Either a ready-to-use ``BasePolicy`` or a config dict that is instantiated through
            ``init_instance_by_config``. The config dict is NOT modified.
        state_interpreter
            A ``StateInterpreter`` instance or its config dict.
        action_interpreter
            An ``ActionInterpreter`` instance or its config dict.
        network
            Required when ``policy`` is a config dict. Either a network instance or a config dict;
            a config dict is instantiated with ``obs_space`` injected into its kwargs and is NOT modified.
        outer_trade_decision, level_infra, common_infra, kwargs
            Forwarded unchanged to the parent strategy.
        backtest
            When True, the strategy drives ``env.reset()`` / ``env.step()`` by itself.

        Raises
        ------
        ValueError
            If ``policy`` is neither a dict nor a ``BasePolicy``, or if ``policy`` is a dict and
            ``network`` is missing.
        """
        super(SAOEIntStrategy, self).__init__(
            policy=policy,
            outer_trade_decision=outer_trade_decision,
            level_infra=level_infra,
            common_infra=common_infra,
            **kwargs,
        )

        self._backtest = backtest
        self._state_interpreter: StateInterpreter = init_instance_by_config(
            state_interpreter,
            accept_types=StateInterpreter,
        )
        self._action_interpreter: ActionInterpreter = init_instance_by_config(
            action_interpreter,
            accept_types=ActionInterpreter,
        )

        if isinstance(policy, dict):
            # Explicit raise instead of `assert`: asserts are stripped when Python runs with -O.
            if network is None:
                raise ValueError("`network` must be provided when `policy` is given as a config dict.")

            obs_space = self._state_interpreter.observation_space

            if isinstance(network, dict):
                # Build a new config rather than updating the caller's dict in place, so a config
                # object shared between several strategies is never polluted.
                network_config = {
                    **network,
                    "kwargs": {**(network.get("kwargs") or {}), "obs_space": obs_space},
                }
                network_inst = init_instance_by_config(network_config)
            else:
                network_inst = network

            policy_config = {
                **policy,
                "kwargs": {
                    **(policy.get("kwargs") or {}),
                    "obs_space": obs_space,
                    "action_space": self._action_interpreter.action_space,
                    "network": network_inst,
                },
            }
            self._policy = init_instance_by_config(policy_config)
        elif isinstance(policy, BasePolicy):
            self._policy = policy
        else:
            raise ValueError(f"Unsupported policy type: {type(policy)}.")

        # The policy is only used for inference here, so switch it to evaluation mode.
        if self._policy is not None:
            self._policy.eval()

    def set_env(self, env: BaseEnvWrapper) -> None:
        """Attach ``env`` to this strategy and to both of its interpreters."""
        # TODO: This method is used to set EnvWrapper for interpreters since they rely on EnvWrapper.
        # We should decompose the interpreters with EnvWrapper in the future and we should remove this method
        # after that.
        self._env = env
        self._state_interpreter.env = self._action_interpreter.env = self._env

    def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
        """Reset the parent strategy and, in backtest mode, the attached env as well."""
        super().reset(outer_trade_decision=outer_trade_decision, **kwargs)

        # In backtest, env.reset() needs to be manually called since there is no outer trainer to call it
        if self._backtest:
            self._env.reset()

    def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
        """Query the policy once for all outer orders and wrap the non-zero volumes into a decision."""
        states = []
        obs_batch = []
        for decision in self.outer_trade_decision.get_decision():
            order = cast(Order, decision)
            state = self.get_saoe_state_by_order(order)

            states.append(state)
            obs_batch.append({"obs": self._state_interpreter.interpret(state)})

        with torch.no_grad():
            policy_out = self._policy(Batch(obs_batch))

        raw_act = policy_out.act
        # Tensor.numpy() only works for CPU tensors, so move the actions to host memory first.
        act = raw_act.detach().cpu().numpy() if torch.is_tensor(raw_act) else raw_act
        exec_vols = [self._action_interpreter.interpret(s, a) for s, a in zip(states, act)]

        # In backtest, env.step() needs to be manually called since there is no outer trainer to call it
        if self._backtest:
            self._env.step(None)

        oh = self.trade_exchange.get_order_helper()
        order_list = []
        for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):
            # A zero volume means "do nothing for this order in this step".
            if exec_vol != 0:
                order = cast(Order, decision)
                order_list.append(oh.create(order.stock_id, exec_vol, order.direction))

        return TradeDecisionWO(order_list=order_list, strategy=self)

View File

@@ -4,6 +4,6 @@
"""Train, test, inference utilities."""
from .api import backtest, train
from .callbacks import EarlyStopping, Checkpoint
from .callbacks import Checkpoint, EarlyStopping
from .trainer import Trainer
from .vessel import TrainingVessel, TrainingVesselBase

View File

@@ -4,7 +4,7 @@
from __future__ import annotations
import weakref
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, cast
from typing import Any, Callable, cast, Dict, Generic, Iterable, Iterator, Optional, Tuple
import gym
from gym import Space
@@ -14,7 +14,6 @@ from qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, State
from qlib.rl.reward import Reward
from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType
from qlib.typehint import TypedDict
from .finite_env import generate_nan_observation
from .log import LogCollector, LogLevel
@@ -49,9 +48,24 @@ class EnvWrapperStatus(TypedDict):
reward_history: list
class EnvWrapper(
class BaseEnvWrapper(
    gym.Env[ObsType, PolicyActType],
    Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType],
):
    """Common parent of the RL environment wrappers.

    Two implementations exist:

    - EnvWrapper: Qlib-based RL environment used in training.
    - CollectDataEnvWrapper: Dummy environment used in collect_data_loop.
    """

    def __init__(self) -> None:
        # The status starts out as None; the cast only keeps the declared attribute type.
        empty_status = cast(EnvWrapperStatus, None)
        self.status: EnvWrapperStatus = empty_status

    def render(self, mode: str = "human") -> None:
        """Rendering is not supported; always raises ``NotImplementedError``."""
        message = "Render is not implemented in BaseEnvWrapper."
        raise NotImplementedError(message)
class EnvWrapper(
BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType],
):
"""Qlib-based RL environment, subclassing ``gym.Env``.
A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
@@ -115,6 +129,8 @@ class EnvWrapper(
# 3. Avoid circular reference.
# 4. When the components get serialized, we can throw away the env without any burden.
# (though this part is not implemented yet)
super().__init__()
for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]:
if obj is not None:
obj.env = weakref.proxy(self) # type: ignore
@@ -247,5 +263,19 @@ class EnvWrapper(
info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
return obs, rew, done, info_dict
def render(self, mode: str = "human") -> None:
raise NotImplementedError("Render is not implemented in EnvWrapper.")
class CollectDataEnvWrapper(BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType]):
    """Dummy EnvWrapper for collect_data_loop. It only has minimum interfaces to support the collect_data_loop."""

    def reset(self, **kwargs: Any) -> None:
        """Replace the status with a fresh one: step counter at 0, not done, empty histories."""
        fresh_status = EnvWrapperStatus(
            cur_step=0,
            done=False,
            initial_state=None,
            obs_history=[],
            action_history=[],
            reward_history=[],
        )
        self.status = fresh_status

    def step(self, policy_action: Any = None, **kwargs: Any) -> None:
        """Advance the step counter by one; ``policy_action`` is not used."""
        self.status["cur_step"] = self.status["cur_step"] + 1

View File

@@ -11,6 +11,7 @@ from qlib.backtest.decision import Order, OrderDir, TradeRangeByTime
from qlib.backtest.executor import SimulatorExecutor
from qlib.rl.order_execution import CategoricalActionInterpreter
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
TOTAL_POSITION = 2100.0
@@ -192,6 +193,8 @@ def test_interpreter() -> None:
order = get_order()
simulator = get_simulator(order)
interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
interpreter_action.env = CollectDataEnvWrapper()
interpreter_action.env.reset()
NUM_STEPS = 7
state = simulator.get_state()

View File

@@ -16,9 +16,11 @@ from qlib.backtest import Order
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.data import pickle_styled
from qlib.rl.data.pickle_styled import PickleProcessedDataProvider
from qlib.rl.order_execution import *
from qlib.rl.trainer import backtest, train
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
@@ -40,16 +42,15 @@ def test_pickle_data_inspect():
data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
assert len(data) == 390
data = pickle_styled.load_intraday_processed_data(
DATA_DIR / "processed", "AAL", "2013-12-11", 5, data.get_time_index()
)
provider = PickleProcessedDataProvider(DATA_DIR / "processed")
data = provider.get_data("AAL", "2013-12-11", 5, data.get_time_index())
assert len(data.today) == len(data.yesterday) == 390
def test_simulator_first_step():
order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
state = simulator.get_state()
assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00")
assert state.position == 30.0
@@ -83,7 +84,7 @@ def test_simulator_first_step():
def test_simulator_stop_twap():
order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
for _ in range(13):
simulator.step(1.0)
@@ -106,10 +107,10 @@ def test_simulator_stop_early():
order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
with pytest.raises(ValueError):
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator.step(2.0)
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator.step(1.0)
with pytest.raises(AssertionError):
@@ -119,7 +120,7 @@ def test_simulator_stop_early():
def test_simulator_start_middle():
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
assert len(simulator.ticks_for_order) == 330
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
simulator.step(2.0)
@@ -138,7 +139,7 @@ def test_simulator_start_middle():
def test_interpreter():
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
assert len(simulator.ticks_for_order) == 330
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
@@ -146,7 +147,7 @@ def test_interpreter():
class EmulateEnvWrapper(NamedTuple):
status: EnvWrapperStatus
interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
interpreter = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
interpreter_step = CurrentStepStateInterpreter(13)
interpreter_action = CategoricalActionInterpreter(20)
interpreter_action_twap = TwapRelativeActionInterpreter()
@@ -185,6 +186,10 @@ def test_interpreter():
assert np.sum(obs["data_processed"][60:]) == 0
# second step: action
interpreter_action.env = CollectDataEnvWrapper()
interpreter_action_twap.env = CollectDataEnvWrapper()
interpreter_action.env.reset()
interpreter_action_twap.env.reset()
action = interpreter_action(simulator.get_state(), 1)
assert action == 15 / 20
@@ -219,13 +224,13 @@ def test_network_sanity():
# we won't check the correctness of networks here
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
assert len(simulator.ticks_for_order) == 390
class EmulateEnvWrapper(NamedTuple):
status: EnvWrapperStatus
interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
interpreter = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
action_interp = CategoricalActionInterpreter(13)
wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
@@ -253,13 +258,15 @@ def test_twap_strategy(finite_env_type):
orders = pickle_styled.load_orders(ORDER_DIR)
assert len(orders) == 248
state_interp = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
state_interp = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
action_interp = TwapRelativeActionInterpreter()
action_interp.env = CollectDataEnvWrapper()
action_interp.env.reset()
policy = AllOne(state_interp.observation_space, action_interp.action_space)
csv_writer = CsvWriter(Path(__file__).parent / ".output")
backtest(
partial(SingleAssetOrderExecution, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
partial(SingleAssetOrderExecutionSimple, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
@@ -282,15 +289,17 @@ def test_cn_ppo_strategy():
orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
assert len(orders) == 40
state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
action_interp = CategoricalActionInterpreter(4)
action_interp.env = CollectDataEnvWrapper()
action_interp.env.reset()
network = Recurrent(state_interp.observation_space)
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"))
csv_writer = CsvWriter(Path(__file__).parent / ".output")
backtest(
partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
@@ -313,13 +322,15 @@ def test_ppo_train():
orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
assert len(orders) == 40
state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
action_interp = CategoricalActionInterpreter(4)
action_interp.env = CollectDataEnvWrapper()
action_interp.env.reset()
network = Recurrent(state_interp.observation_space)
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
train(
partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,