mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-06 04:20:57 +08:00
Refine RL todos (#1332)
* Refine several todos * CI issues * Remove Dropna limitation of `quote_df` in Exchange (#1334) * Remove Dropna limitation of `quote_df` of Exchange * Impreove docstring * Fix type error when expression is specified (#1335) * Refine fill_missing_data() * Remove several TODO comments * Add back env for interpreters * Change Literal import * Resolve PR comments * Move to SAOEState * Add Trainer.get_policy_state_dict() * Mypy issue Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union
|
||||
import pandas as pd
|
||||
|
||||
from .account import Account
|
||||
from .report import Indicator, PortfolioMetrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..strategy.base import BaseStrategy
|
||||
@@ -20,7 +19,7 @@ if TYPE_CHECKING:
|
||||
from ..config import C
|
||||
from ..log import get_module_logger
|
||||
from ..utils import init_instance_by_config
|
||||
from .backtest import backtest_loop, collect_data_loop
|
||||
from .backtest import INDICATOR_METRIC, PORT_METRIC, backtest_loop, collect_data_loop
|
||||
from .decision import Order
|
||||
from .exchange import Exchange
|
||||
from .utils import CommonInfrastructure
|
||||
@@ -223,7 +222,7 @@ def backtest(
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||
) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:
|
||||
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
|
||||
executor in the nested decision execution
|
||||
|
||||
@@ -256,9 +255,9 @@ def backtest(
|
||||
|
||||
Returns
|
||||
-------
|
||||
portfolio_metrics_dict: Dict[PortfolioMetrics]
|
||||
portfolio_dict: PORT_METRIC
|
||||
it records the trading portfolio_metrics information
|
||||
indicator_dict: Dict[Indicator]
|
||||
indicator_dict: INDICATOR_METRIC
|
||||
it computes the trading indicator
|
||||
It is organized in a dict format
|
||||
|
||||
@@ -273,8 +272,7 @@ def backtest(
|
||||
exchange_kwargs,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
portfolio_metrics, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
return portfolio_metrics, indicator
|
||||
return backtest_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
|
||||
|
||||
def collect_data(
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
|
||||
from typing import Dict, TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import BaseTradeDecision
|
||||
from qlib.backtest.report import Indicator, PortfolioMetrics
|
||||
from qlib.backtest.report import Indicator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
@@ -19,30 +19,35 @@ from tqdm.auto import tqdm
|
||||
from ..utils.time import Freq
|
||||
|
||||
|
||||
PORT_METRIC = Dict[str, Tuple[pd.DataFrame, dict]]
|
||||
INDICATOR_METRIC = Dict[str, Tuple[pd.DataFrame, Indicator]]
|
||||
|
||||
|
||||
def backtest_loop(
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||
) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:
|
||||
"""backtest function for the interaction of the outermost strategy and executor in the nested decision execution
|
||||
|
||||
please refer to the docs of `collect_data_loop`
|
||||
|
||||
Returns
|
||||
-------
|
||||
portfolio_metrics: PortfolioMetrics
|
||||
portfolio_dict: PORT_METRIC
|
||||
it records the trading portfolio_metrics information
|
||||
indicator: Indicator
|
||||
indicator_dict: INDICATOR_METRIC
|
||||
it computes the trading indicator
|
||||
"""
|
||||
return_value: dict = {}
|
||||
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
||||
pass
|
||||
|
||||
portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
|
||||
indicator = cast(Indicator, return_value.get("indicator"))
|
||||
return portfolio_metrics, indicator
|
||||
portfolio_dict = cast(PORT_METRIC, return_value.get("portfolio_dict"))
|
||||
indicator_dict = cast(INDICATOR_METRIC, return_value.get("indicator_dict"))
|
||||
|
||||
return portfolio_dict, indicator_dict
|
||||
|
||||
|
||||
def collect_data_loop(
|
||||
@@ -89,14 +94,17 @@ def collect_data_loop(
|
||||
|
||||
if return_value is not None:
|
||||
all_executors = trade_executor.get_all_executors()
|
||||
all_portfolio_metrics = {
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
|
||||
for _executor in all_executors
|
||||
if _executor.trade_account.is_port_metr_enabled()
|
||||
}
|
||||
all_indicators = {}
|
||||
for _executor in all_executors:
|
||||
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
|
||||
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
|
||||
return_value.update({"portfolio_metrics": all_portfolio_metrics, "indicator": all_indicators})
|
||||
|
||||
portfolio_dict: PORT_METRIC = {}
|
||||
indicator_dict: INDICATOR_METRIC = {}
|
||||
|
||||
for executor in all_executors:
|
||||
key = "{}{}".format(*Freq.parse(executor.time_per_step))
|
||||
if executor.trade_account.is_port_metr_enabled():
|
||||
portfolio_dict[key] = executor.trade_account.get_portfolio_metrics()
|
||||
|
||||
indicator_df = executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
indicator_obj = executor.trade_account.get_trade_indicator()
|
||||
indicator_dict[key] = (indicator_df, indicator_obj)
|
||||
|
||||
return_value.update({"portfolio_dict": portfolio_dict, "indicator_dict": indicator_dict})
|
||||
|
||||
@@ -26,6 +26,15 @@ from .high_performance_ds import BaseQuote, NumpyQuote
|
||||
|
||||
|
||||
class Exchange:
|
||||
# `quote_df` is a pd.DataFrame class that contains basic information for backtesting
|
||||
# After some processing, the data will later be maintained by `quote_cls` object for faster data retriving.
|
||||
# Some conventions for `quote_df`
|
||||
# - $close is for calculating the total value at end of each day.
|
||||
# - if $close is None, the stock on that day is reguarded as suspended.
|
||||
# - $factor is for rounding to the trading unit;
|
||||
# - if any $factor is missing when $close exists, trading unit rounding will be disabled
|
||||
quote_df: pd.DataFrame
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
freq: str = "day",
|
||||
@@ -159,6 +168,7 @@ class Exchange:
|
||||
self.codes = codes
|
||||
# Necessary fields
|
||||
# $close is for calculating the total value at end of each day.
|
||||
# - if $close is None, the stock on that day is reguarded as suspended.
|
||||
# $factor is for rounding to the trading unit
|
||||
# $change is for calculating the limit of the stock
|
||||
|
||||
@@ -199,7 +209,7 @@ class Exchange:
|
||||
self.end_time,
|
||||
freq=self.freq,
|
||||
disk_cache=True,
|
||||
).dropna(subset=["$close"])
|
||||
)
|
||||
self.quote_df.columns = self.all_fields
|
||||
|
||||
# check buy_price data and sell_price data
|
||||
@@ -209,7 +219,7 @@ class Exchange:
|
||||
self.logger.warning("{} field data contains nan.".format(pstr))
|
||||
|
||||
# update trade_w_adj_price
|
||||
if self.quote_df["$factor"].isna().any():
|
||||
if (self.quote_df["$factor"].isna() & ~self.quote_df["$close"].isna()).any():
|
||||
# The 'factor.day.bin' file not exists, and `factor` field contains `nan`
|
||||
# Use adjusted price
|
||||
self.trade_w_adj_price = True
|
||||
@@ -245,9 +255,9 @@ class Exchange:
|
||||
assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"}
|
||||
self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0)
|
||||
|
||||
LT_TP_EXP = "(exp)" # Tuple[str, str]
|
||||
LT_FLT = "float" # float
|
||||
LT_NONE = "none" # none
|
||||
LT_TP_EXP = "(exp)" # Tuple[str, str]: the limitation is calculated by a Qlib expression.
|
||||
LT_FLT = "float" # float: the trading limitation is based on `abs($change) < limit_threshold`
|
||||
LT_NONE = "none" # none: there is no trading limitation
|
||||
|
||||
def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
|
||||
"""get limit type"""
|
||||
@@ -261,20 +271,25 @@ class Exchange:
|
||||
raise NotImplementedError(f"This type of `limit_threshold` is not supported")
|
||||
|
||||
def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None:
|
||||
# $close is may contains NaN, the nan indicates that the stock is not tradable at that timestamp
|
||||
suspended = self.quote_df["$close"].isna()
|
||||
# check limit_threshold
|
||||
limit_type = self._get_limit_type(limit_threshold)
|
||||
if limit_type == self.LT_NONE:
|
||||
self.quote_df["limit_buy"] = False
|
||||
self.quote_df["limit_sell"] = False
|
||||
self.quote_df["limit_buy"] = suspended
|
||||
self.quote_df["limit_sell"] = suspended
|
||||
elif limit_type == self.LT_TP_EXP:
|
||||
# set limit
|
||||
limit_threshold = cast(tuple, limit_threshold)
|
||||
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
|
||||
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
|
||||
# astype bool is necessary, because quote_df is an expression and could be float
|
||||
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]].astype("bool") | suspended
|
||||
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]].astype("bool") | suspended
|
||||
elif limit_type == self.LT_FLT:
|
||||
limit_threshold = cast(float, limit_threshold)
|
||||
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
|
||||
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
|
||||
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) | suspended
|
||||
self.quote_df["limit_sell"] = (
|
||||
self.quote_df["$change"].le(-limit_threshold) | suspended
|
||||
) # pylint: disable=E1130
|
||||
|
||||
@staticmethod
|
||||
def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
|
||||
@@ -338,8 +353,18 @@ class Exchange:
|
||||
- if direction is None, check if tradable for buying and selling.
|
||||
- if direction == Order.BUY, check the if tradable for buying
|
||||
- if direction == Order.SELL, check the sell limit for selling.
|
||||
|
||||
Returns
|
||||
-------
|
||||
True: the trading of the stock is limted (maybe hit the highest/lowest price), hence the stock is not tradable
|
||||
False: the trading of the stock is not limited, hence the stock may be tradable
|
||||
"""
|
||||
# NOTE:
|
||||
# **all** is used when checking limitation.
|
||||
# For example, the stock trading is limited in a day if every miniute is limited in a day if every miniute is limited.
|
||||
if direction is None:
|
||||
# The trading limitation is related to the trading direction
|
||||
# if the direction is not provided, then any limitation from buy or sell will result in trading limitation
|
||||
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
|
||||
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
|
||||
return bool(buy_limit or sell_limit)
|
||||
@@ -356,10 +381,24 @@ class Exchange:
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> bool:
|
||||
"""if stock is suspended(hence not tradable), True will be returned"""
|
||||
# is suspended
|
||||
if stock_id in self.quote.get_all_stock():
|
||||
return self.quote.get_data(stock_id, start_time, end_time, "$close") is None
|
||||
# suspended stocks are represented by None $close stock
|
||||
# The $close may contains NaN,
|
||||
close = self.quote.get_data(stock_id, start_time, end_time, "$close")
|
||||
if close is None:
|
||||
# if no close record exists
|
||||
return True
|
||||
elif isinstance(close, IndexData):
|
||||
# **any** non-NaN $close represents trading opportunity may exists
|
||||
# if all returned is nan, then the stock is suspended
|
||||
return cast(bool, cast(IndexData, close).isna().all())
|
||||
else:
|
||||
# it is single value, make sure is is not None
|
||||
return np.isnan(close)
|
||||
else:
|
||||
# if the stock is not in the stock list, then it is not tradable and regarded as suspended
|
||||
return True
|
||||
|
||||
def is_stock_tradable(
|
||||
|
||||
@@ -8,23 +8,22 @@ import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from qlib.typehint import Literal
|
||||
from qlib.backtest import collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
|
||||
from qlib.backtest.executor import SimulatorExecutor
|
||||
from qlib.backtest.high_performance_ds import BaseOrderIndicator
|
||||
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
|
||||
from qlib.rl.contrib.utils import read_order_file
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
||||
def _get_multi_level_executor_config(
|
||||
@@ -61,15 +60,6 @@ def _get_multi_level_executor_config(
|
||||
return executor_config
|
||||
|
||||
|
||||
def _set_env_for_all_strategy(executor: BaseExecutor) -> None:
|
||||
if isinstance(executor, NestedExecutor):
|
||||
if hasattr(executor.inner_strategy, "set_env"):
|
||||
env = CollectDataEnvWrapper()
|
||||
env.reset()
|
||||
executor.inner_strategy.set_env(env)
|
||||
_set_env_for_all_strategy(executor.inner_executor)
|
||||
|
||||
|
||||
def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
|
||||
record_list = []
|
||||
for time, value_dict in indicator.items():
|
||||
@@ -94,9 +84,10 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
|
||||
return records
|
||||
|
||||
|
||||
# TODO: there should be richer annotation for the input (e.g. report) and the returned report
|
||||
# TODO: For example, @ dataclass with typed fields and detailed docstrings.
|
||||
def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List[dict]) -> dict:
|
||||
def _generate_report(
|
||||
decisions: List[BaseTradeDecision],
|
||||
report_indicators: List[INDICATOR_METRIC],
|
||||
) -> Dict[str, Tuple[pd.DataFrame, pd.DataFrame]]:
|
||||
"""Generate backtest reports
|
||||
|
||||
Parameters
|
||||
@@ -109,28 +100,25 @@ def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List
|
||||
-------
|
||||
|
||||
"""
|
||||
indicator_dict = defaultdict(list)
|
||||
indicator_his = defaultdict(list)
|
||||
indicator_dict: Dict[str, List[pd.DataFrame]] = defaultdict(list)
|
||||
indicator_his: Dict[str, List[dict]] = defaultdict(list)
|
||||
|
||||
for report_indicator in report_indicators:
|
||||
for key, value in report_indicator.items():
|
||||
if key.endswith("_obj"):
|
||||
indicator_his[key].append(value.order_indicator_his)
|
||||
else:
|
||||
indicator_dict[key].append(value)
|
||||
for key, (indicator_df, indicator_obj) in report_indicator.items():
|
||||
indicator_dict[key].append(indicator_df)
|
||||
indicator_his[key].append(indicator_obj.order_indicator_his)
|
||||
|
||||
report = {}
|
||||
decision_details = pd.concat([getattr(d, "details") for d in decisions if hasattr(d, "details")])
|
||||
for key in ["1min", "5min", "30min", "1day"]:
|
||||
if key not in indicator_dict:
|
||||
continue
|
||||
|
||||
report[key] = pd.concat(indicator_dict[key])
|
||||
report[key + "_obj"] = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key + "_obj"]])
|
||||
|
||||
for key in indicator_dict:
|
||||
cur_dict = pd.concat(indicator_dict[key])
|
||||
cur_his = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key]])
|
||||
cur_details = decision_details[decision_details.freq == key].set_index(["instrument", "datetime"])
|
||||
if len(cur_details) > 0:
|
||||
cur_details.pop("freq")
|
||||
report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer")
|
||||
cur_his = cur_his.join(cur_details, how="outer")
|
||||
|
||||
report[key] = (cur_dict, cur_his)
|
||||
|
||||
return report
|
||||
|
||||
@@ -209,25 +197,25 @@ def single_with_simulator(
|
||||
exchange_config=exchange_config,
|
||||
qlib_config=None,
|
||||
cash_limit=None,
|
||||
backtest_mode=True,
|
||||
)
|
||||
|
||||
reports.append(simulator.report_dict)
|
||||
decisions += simulator.decisions
|
||||
|
||||
indicator = {k: v for report in reports for k, v in report["indicator"]["1day_obj"].order_indicator_his.items()}
|
||||
records = _convert_indicator_to_dataframe(indicator)
|
||||
indicator_1day_objs = [report["indicator"]["1day"][1] for report in reports]
|
||||
indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}
|
||||
records = _convert_indicator_to_dataframe(indicator_info)
|
||||
assert records is None or not np.isnan(records["ffr"]).any()
|
||||
|
||||
if generate_report:
|
||||
report = _generate_report(decisions, [report["indicator"] for report in reports])
|
||||
_report = _generate_report(decisions, [report["indicator"] for report in reports])
|
||||
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
report = {stock_id: report}
|
||||
report = {stock_id: _report}
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
report = {day: report}
|
||||
report = {day: _report}
|
||||
|
||||
return records, report
|
||||
else:
|
||||
@@ -312,22 +300,22 @@ def single_with_collect_data_loop(
|
||||
exchange_kwargs=exchange_config,
|
||||
pos_type="Position" if cash_limit is not None else "InfPosition",
|
||||
)
|
||||
_set_env_for_all_strategy(executor=executor)
|
||||
|
||||
report_dict: dict = {}
|
||||
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))
|
||||
|
||||
records = _convert_indicator_to_dataframe(report_dict["indicator"]["1day_obj"].order_indicator_his)
|
||||
indicator_dict = cast(INDICATOR_METRIC, report_dict.get("indicator_dict"))
|
||||
records = _convert_indicator_to_dataframe(indicator_dict["1day"][1].order_indicator_his)
|
||||
assert records is None or not np.isnan(records["ffr"]).any()
|
||||
|
||||
if generate_report:
|
||||
report = _generate_report(decisions, [report_dict["indicator"]])
|
||||
_report = _generate_report(decisions, [indicator_dict])
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
report = {stock_id: report}
|
||||
report = {stock_id: _report}
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
report = {day: report}
|
||||
report = {day: _report}
|
||||
return records, report
|
||||
else:
|
||||
return records
|
||||
@@ -337,7 +325,7 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram
|
||||
order_df = read_order_file(backtest_config["order_file"])
|
||||
|
||||
cash_limit = backtest_config["exchange"].pop("cash_limit")
|
||||
generate_report = backtest_config["exchange"].pop("generate_report")
|
||||
generate_report = backtest_config.pop("generate_report")
|
||||
|
||||
stock_pool = order_df["instrument"].unique().tolist()
|
||||
stock_pool.sort()
|
||||
@@ -382,9 +370,19 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
parser.add_argument("--use_simulator", action="store_true", help="Whether to use simulator as the backend")
|
||||
parser.add_argument(
|
||||
"--n_jobs",
|
||||
type=int,
|
||||
required=False,
|
||||
help="The number of jobs for running backtest parallely(1 for single process)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = get_backtest_config_fromfile(args.config_path)
|
||||
if args.n_jobs is not None:
|
||||
config["concurrency"] = args.n_jobs
|
||||
|
||||
backtest(
|
||||
backtest_config=get_backtest_config_fromfile(args.config_path),
|
||||
backtest_config=config,
|
||||
with_simulator=args.use_simulator,
|
||||
)
|
||||
|
||||
@@ -11,11 +11,14 @@ from importlib import import_module
|
||||
import yaml
|
||||
|
||||
|
||||
DELETE_KEY = "_delete_"
|
||||
|
||||
|
||||
def merge_a_into_b(a: dict, b: dict) -> dict:
|
||||
b = b.copy()
|
||||
for k, v in a.items():
|
||||
if isinstance(v, dict) and k in b:
|
||||
v.pop("_delete_", False) # TODO: make this more elegant
|
||||
v.pop(DELETE_KEY, False)
|
||||
b[k] = merge_a_into_b(v, b[k])
|
||||
else:
|
||||
b[k] = v
|
||||
@@ -86,7 +89,6 @@ def get_backtest_config_fromfile(path: str) -> dict:
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
"cash_limit": None,
|
||||
"generate_report": False,
|
||||
}
|
||||
backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default)
|
||||
backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"])
|
||||
@@ -97,7 +99,7 @@ def get_backtest_config_fromfile(path: str) -> dict:
|
||||
"concurrency": -1,
|
||||
"multiplier": 1.0,
|
||||
"output_dir": "outputs/",
|
||||
# "runtime": {},
|
||||
"generate_report": False,
|
||||
}
|
||||
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from qlib.rl.order_execution.utils import get_ticks_slice
|
||||
|
||||
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
from .integration import fetch_features
|
||||
from ...data import D
|
||||
|
||||
|
||||
class IntradayBacktestData(BaseIntradayBacktestData):
|
||||
@@ -81,17 +80,7 @@ def load_backtest_data(
|
||||
trade_exchange: Exchange,
|
||||
trade_range: TradeRange,
|
||||
) -> IntradayBacktestData:
|
||||
# TODO: making exchange return data without missing will make it more elegant. Fix this in the future.
|
||||
tmp_data = D.features(
|
||||
trade_exchange.codes,
|
||||
trade_exchange.all_fields,
|
||||
trade_exchange.start_time,
|
||||
trade_exchange.end_time,
|
||||
freq=trade_exchange.freq,
|
||||
disk_cache=True,
|
||||
)
|
||||
|
||||
ticks_index = pd.DatetimeIndex(tmp_data.reset_index()["datetime"])
|
||||
ticks_index = pd.DatetimeIndex(trade_exchange.quote_df.reset_index()["datetime"])
|
||||
ticks_index = ticks_index[order.start_time <= ticks_index]
|
||||
ticks_index = ticks_index[ticks_index <= order.end_time]
|
||||
|
||||
|
||||
@@ -3,19 +3,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from qlib.typehint import final
|
||||
from .simulator import ActType, StateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils.env_wrapper import BaseEnvWrapper
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
|
||||
ObsType = TypeVar("ObsType")
|
||||
PolicyActType = TypeVar("PolicyActType")
|
||||
|
||||
@@ -39,8 +35,6 @@ class Interpreter:
|
||||
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
|
||||
|
||||
env: Optional[BaseEnvWrapper] = None
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gym.Space:
|
||||
raise NotImplementedError()
|
||||
@@ -73,8 +67,6 @@ class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
|
||||
"""Action Interpreter that interpret rl agent action into qlib orders"""
|
||||
|
||||
env: Optional[BaseEnvWrapper] = None
|
||||
|
||||
@property
|
||||
def action_space(self) -> gym.Space:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -69,8 +69,6 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
|
||||
Provider of the processed data.
|
||||
"""
|
||||
|
||||
# TODO: All implementations related to `data_dir` is coupled with the specific data format for that specific case.
|
||||
# TODO: So it should be redesigned after the data interface is well-designed.
|
||||
def __init__(
|
||||
self,
|
||||
max_step: int,
|
||||
@@ -78,6 +76,8 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
|
||||
data_dim: int,
|
||||
processed_data_provider: dict | ProcessedDataProvider,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.max_step = max_step
|
||||
self.data_ticks = data_ticks
|
||||
self.data_dim = data_dim
|
||||
@@ -87,10 +87,6 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
|
||||
)
|
||||
|
||||
def interpret(self, state: SAOEState) -> FullHistoryObs:
|
||||
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
|
||||
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
|
||||
# way to decompose interpreter and EnvWrapper in the future.
|
||||
|
||||
processed = self.processed_data_provider.get_data(
|
||||
stock_id=state.order.stock_id,
|
||||
date=pd.Timestamp(state.order.start_time.date()),
|
||||
@@ -102,8 +98,6 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
|
||||
position_history[0] = state.order.amount
|
||||
position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy()
|
||||
|
||||
assert self.env is not None
|
||||
|
||||
# The min, slice here are to make sure that indices fit into the range,
|
||||
# even after the final step of the simulator (in the done step),
|
||||
# to make network in policy happy.
|
||||
@@ -115,7 +109,7 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
|
||||
"data_processed_prev": np.array(processed.yesterday),
|
||||
"acquiring": _to_int32(state.order.direction == state.order.BUY),
|
||||
"cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)),
|
||||
"cur_step": _to_int32(min(self.env.status["cur_step"], self.max_step - 1)),
|
||||
"cur_step": _to_int32(min(state.cur_step, self.max_step - 1)),
|
||||
"num_step": _to_int32(self.max_step),
|
||||
"target": _to_float32(state.order.amount),
|
||||
"position": _to_float32(state.position),
|
||||
@@ -163,6 +157,8 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
|
||||
"""
|
||||
|
||||
def __init__(self, max_step: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.max_step = max_step
|
||||
|
||||
@property
|
||||
@@ -177,15 +173,10 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
|
||||
return spaces.Dict(space)
|
||||
|
||||
def interpret(self, state: SAOEState) -> CurrentStateObs:
|
||||
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
|
||||
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
|
||||
# way to decompose interpreter and EnvWrapper in the future.
|
||||
|
||||
assert self.env is not None
|
||||
assert self.env.status["cur_step"] <= self.max_step
|
||||
assert state.cur_step <= self.max_step
|
||||
obs = CurrentStateObs(
|
||||
acquiring=state.order.direction == state.order.BUY,
|
||||
cur_step=self.env.status["cur_step"],
|
||||
cur_step=state.cur_step,
|
||||
num_step=self.max_step,
|
||||
target=state.order.amount,
|
||||
position=state.position,
|
||||
@@ -208,6 +199,8 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
|
||||
"""
|
||||
|
||||
def __init__(self, values: int | List[float], max_step: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
if isinstance(values, int):
|
||||
values = [i / values for i in range(0, values + 1)]
|
||||
self.action_values = values
|
||||
@@ -218,13 +211,8 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
|
||||
return spaces.Discrete(len(self.action_values))
|
||||
|
||||
def interpret(self, state: SAOEState, action: int) -> float:
|
||||
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
|
||||
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
|
||||
# way to decompose interpreter and EnvWrapper in the future.
|
||||
|
||||
assert 0 <= action < len(self.action_values)
|
||||
assert self.env is not None
|
||||
if self.max_step is not None and self.env.status["cur_step"] >= self.max_step - 1:
|
||||
if self.max_step is not None and state.cur_step >= self.max_step - 1:
|
||||
return state.position
|
||||
else:
|
||||
return min(state.position, state.order.amount * self.action_values[action])
|
||||
@@ -244,13 +232,8 @@ class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
|
||||
return spaces.Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def interpret(self, state: SAOEState, action: float) -> float:
|
||||
# TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
|
||||
# backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
|
||||
# way to decompose interpreter and EnvWrapper in the future.
|
||||
|
||||
assert self.env is not None
|
||||
estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step)
|
||||
twap_volume = state.position / (estimated_total_steps - self.env.status["cur_step"])
|
||||
twap_volume = state.position / (estimated_total_steps - state.cur_step)
|
||||
return min(state.position, twap_volume * action)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, Iterable, Optional, Tuple, cast
|
||||
from typing import Any, Dict, Generator, Iterable, Optional, OrderedDict, Tuple, cast
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@@ -14,6 +14,8 @@ from gym.spaces import Discrete
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch
|
||||
from tianshou.policy import BasePolicy, PPOPolicy
|
||||
|
||||
from qlib.rl.trainer.trainer import Trainer
|
||||
|
||||
__all__ = ["AllOne", "PPO"]
|
||||
|
||||
|
||||
@@ -148,7 +150,7 @@ class PPO(PPOPolicy):
|
||||
action_space=action_space,
|
||||
)
|
||||
if weight_file is not None:
|
||||
load_weight(self, weight_file)
|
||||
set_weight(self, Trainer.get_policy_state_dict(weight_file))
|
||||
|
||||
|
||||
# utilities: these should be put in a separate (common) file. #
|
||||
@@ -160,15 +162,7 @@ def auto_device(module: nn.Module) -> torch.device:
|
||||
return torch.device("cpu") # fallback to cpu
|
||||
|
||||
|
||||
def load_weight(policy: nn.Module, path: Path) -> None:
|
||||
assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
|
||||
loaded_weight = torch.load(path, map_location="cpu")
|
||||
|
||||
# TODO: this should be handled by whoever calls load_weight.
|
||||
# TODO: For example, when the outer class receives a weight, it should first unpack it,
|
||||
# TODO: and send the corresponding part to individual component.
|
||||
if "vessel" in loaded_weight:
|
||||
loaded_weight = loaded_weight["vessel"]["policy"]
|
||||
def set_weight(policy: nn.Module, loaded_weight: OrderedDict) -> None:
|
||||
try:
|
||||
policy.load_state_dict(loaded_weight)
|
||||
except RuntimeError:
|
||||
|
||||
@@ -9,12 +9,11 @@ import pandas as pd
|
||||
|
||||
from qlib.backtest import collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor
|
||||
from qlib.backtest.executor import NestedExecutor
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.simulator import Simulator
|
||||
from .state import SAOEState, SAOEStateAdapter
|
||||
from .strategy import SAOEStrategy
|
||||
from ..utils.env_wrapper import CollectDataEnvWrapper
|
||||
from .state import SAOEState
|
||||
from .strategy import SAOEStateAdapter, SAOEStrategy
|
||||
|
||||
|
||||
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
@@ -32,8 +31,6 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
Configuration used to initialize Qlib. If it is None, Qlib will not be initialized.
|
||||
cash_limit:
|
||||
Cash limit.
|
||||
backtest_mode
|
||||
Whether the simulator is under backtest mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -43,7 +40,6 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
exchange_config: dict,
|
||||
qlib_config: dict = None,
|
||||
cash_limit: Optional[float] = None,
|
||||
backtest_mode: bool = False,
|
||||
) -> None:
|
||||
super().__init__(initial=order)
|
||||
|
||||
@@ -59,7 +55,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
}
|
||||
|
||||
self._collect_data_loop: Optional[Generator] = None
|
||||
self.reset(order, strategy_config, executor_config, exchange_config, qlib_config, cash_limit, backtest_mode)
|
||||
self.reset(order, strategy_config, executor_config, exchange_config, qlib_config, cash_limit)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
@@ -69,7 +65,6 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
exchange_config: dict,
|
||||
qlib_config: dict = None,
|
||||
cash_limit: Optional[float] = None,
|
||||
backtest_mode: bool = False,
|
||||
) -> None:
|
||||
if qlib_config is not None:
|
||||
init_qlib(qlib_config, part="skip")
|
||||
@@ -98,16 +93,6 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
# TODO: backtest_mode is not a necessary parameter if we carefully design it.
|
||||
# TODO: It should disappear with CollectDataEnvWrapper in the future.
|
||||
if backtest_mode:
|
||||
executor: BaseExecutor = self._executor
|
||||
while isinstance(executor, NestedExecutor):
|
||||
if hasattr(executor.inner_strategy, "set_env"):
|
||||
executor.inner_strategy.set_env(CollectDataEnvWrapper())
|
||||
executor = executor.inner_executor
|
||||
|
||||
# Call `step()` with None action to initialize the internal generator.
|
||||
self.step(action=None)
|
||||
|
||||
self._order = order
|
||||
|
||||
@@ -16,8 +16,6 @@ from qlib.rl.utils import LogLevel
|
||||
|
||||
from .state import SAOEMetrics, SAOEState
|
||||
|
||||
# TODO: Integrating Qlib's native data with simulator_simple
|
||||
|
||||
__all__ = ["SingleAssetOrderExecutionSimple"]
|
||||
|
||||
|
||||
@@ -98,6 +96,7 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
self.ticks_for_order = self._get_ticks_slice(self.order.start_time, self.order.end_time)
|
||||
|
||||
self.cur_time = self.ticks_for_order[0]
|
||||
self.cur_step = 0
|
||||
# NOTE: astype(float) is necessary in some systems.
|
||||
# this will align the precision with `.to_numpy()` in `_split_exec_vol`
|
||||
self.twap_price = float(self.backtest_data.get_deal_price().loc[self.ticks_for_order].astype(float).mean())
|
||||
@@ -194,11 +193,13 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
self.env.logger.add_any(key, value)
|
||||
|
||||
self.cur_time = self._next_time()
|
||||
self.cur_step += 1
|
||||
|
||||
def get_state(self) -> SAOEState:
|
||||
return SAOEState(
|
||||
order=self.order,
|
||||
cur_time=self.cur_time,
|
||||
cur_step=self.cur_step,
|
||||
position=self.position,
|
||||
history_exec=self.history_exec,
|
||||
history_steps=self.history_steps,
|
||||
|
||||
@@ -4,290 +4,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import cast, Callable, List, NamedTuple, Optional, Tuple
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from qlib.constant import EPS, ONE_MIN, REG_CN
|
||||
from qlib.rl.order_execution.utils import dataframe_append, price_advantage
|
||||
from qlib.backtest import Order
|
||||
from qlib.typehint import TypedDict
|
||||
from qlib.utils.index_data import IndexData
|
||||
from qlib.utils.time import get_day_min_idx_range
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from qlib.rl.data.base import BaseIntradayBacktestData
|
||||
from qlib.rl.data.native import IntradayBacktestData
|
||||
|
||||
|
||||
def _get_all_timestamps(
|
||||
start: pd.Timestamp,
|
||||
end: pd.Timestamp,
|
||||
granularity: pd.Timedelta = ONE_MIN,
|
||||
include_end: bool = True,
|
||||
) -> pd.DatetimeIndex:
|
||||
ret = []
|
||||
while start <= end:
|
||||
ret.append(start)
|
||||
start += granularity
|
||||
|
||||
if ret[-1] > end:
|
||||
ret.pop()
|
||||
if ret[-1] == end and not include_end:
|
||||
ret.pop()
|
||||
return pd.DatetimeIndex(ret)
|
||||
|
||||
|
||||
def fill_missing_data(
|
||||
original_data: np.ndarray,
|
||||
total_time_list: List[pd.Timestamp],
|
||||
found_time_list: List[pd.Timestamp],
|
||||
fill_method: Callable = np.median,
|
||||
) -> np.ndarray:
|
||||
"""Fill missing data. We need this function to deal with data that have missing values in some minutes.
|
||||
|
||||
TODO: making exchange return data without missing will make it more elegant. Fix this in the future.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
original_data
|
||||
Original data without missing values.
|
||||
total_time_list
|
||||
All timestamps that required.
|
||||
found_time_list
|
||||
Timestamps found in the original data.
|
||||
fill_method
|
||||
Method used to fill the missing data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The filled data.
|
||||
"""
|
||||
assert len(original_data) == len(found_time_list)
|
||||
tmp = dict(zip(found_time_list, original_data))
|
||||
fill_val = fill_method(original_data)
|
||||
return np.array([tmp.get(t, fill_val) for t in total_time_list])
|
||||
|
||||
|
||||
class SAOEStateAdapter:
|
||||
"""
|
||||
Maintain states of the environment. SAOEStateAdapter accepts execution results and update its internal state
|
||||
according to the execution results with additional information acquired from executors & exchange. For example,
|
||||
it gets the dealt order amount from execution results, and get the corresponding market price / volume from
|
||||
exchange.
|
||||
|
||||
Example usage::
|
||||
|
||||
adapter = SAOEStateAdapter(...)
|
||||
adapter.update(...)
|
||||
state = adapter.saoe_state
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
order: Order,
|
||||
executor: BaseExecutor,
|
||||
exchange: Exchange,
|
||||
ticks_per_step: int,
|
||||
backtest_data: IntradayBacktestData,
|
||||
) -> None:
|
||||
self.position = order.amount
|
||||
self.order = order
|
||||
self.executor = executor
|
||||
self.exchange = exchange
|
||||
self.backtest_data = backtest_data
|
||||
|
||||
self.twap_price = self.backtest_data.get_deal_price().mean()
|
||||
|
||||
metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member
|
||||
self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.metrics: Optional[SAOEMetrics] = None
|
||||
|
||||
self.cur_time = max(backtest_data.ticks_for_order[0], order.start_time)
|
||||
self.ticks_per_step = ticks_per_step
|
||||
|
||||
def _next_time(self) -> pd.Timestamp:
|
||||
current_loc = self.backtest_data.ticks_index.get_loc(self.cur_time)
|
||||
next_loc = current_loc + self.ticks_per_step
|
||||
next_loc = next_loc - next_loc % self.ticks_per_step
|
||||
if (
|
||||
next_loc < len(self.backtest_data.ticks_index)
|
||||
and self.backtest_data.ticks_index[next_loc] < self.order.end_time
|
||||
):
|
||||
return self.backtest_data.ticks_index[next_loc]
|
||||
else:
|
||||
return self.order.end_time
|
||||
|
||||
def update(
|
||||
self,
|
||||
execute_result: list,
|
||||
last_step_range: Tuple[int, int],
|
||||
) -> None:
|
||||
last_step_size = last_step_range[1] - last_step_range[0] + 1
|
||||
start_time = self.backtest_data.ticks_index[last_step_range[0]]
|
||||
end_time = self.backtest_data.ticks_index[last_step_range[1]]
|
||||
|
||||
exec_vol = np.zeros(last_step_size)
|
||||
for order, _, __, ___ in execute_result:
|
||||
idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN)
|
||||
exec_vol[idx - last_step_range[0]] = order.deal_amount
|
||||
|
||||
if exec_vol.sum() > self.position and exec_vol.sum() > 0.0:
|
||||
assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large"
|
||||
exec_vol *= self.position / (exec_vol.sum())
|
||||
|
||||
market_volume = cast(
|
||||
IndexData,
|
||||
self.exchange.get_volume(
|
||||
self.order.stock_id,
|
||||
pd.Timestamp(start_time),
|
||||
pd.Timestamp(end_time),
|
||||
method=None,
|
||||
),
|
||||
)
|
||||
market_price = cast(
|
||||
IndexData,
|
||||
self.exchange.get_deal_price(
|
||||
self.order.stock_id,
|
||||
pd.Timestamp(start_time),
|
||||
pd.Timestamp(end_time),
|
||||
method=None,
|
||||
direction=self.order.direction,
|
||||
),
|
||||
)
|
||||
found_time_list = [pd.Timestamp(e) for e in list(market_volume.index)]
|
||||
total_time_list = _get_all_timestamps(start_time, end_time)
|
||||
market_price = fill_missing_data(np.array(market_price).reshape(-1), total_time_list, found_time_list)
|
||||
market_volume = fill_missing_data(np.array(market_volume).reshape(-1), total_time_list, found_time_list)
|
||||
|
||||
assert market_price.shape == market_volume.shape == exec_vol.shape
|
||||
|
||||
# Get data from the current level executor's indicator
|
||||
current_trade_account = self.executor.trade_account
|
||||
current_df = current_trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
self.history_exec = dataframe_append(
|
||||
self.history_exec,
|
||||
self._collect_multi_order_metric(
|
||||
order=self.order,
|
||||
datetime=_get_all_timestamps(start_time, end_time, include_end=True),
|
||||
market_vol=market_volume,
|
||||
market_price=market_price,
|
||||
exec_vol=exec_vol,
|
||||
pa=current_df.iloc[-1]["pa"],
|
||||
),
|
||||
)
|
||||
|
||||
self.history_steps = dataframe_append(
|
||||
self.history_steps,
|
||||
[
|
||||
self._collect_single_order_metric(
|
||||
self.order,
|
||||
self.cur_time,
|
||||
market_volume,
|
||||
market_price,
|
||||
exec_vol.sum(),
|
||||
exec_vol,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: check whether we need this. Can we get this information from Account?
|
||||
# Do this at the end
|
||||
self.position -= exec_vol.sum()
|
||||
|
||||
self.cur_time = self._next_time()
|
||||
|
||||
def generate_metrics_after_done(self) -> None:
|
||||
"""Generate metrics once the upper level execution is done"""
|
||||
|
||||
self.metrics = self._collect_single_order_metric(
|
||||
self.order,
|
||||
self.backtest_data.ticks_index[0], # start time
|
||||
self.history_exec["market_volume"],
|
||||
self.history_exec["market_price"],
|
||||
self.history_steps["amount"].sum(),
|
||||
self.history_exec["deal_amount"],
|
||||
)
|
||||
|
||||
def _collect_multi_order_metric(
|
||||
self,
|
||||
order: Order,
|
||||
datetime: pd.DatetimeIndex,
|
||||
market_vol: np.ndarray,
|
||||
market_price: np.ndarray,
|
||||
exec_vol: np.ndarray,
|
||||
pa: float,
|
||||
) -> SAOEMetrics:
|
||||
return SAOEMetrics(
|
||||
# It should have the same keys with SAOEMetrics,
|
||||
# but the values do not necessarily have the annotated type.
|
||||
# Some values could be vectorized (e.g., exec_vol).
|
||||
stock_id=order.stock_id,
|
||||
datetime=datetime,
|
||||
direction=order.direction,
|
||||
market_volume=market_vol,
|
||||
market_price=market_price,
|
||||
amount=exec_vol,
|
||||
inner_amount=exec_vol,
|
||||
deal_amount=exec_vol,
|
||||
trade_price=market_price,
|
||||
trade_value=market_price * exec_vol,
|
||||
position=self.position - np.cumsum(exec_vol),
|
||||
ffr=exec_vol / order.amount,
|
||||
pa=pa,
|
||||
)
|
||||
|
||||
def _collect_single_order_metric(
|
||||
self,
|
||||
order: Order,
|
||||
datetime: pd.Timestamp,
|
||||
market_vol: np.ndarray,
|
||||
market_price: np.ndarray,
|
||||
amount: float, # intended to trade such amount
|
||||
exec_vol: np.ndarray,
|
||||
) -> SAOEMetrics:
|
||||
assert len(market_vol) == len(market_price) == len(exec_vol)
|
||||
|
||||
if np.abs(np.sum(exec_vol)) < EPS:
|
||||
exec_avg_price = 0.0
|
||||
else:
|
||||
exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan
|
||||
if hasattr(exec_avg_price, "item"): # could be numpy scalar
|
||||
exec_avg_price = exec_avg_price.item() # type: ignore
|
||||
|
||||
exec_sum = exec_vol.sum()
|
||||
return SAOEMetrics(
|
||||
stock_id=order.stock_id,
|
||||
datetime=datetime,
|
||||
direction=order.direction,
|
||||
market_volume=market_vol.sum(),
|
||||
market_price=market_price.mean() if len(market_price) > 0 else np.nan,
|
||||
amount=amount,
|
||||
inner_amount=exec_sum,
|
||||
deal_amount=exec_sum, # in this simulator, there's no other restrictions
|
||||
trade_price=exec_avg_price,
|
||||
trade_value=float(np.sum(market_price * exec_vol)),
|
||||
position=self.position - exec_sum,
|
||||
ffr=float(exec_sum / order.amount),
|
||||
pa=price_advantage(exec_avg_price, self.twap_price, order.direction),
|
||||
)
|
||||
|
||||
@property
|
||||
def saoe_state(self) -> SAOEState:
|
||||
return SAOEState(
|
||||
order=self.order,
|
||||
cur_time=self.cur_time,
|
||||
position=self.position,
|
||||
history_exec=self.history_exec,
|
||||
history_steps=self.history_steps,
|
||||
metrics=self.metrics,
|
||||
backtest_data=self.backtest_data,
|
||||
ticks_per_step=self.ticks_per_step,
|
||||
ticks_index=self.backtest_data.ticks_index,
|
||||
ticks_for_order=self.backtest_data.ticks_for_order,
|
||||
)
|
||||
|
||||
|
||||
class SAOEMetrics(TypedDict):
|
||||
@@ -302,7 +27,7 @@ class SAOEMetrics(TypedDict):
|
||||
|
||||
stock_id: str
|
||||
"""Stock ID of this record."""
|
||||
datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this
|
||||
datetime: pd.Timestamp | pd.DatetimeIndex
|
||||
"""Datetime of this record (this is index in the dataframe)."""
|
||||
direction: int
|
||||
"""Direction of the order. 0 for sell, 1 for buy."""
|
||||
@@ -349,6 +74,8 @@ class SAOEState(NamedTuple):
|
||||
"""The order we are dealing with."""
|
||||
cur_time: pd.Timestamp
|
||||
"""Current time, e.g., 9:30."""
|
||||
cur_step: int
|
||||
"""Current step, e.g., 0."""
|
||||
position: float
|
||||
"""Current remaining volume to execute."""
|
||||
history_exec: pd.DataFrame
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from types import GeneratorType
|
||||
from typing import Any, cast, Dict, Generator, List, Optional, Union
|
||||
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -15,14 +15,276 @@ from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.backtest import CommonInfrastructure, Order
|
||||
from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWithDetails, TradeDecisionWO, TradeRange
|
||||
from qlib.backtest.utils import LevelInfrastructure
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.data.native import load_backtest_data
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from qlib.backtest.utils import LevelInfrastructure, get_start_end_idx
|
||||
from qlib.constant import EPS, ONE_MIN, REG_CN
|
||||
from qlib.rl.data.native import IntradayBacktestData, load_backtest_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution.state import SAOEState, SAOEStateAdapter
|
||||
from qlib.rl.utils.env_wrapper import BaseEnvWrapper
|
||||
from qlib.rl.order_execution.state import SAOEMetrics, SAOEState
|
||||
from qlib.rl.order_execution.utils import dataframe_append, price_advantage
|
||||
from qlib.strategy.base import RLStrategy
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.utils.index_data import IndexData
|
||||
from qlib.utils.time import get_day_min_idx_range
|
||||
|
||||
|
||||
def _get_all_timestamps(
|
||||
start: pd.Timestamp,
|
||||
end: pd.Timestamp,
|
||||
granularity: pd.Timedelta = ONE_MIN,
|
||||
include_end: bool = True,
|
||||
) -> pd.DatetimeIndex:
|
||||
ret = []
|
||||
while start <= end:
|
||||
ret.append(start)
|
||||
start += granularity
|
||||
|
||||
if ret[-1] > end:
|
||||
ret.pop()
|
||||
if ret[-1] == end and not include_end:
|
||||
ret.pop()
|
||||
return pd.DatetimeIndex(ret)
|
||||
|
||||
|
||||
def fill_missing_data(
|
||||
original_data: np.ndarray,
|
||||
fill_method: Callable = np.nanmedian,
|
||||
) -> np.ndarray:
|
||||
"""Fill missing data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
original_data
|
||||
Original data without missing values.
|
||||
fill_method
|
||||
Method used to fill the missing data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The filled data.
|
||||
"""
|
||||
return np.nan_to_num(original_data, nan=fill_method(original_data))
|
||||
|
||||
|
||||
class SAOEStateAdapter:
|
||||
"""
|
||||
Maintain states of the environment. SAOEStateAdapter accepts execution results and update its internal state
|
||||
according to the execution results with additional information acquired from executors & exchange. For example,
|
||||
it gets the dealt order amount from execution results, and get the corresponding market price / volume from
|
||||
exchange.
|
||||
|
||||
Example usage::
|
||||
|
||||
adapter = SAOEStateAdapter(...)
|
||||
adapter.update(...)
|
||||
state = adapter.saoe_state
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
order: Order,
|
||||
trade_decision: BaseTradeDecision,
|
||||
executor: BaseExecutor,
|
||||
exchange: Exchange,
|
||||
ticks_per_step: int,
|
||||
backtest_data: IntradayBacktestData,
|
||||
) -> None:
|
||||
self.position = order.amount
|
||||
self.order = order
|
||||
self.executor = executor
|
||||
self.exchange = exchange
|
||||
self.backtest_data = backtest_data
|
||||
self.start_idx, _ = get_start_end_idx(self.executor.trade_calendar, trade_decision)
|
||||
|
||||
self.twap_price = self.backtest_data.get_deal_price().mean()
|
||||
|
||||
metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member
|
||||
self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.metrics: Optional[SAOEMetrics] = None
|
||||
|
||||
self.cur_time = max(backtest_data.ticks_for_order[0], order.start_time)
|
||||
self.ticks_per_step = ticks_per_step
|
||||
|
||||
def _next_time(self) -> pd.Timestamp:
|
||||
current_loc = self.backtest_data.ticks_index.get_loc(self.cur_time)
|
||||
next_loc = current_loc + self.ticks_per_step
|
||||
next_loc = next_loc - next_loc % self.ticks_per_step
|
||||
if (
|
||||
next_loc < len(self.backtest_data.ticks_index)
|
||||
and self.backtest_data.ticks_index[next_loc] < self.order.end_time
|
||||
):
|
||||
return self.backtest_data.ticks_index[next_loc]
|
||||
else:
|
||||
return self.order.end_time
|
||||
|
||||
def update(
|
||||
self,
|
||||
execute_result: list,
|
||||
last_step_range: Tuple[int, int],
|
||||
) -> None:
|
||||
last_step_size = last_step_range[1] - last_step_range[0] + 1
|
||||
start_time = self.backtest_data.ticks_index[last_step_range[0]]
|
||||
end_time = self.backtest_data.ticks_index[last_step_range[1]]
|
||||
|
||||
exec_vol = np.zeros(last_step_size)
|
||||
for order, _, __, ___ in execute_result:
|
||||
idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN)
|
||||
exec_vol[idx - last_step_range[0]] = order.deal_amount
|
||||
|
||||
if exec_vol.sum() > self.position and exec_vol.sum() > 0.0:
|
||||
assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large"
|
||||
exec_vol *= self.position / (exec_vol.sum())
|
||||
|
||||
market_volume = cast(
|
||||
IndexData,
|
||||
self.exchange.get_volume(
|
||||
self.order.stock_id,
|
||||
pd.Timestamp(start_time),
|
||||
pd.Timestamp(end_time),
|
||||
method=None,
|
||||
),
|
||||
)
|
||||
market_price = cast(
|
||||
IndexData,
|
||||
self.exchange.get_deal_price(
|
||||
self.order.stock_id,
|
||||
pd.Timestamp(start_time),
|
||||
pd.Timestamp(end_time),
|
||||
method=None,
|
||||
direction=self.order.direction,
|
||||
),
|
||||
)
|
||||
market_price = fill_missing_data(np.array(market_price, dtype=float).reshape(-1))
|
||||
market_volume = fill_missing_data(np.array(market_volume, dtype=float).reshape(-1))
|
||||
|
||||
assert market_price.shape == market_volume.shape == exec_vol.shape
|
||||
|
||||
# Get data from the current level executor's indicator
|
||||
current_trade_account = self.executor.trade_account
|
||||
current_df = current_trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
self.history_exec = dataframe_append(
|
||||
self.history_exec,
|
||||
self._collect_multi_order_metric(
|
||||
order=self.order,
|
||||
datetime=_get_all_timestamps(start_time, end_time, include_end=True),
|
||||
market_vol=market_volume,
|
||||
market_price=market_price,
|
||||
exec_vol=exec_vol,
|
||||
pa=current_df.iloc[-1]["pa"],
|
||||
),
|
||||
)
|
||||
|
||||
self.history_steps = dataframe_append(
|
||||
self.history_steps,
|
||||
[
|
||||
self._collect_single_order_metric(
|
||||
self.order,
|
||||
self.cur_time,
|
||||
market_volume,
|
||||
market_price,
|
||||
exec_vol.sum(),
|
||||
exec_vol,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Do this at the end
|
||||
self.position -= exec_vol.sum()
|
||||
|
||||
self.cur_time = self._next_time()
|
||||
|
||||
def generate_metrics_after_done(self) -> None:
|
||||
"""Generate metrics once the upper level execution is done"""
|
||||
|
||||
self.metrics = self._collect_single_order_metric(
|
||||
self.order,
|
||||
self.backtest_data.ticks_index[0], # start time
|
||||
self.history_exec["market_volume"],
|
||||
self.history_exec["market_price"],
|
||||
self.history_steps["amount"].sum(),
|
||||
self.history_exec["deal_amount"],
|
||||
)
|
||||
|
||||
def _collect_multi_order_metric(
|
||||
self,
|
||||
order: Order,
|
||||
datetime: pd.DatetimeIndex,
|
||||
market_vol: np.ndarray,
|
||||
market_price: np.ndarray,
|
||||
exec_vol: np.ndarray,
|
||||
pa: float,
|
||||
) -> SAOEMetrics:
|
||||
return SAOEMetrics(
|
||||
# It should have the same keys with SAOEMetrics,
|
||||
# but the values do not necessarily have the annotated type.
|
||||
# Some values could be vectorized (e.g., exec_vol).
|
||||
stock_id=order.stock_id,
|
||||
datetime=datetime,
|
||||
direction=order.direction,
|
||||
market_volume=market_vol,
|
||||
market_price=market_price,
|
||||
amount=exec_vol,
|
||||
inner_amount=exec_vol,
|
||||
deal_amount=exec_vol,
|
||||
trade_price=market_price,
|
||||
trade_value=market_price * exec_vol,
|
||||
position=self.position - np.cumsum(exec_vol),
|
||||
ffr=exec_vol / order.amount,
|
||||
pa=pa,
|
||||
)
|
||||
|
||||
def _collect_single_order_metric(
|
||||
self,
|
||||
order: Order,
|
||||
datetime: pd.Timestamp,
|
||||
market_vol: np.ndarray,
|
||||
market_price: np.ndarray,
|
||||
amount: float, # intended to trade such amount
|
||||
exec_vol: np.ndarray,
|
||||
) -> SAOEMetrics:
|
||||
assert len(market_vol) == len(market_price) == len(exec_vol)
|
||||
|
||||
if np.abs(np.sum(exec_vol)) < EPS:
|
||||
exec_avg_price = 0.0
|
||||
else:
|
||||
exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan
|
||||
if hasattr(exec_avg_price, "item"): # could be numpy scalar
|
||||
exec_avg_price = exec_avg_price.item() # type: ignore
|
||||
|
||||
exec_sum = exec_vol.sum()
|
||||
return SAOEMetrics(
|
||||
stock_id=order.stock_id,
|
||||
datetime=datetime,
|
||||
direction=order.direction,
|
||||
market_volume=market_vol.sum(),
|
||||
market_price=market_price.mean() if len(market_price) > 0 else np.nan,
|
||||
amount=amount,
|
||||
inner_amount=exec_sum,
|
||||
deal_amount=exec_sum, # in this simulator, there's no other restrictions
|
||||
trade_price=exec_avg_price,
|
||||
trade_value=float(np.sum(market_price * exec_vol)),
|
||||
position=self.position - exec_sum,
|
||||
ffr=float(exec_sum / order.amount),
|
||||
pa=price_advantage(exec_avg_price, self.twap_price, order.direction),
|
||||
)
|
||||
|
||||
@property
|
||||
def saoe_state(self) -> SAOEState:
|
||||
return SAOEState(
|
||||
order=self.order,
|
||||
cur_time=self.cur_time,
|
||||
cur_step=self.executor.trade_calendar.get_trade_step() - self.start_idx,
|
||||
position=self.position,
|
||||
history_exec=self.history_exec,
|
||||
history_steps=self.history_steps,
|
||||
metrics=self.metrics,
|
||||
backtest_data=self.backtest_data,
|
||||
ticks_per_step=self.ticks_per_step,
|
||||
ticks_index=self.backtest_data.ticks_index,
|
||||
ticks_for_order=self.backtest_data.ticks_for_order,
|
||||
)
|
||||
|
||||
|
||||
class SAOEStrategy(RLStrategy):
|
||||
@@ -30,7 +292,7 @@ class SAOEStrategy(RLStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: object, # TODO: add accurate typehint later.
|
||||
policy: BasePolicy,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
@@ -47,11 +309,17 @@ class SAOEStrategy(RLStrategy):
|
||||
self.adapter_dict: Dict[tuple, SAOEStateAdapter] = {}
|
||||
self._last_step_range = (0, 0)
|
||||
|
||||
def _create_qlib_backtest_adapter(self, order: Order, trade_range: TradeRange) -> SAOEStateAdapter:
|
||||
def _create_qlib_backtest_adapter(
|
||||
self,
|
||||
order: Order,
|
||||
trade_decision: BaseTradeDecision,
|
||||
trade_range: TradeRange,
|
||||
) -> SAOEStateAdapter:
|
||||
backtest_data = load_backtest_data(order, self.trade_exchange, trade_range)
|
||||
|
||||
return SAOEStateAdapter(
|
||||
order=order,
|
||||
trade_decision=trade_decision,
|
||||
executor=self.executor,
|
||||
exchange=self.trade_exchange,
|
||||
ticks_per_step=int(pd.Timedelta(self.trade_calendar.get_freq()) / ONE_MIN),
|
||||
@@ -71,7 +339,9 @@ class SAOEStrategy(RLStrategy):
|
||||
self.adapter_dict = {}
|
||||
for decision in outer_trade_decision.get_decision():
|
||||
order = cast(Order, decision)
|
||||
self.adapter_dict[order.key_by_day] = self._create_qlib_backtest_adapter(order, trade_range)
|
||||
self.adapter_dict[order.key_by_day] = self._create_qlib_backtest_adapter(
|
||||
order, outer_trade_decision, trade_range
|
||||
)
|
||||
|
||||
def get_saoe_state_by_order(self, order: Order) -> SAOEState:
|
||||
return self.adapter_dict[order.key_by_day].saoe_state
|
||||
@@ -166,11 +436,10 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
policy: dict | BasePolicy,
|
||||
state_interpreter: dict | StateInterpreter,
|
||||
action_interpreter: dict | ActionInterpreter,
|
||||
network: object = None, # TODO: add accurate typehint later.
|
||||
network: dict | torch.nn.Module | None = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
backtest: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super(SAOEIntStrategy, self).__init__(
|
||||
@@ -181,8 +450,6 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._backtest = backtest
|
||||
|
||||
self._state_interpreter: StateInterpreter = init_instance_by_config(
|
||||
state_interpreter,
|
||||
accept_types=StateInterpreter,
|
||||
@@ -221,21 +488,9 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
if self._policy is not None:
|
||||
self._policy.eval()
|
||||
|
||||
def set_env(self, env: BaseEnvWrapper) -> None:
|
||||
# TODO: This method is used to set EnvWrapper for interpreters since they rely on EnvWrapper.
|
||||
# We should decompose the interpreters with EnvWrapper in the future and we should remove this method
|
||||
# after that.
|
||||
|
||||
self._env = env
|
||||
self._state_interpreter.env = self._action_interpreter.env = self._env
|
||||
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
|
||||
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
|
||||
# In backtest, env.reset() needs to be manually called since there is no outer trainer to call it
|
||||
if self._backtest:
|
||||
self._env.reset()
|
||||
|
||||
def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
|
||||
assert hasattr(self.outer_trade_decision, "order_list")
|
||||
|
||||
@@ -268,10 +523,6 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
act = policy_out.act.numpy() if torch.is_tensor(policy_out.act) else policy_out.act
|
||||
exec_vols = [self._action_interpreter.interpret(s, a) for s, a in zip(states, act)]
|
||||
|
||||
# In backtest, env.step() needs to be manually called since there is no outer trainer to call it
|
||||
if self._backtest:
|
||||
self._env.step(None)
|
||||
|
||||
oh = self.trade_exchange.get_order_helper()
|
||||
order_list = []
|
||||
for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):
|
||||
|
||||
@@ -7,7 +7,7 @@ import collections
|
||||
import copy
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Sequence, TypeVar, cast
|
||||
from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast
|
||||
|
||||
import torch
|
||||
|
||||
@@ -152,6 +152,13 @@ class Trainer:
|
||||
"metrics": self.metrics,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_policy_state_dict(ckpt_path: Path) -> OrderedDict:
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
if "vessel" in state_dict:
|
||||
state_dict = state_dict["vessel"]["policy"]
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""Load all states into current trainer."""
|
||||
self.vessel.load_state_dict(state_dict["vessel"])
|
||||
|
||||
@@ -48,24 +48,9 @@ class EnvWrapperStatus(TypedDict):
|
||||
reward_history: list
|
||||
|
||||
|
||||
class BaseEnvWrapper(
|
||||
class EnvWrapper(
|
||||
gym.Env[ObsType, PolicyActType],
|
||||
Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType],
|
||||
):
|
||||
"""Base env wrapper for RL environments. It has two implementations:
|
||||
- EnvWrapper: Qlib-based RL environment used in training.
|
||||
- CollectDataEnvWrapper: Dummy environment used in collect_data_loop.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
|
||||
|
||||
def render(self, mode: str = "human") -> None:
|
||||
raise NotImplementedError("Render is not implemented in BaseEnvWrapper.")
|
||||
|
||||
|
||||
class EnvWrapper(
|
||||
BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType],
|
||||
):
|
||||
"""Qlib-based RL environment, subclassing ``gym.Env``.
|
||||
A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
|
||||
@@ -129,8 +114,6 @@ class EnvWrapper(
|
||||
# 3. Avoid circular reference.
|
||||
# 4. When the components get serialized, we can throw away the env without any burden.
|
||||
# (though this part is not implemented yet)
|
||||
super().__init__()
|
||||
|
||||
for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]:
|
||||
if obj is not None:
|
||||
obj.env = weakref.proxy(self) # type: ignore
|
||||
@@ -263,19 +246,5 @@ class EnvWrapper(
|
||||
info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
|
||||
return obs, rew, done, info_dict
|
||||
|
||||
|
||||
class CollectDataEnvWrapper(BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType]):
|
||||
"""Dummy EnvWrapper for collect_data_loop. It only has minimum interfaces to support the collect_data_loop."""
|
||||
|
||||
def reset(self, **kwargs: Any) -> None:
|
||||
self.status = EnvWrapperStatus(
|
||||
cur_step=0,
|
||||
done=False,
|
||||
initial_state=None,
|
||||
obs_history=[],
|
||||
action_history=[],
|
||||
reward_history=[],
|
||||
)
|
||||
|
||||
def step(self, policy_action: Any = None, **kwargs: Any) -> None:
|
||||
self.status["cur_step"] += 1
|
||||
def render(self, mode: str = "human") -> None:
|
||||
raise NotImplementedError("Render is not implemented in EnvWrapper.")
|
||||
|
||||
@@ -473,7 +473,8 @@ class PortAnaRecord(ACRecordTemp):
|
||||
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
|
||||
|
||||
for _freq, indicators_normal in indicator_dict.items():
|
||||
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal})
|
||||
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
|
||||
self.save(**{f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
|
||||
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq not in portfolio_metric_dict:
|
||||
@@ -511,7 +512,7 @@ class PortAnaRecord(ACRecordTemp):
|
||||
if _analysis_freq not in indicator_dict:
|
||||
warnings.warn(f"the freq {_analysis_freq} indicator is not found")
|
||||
else:
|
||||
indicators_normal = indicator_dict.get(_analysis_freq)
|
||||
indicators_normal = indicator_dict.get(_analysis_freq)[0]
|
||||
if self.indicator_analysis_method is None:
|
||||
analysis_df = indicator_analysis(indicators_normal)
|
||||
else:
|
||||
|
||||
@@ -107,7 +107,7 @@ class FileStrTest(TestAutoData):
|
||||
)
|
||||
|
||||
# ffr valid
|
||||
ffr_dict = indicator_dict["1day"]["ffr"].to_dict()
|
||||
ffr_dict = indicator_dict["1day"][0]["ffr"].to_dict()
|
||||
ffr_dict = {str(date).split()[0]: ffr_dict[date] for date in ffr_dict}
|
||||
assert np.isclose(ffr_dict["2020-01-03"], dealt_num_for_1000 / 1000)
|
||||
assert np.isclose(ffr_dict["2020-01-06"], 0)
|
||||
|
||||
@@ -125,7 +125,7 @@ class TestHFBacktest(TestAutoData):
|
||||
# NOTE: please refer to the docs of format_decisions
|
||||
# NOTE: `"track_data": True,` is very NECESSARY for collecting the decision!!!!!
|
||||
f_dec = format_decisions(decisions)
|
||||
print(indicator["1day"])
|
||||
print(indicator["1day"][0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -7,11 +7,11 @@ from typing import Tuple
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from qlib.backtest.decision import Order, OrderDir, TradeRangeByTime
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.backtest.executor import SimulatorExecutor
|
||||
from qlib.rl.order_execution import CategoricalActionInterpreter
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
|
||||
|
||||
TOTAL_POSITION = 2100.0
|
||||
|
||||
@@ -183,8 +183,6 @@ def test_interpreter() -> None:
|
||||
order = get_order()
|
||||
simulator = get_simulator(order)
|
||||
interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
|
||||
interpreter_action.env = CollectDataEnvWrapper()
|
||||
interpreter_action.env.reset()
|
||||
|
||||
NUM_STEPS = 7
|
||||
state = simulator.get_state()
|
||||
|
||||
@@ -20,7 +20,6 @@ from qlib.rl.data.pickle_styled import PickleProcessedDataProvider
|
||||
from qlib.rl.order_execution import *
|
||||
from qlib.rl.trainer import backtest, train
|
||||
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
|
||||
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
|
||||
|
||||
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
|
||||
|
||||
@@ -186,10 +185,6 @@ def test_interpreter():
|
||||
assert np.sum(obs["data_processed"][60:]) == 0
|
||||
|
||||
# second step: action
|
||||
interpreter_action.env = CollectDataEnvWrapper()
|
||||
interpreter_action_twap.env = CollectDataEnvWrapper()
|
||||
interpreter_action.env.reset()
|
||||
interpreter_action_twap.env.reset()
|
||||
action = interpreter_action(simulator.get_state(), 1)
|
||||
assert action == 15 / 20
|
||||
|
||||
@@ -260,8 +255,6 @@ def test_twap_strategy(finite_env_type):
|
||||
|
||||
state_interp = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
|
||||
action_interp = TwapRelativeActionInterpreter()
|
||||
action_interp.env = CollectDataEnvWrapper()
|
||||
action_interp.env.reset()
|
||||
policy = AllOne(state_interp.observation_space, action_interp.action_space)
|
||||
csv_writer = CsvWriter(Path(__file__).parent / ".output")
|
||||
|
||||
@@ -291,8 +284,6 @@ def test_cn_ppo_strategy():
|
||||
|
||||
state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
|
||||
action_interp = CategoricalActionInterpreter(4)
|
||||
action_interp.env = CollectDataEnvWrapper()
|
||||
action_interp.env.reset()
|
||||
network = Recurrent(state_interp.observation_space)
|
||||
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
|
||||
policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"))
|
||||
@@ -324,8 +315,6 @@ def test_ppo_train():
|
||||
|
||||
state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
|
||||
action_interp = CategoricalActionInterpreter(4)
|
||||
action_interp.env = CollectDataEnvWrapper()
|
||||
action_interp.env.reset()
|
||||
network = Recurrent(state_interp.observation_space)
|
||||
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user