1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-06 04:20:57 +08:00

Refine RL todos (#1332)

* Refine several todos

* CI issues

* Remove Dropna limitation of `quote_df` in Exchange  (#1334)

* Remove Dropna limitation of `quote_df` of Exchange

* Improve docstring

* Fix type error when expression is specified (#1335)

* Refine fill_missing_data()

* Remove several TODO comments

* Add back env for interpreters

* Change Literal import

* Resolve PR comments

* Move `cur_step` to SAOEState

* Add Trainer.get_policy_state_dict()

* Mypy issue

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
Huoran Li
2022-11-10 21:10:11 +08:00
committed by GitHub
parent 49a5bccfec
commit 35794846ff
20 changed files with 461 additions and 530 deletions

View File

@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union
import pandas as pd
from .account import Account
from .report import Indicator, PortfolioMetrics
if TYPE_CHECKING:
from ..strategy.base import BaseStrategy
@@ -20,7 +19,7 @@ if TYPE_CHECKING:
from ..config import C
from ..log import get_module_logger
from ..utils import init_instance_by_config
from .backtest import backtest_loop, collect_data_loop
from .backtest import INDICATOR_METRIC, PORT_METRIC, backtest_loop, collect_data_loop
from .decision import Order
from .exchange import Exchange
from .utils import CommonInfrastructure
@@ -223,7 +222,7 @@ def backtest(
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
) -> Tuple[PortfolioMetrics, Indicator]:
) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
executor in the nested decision execution
@@ -256,9 +255,9 @@ def backtest(
Returns
-------
portfolio_metrics_dict: Dict[PortfolioMetrics]
portfolio_dict: PORT_METRIC
it records the trading portfolio_metrics information
indicator_dict: Dict[Indicator]
indicator_dict: INDICATOR_METRIC
it computes the trading indicator
It is organized in a dict format
@@ -273,8 +272,7 @@ def backtest(
exchange_kwargs,
pos_type=pos_type,
)
portfolio_metrics, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
return portfolio_metrics, indicator
return backtest_loop(start_time, end_time, trade_strategy, trade_executor)
def collect_data(

View File

@@ -3,12 +3,12 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
from typing import Dict, TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
import pandas as pd
from qlib.backtest.decision import BaseTradeDecision
from qlib.backtest.report import Indicator, PortfolioMetrics
from qlib.backtest.report import Indicator
if TYPE_CHECKING:
from qlib.strategy.base import BaseStrategy
@@ -19,30 +19,35 @@ from tqdm.auto import tqdm
from ..utils.time import Freq
PORT_METRIC = Dict[str, Tuple[pd.DataFrame, dict]]
INDICATOR_METRIC = Dict[str, Tuple[pd.DataFrame, Indicator]]
def backtest_loop(
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
trade_strategy: BaseStrategy,
trade_executor: BaseExecutor,
) -> Tuple[PortfolioMetrics, Indicator]:
) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:
"""backtest function for the interaction of the outermost strategy and executor in the nested decision execution
please refer to the docs of `collect_data_loop`
Returns
-------
portfolio_metrics: PortfolioMetrics
portfolio_dict: PORT_METRIC
it records the trading portfolio_metrics information
indicator: Indicator
indicator_dict: INDICATOR_METRIC
it computes the trading indicator
"""
return_value: dict = {}
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
pass
portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
indicator = cast(Indicator, return_value.get("indicator"))
return portfolio_metrics, indicator
portfolio_dict = cast(PORT_METRIC, return_value.get("portfolio_dict"))
indicator_dict = cast(INDICATOR_METRIC, return_value.get("indicator_dict"))
return portfolio_dict, indicator_dict
def collect_data_loop(
@@ -89,14 +94,17 @@ def collect_data_loop(
if return_value is not None:
all_executors = trade_executor.get_all_executors()
all_portfolio_metrics = {
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
for _executor in all_executors
if _executor.trade_account.is_port_metr_enabled()
}
all_indicators = {}
for _executor in all_executors:
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
return_value.update({"portfolio_metrics": all_portfolio_metrics, "indicator": all_indicators})
portfolio_dict: PORT_METRIC = {}
indicator_dict: INDICATOR_METRIC = {}
for executor in all_executors:
key = "{}{}".format(*Freq.parse(executor.time_per_step))
if executor.trade_account.is_port_metr_enabled():
portfolio_dict[key] = executor.trade_account.get_portfolio_metrics()
indicator_df = executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
indicator_obj = executor.trade_account.get_trade_indicator()
indicator_dict[key] = (indicator_df, indicator_obj)
return_value.update({"portfolio_dict": portfolio_dict, "indicator_dict": indicator_dict})

View File

@@ -26,6 +26,15 @@ from .high_performance_ds import BaseQuote, NumpyQuote
class Exchange:
# `quote_df` is a pd.DataFrame class that contains basic information for backtesting
# After some processing, the data will later be maintained by a `quote_cls` object for faster data retrieval.
# Some conventions for `quote_df`
# - $close is for calculating the total value at end of each day.
# - if $close is None, the stock on that day is regarded as suspended.
# - $factor is for rounding to the trading unit;
# - if any $factor is missing when $close exists, trading unit rounding will be disabled
quote_df: pd.DataFrame
def __init__(
self,
freq: str = "day",
@@ -159,6 +168,7 @@ class Exchange:
self.codes = codes
# Necessary fields
# $close is for calculating the total value at end of each day.
# - if $close is None, the stock on that day is regarded as suspended.
# $factor is for rounding to the trading unit
# $change is for calculating the limit of the stock
@@ -199,7 +209,7 @@ class Exchange:
self.end_time,
freq=self.freq,
disk_cache=True,
).dropna(subset=["$close"])
)
self.quote_df.columns = self.all_fields
# check buy_price data and sell_price data
@@ -209,7 +219,7 @@ class Exchange:
self.logger.warning("{} field data contains nan.".format(pstr))
# update trade_w_adj_price
if self.quote_df["$factor"].isna().any():
if (self.quote_df["$factor"].isna() & ~self.quote_df["$close"].isna()).any():
# The 'factor.day.bin' file does not exist, and the `factor` field contains `nan`
# Use adjusted price
self.trade_w_adj_price = True
@@ -245,9 +255,9 @@ class Exchange:
assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"}
self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0)
LT_TP_EXP = "(exp)" # Tuple[str, str]
LT_FLT = "float" # float
LT_NONE = "none" # none
LT_TP_EXP = "(exp)" # Tuple[str, str]: the limitation is calculated by a Qlib expression.
LT_FLT = "float" # float: the trading limitation is based on `abs($change) < limit_threshold`
LT_NONE = "none" # none: there is no trading limitation
def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
"""get limit type"""
@@ -261,20 +271,25 @@ class Exchange:
raise NotImplementedError(f"This type of `limit_threshold` is not supported")
def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None:
# $close may contain NaN; a NaN indicates that the stock is not tradable at that timestamp
suspended = self.quote_df["$close"].isna()
# check limit_threshold
limit_type = self._get_limit_type(limit_threshold)
if limit_type == self.LT_NONE:
self.quote_df["limit_buy"] = False
self.quote_df["limit_sell"] = False
self.quote_df["limit_buy"] = suspended
self.quote_df["limit_sell"] = suspended
elif limit_type == self.LT_TP_EXP:
# set limit
limit_threshold = cast(tuple, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
# astype bool is necessary, because quote_df is an expression and could be float
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]].astype("bool") | suspended
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]].astype("bool") | suspended
elif limit_type == self.LT_FLT:
limit_threshold = cast(float, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) | suspended
self.quote_df["limit_sell"] = (
self.quote_df["$change"].le(-limit_threshold) | suspended
) # pylint: disable=E1130
@staticmethod
def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
@@ -338,8 +353,18 @@ class Exchange:
- if direction is None, check if tradable for buying and selling.
- if direction == Order.BUY, check the if tradable for buying
- if direction == Order.SELL, check the sell limit for selling.
Returns
-------
True: the trading of the stock is limited (maybe hit the highest/lowest price), hence the stock is not tradable
False: the trading of the stock is not limited, hence the stock may be tradable
"""
# NOTE:
# **all** is used when checking limitation.
# For example, the stock trading is limited in a day if every minute is limited.
if direction is None:
# The trading limitation is related to the trading direction
# if the direction is not provided, then any limitation from buy or sell will result in trading limitation
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
return bool(buy_limit or sell_limit)
@@ -356,10 +381,24 @@ class Exchange:
start_time: pd.Timestamp,
end_time: pd.Timestamp,
) -> bool:
"""if stock is suspended(hence not tradable), True will be returned"""
# is suspended
if stock_id in self.quote.get_all_stock():
return self.quote.get_data(stock_id, start_time, end_time, "$close") is None
# suspended stocks are represented by a None $close
# The $close may also contain NaN
close = self.quote.get_data(stock_id, start_time, end_time, "$close")
if close is None:
# if no close record exists
return True
elif isinstance(close, IndexData):
# **any** non-NaN $close means a trading opportunity may exist
# if all returned values are NaN, then the stock is suspended
return cast(bool, cast(IndexData, close).isna().all())
else:
# it is a single value; the stock is suspended if the value is NaN
return np.isnan(close)
else:
# if the stock is not in the stock list, then it is not tradable and regarded as suspended
return True
def is_stock_tradable(

View File

@@ -8,23 +8,22 @@ import os
import pickle
from collections import defaultdict
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union, cast
import numpy as np
import pandas as pd
import torch
from joblib import Parallel, delayed
from qlib.typehint import Literal
from qlib.backtest import collect_data_loop, get_strategy_executor
from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
from qlib.backtest.executor import SimulatorExecutor
from qlib.backtest.high_performance_ds import BaseOrderIndicator
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
from qlib.rl.contrib.utils import read_order_file
from qlib.rl.data.integration import init_qlib
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
from qlib.typehint import Literal
def _get_multi_level_executor_config(
@@ -61,15 +60,6 @@ def _get_multi_level_executor_config(
return executor_config
def _set_env_for_all_strategy(executor: BaseExecutor) -> None:
if isinstance(executor, NestedExecutor):
if hasattr(executor.inner_strategy, "set_env"):
env = CollectDataEnvWrapper()
env.reset()
executor.inner_strategy.set_env(env)
_set_env_for_all_strategy(executor.inner_executor)
def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
record_list = []
for time, value_dict in indicator.items():
@@ -94,9 +84,10 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
return records
# TODO: there should be richer annotation for the input (e.g. report) and the returned report
# TODO: For example, @ dataclass with typed fields and detailed docstrings.
def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List[dict]) -> dict:
def _generate_report(
decisions: List[BaseTradeDecision],
report_indicators: List[INDICATOR_METRIC],
) -> Dict[str, Tuple[pd.DataFrame, pd.DataFrame]]:
"""Generate backtest reports
Parameters
@@ -109,28 +100,25 @@ def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List
-------
"""
indicator_dict = defaultdict(list)
indicator_his = defaultdict(list)
indicator_dict: Dict[str, List[pd.DataFrame]] = defaultdict(list)
indicator_his: Dict[str, List[dict]] = defaultdict(list)
for report_indicator in report_indicators:
for key, value in report_indicator.items():
if key.endswith("_obj"):
indicator_his[key].append(value.order_indicator_his)
else:
indicator_dict[key].append(value)
for key, (indicator_df, indicator_obj) in report_indicator.items():
indicator_dict[key].append(indicator_df)
indicator_his[key].append(indicator_obj.order_indicator_his)
report = {}
decision_details = pd.concat([getattr(d, "details") for d in decisions if hasattr(d, "details")])
for key in ["1min", "5min", "30min", "1day"]:
if key not in indicator_dict:
continue
report[key] = pd.concat(indicator_dict[key])
report[key + "_obj"] = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key + "_obj"]])
for key in indicator_dict:
cur_dict = pd.concat(indicator_dict[key])
cur_his = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key]])
cur_details = decision_details[decision_details.freq == key].set_index(["instrument", "datetime"])
if len(cur_details) > 0:
cur_details.pop("freq")
report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer")
cur_his = cur_his.join(cur_details, how="outer")
report[key] = (cur_dict, cur_his)
return report
@@ -209,25 +197,25 @@ def single_with_simulator(
exchange_config=exchange_config,
qlib_config=None,
cash_limit=None,
backtest_mode=True,
)
reports.append(simulator.report_dict)
decisions += simulator.decisions
indicator = {k: v for report in reports for k, v in report["indicator"]["1day_obj"].order_indicator_his.items()}
records = _convert_indicator_to_dataframe(indicator)
indicator_1day_objs = [report["indicator"]["1day"][1] for report in reports]
indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}
records = _convert_indicator_to_dataframe(indicator_info)
assert records is None or not np.isnan(records["ffr"]).any()
if generate_report:
report = _generate_report(decisions, [report["indicator"] for report in reports])
_report = _generate_report(decisions, [report["indicator"] for report in reports])
if split == "stock":
stock_id = orders.iloc[0].instrument
report = {stock_id: report}
report = {stock_id: _report}
else:
day = orders.iloc[0].datetime
report = {day: report}
report = {day: _report}
return records, report
else:
@@ -312,22 +300,22 @@ def single_with_collect_data_loop(
exchange_kwargs=exchange_config,
pos_type="Position" if cash_limit is not None else "InfPosition",
)
_set_env_for_all_strategy(executor=executor)
report_dict: dict = {}
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))
records = _convert_indicator_to_dataframe(report_dict["indicator"]["1day_obj"].order_indicator_his)
indicator_dict = cast(INDICATOR_METRIC, report_dict.get("indicator_dict"))
records = _convert_indicator_to_dataframe(indicator_dict["1day"][1].order_indicator_his)
assert records is None or not np.isnan(records["ffr"]).any()
if generate_report:
report = _generate_report(decisions, [report_dict["indicator"]])
_report = _generate_report(decisions, [indicator_dict])
if split == "stock":
stock_id = orders.iloc[0].instrument
report = {stock_id: report}
report = {stock_id: _report}
else:
day = orders.iloc[0].datetime
report = {day: report}
report = {day: _report}
return records, report
else:
return records
@@ -337,7 +325,7 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram
order_df = read_order_file(backtest_config["order_file"])
cash_limit = backtest_config["exchange"].pop("cash_limit")
generate_report = backtest_config["exchange"].pop("generate_report")
generate_report = backtest_config.pop("generate_report")
stock_pool = order_df["instrument"].unique().tolist()
stock_pool.sort()
@@ -382,9 +370,19 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
parser.add_argument("--use_simulator", action="store_true", help="Whether to use simulator as the backend")
parser.add_argument(
"--n_jobs",
type=int,
required=False,
help="The number of jobs for running backtest parallely(1 for single process)",
)
args = parser.parse_args()
config = get_backtest_config_fromfile(args.config_path)
if args.n_jobs is not None:
config["concurrency"] = args.n_jobs
backtest(
backtest_config=get_backtest_config_fromfile(args.config_path),
backtest_config=config,
with_simulator=args.use_simulator,
)

View File

@@ -11,11 +11,14 @@ from importlib import import_module
import yaml
DELETE_KEY = "_delete_"
def merge_a_into_b(a: dict, b: dict) -> dict:
b = b.copy()
for k, v in a.items():
if isinstance(v, dict) and k in b:
v.pop("_delete_", False) # TODO: make this more elegant
v.pop(DELETE_KEY, False)
b[k] = merge_a_into_b(v, b[k])
else:
b[k] = v
@@ -86,7 +89,6 @@ def get_backtest_config_fromfile(path: str) -> dict:
"min_cost": 5.0,
"trade_unit": 100.0,
"cash_limit": None,
"generate_report": False,
}
backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default)
backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"])
@@ -97,7 +99,7 @@ def get_backtest_config_fromfile(path: str) -> dict:
"concurrency": -1,
"multiplier": 1.0,
"output_dir": "outputs/",
# "runtime": {},
"generate_report": False,
}
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)

View File

@@ -13,7 +13,6 @@ from qlib.rl.order_execution.utils import get_ticks_slice
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
from .integration import fetch_features
from ...data import D
class IntradayBacktestData(BaseIntradayBacktestData):
@@ -81,17 +80,7 @@ def load_backtest_data(
trade_exchange: Exchange,
trade_range: TradeRange,
) -> IntradayBacktestData:
# TODO: making exchange return data without missing will make it more elegant. Fix this in the future.
tmp_data = D.features(
trade_exchange.codes,
trade_exchange.all_fields,
trade_exchange.start_time,
trade_exchange.end_time,
freq=trade_exchange.freq,
disk_cache=True,
)
ticks_index = pd.DatetimeIndex(tmp_data.reset_index()["datetime"])
ticks_index = pd.DatetimeIndex(trade_exchange.quote_df.reset_index()["datetime"])
ticks_index = ticks_index[order.start_time <= ticks_index]
ticks_index = ticks_index[ticks_index <= order.end_time]

View File

@@ -3,19 +3,15 @@
from __future__ import annotations
from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar
from typing import Any, Generic, TypeVar
import gym
import numpy as np
from gym import spaces
from qlib.typehint import final
from .simulator import ActType, StateType
if TYPE_CHECKING:
from .utils.env_wrapper import BaseEnvWrapper
import gym
from gym import spaces
ObsType = TypeVar("ObsType")
PolicyActType = TypeVar("PolicyActType")
@@ -39,8 +35,6 @@ class Interpreter:
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
env: Optional[BaseEnvWrapper] = None
@property
def observation_space(self) -> gym.Space:
raise NotImplementedError()
@@ -73,8 +67,6 @@ class StateInterpreter(Generic[StateType, ObsType], Interpreter):
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
"""Action Interpreter that interpret rl agent action into qlib orders"""
env: Optional[BaseEnvWrapper] = None
@property
def action_space(self) -> gym.Space:
raise NotImplementedError()

View File

@@ -69,8 +69,6 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
Provider of the processed data.
"""
# TODO: All implementations related to `data_dir` is coupled with the specific data format for that specific case.
# TODO: So it should be redesigned after the data interface is well-designed.
def __init__(
self,
max_step: int,
@@ -78,6 +76,8 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
data_dim: int,
processed_data_provider: dict | ProcessedDataProvider,
) -> None:
super().__init__()
self.max_step = max_step
self.data_ticks = data_ticks
self.data_dim = data_dim
@@ -87,10 +87,6 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
)
def interpret(self, state: SAOEState) -> FullHistoryObs:
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
# way to decompose interpreter and EnvWrapper in the future.
processed = self.processed_data_provider.get_data(
stock_id=state.order.stock_id,
date=pd.Timestamp(state.order.start_time.date()),
@@ -102,8 +98,6 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
position_history[0] = state.order.amount
position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy()
assert self.env is not None
# The min, slice here are to make sure that indices fit into the range,
# even after the final step of the simulator (in the done step),
# to make network in policy happy.
@@ -115,7 +109,7 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
"data_processed_prev": np.array(processed.yesterday),
"acquiring": _to_int32(state.order.direction == state.order.BUY),
"cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)),
"cur_step": _to_int32(min(self.env.status["cur_step"], self.max_step - 1)),
"cur_step": _to_int32(min(state.cur_step, self.max_step - 1)),
"num_step": _to_int32(self.max_step),
"target": _to_float32(state.order.amount),
"position": _to_float32(state.position),
@@ -163,6 +157,8 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
"""
def __init__(self, max_step: int) -> None:
super().__init__()
self.max_step = max_step
@property
@@ -177,15 +173,10 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
return spaces.Dict(space)
def interpret(self, state: SAOEState) -> CurrentStateObs:
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
# way to decompose interpreter and EnvWrapper in the future.
assert self.env is not None
assert self.env.status["cur_step"] <= self.max_step
assert state.cur_step <= self.max_step
obs = CurrentStateObs(
acquiring=state.order.direction == state.order.BUY,
cur_step=self.env.status["cur_step"],
cur_step=state.cur_step,
num_step=self.max_step,
target=state.order.amount,
position=state.position,
@@ -208,6 +199,8 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
"""
def __init__(self, values: int | List[float], max_step: Optional[int] = None) -> None:
super().__init__()
if isinstance(values, int):
values = [i / values for i in range(0, values + 1)]
self.action_values = values
@@ -218,13 +211,8 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
return spaces.Discrete(len(self.action_values))
def interpret(self, state: SAOEState, action: int) -> float:
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
# way to decompose interpreter and EnvWrapper in the future.
assert 0 <= action < len(self.action_values)
assert self.env is not None
if self.max_step is not None and self.env.status["cur_step"] >= self.max_step - 1:
if self.max_step is not None and state.cur_step >= self.max_step - 1:
return state.position
else:
return min(state.position, state.order.amount * self.action_values[action])
@@ -244,13 +232,8 @@ class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
return spaces.Box(0, np.inf, shape=(), dtype=np.float32)
def interpret(self, state: SAOEState, action: float) -> float:
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
# way to decompose interpreter and EnvWrapper in the future.
assert self.env is not None
estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step)
twap_volume = state.position / (estimated_total_steps - self.env.status["cur_step"])
twap_volume = state.position / (estimated_total_steps - state.cur_step)
return min(state.position, twap_volume * action)

View File

@@ -4,7 +4,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, Optional, Tuple, cast
from typing import Any, Dict, Generator, Iterable, Optional, OrderedDict, Tuple, cast
import gym
import numpy as np
@@ -14,6 +14,8 @@ from gym.spaces import Discrete
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.policy import BasePolicy, PPOPolicy
from qlib.rl.trainer.trainer import Trainer
__all__ = ["AllOne", "PPO"]
@@ -148,7 +150,7 @@ class PPO(PPOPolicy):
action_space=action_space,
)
if weight_file is not None:
load_weight(self, weight_file)
set_weight(self, Trainer.get_policy_state_dict(weight_file))
# utilities: these should be put in a separate (common) file. #
@@ -160,15 +162,7 @@ def auto_device(module: nn.Module) -> torch.device:
return torch.device("cpu") # fallback to cpu
def load_weight(policy: nn.Module, path: Path) -> None:
assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
loaded_weight = torch.load(path, map_location="cpu")
# TODO: this should be handled by whoever calls load_weight.
# TODO: For example, when the outer class receives a weight, it should first unpack it,
# TODO: and send the corresponding part to individual component.
if "vessel" in loaded_weight:
loaded_weight = loaded_weight["vessel"]["policy"]
def set_weight(policy: nn.Module, loaded_weight: OrderedDict) -> None:
try:
policy.load_state_dict(loaded_weight)
except RuntimeError:

View File

@@ -9,12 +9,11 @@ import pandas as pd
from qlib.backtest import collect_data_loop, get_strategy_executor
from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime
from qlib.backtest.executor import BaseExecutor, NestedExecutor
from qlib.backtest.executor import NestedExecutor
from qlib.rl.data.integration import init_qlib
from qlib.rl.simulator import Simulator
from .state import SAOEState, SAOEStateAdapter
from .strategy import SAOEStrategy
from ..utils.env_wrapper import CollectDataEnvWrapper
from .state import SAOEState
from .strategy import SAOEStateAdapter, SAOEStrategy
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
@@ -32,8 +31,6 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
Configuration used to initialize Qlib. If it is None, Qlib will not be initialized.
cash_limit:
Cash limit.
backtest_mode
Whether the simulator is under backtest mode.
"""
def __init__(
@@ -43,7 +40,6 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
exchange_config: dict,
qlib_config: dict = None,
cash_limit: Optional[float] = None,
backtest_mode: bool = False,
) -> None:
super().__init__(initial=order)
@@ -59,7 +55,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
}
self._collect_data_loop: Optional[Generator] = None
self.reset(order, strategy_config, executor_config, exchange_config, qlib_config, cash_limit, backtest_mode)
self.reset(order, strategy_config, executor_config, exchange_config, qlib_config, cash_limit)
def reset(
self,
@@ -69,7 +65,6 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
exchange_config: dict,
qlib_config: dict = None,
cash_limit: Optional[float] = None,
backtest_mode: bool = False,
) -> None:
if qlib_config is not None:
init_qlib(qlib_config, part="skip")
@@ -98,16 +93,6 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
)
assert isinstance(self._collect_data_loop, Generator)
# TODO: backtest_mode is not a necessary parameter if we carefully design it.
# TODO: It should disappear with CollectDataEnvWrapper in the future.
if backtest_mode:
executor: BaseExecutor = self._executor
while isinstance(executor, NestedExecutor):
if hasattr(executor.inner_strategy, "set_env"):
executor.inner_strategy.set_env(CollectDataEnvWrapper())
executor = executor.inner_executor
# Call `step()` with None action to initialize the internal generator.
self.step(action=None)
self._order = order

View File

@@ -16,8 +16,6 @@ from qlib.rl.utils import LogLevel
from .state import SAOEMetrics, SAOEState
# TODO: Integrating Qlib's native data with simulator_simple
__all__ = ["SingleAssetOrderExecutionSimple"]
@@ -98,6 +96,7 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
self.ticks_for_order = self._get_ticks_slice(self.order.start_time, self.order.end_time)
self.cur_time = self.ticks_for_order[0]
self.cur_step = 0
# NOTE: astype(float) is necessary in some systems.
# this will align the precision with `.to_numpy()` in `_split_exec_vol`
self.twap_price = float(self.backtest_data.get_deal_price().loc[self.ticks_for_order].astype(float).mean())
@@ -194,11 +193,13 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
self.env.logger.add_any(key, value)
self.cur_time = self._next_time()
self.cur_step += 1
def get_state(self) -> SAOEState:
return SAOEState(
order=self.order,
cur_time=self.cur_time,
cur_step=self.cur_step,
position=self.position,
history_exec=self.history_exec,
history_steps=self.history_steps,

View File

@@ -4,290 +4,15 @@
from __future__ import annotations
import typing
from typing import cast, Callable, List, NamedTuple, Optional, Tuple
from typing import NamedTuple, Optional
import numpy as np
import pandas as pd
from qlib.backtest import Exchange, Order
from qlib.backtest.executor import BaseExecutor
from qlib.constant import EPS, ONE_MIN, REG_CN
from qlib.rl.order_execution.utils import dataframe_append, price_advantage
from qlib.backtest import Order
from qlib.typehint import TypedDict
from qlib.utils.index_data import IndexData
from qlib.utils.time import get_day_min_idx_range
if typing.TYPE_CHECKING:
from qlib.rl.data.base import BaseIntradayBacktestData
from qlib.rl.data.native import IntradayBacktestData
def _get_all_timestamps(
    start: pd.Timestamp,
    end: pd.Timestamp,
    granularity: pd.Timedelta = ONE_MIN,
    include_end: bool = True,
) -> pd.DatetimeIndex:
    """Enumerate evenly spaced timestamps from ``start`` up to ``end``.

    Parameters
    ----------
    start
        First timestamp of the sequence.
    end
        Upper bound of the sequence. It is only part of the result when it lies exactly on the
        ``start + k * granularity`` grid and ``include_end`` is True.
    granularity
        Distance between two consecutive timestamps. Must be positive.
    include_end
        Whether a timestamp equal to ``end`` is kept.

    Returns
    -------
    The generated timestamps. Empty when ``start > end`` (the previous implementation raised
    ``IndexError`` on ``ret[-1]`` in that case).

    Raises
    ------
    ValueError
        If ``granularity`` is not positive, which would otherwise loop forever.
    """
    if granularity <= pd.Timedelta(0):
        raise ValueError(f"granularity must be positive, got {granularity}")

    timestamps = []
    current = start
    while current <= end:
        timestamps.append(current)
        current += granularity

    # The loop never appends a value greater than `end`, so only the "equal to end" case needs trimming.
    if timestamps and not include_end and timestamps[-1] == end:
        timestamps.pop()
    return pd.DatetimeIndex(timestamps)
def fill_missing_data(
    original_data: np.ndarray,
    total_time_list: List[pd.Timestamp],
    found_time_list: List[pd.Timestamp],
    fill_method: Callable = np.median,
) -> np.ndarray:
    """Expand ``original_data`` onto ``total_time_list``, filling the timestamps that have no value.

    Needed because the data may lack values for some minutes.
    TODO: making exchange return data without missing will make it more elegant. Fix this in the future.

    Parameters
    ----------
    original_data
        Values that were actually found, one per entry of ``found_time_list``.
    total_time_list
        Every timestamp the result must cover, in output order.
    found_time_list
        Timestamps the entries of ``original_data`` belong to.
    fill_method
        Reducer applied to ``original_data`` to obtain the value used for missing timestamps.

    Returns
    -------
    An array with one value per entry of ``total_time_list``.
    """
    assert len(original_data) == len(found_time_list)

    value_by_time = {}
    for timestamp, value in zip(found_time_list, original_data):
        value_by_time[timestamp] = value

    default_value = fill_method(original_data)
    filled = [value_by_time[t] if t in value_by_time else default_value for t in total_time_list]
    return np.array(filled)
class SAOEStateAdapter:
    """
    Maintain states of the environment. SAOEStateAdapter accepts execution results and updates its internal state
    according to the execution results with additional information acquired from executors & exchange. For example,
    it gets the dealt order amount from execution results, and gets the corresponding market price / volume from
    exchange.

    Example usage::

        adapter = SAOEStateAdapter(...)
        adapter.update(...)
        state = adapter.saoe_state
    """

    def __init__(
        self,
        order: Order,
        executor: BaseExecutor,
        exchange: Exchange,
        ticks_per_step: int,
        backtest_data: IntradayBacktestData,
    ) -> None:
        """Store the collaborators and initialise the bookkeeping for one order.

        Parameters
        ----------
        order
            The order being executed; its ``amount`` is the initial remaining position.
        executor
            Executor whose trade account is queried for indicators in ``update``.
        exchange
            Exchange queried for market volume and deal price in ``update``.
        ticks_per_step
            Number of ticks one step spans; used by ``_next_time``.
        backtest_data
            Intraday data providing ``ticks_index``, ``ticks_for_order`` and deal prices.
        """
        # Remaining amount still to be executed; decreased at the end of every `update`.
        self.position = order.amount
        self.order = order
        self.executor = executor
        self.exchange = exchange
        self.backtest_data = backtest_data

        # Mean of the deal prices of the backtest data; reference price passed to `price_advantage`.
        self.twap_price = self.backtest_data.get_deal_price().mean()

        # Both history frames share the SAOEMetrics keys as columns and are indexed by "datetime".
        metric_keys = list(SAOEMetrics.__annotations__.keys())  # pylint: disable=no-member
        self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
        self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
        # Stays None until `generate_metrics_after_done` is called.
        self.metrics: Optional[SAOEMetrics] = None

        # Start from whichever is later: the first tick available for the order or the order's start time.
        self.cur_time = max(backtest_data.ticks_for_order[0], order.start_time)
        self.ticks_per_step = ticks_per_step

    def _next_time(self) -> pd.Timestamp:
        """Return the time at which the next step starts, capped at ``order.end_time``."""
        current_loc = self.backtest_data.ticks_index.get_loc(self.cur_time)
        next_loc = current_loc + self.ticks_per_step
        # Round down to a multiple of `ticks_per_step` so that steps stay aligned to the step grid.
        next_loc = next_loc - next_loc % self.ticks_per_step
        if (
            next_loc < len(self.backtest_data.ticks_index)
            and self.backtest_data.ticks_index[next_loc] < self.order.end_time
        ):
            return self.backtest_data.ticks_index[next_loc]
        else:
            # Ran past the available ticks, or reached/passed the order's end time.
            return self.order.end_time

    def update(
        self,
        execute_result: list,
        last_step_range: Tuple[int, int],
    ) -> None:
        """Fold the execution results of the last step into the adapter's state.

        Parameters
        ----------
        execute_result
            Iterable of 4-tuples; only the first element (an order exposing ``start_time``, ``end_time``
            and ``deal_amount``) is used.
        last_step_range
            Inclusive ``(first, last)`` positions in ``backtest_data.ticks_index`` covered by the last step.
        """
        last_step_size = last_step_range[1] - last_step_range[0] + 1
        start_time = self.backtest_data.ticks_index[last_step_range[0]]
        end_time = self.backtest_data.ticks_index[last_step_range[1]]

        # Executed volume per tick of the last step; ticks without a deal stay 0.
        exec_vol = np.zeros(last_step_size)
        for order, _, __, ___ in execute_result:
            # NOTE(review): `idx` is presumably the minute index of the order's start within the trading day;
            # confirm against `get_day_min_idx_range`. It is shifted so that the step's first tick maps to 0.
            idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN)
            exec_vol[idx - last_step_range[0]] = order.deal_amount

        # If slightly more than the remaining position was dealt (by less than 1), rescale so that the
        # total equals the remaining position; a larger excess fails the assertion.
        if exec_vol.sum() > self.position and exec_vol.sum() > 0.0:
            assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large"
            exec_vol *= self.position / (exec_vol.sum())

        market_volume = cast(
            IndexData,
            self.exchange.get_volume(
                self.order.stock_id,
                pd.Timestamp(start_time),
                pd.Timestamp(end_time),
                method=None,
            ),
        )
        market_price = cast(
            IndexData,
            self.exchange.get_deal_price(
                self.order.stock_id,
                pd.Timestamp(start_time),
                pd.Timestamp(end_time),
                method=None,
                direction=self.order.direction,
            ),
        )

        # The exchange may not return a value for every tick: expand both series onto the full list of
        # timestamps of the step, filling the gaps (the timestamps present are taken from `market_volume`).
        found_time_list = [pd.Timestamp(e) for e in list(market_volume.index)]
        total_time_list = _get_all_timestamps(start_time, end_time)
        market_price = fill_missing_data(np.array(market_price).reshape(-1), total_time_list, found_time_list)
        market_volume = fill_missing_data(np.array(market_volume).reshape(-1), total_time_list, found_time_list)

        assert market_price.shape == market_volume.shape == exec_vol.shape

        # Get data from the current level executor's indicator
        current_trade_account = self.executor.trade_account
        current_df = current_trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
        # Per-tick records of this step; `pa` is taken from the last row of the executor's indicator frame.
        self.history_exec = dataframe_append(
            self.history_exec,
            self._collect_multi_order_metric(
                order=self.order,
                datetime=_get_all_timestamps(start_time, end_time, include_end=True),
                market_vol=market_volume,
                market_price=market_price,
                exec_vol=exec_vol,
                pa=current_df.iloc[-1]["pa"],
            ),
        )

        # One aggregated record for the whole step, stamped with the step's start time (`cur_time`).
        self.history_steps = dataframe_append(
            self.history_steps,
            [
                self._collect_single_order_metric(
                    self.order,
                    self.cur_time,
                    market_volume,
                    market_price,
                    exec_vol.sum(),
                    exec_vol,
                ),
            ],
        )

        # TODO: check whether we need this. Can we get this information from Account?
        # Do this at the end
        # (the metric collectors above read `self.position` and `self.cur_time` as they were before this step)
        self.position -= exec_vol.sum()

        self.cur_time = self._next_time()

    def generate_metrics_after_done(self) -> None:
        """Generate metrics once the upper level execution is done"""

        # Aggregate the whole per-tick history into a single record stamped with the first tick.
        self.metrics = self._collect_single_order_metric(
            self.order,
            self.backtest_data.ticks_index[0],  # start time
            self.history_exec["market_volume"],
            self.history_exec["market_price"],
            self.history_steps["amount"].sum(),
            self.history_exec["deal_amount"],
        )

    def _collect_multi_order_metric(
        self,
        order: Order,
        datetime: pd.DatetimeIndex,
        market_vol: np.ndarray,
        market_price: np.ndarray,
        exec_vol: np.ndarray,
        pa: float,
    ) -> SAOEMetrics:
        """Build a vectorised metrics record holding one entry per timestamp in ``datetime``.

        The market price is used as the trade price, and ``exec_vol`` is used for the intended,
        inner and dealt amounts alike. ``position`` is the remaining position after each tick.
        """
        return SAOEMetrics(
            # It should have the same keys with SAOEMetrics,
            # but the values do not necessarily have the annotated type.
            # Some values could be vectorized (e.g., exec_vol).
            stock_id=order.stock_id,
            datetime=datetime,
            direction=order.direction,
            market_volume=market_vol,
            market_price=market_price,
            amount=exec_vol,
            inner_amount=exec_vol,
            deal_amount=exec_vol,
            trade_price=market_price,
            trade_value=market_price * exec_vol,
            position=self.position - np.cumsum(exec_vol),
            ffr=exec_vol / order.amount,
            pa=pa,
        )

    def _collect_single_order_metric(
        self,
        order: Order,
        datetime: pd.Timestamp,
        market_vol: np.ndarray,
        market_price: np.ndarray,
        amount: float,  # intended to trade such amount
        exec_vol: np.ndarray,
    ) -> SAOEMetrics:
        """Aggregate per-tick series into one scalar metrics record stamped with ``datetime``.

        The three series must have equal length. The trade price is the average of ``market_price``
        weighted by ``exec_vol``, or 0.0 when (almost) nothing was executed.
        """
        assert len(market_vol) == len(market_price) == len(exec_vol)

        # Avoid a weighted average with (near-)zero total weight.
        if np.abs(np.sum(exec_vol)) < EPS:
            exec_avg_price = 0.0
        else:
            exec_avg_price = cast(float, np.average(market_price, weights=exec_vol))  # could be nan
            if hasattr(exec_avg_price, "item"):  # could be numpy scalar
                exec_avg_price = exec_avg_price.item()  # type: ignore

        exec_sum = exec_vol.sum()
        return SAOEMetrics(
            stock_id=order.stock_id,
            datetime=datetime,
            direction=order.direction,
            market_volume=market_vol.sum(),
            market_price=market_price.mean() if len(market_price) > 0 else np.nan,
            amount=amount,
            inner_amount=exec_sum,
            deal_amount=exec_sum,  # in this simulator, there's no other restrictions
            trade_price=exec_avg_price,
            trade_value=float(np.sum(market_price * exec_vol)),
            position=self.position - exec_sum,
            ffr=float(exec_sum / order.amount),
            pa=price_advantage(exec_avg_price, self.twap_price, order.direction),
        )

    @property
    def saoe_state(self) -> SAOEState:
        """Snapshot of the adapter's current fields packed into a ``SAOEState``."""
        return SAOEState(
            order=self.order,
            cur_time=self.cur_time,
            position=self.position,
            history_exec=self.history_exec,
            history_steps=self.history_steps,
            metrics=self.metrics,
            backtest_data=self.backtest_data,
            ticks_per_step=self.ticks_per_step,
            ticks_index=self.backtest_data.ticks_index,
            ticks_for_order=self.backtest_data.ticks_for_order,
        )
class SAOEMetrics(TypedDict):
@@ -302,7 +27,7 @@ class SAOEMetrics(TypedDict):
stock_id: str
"""Stock ID of this record."""
datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this
datetime: pd.Timestamp | pd.DatetimeIndex
"""Datetime of this record (this is index in the dataframe)."""
direction: int
"""Direction of the order. 0 for sell, 1 for buy."""
@@ -349,6 +74,8 @@ class SAOEState(NamedTuple):
"""The order we are dealing with."""
cur_time: pd.Timestamp
"""Current time, e.g., 9:30."""
cur_step: int
"""Current step, e.g., 0."""
position: float
"""Current remaining volume to execute."""
history_exec: pd.DataFrame

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import collections
from types import GeneratorType
from typing import Any, cast, Dict, Generator, List, Optional, Union
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
@@ -15,14 +15,276 @@ from tianshou.policy import BasePolicy
from qlib.backtest import CommonInfrastructure, Order
from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWithDetails, TradeDecisionWO, TradeRange
from qlib.backtest.utils import LevelInfrastructure
from qlib.constant import ONE_MIN
from qlib.rl.data.native import load_backtest_data
from qlib.backtest.exchange import Exchange
from qlib.backtest.executor import BaseExecutor
from qlib.backtest.utils import LevelInfrastructure, get_start_end_idx
from qlib.constant import EPS, ONE_MIN, REG_CN
from qlib.rl.data.native import IntradayBacktestData, load_backtest_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution.state import SAOEState, SAOEStateAdapter
from qlib.rl.utils.env_wrapper import BaseEnvWrapper
from qlib.rl.order_execution.state import SAOEMetrics, SAOEState
from qlib.rl.order_execution.utils import dataframe_append, price_advantage
from qlib.strategy.base import RLStrategy
from qlib.utils import init_instance_by_config
from qlib.utils.index_data import IndexData
from qlib.utils.time import get_day_min_idx_range
def _get_all_timestamps(
    start: pd.Timestamp,
    end: pd.Timestamp,
    granularity: pd.Timedelta = ONE_MIN,
    include_end: bool = True,
) -> pd.DatetimeIndex:
    """Enumerate evenly spaced timestamps from ``start`` up to ``end``.

    Parameters
    ----------
    start
        First timestamp of the sequence.
    end
        Upper bound of the sequence. It is only part of the result when it lies exactly on the
        ``start + k * granularity`` grid and ``include_end`` is True.
    granularity
        Distance between two consecutive timestamps. Must be positive.
    include_end
        Whether a timestamp equal to ``end`` is kept.

    Returns
    -------
    The generated timestamps. Empty when ``start > end`` (the previous implementation raised
    ``IndexError`` on ``ret[-1]`` in that case).

    Raises
    ------
    ValueError
        If ``granularity`` is not positive, which would otherwise loop forever.
    """
    if granularity <= pd.Timedelta(0):
        raise ValueError(f"granularity must be positive, got {granularity}")

    timestamps = []
    current = start
    while current <= end:
        timestamps.append(current)
        current += granularity

    # The loop never appends a value greater than `end`, so only the "equal to end" case needs trimming.
    if timestamps and not include_end and timestamps[-1] == end:
        timestamps.pop()
    return pd.DatetimeIndex(timestamps)
def fill_missing_data(
    original_data: np.ndarray,
    fill_method: Callable = np.nanmedian,
) -> np.ndarray:
    """Replace the NaN entries of ``original_data`` with a single value derived from the data.

    Parameters
    ----------
    original_data
        Data that may contain missing values, encoded as NaN.
    fill_method
        Reducer applied to the whole input to obtain the replacement value. It should ignore NaN
        (like the default ``np.nanmedian``), otherwise the replacement value is NaN itself.

    Returns
    -------
    A new array of the same shape as the input with NaN replaced by the fill value. As with
    ``np.nan_to_num``, infinities are replaced by very large finite numbers. If every entry is NaN,
    a NaN-ignoring reducer has nothing to work with and the result stays NaN.
    """
    data = np.asarray(original_data)
    if data.size == 0:
        # Nothing to fill; skipping the reducer avoids a spurious RuntimeWarning on an empty slice.
        return np.nan_to_num(data)
    return np.nan_to_num(data, nan=fill_method(data))
class SAOEStateAdapter:
    """
    Maintain states of the environment. SAOEStateAdapter accepts execution results and updates its internal state
    according to the execution results with additional information acquired from executors & exchange. For example,
    it gets the dealt order amount from execution results, and gets the corresponding market price / volume from
    exchange.

    Example usage::

        adapter = SAOEStateAdapter(...)
        adapter.update(...)
        state = adapter.saoe_state
    """

    def __init__(
        self,
        order: Order,
        trade_decision: BaseTradeDecision,
        executor: BaseExecutor,
        exchange: Exchange,
        ticks_per_step: int,
        backtest_data: IntradayBacktestData,
    ) -> None:
        """Store the collaborators and initialise the bookkeeping for one order.

        Parameters
        ----------
        order
            The order being executed; its ``amount`` is the initial remaining position.
        trade_decision
            Decision passed to ``get_start_end_idx`` to obtain the start index of the trading range.
        executor
            Executor providing the trade calendar and the trade account queried for indicators.
        exchange
            Exchange queried for market volume and deal price in ``update``.
        ticks_per_step
            Number of ticks one step spans; used by ``_next_time``.
        backtest_data
            Intraday data providing ``ticks_index``, ``ticks_for_order`` and deal prices.
        """
        # Remaining amount still to be executed; decreased at the end of every `update`.
        self.position = order.amount
        self.order = order
        self.executor = executor
        self.exchange = exchange
        self.backtest_data = backtest_data
        # Start index of the decision on the executor's calendar; `saoe_state` subtracts it from the
        # calendar's current trade step to obtain `cur_step`.
        self.start_idx, _ = get_start_end_idx(self.executor.trade_calendar, trade_decision)

        # Mean of the deal prices of the backtest data; reference price passed to `price_advantage`.
        self.twap_price = self.backtest_data.get_deal_price().mean()

        # Both history frames share the SAOEMetrics keys as columns and are indexed by "datetime".
        metric_keys = list(SAOEMetrics.__annotations__.keys())  # pylint: disable=no-member
        self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
        self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
        # Stays None until `generate_metrics_after_done` is called.
        self.metrics: Optional[SAOEMetrics] = None

        # Start from whichever is later: the first tick available for the order or the order's start time.
        self.cur_time = max(backtest_data.ticks_for_order[0], order.start_time)
        self.ticks_per_step = ticks_per_step

    def _next_time(self) -> pd.Timestamp:
        """Return the time at which the next step starts, capped at ``order.end_time``."""
        current_loc = self.backtest_data.ticks_index.get_loc(self.cur_time)
        next_loc = current_loc + self.ticks_per_step
        # Round down to a multiple of `ticks_per_step` so that steps stay aligned to the step grid.
        next_loc = next_loc - next_loc % self.ticks_per_step
        if (
            next_loc < len(self.backtest_data.ticks_index)
            and self.backtest_data.ticks_index[next_loc] < self.order.end_time
        ):
            return self.backtest_data.ticks_index[next_loc]
        else:
            # Ran past the available ticks, or reached/passed the order's end time.
            return self.order.end_time

    def update(
        self,
        execute_result: list,
        last_step_range: Tuple[int, int],
    ) -> None:
        """Fold the execution results of the last step into the adapter's state.

        Parameters
        ----------
        execute_result
            Iterable of 4-tuples; only the first element (an order exposing ``start_time``, ``end_time``
            and ``deal_amount``) is used.
        last_step_range
            Inclusive ``(first, last)`` positions in ``backtest_data.ticks_index`` covered by the last step.
        """
        last_step_size = last_step_range[1] - last_step_range[0] + 1
        start_time = self.backtest_data.ticks_index[last_step_range[0]]
        end_time = self.backtest_data.ticks_index[last_step_range[1]]

        # Executed volume per tick of the last step; ticks without a deal stay 0.
        exec_vol = np.zeros(last_step_size)
        for order, _, __, ___ in execute_result:
            # NOTE(review): `idx` is presumably the minute index of the order's start within the trading day;
            # confirm against `get_day_min_idx_range`. It is shifted so that the step's first tick maps to 0.
            idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN)
            exec_vol[idx - last_step_range[0]] = order.deal_amount

        # If slightly more than the remaining position was dealt (by less than 1), rescale so that the
        # total equals the remaining position; a larger excess fails the assertion.
        if exec_vol.sum() > self.position and exec_vol.sum() > 0.0:
            assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large"
            exec_vol *= self.position / (exec_vol.sum())

        market_volume = cast(
            IndexData,
            self.exchange.get_volume(
                self.order.stock_id,
                pd.Timestamp(start_time),
                pd.Timestamp(end_time),
                method=None,
            ),
        )
        market_price = cast(
            IndexData,
            self.exchange.get_deal_price(
                self.order.stock_id,
                pd.Timestamp(start_time),
                pd.Timestamp(end_time),
                method=None,
                direction=self.order.direction,
            ),
        )
        # Flatten to float arrays and replace NaN entries; the assertion below requires one value per tick.
        market_price = fill_missing_data(np.array(market_price, dtype=float).reshape(-1))
        market_volume = fill_missing_data(np.array(market_volume, dtype=float).reshape(-1))

        assert market_price.shape == market_volume.shape == exec_vol.shape

        # Get data from the current level executor's indicator
        current_trade_account = self.executor.trade_account
        current_df = current_trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
        # Per-tick records of this step; `pa` is taken from the last row of the executor's indicator frame.
        self.history_exec = dataframe_append(
            self.history_exec,
            self._collect_multi_order_metric(
                order=self.order,
                datetime=_get_all_timestamps(start_time, end_time, include_end=True),
                market_vol=market_volume,
                market_price=market_price,
                exec_vol=exec_vol,
                pa=current_df.iloc[-1]["pa"],
            ),
        )

        # One aggregated record for the whole step, stamped with the step's start time (`cur_time`).
        self.history_steps = dataframe_append(
            self.history_steps,
            [
                self._collect_single_order_metric(
                    self.order,
                    self.cur_time,
                    market_volume,
                    market_price,
                    exec_vol.sum(),
                    exec_vol,
                ),
            ],
        )

        # Do this at the end
        # (the metric collectors above read `self.position` and `self.cur_time` as they were before this step)
        self.position -= exec_vol.sum()

        self.cur_time = self._next_time()

    def generate_metrics_after_done(self) -> None:
        """Generate metrics once the upper level execution is done"""

        # Aggregate the whole per-tick history into a single record stamped with the first tick.
        self.metrics = self._collect_single_order_metric(
            self.order,
            self.backtest_data.ticks_index[0],  # start time
            self.history_exec["market_volume"],
            self.history_exec["market_price"],
            self.history_steps["amount"].sum(),
            self.history_exec["deal_amount"],
        )

    def _collect_multi_order_metric(
        self,
        order: Order,
        datetime: pd.DatetimeIndex,
        market_vol: np.ndarray,
        market_price: np.ndarray,
        exec_vol: np.ndarray,
        pa: float,
    ) -> SAOEMetrics:
        """Build a vectorised metrics record holding one entry per timestamp in ``datetime``.

        The market price is used as the trade price, and ``exec_vol`` is used for the intended,
        inner and dealt amounts alike. ``position`` is the remaining position after each tick.
        """
        return SAOEMetrics(
            # It should have the same keys with SAOEMetrics,
            # but the values do not necessarily have the annotated type.
            # Some values could be vectorized (e.g., exec_vol).
            stock_id=order.stock_id,
            datetime=datetime,
            direction=order.direction,
            market_volume=market_vol,
            market_price=market_price,
            amount=exec_vol,
            inner_amount=exec_vol,
            deal_amount=exec_vol,
            trade_price=market_price,
            trade_value=market_price * exec_vol,
            position=self.position - np.cumsum(exec_vol),
            ffr=exec_vol / order.amount,
            pa=pa,
        )

    def _collect_single_order_metric(
        self,
        order: Order,
        datetime: pd.Timestamp,
        market_vol: np.ndarray,
        market_price: np.ndarray,
        amount: float,  # intended to trade such amount
        exec_vol: np.ndarray,
    ) -> SAOEMetrics:
        """Aggregate per-tick series into one scalar metrics record stamped with ``datetime``.

        The three series must have equal length. The trade price is the average of ``market_price``
        weighted by ``exec_vol``, or 0.0 when (almost) nothing was executed.
        """
        assert len(market_vol) == len(market_price) == len(exec_vol)

        # Avoid a weighted average with (near-)zero total weight.
        if np.abs(np.sum(exec_vol)) < EPS:
            exec_avg_price = 0.0
        else:
            exec_avg_price = cast(float, np.average(market_price, weights=exec_vol))  # could be nan
            if hasattr(exec_avg_price, "item"):  # could be numpy scalar
                exec_avg_price = exec_avg_price.item()  # type: ignore

        exec_sum = exec_vol.sum()
        return SAOEMetrics(
            stock_id=order.stock_id,
            datetime=datetime,
            direction=order.direction,
            market_volume=market_vol.sum(),
            market_price=market_price.mean() if len(market_price) > 0 else np.nan,
            amount=amount,
            inner_amount=exec_sum,
            deal_amount=exec_sum,  # in this simulator, there's no other restrictions
            trade_price=exec_avg_price,
            trade_value=float(np.sum(market_price * exec_vol)),
            position=self.position - exec_sum,
            ffr=float(exec_sum / order.amount),
            pa=price_advantage(exec_avg_price, self.twap_price, order.direction),
        )

    @property
    def saoe_state(self) -> SAOEState:
        """Snapshot of the adapter's current fields packed into a ``SAOEState``."""
        return SAOEState(
            order=self.order,
            cur_time=self.cur_time,
            # Step counter relative to the start index recorded in `__init__`.
            cur_step=self.executor.trade_calendar.get_trade_step() - self.start_idx,
            position=self.position,
            history_exec=self.history_exec,
            history_steps=self.history_steps,
            metrics=self.metrics,
            backtest_data=self.backtest_data,
            ticks_per_step=self.ticks_per_step,
            ticks_index=self.backtest_data.ticks_index,
            ticks_for_order=self.backtest_data.ticks_for_order,
        )
class SAOEStrategy(RLStrategy):
@@ -30,7 +292,7 @@ class SAOEStrategy(RLStrategy):
def __init__(
self,
policy: object, # TODO: add accurate typehint later.
policy: BasePolicy,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
@@ -47,11 +309,17 @@ class SAOEStrategy(RLStrategy):
self.adapter_dict: Dict[tuple, SAOEStateAdapter] = {}
self._last_step_range = (0, 0)
def _create_qlib_backtest_adapter(self, order: Order, trade_range: TradeRange) -> SAOEStateAdapter:
def _create_qlib_backtest_adapter(
self,
order: Order,
trade_decision: BaseTradeDecision,
trade_range: TradeRange,
) -> SAOEStateAdapter:
backtest_data = load_backtest_data(order, self.trade_exchange, trade_range)
return SAOEStateAdapter(
order=order,
trade_decision=trade_decision,
executor=self.executor,
exchange=self.trade_exchange,
ticks_per_step=int(pd.Timedelta(self.trade_calendar.get_freq()) / ONE_MIN),
@@ -71,7 +339,9 @@ class SAOEStrategy(RLStrategy):
self.adapter_dict = {}
for decision in outer_trade_decision.get_decision():
order = cast(Order, decision)
self.adapter_dict[order.key_by_day] = self._create_qlib_backtest_adapter(order, trade_range)
self.adapter_dict[order.key_by_day] = self._create_qlib_backtest_adapter(
order, outer_trade_decision, trade_range
)
def get_saoe_state_by_order(self, order: Order) -> SAOEState:
return self.adapter_dict[order.key_by_day].saoe_state
@@ -166,11 +436,10 @@ class SAOEIntStrategy(SAOEStrategy):
policy: dict | BasePolicy,
state_interpreter: dict | StateInterpreter,
action_interpreter: dict | ActionInterpreter,
network: object = None, # TODO: add accurate typehint later.
network: dict | torch.nn.Module | None = None,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
backtest: bool = False,
**kwargs: Any,
) -> None:
super(SAOEIntStrategy, self).__init__(
@@ -181,8 +450,6 @@ class SAOEIntStrategy(SAOEStrategy):
**kwargs,
)
self._backtest = backtest
self._state_interpreter: StateInterpreter = init_instance_by_config(
state_interpreter,
accept_types=StateInterpreter,
@@ -221,21 +488,9 @@ class SAOEIntStrategy(SAOEStrategy):
if self._policy is not None:
self._policy.eval()
def set_env(self, env: BaseEnvWrapper) -> None:
# TODO: This method is used to set EnvWrapper for interpreters since they rely on EnvWrapper.
# We should decompose the interpreters with EnvWrapper in the future and we should remove this method
# after that.
self._env = env
self._state_interpreter.env = self._action_interpreter.env = self._env
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
# In backtest, env.reset() needs to be manually called since there is no outer trainer to call it
if self._backtest:
self._env.reset()
def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
assert hasattr(self.outer_trade_decision, "order_list")
@@ -268,10 +523,6 @@ class SAOEIntStrategy(SAOEStrategy):
act = policy_out.act.numpy() if torch.is_tensor(policy_out.act) else policy_out.act
exec_vols = [self._action_interpreter.interpret(s, a) for s, a in zip(states, act)]
# In backtest, env.step() needs to be manually called since there is no outer trainer to call it
if self._backtest:
self._env.step(None)
oh = self.trade_exchange.get_order_helper()
order_list = []
for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):

View File

@@ -7,7 +7,7 @@ import collections
import copy
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any, Dict, Iterable, List, Sequence, TypeVar, cast
from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast
import torch
@@ -152,6 +152,13 @@ class Trainer:
"metrics": self.metrics,
}
@staticmethod
def get_policy_state_dict(ckpt_path: Path) -> OrderedDict:
state_dict = torch.load(ckpt_path, map_location="cpu")
if "vessel" in state_dict:
state_dict = state_dict["vessel"]["policy"]
return state_dict
def load_state_dict(self, state_dict: dict) -> None:
"""Load all states into current trainer."""
self.vessel.load_state_dict(state_dict["vessel"])

View File

@@ -48,24 +48,9 @@ class EnvWrapperStatus(TypedDict):
reward_history: list
class BaseEnvWrapper(
class EnvWrapper(
gym.Env[ObsType, PolicyActType],
Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType],
):
"""Base env wrapper for RL environments. It has two implementations:
- EnvWrapper: Qlib-based RL environment used in training.
- CollectDataEnvWrapper: Dummy environment used in collect_data_loop.
"""
def __init__(self) -> None:
self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
def render(self, mode: str = "human") -> None:
raise NotImplementedError("Render is not implemented in BaseEnvWrapper.")
class EnvWrapper(
BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType],
):
"""Qlib-based RL environment, subclassing ``gym.Env``.
A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
@@ -129,8 +114,6 @@ class EnvWrapper(
# 3. Avoid circular reference.
# 4. When the components get serialized, we can throw away the env without any burden.
# (though this part is not implemented yet)
super().__init__()
for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]:
if obj is not None:
obj.env = weakref.proxy(self) # type: ignore
@@ -263,19 +246,5 @@ class EnvWrapper(
info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
return obs, rew, done, info_dict
class CollectDataEnvWrapper(BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType]):
    """Dummy EnvWrapper for collect_data_loop.

    It offers only ``reset`` and ``step``, and both do nothing except maintain ``self.status``.
    """

    def reset(self, **kwargs: Any) -> None:
        """Replace ``self.status`` with a fresh record: step 0, not done, empty histories."""
        fresh_status = EnvWrapperStatus(
            cur_step=0,
            done=False,
            initial_state=None,
            obs_history=[],
            action_history=[],
            reward_history=[],
        )
        self.status = fresh_status

    def step(self, policy_action: Any = None, **kwargs: Any) -> None:
        """Advance the step counter by one; ``policy_action`` and ``kwargs`` are ignored."""
        self.status["cur_step"] = self.status["cur_step"] + 1
def render(self, mode: str = "human") -> None:
    """Rendering is not supported: this always raises ``NotImplementedError``, whatever ``mode`` is."""
    raise NotImplementedError("Render is not implemented in EnvWrapper.")

View File

@@ -473,7 +473,8 @@ class PortAnaRecord(ACRecordTemp):
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
for _freq, indicators_normal in indicator_dict.items():
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal})
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
self.save(**{f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq not in portfolio_metric_dict:
@@ -511,7 +512,7 @@ class PortAnaRecord(ACRecordTemp):
if _analysis_freq not in indicator_dict:
warnings.warn(f"the freq {_analysis_freq} indicator is not found")
else:
indicators_normal = indicator_dict.get(_analysis_freq)
indicators_normal = indicator_dict.get(_analysis_freq)[0]
if self.indicator_analysis_method is None:
analysis_df = indicator_analysis(indicators_normal)
else:

View File

@@ -107,7 +107,7 @@ class FileStrTest(TestAutoData):
)
# ffr valid
ffr_dict = indicator_dict["1day"]["ffr"].to_dict()
ffr_dict = indicator_dict["1day"][0]["ffr"].to_dict()
ffr_dict = {str(date).split()[0]: ffr_dict[date] for date in ffr_dict}
assert np.isclose(ffr_dict["2020-01-03"], dealt_num_for_1000 / 1000)
assert np.isclose(ffr_dict["2020-01-06"], 0)

View File

@@ -125,7 +125,7 @@ class TestHFBacktest(TestAutoData):
# NOTE: please refer to the docs of format_decisions
# NOTE: `"track_data": True,` is very NECESSARY for collecting the decision!!!!!
f_dec = format_decisions(decisions)
print(indicator["1day"])
print(indicator["1day"][0])
if __name__ == "__main__":

View File

@@ -7,11 +7,11 @@ from typing import Tuple
import pandas as pd
import pytest
from qlib.backtest.decision import Order, OrderDir, TradeRangeByTime
from qlib.backtest.decision import Order, OrderDir
from qlib.backtest.executor import SimulatorExecutor
from qlib.rl.order_execution import CategoricalActionInterpreter
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
TOTAL_POSITION = 2100.0
@@ -183,8 +183,6 @@ def test_interpreter() -> None:
order = get_order()
simulator = get_simulator(order)
interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
interpreter_action.env = CollectDataEnvWrapper()
interpreter_action.env.reset()
NUM_STEPS = 7
state = simulator.get_state()

View File

@@ -20,7 +20,6 @@ from qlib.rl.data.pickle_styled import PickleProcessedDataProvider
from qlib.rl.order_execution import *
from qlib.rl.trainer import backtest, train
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
@@ -186,10 +185,6 @@ def test_interpreter():
assert np.sum(obs["data_processed"][60:]) == 0
# second step: action
interpreter_action.env = CollectDataEnvWrapper()
interpreter_action_twap.env = CollectDataEnvWrapper()
interpreter_action.env.reset()
interpreter_action_twap.env.reset()
action = interpreter_action(simulator.get_state(), 1)
assert action == 15 / 20
@@ -260,8 +255,6 @@ def test_twap_strategy(finite_env_type):
state_interp = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
action_interp = TwapRelativeActionInterpreter()
action_interp.env = CollectDataEnvWrapper()
action_interp.env.reset()
policy = AllOne(state_interp.observation_space, action_interp.action_space)
csv_writer = CsvWriter(Path(__file__).parent / ".output")
@@ -291,8 +284,6 @@ def test_cn_ppo_strategy():
state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
action_interp = CategoricalActionInterpreter(4)
action_interp.env = CollectDataEnvWrapper()
action_interp.env.reset()
network = Recurrent(state_interp.observation_space)
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"))
@@ -324,8 +315,6 @@ def test_ppo_train():
state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
action_interp = CategoricalActionInterpreter(4)
action_interp.env = CollectDataEnvWrapper()
action_interp.env.reset()
network = Recurrent(state_interp.observation_space)
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)