mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1f5f3a6af0 | ||
|
|
2f8fc8d28a | ||
|
|
3e9ccd3ad2 |
@@ -179,7 +179,7 @@ def get_strategy_executor(
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: Optional[str] = "SH000300",
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
exchange_kwargs: Union[dict, Exchange] = {}, # TODO: rename parameter
|
||||
pos_type: str = "Position",
|
||||
) -> Tuple[BaseStrategy, BaseExecutor]:
|
||||
|
||||
@@ -197,12 +197,15 @@ def get_strategy_executor(
|
||||
pos_type=pos_type,
|
||||
)
|
||||
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
if "start_time" not in exchange_kwargs:
|
||||
exchange_kwargs["start_time"] = start_time
|
||||
if "end_time" not in exchange_kwargs:
|
||||
exchange_kwargs["end_time"] = end_time
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
if isinstance(exchange_kwargs, Exchange):
|
||||
trade_exchange = exchange_kwargs
|
||||
else:
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
if "start_time" not in exchange_kwargs:
|
||||
exchange_kwargs["start_time"] = start_time
|
||||
if "end_time" not in exchange_kwargs:
|
||||
exchange_kwargs["end_time"] = end_time
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
|
||||
|
||||
@@ -56,6 +56,7 @@ def collect_data_loop(
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
return_value: dict | None = None,
|
||||
show_progress: bool = True,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
@@ -74,6 +75,8 @@ def collect_data_loop(
|
||||
the outermost executor
|
||||
return_value : dict
|
||||
used for backtest_loop
|
||||
show_progress: bool
|
||||
whether to show execution progress
|
||||
|
||||
Yields
|
||||
-------
|
||||
@@ -83,7 +86,8 @@ def collect_data_loop(
|
||||
trade_executor.reset(start_time=start_time, end_time=end_time)
|
||||
trade_strategy.reset(level_infra=trade_executor.get_level_infra())
|
||||
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
|
||||
disable = not show_progress
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop", disable=disable) as bar:
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
|
||||
@@ -177,7 +177,7 @@ class Exchange:
|
||||
|
||||
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
|
||||
if self.limit_type == self.LT_TP_EXP:
|
||||
assert isinstance(limit_threshold, tuple)
|
||||
assert isinstance(limit_threshold, tuple) or (isinstance(limit_threshold, list) and len(limit_threshold) == 2)
|
||||
for exp in limit_threshold:
|
||||
necessary_fields.add(exp)
|
||||
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
|
||||
@@ -263,6 +263,9 @@ class Exchange:
|
||||
"""get limit type"""
|
||||
if isinstance(limit_threshold, tuple):
|
||||
return self.LT_TP_EXP
|
||||
if isinstance(limit_threshold, list):
|
||||
assert len(limit_threshold) == 2
|
||||
return self.LT_TP_EXP
|
||||
elif isinstance(limit_threshold, float):
|
||||
return self.LT_FLT
|
||||
elif limit_threshold is None:
|
||||
@@ -325,7 +328,7 @@ class Exchange:
|
||||
|
||||
assert isinstance(volume_threshold, dict)
|
||||
for key, vol_limit in volume_threshold.items():
|
||||
assert isinstance(vol_limit, tuple)
|
||||
assert isinstance(vol_limit, tuple) or (isinstance(vol_limit, list) and len(vol_limit) == 2)
|
||||
fields.add(vol_limit[1])
|
||||
|
||||
if key in ("buy", "all"):
|
||||
@@ -803,7 +806,7 @@ class Exchange:
|
||||
|
||||
vol_limit_num: List[float] = []
|
||||
for limit in vol_limit:
|
||||
assert isinstance(limit, tuple)
|
||||
assert isinstance(limit, tuple) or (isinstance(limit, list) and len(limit) == 2)
|
||||
if limit[0] == "current":
|
||||
limit_value = self.quote.get_data(
|
||||
order.stock_id,
|
||||
|
||||
@@ -16,13 +16,12 @@ import torch
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
|
||||
from qlib.backtest.decision import BaseTradeDecision, TradeRangeByTime
|
||||
from qlib.backtest.executor import SimulatorExecutor
|
||||
from qlib.backtest.high_performance_ds import BaseOrderIndicator
|
||||
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
|
||||
from qlib.rl.contrib.naive_config_parser import BacktestConfigParser
|
||||
from qlib.rl.contrib.utils import read_order_file
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
||||
@@ -124,105 +123,13 @@ def _generate_report(
|
||||
return report
|
||||
|
||||
|
||||
def single_with_simulator(
|
||||
backtest_config: dict,
|
||||
orders: pd.DataFrame,
|
||||
split: Literal["stock", "day"] = "stock",
|
||||
cash_limit: float | None = None,
|
||||
generate_report: bool = False,
|
||||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
|
||||
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.
|
||||
A new simulator will be created and used for every single-day order.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backtest_config:
|
||||
Backtest config
|
||||
orders:
|
||||
Orders to be executed. Example format:
|
||||
datetime instrument amount direction
|
||||
0 2020-06-01 INST 600.0 0
|
||||
1 2020-06-02 INST 700.0 1
|
||||
...
|
||||
split
|
||||
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
|
||||
cash_limit
|
||||
Limitation of cash.
|
||||
generate_report
|
||||
Whether to generate reports.
|
||||
|
||||
Returns
|
||||
-------
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
init_qlib(backtest_config["qlib"])
|
||||
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
|
||||
reports = []
|
||||
decisions = []
|
||||
for _, row in orders.iterrows():
|
||||
date = pd.Timestamp(row["datetime"])
|
||||
start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day)
|
||||
end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day)
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(row["direction"]),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=backtest_config["strategies"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
data_granularity=backtest_config["data_granularity"],
|
||||
)
|
||||
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
"codes": stocks,
|
||||
"freq": backtest_config["data_granularity"],
|
||||
}
|
||||
)
|
||||
|
||||
simulator = SingleAssetOrderExecution(
|
||||
order=order,
|
||||
executor_config=executor_config,
|
||||
exchange_config=exchange_config,
|
||||
qlib_config=None,
|
||||
cash_limit=None,
|
||||
)
|
||||
|
||||
reports.append(simulator.report_dict)
|
||||
decisions += simulator.decisions
|
||||
|
||||
indicator_1day_objs = [report["indicator_dict"]["1day"][1] for report in reports]
|
||||
indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}
|
||||
records = _convert_indicator_to_dataframe(indicator_info)
|
||||
assert records is None or not np.isnan(records["ffr"]).any()
|
||||
|
||||
if generate_report:
|
||||
_report = _generate_report(decisions, [report["indicator"] for report in reports])
|
||||
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
report = {stock_id: _report}
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
report = {day: _report}
|
||||
|
||||
return records, report
|
||||
else:
|
||||
return records
|
||||
|
||||
|
||||
def single_with_collect_data_loop(
|
||||
backtest_config: dict,
|
||||
orders: pd.DataFrame,
|
||||
time_range: Tuple[str, str],
|
||||
exchange_config: dict,
|
||||
strategy_config: dict,
|
||||
split: Literal["stock", "day"] = "stock",
|
||||
data_granularity: str = "1min",
|
||||
cash_limit: float | None = None,
|
||||
generate_report: bool = False,
|
||||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
|
||||
@@ -250,44 +157,42 @@ def single_with_collect_data_loop(
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
|
||||
init_qlib(backtest_config["qlib"])
|
||||
|
||||
trade_start_time = orders["datetime"].min()
|
||||
trade_end_time = orders["datetime"].max()
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
|
||||
strategy_config = {
|
||||
top_strategy_config = {
|
||||
"class": "FileOrderStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"file": orders,
|
||||
"trade_range": TradeRangeByTime(
|
||||
pd.Timestamp(backtest_config["start_time"]).time(),
|
||||
pd.Timestamp(backtest_config["end_time"]).time(),
|
||||
pd.Timestamp(time_range[0]).time(),
|
||||
pd.Timestamp(time_range[1]).time(),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=backtest_config["strategies"],
|
||||
top_executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=strategy_config,
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
data_granularity=backtest_config["data_granularity"],
|
||||
data_granularity=data_granularity,
|
||||
)
|
||||
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
exchange_config = {
|
||||
**exchange_config,
|
||||
**{
|
||||
"codes": stocks,
|
||||
"freq": backtest_config["data_granularity"],
|
||||
}
|
||||
)
|
||||
"freq": data_granularity,
|
||||
},
|
||||
}
|
||||
|
||||
strategy, executor = get_strategy_executor(
|
||||
start_time=pd.Timestamp(trade_start_time),
|
||||
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
|
||||
strategy=strategy_config,
|
||||
executor=executor_config,
|
||||
strategy=top_strategy_config,
|
||||
executor=top_executor_config,
|
||||
benchmark=None,
|
||||
account=cash_limit if cash_limit is not None else int(1e12),
|
||||
exchange_kwargs=exchange_config,
|
||||
@@ -295,7 +200,7 @@ def single_with_collect_data_loop(
|
||||
)
|
||||
|
||||
report_dict: dict = {}
|
||||
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))
|
||||
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict, show_progress=False))
|
||||
|
||||
indicator_dict = cast(INDICATOR_METRIC, report_dict.get("indicator_dict"))
|
||||
records = _convert_indicator_to_dataframe(indicator_dict["1day"][1].order_indicator_his)
|
||||
@@ -315,46 +220,54 @@ def single_with_collect_data_loop(
|
||||
|
||||
|
||||
def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame:
|
||||
order_df = read_order_file(backtest_config["order_file"])
|
||||
|
||||
cash_limit = backtest_config["exchange"].pop("cash_limit")
|
||||
generate_report = backtest_config.pop("generate_report")
|
||||
|
||||
stock_pool = order_df["instrument"].unique().tolist()
|
||||
stock_pool.sort()
|
||||
|
||||
single = single_with_simulator if with_simulator else single_with_collect_data_loop
|
||||
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
|
||||
init_qlib(backtest_config["simulator"]["qlib"])
|
||||
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
|
||||
res = Parallel(**mp_config)(
|
||||
delayed(single)(
|
||||
backtest_config=backtest_config,
|
||||
orders=order_df[order_df["instrument"] == stock].copy(),
|
||||
split="stock",
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
|
||||
single = single_with_collect_data_loop
|
||||
mp_config = {"n_jobs": backtest_config["runtime"]["concurrency"], "verbose": 10, "backend": "multiprocessing"}
|
||||
|
||||
for task_config in backtest_config["tasks"]:
|
||||
order_df = read_order_file(task_config["order_file"])
|
||||
exchange_config = task_config["exchange"]
|
||||
cash_limit = exchange_config.pop("cash_limit")
|
||||
generate_report = backtest_config["runtime"]["generate_report"]
|
||||
|
||||
stock_pool = order_df["instrument"].unique().tolist()
|
||||
stock_pool.sort()
|
||||
|
||||
#
|
||||
res = Parallel(**mp_config)(
|
||||
delayed(single)(
|
||||
orders=order_df[order_df["instrument"] == stock].copy(),
|
||||
time_range=task_config["time_range"],
|
||||
exchange_config=task_config["exchange"],
|
||||
strategy_config=backtest_config["strategies"],
|
||||
split="stock",
|
||||
data_granularity=task_config["data_granularity"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
)
|
||||
for stock in stock_pool
|
||||
)
|
||||
for stock in stock_pool
|
||||
)
|
||||
|
||||
output_path = Path(backtest_config["output_dir"])
|
||||
if generate_report:
|
||||
with (output_path / "report.pkl").open("wb") as f:
|
||||
report = {}
|
||||
for r in res:
|
||||
report.update(r[1])
|
||||
pickle.dump(report, f)
|
||||
res = pd.concat([r[0] for r in res], 0)
|
||||
else:
|
||||
res = pd.concat(res)
|
||||
|
||||
if not output_path.exists():
|
||||
os.makedirs(output_path)
|
||||
|
||||
if "pa" in res.columns:
|
||||
res["pa"] = res["pa"] * 10000.0 # align with training metrics
|
||||
res.to_csv(output_path / "backtest_result.csv")
|
||||
return res
|
||||
|
||||
#
|
||||
output_path = Path(task_config["output_dir"])
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
if generate_report:
|
||||
with (output_path / "report.pkl").open("wb") as f:
|
||||
report = {}
|
||||
for r in res:
|
||||
report.update(r[1])
|
||||
pickle.dump(report, f)
|
||||
res = pd.concat([r[0] for r in res], 0)
|
||||
else:
|
||||
res = pd.concat(res)
|
||||
|
||||
if "pa" in res.columns:
|
||||
res["pa"] = res["pa"] * 10000.0 # align with training metrics
|
||||
res.to_csv(output_path / "backtest_result.csv")
|
||||
# return res # TODO
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -362,6 +275,7 @@ if __name__ == "__main__":
|
||||
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
@@ -374,9 +288,11 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = get_backtest_config_fromfile(args.config_path)
|
||||
if args.n_jobs is not None:
|
||||
config["concurrency"] = args.n_jobs
|
||||
|
||||
config_parser = BacktestConfigParser(args.config_path)
|
||||
config = config_parser.parse()
|
||||
if args.n_jobs is not None: # Overwrite concurrency
|
||||
config["runtime"]["concurrency"] = args.n_jobs
|
||||
|
||||
backtest(
|
||||
backtest_config=config,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
@@ -30,7 +31,7 @@ def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist')
|
||||
raise FileNotFoundError(msg_tmpl.format(filename))
|
||||
|
||||
|
||||
def parse_backtest_config(path: str) -> dict:
|
||||
def load_config(path: str) -> dict:
|
||||
abs_path = os.path.abspath(path)
|
||||
check_file_exist(abs_path)
|
||||
|
||||
@@ -65,43 +66,154 @@ def parse_backtest_config(path: str) -> dict:
|
||||
base_file_name = [base_file_name]
|
||||
|
||||
for f in base_file_name:
|
||||
base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f))
|
||||
base_config = load_config(os.path.join(os.path.dirname(abs_path), f))
|
||||
config = merge_a_into_b(a=config, b=base_config)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _convert_all_list_to_tuple(config: dict) -> dict:
|
||||
for k, v in config.items():
|
||||
if isinstance(v, list):
|
||||
config[k] = tuple(v)
|
||||
elif isinstance(v, dict):
|
||||
config[k] = _convert_all_list_to_tuple(v)
|
||||
return config
|
||||
class BacktestConfigParser:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.raw_config = load_config(path)
|
||||
|
||||
def parse(self) -> dict:
|
||||
self._simulator_config = self._parse_simulator()
|
||||
self._exchange_config = self._simulator_config.pop("exchange")
|
||||
config = {
|
||||
"strategies": self.raw_config["strategies"],
|
||||
"runtime": self.raw_config["runtime"],
|
||||
"tasks": self._parse_tasks(),
|
||||
"simulator": self._simulator_config,
|
||||
}
|
||||
return config
|
||||
|
||||
def _parse_tasks(self) -> dict:
|
||||
task_config = []
|
||||
for task in self.raw_config["tasks"]:
|
||||
if "output_dir" not in task:
|
||||
task["output_dir"] = os.path.join("outputs_backtest", task["name"])
|
||||
if "exchange" not in task:
|
||||
task["exchange"] = copy.deepcopy(self._exchange_config)
|
||||
else:
|
||||
task["exchange"] = self._complete_exchange_config(task["exchange"])
|
||||
task_config.append(task)
|
||||
|
||||
return task_config
|
||||
|
||||
def _complete_exchange_config(self, exchange_config: dict) -> dict:
|
||||
exchange_config_default = {
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
"cash_limit": None,
|
||||
}
|
||||
exchange_config = merge_a_into_b(a=exchange_config, b=exchange_config_default)
|
||||
return exchange_config
|
||||
|
||||
def _parse_simulator(self) -> dict:
|
||||
config = self.raw_config["simulator"]
|
||||
|
||||
return {
|
||||
"qlib": config["qlib"],
|
||||
"exchange": self._complete_exchange_config(config["exchange"]),
|
||||
}
|
||||
|
||||
|
||||
def get_backtest_config_fromfile(path: str) -> dict:
|
||||
backtest_config = parse_backtest_config(path)
|
||||
class TrainingConfigParser:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.raw_config = load_config(path)
|
||||
|
||||
exchange_config_default = {
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
"cash_limit": None,
|
||||
}
|
||||
backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default)
|
||||
backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"])
|
||||
def parse(self) -> dict:
|
||||
return {
|
||||
"general": self._parse_general(),
|
||||
"policy": self.raw_config["policy"],
|
||||
"interpreter": self.raw_config["interpreter"],
|
||||
"runtime": self._parse_runtime(),
|
||||
"training": self._parse_training(),
|
||||
"simulator": self._parse_simulator(),
|
||||
}
|
||||
|
||||
backtest_config_default = {
|
||||
"debug_single_stock": None,
|
||||
"debug_single_day": None,
|
||||
"concurrency": -1,
|
||||
"multiplier": 1.0,
|
||||
"output_dir": "outputs_backtest/",
|
||||
"generate_report": False,
|
||||
"data_granularity": "1min",
|
||||
}
|
||||
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
|
||||
def _parse_general(self) -> dict:
|
||||
default = {
|
||||
"freq": "1min",
|
||||
"extra_module_paths": [],
|
||||
}
|
||||
return {**default, **self.raw_config["general"]}
|
||||
|
||||
return backtest_config
|
||||
def _parse_runtime(self) -> dict:
|
||||
default = {
|
||||
"seed": None,
|
||||
"use_cuda": False,
|
||||
"concurrency": 1,
|
||||
"parallel_mode": "dummy",
|
||||
}
|
||||
return {**default, **self.raw_config["runtime"]}
|
||||
|
||||
def _parse_training(self) -> dict:
|
||||
default = {
|
||||
"max_epoch": 100,
|
||||
"repeat_per_collect": 2,
|
||||
"earlystop_patience": float("inf"),
|
||||
"episode_per_collect": 10000,
|
||||
"batch_size": 256,
|
||||
"val_every_n_epoch": None,
|
||||
"checkpoint_path": "./outputs",
|
||||
"checkpoint_every_n_iters": 10,
|
||||
}
|
||||
|
||||
config = self.raw_config["training"]
|
||||
assert "order_dir" in config
|
||||
|
||||
return {**default, **config}
|
||||
|
||||
def _parse_simulator(self) -> dict:
|
||||
config = self.raw_config["simulator"]
|
||||
sim_type = config["type"]
|
||||
assert sim_type in ("simple", "full")
|
||||
|
||||
if sim_type == "simple":
|
||||
return {
|
||||
"type": sim_type,
|
||||
"data": {
|
||||
"feature_root_dir": config["data"]["feature_root_dir"],
|
||||
"feature_columns_today": config["data"]["feature_columns_today"],
|
||||
"default_start_time_index": config["data"].get("default_start_time_index", 0),
|
||||
"default_end_time_index": config["data"].get("default_end_time_index", 240),
|
||||
},
|
||||
"time_per_step": config["time_per_step"],
|
||||
"vol_limit": config["vol_limit"],
|
||||
}
|
||||
else:
|
||||
exchange_config_default = {
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
# "cash_limit": None,
|
||||
}
|
||||
exchange_config = {**exchange_config_default, **config["exchange"]}
|
||||
exchange_config["freq"] = self.raw_config["general"].get("freq", "1min")
|
||||
|
||||
ret_config = {
|
||||
"type": sim_type,
|
||||
"data": {
|
||||
"feature_root_dir": config["data"]["feature_root_dir"],
|
||||
"default_start_time_index": config["data"].get("default_start_time_index", 0),
|
||||
"default_end_time_index": config["data"].get("default_end_time_index", 240),
|
||||
},
|
||||
"qlib": {
|
||||
"provider_uri_1min": config["qlib"]["provider_uri_1min"],
|
||||
},
|
||||
"exchange": exchange_config,
|
||||
}
|
||||
|
||||
return ret_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrainingConfigParser("/home/huoran/exp_configs/amc4th_training_refined.yml")
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
pprint(parser.parse())
|
||||
|
||||
362
qlib/rl/contrib/train.py
Normal file
362
qlib/rl/contrib/train.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, cast, List, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl import Simulator
|
||||
from qlib.rl.contrib.naive_config_parser import TrainingConfigParser
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, backtest, train
|
||||
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
|
||||
from qlib.rl.utils.log import CsvWriter
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def get_executor_config(freq: int) -> dict:
|
||||
return {
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"inner_executor": {
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"generate_report": False,
|
||||
"time_per_step": f"{freq}min",
|
||||
"track_data": True,
|
||||
"trade_type": "serial",
|
||||
"verbose": False,
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "TWAPStrategy",
|
||||
"kwargs": {},
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
},
|
||||
"time_per_step": "30min",
|
||||
"track_data": True,
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "ProxySAOEStrategy",
|
||||
"module_path": "qlib.rl.order_execution.strategy",
|
||||
"kwargs": {},
|
||||
},
|
||||
"time_per_step": "1day",
|
||||
"track_data": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
if os.path.isfile(order_dir):
|
||||
return pd.read_pickle(order_dir)
|
||||
else:
|
||||
orders = []
|
||||
for file in order_dir.iterdir():
|
||||
order_data = pd.read_pickle(file)
|
||||
orders.append(order_data)
|
||||
return pd.concat(orders)
|
||||
|
||||
|
||||
def _freq_str_to_int(freq: str) -> int:
|
||||
if freq.endswith("min"):
|
||||
return int(freq.replace("min", ""))
|
||||
elif freq.endswith("hour"):
|
||||
return int(freq.replace("hour", "") * 60)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized freq string: {freq}")
|
||||
|
||||
|
||||
class LazyLoadDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
order_df: pd.DataFrame,
|
||||
default_start_time_index: int,
|
||||
default_end_time_index: int,
|
||||
) -> None:
|
||||
self._default_start_time_index = default_start_time_index
|
||||
self._default_end_time_index = default_end_time_index
|
||||
|
||||
self._order_df = order_df
|
||||
self._ticks_index: Optional[pd.DatetimeIndex] = None
|
||||
self._data_dir = Path(data_dir)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._order_df)
|
||||
|
||||
def __getitem__(self, index: int) -> Order:
|
||||
row = self._order_df.iloc[index]
|
||||
date = pd.Timestamp(str(row["date"]))
|
||||
|
||||
if self._ticks_index is None:
|
||||
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
|
||||
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
|
||||
# TODO: of all dates.
|
||||
|
||||
data = load_pickle_intraday_processed_data(
|
||||
data_dir=self._data_dir,
|
||||
stock_id=row["instrument"],
|
||||
date=date,
|
||||
feature_columns_today=[],
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
)
|
||||
self._ticks_index = [t - date for t in data.today.index]
|
||||
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(int(row["order_type"])),
|
||||
start_time=date + self._ticks_index[self._default_start_time_index],
|
||||
end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
|
||||
)
|
||||
|
||||
return order
|
||||
|
||||
|
||||
def _split_order_df_by_instrument(df: pd.DataFrame, k: int) -> List[pd.DataFrame]:
|
||||
df = df.copy()
|
||||
df["group"] = df["instrument"].apply(lambda s: hash(s) % k)
|
||||
dfs = [df[df["group"] == i].drop(columns=["group"]) for i in range(k)]
|
||||
return dfs
|
||||
|
||||
|
||||
def _get_simulator_factory(
|
||||
sim_type: str,
|
||||
data_dir: Path,
|
||||
freq_min: int,
|
||||
simulator_config: dict,
|
||||
) -> Callable[[Order], Simulator]:
|
||||
if sim_type == "simple":
|
||||
|
||||
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
|
||||
simulator = SingleAssetOrderExecutionSimple(
|
||||
order=order,
|
||||
data_dir=data_dir,
|
||||
feature_columns_today=simulator_config["data"]["feature_columns_today"],
|
||||
data_granularity=freq_min,
|
||||
ticks_per_step=simulator_config["time_per_step"],
|
||||
vol_threshold=simulator_config["vol_limit"],
|
||||
)
|
||||
return simulator
|
||||
|
||||
return _simulator_factory_simple
|
||||
elif sim_type == "full":
|
||||
init_qlib(simulator_config["qlib"])
|
||||
executor_config = get_executor_config(freq_min)
|
||||
exchange_config = simulator_config["exchange"]
|
||||
|
||||
def _simulator_factory_full(order: Order) -> SingleAssetOrderExecution:
|
||||
simulator = SingleAssetOrderExecution(
|
||||
order=order,
|
||||
executor_config=executor_config,
|
||||
exchange_config=exchange_config, # `codes` will be set in SingleAssetOrderExecution.__init__()
|
||||
qlib_config=None,
|
||||
cash_limit=None,
|
||||
)
|
||||
return simulator
|
||||
|
||||
return _simulator_factory_full
|
||||
else:
|
||||
raise ValueError(f"Unknown simulator type: {sim_type}")
|
||||
|
||||
|
||||
def train_and_test(
|
||||
freq: str,
|
||||
concurrency: int,
|
||||
parallel_mode: str,
|
||||
training_config: dict,
|
||||
simulator_config: dict,
|
||||
policy: BasePolicy,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
reward: Reward,
|
||||
run_training: bool,
|
||||
run_backtest: bool,
|
||||
) -> None:
|
||||
freq_min: int = _freq_str_to_int(freq)
|
||||
order_root_path = Path(training_config["order_dir"])
|
||||
feature_root_dir = simulator_config["data"]["feature_root_dir"]
|
||||
assert simulator_config["data"]["default_start_time_index"] % freq_min == 0
|
||||
assert simulator_config["data"]["default_end_time_index"] % freq_min == 0
|
||||
|
||||
_simulator_factory = _get_simulator_factory(
|
||||
sim_type=simulator_config["type"],
|
||||
data_dir=feature_root_dir,
|
||||
freq_min=freq_min,
|
||||
simulator_config=simulator_config,
|
||||
)
|
||||
|
||||
# Load orders
|
||||
load_data_tags = []
|
||||
orders_by_tag = {}
|
||||
if run_training:
|
||||
load_data_tags += ["train", "valid"]
|
||||
if run_backtest:
|
||||
load_data_tags += ["test"]
|
||||
for tag in load_data_tags:
|
||||
order_df = _read_orders(order_root_path / tag).reset_index()
|
||||
dfs = _split_order_df_by_instrument(order_df, concurrency)
|
||||
datasets = [
|
||||
LazyLoadDataset(
|
||||
data_dir=feature_root_dir,
|
||||
order_df=df,
|
||||
default_start_time_index=simulator_config["data"]["default_start_time_index"] // freq_min,
|
||||
default_end_time_index=simulator_config["data"]["default_end_time_index"] // freq_min,
|
||||
)
|
||||
for df in dfs
|
||||
]
|
||||
orders_by_tag[tag] = datasets
|
||||
|
||||
if run_training:
|
||||
callbacks: List[Callback] = [
|
||||
MetricsWriter(dirpath=Path(training_config["checkpoint_path"])),
|
||||
Checkpoint(
|
||||
dirpath=Path(training_config["checkpoint_path"]) / "checkpoints",
|
||||
every_n_iters=training_config["checkpoint_every_n_iters"],
|
||||
save_latest="copy",
|
||||
),
|
||||
EarlyStopping(
|
||||
patience=training_config["earlystop_patience"],
|
||||
monitor="val/pa",
|
||||
),
|
||||
]
|
||||
|
||||
train(
|
||||
simulator_fn=_simulator_factory,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
initial_states=cast(List[Sequence[Order]], orders_by_tag["train"]),
|
||||
trainer_kwargs={
|
||||
"max_iters": training_config["max_epoch"],
|
||||
"finite_env_type": parallel_mode,
|
||||
"concurrency": concurrency,
|
||||
"val_every_n_iters": training_config["val_every_n_epoch"],
|
||||
"callbacks": callbacks,
|
||||
},
|
||||
vessel_kwargs={
|
||||
"episode_per_iter": training_config["episode_per_collect"],
|
||||
"update_kwargs": {
|
||||
"batch_size": training_config["batch_size"],
|
||||
"repeat": training_config["repeat_per_collect"],
|
||||
},
|
||||
"val_initial_states": cast(List[Sequence[Order]], orders_by_tag["valid"]),
|
||||
},
|
||||
)
|
||||
|
||||
if run_backtest:
|
||||
backtest(
|
||||
simulator_fn=_simulator_factory,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
initial_states=cast(List[Sequence[Order]], orders_by_tag["test"]),
|
||||
policy=policy,
|
||||
logger=CsvWriter(Path(training_config["checkpoint_path"])),
|
||||
reward=reward,
|
||||
finite_env_type=parallel_mode, # type: ignore[arg-type]
|
||||
concurrency=concurrency,
|
||||
)
|
||||
|
||||
|
||||
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
|
||||
if not run_training and not run_backtest:
|
||||
warnings.warn("Skip the entire job since training and backtest are both skipped.")
|
||||
return
|
||||
|
||||
seed = config["runtime"]["seed"]
|
||||
if seed is not None:
|
||||
seed_everything(seed)
|
||||
|
||||
for extra_module_path in config["general"]["extra_module_paths"]:
|
||||
sys.path.append(extra_module_path)
|
||||
|
||||
state_interpreter: StateInterpreter = init_instance_by_config(config["interpreter"]["state"])
|
||||
action_interpreter: ActionInterpreter = init_instance_by_config(config["interpreter"]["action"])
|
||||
reward: Reward = init_instance_by_config(config["interpreter"]["reward"])
|
||||
|
||||
additional_policy_kwargs = {
|
||||
"obs_space": state_interpreter.observation_space,
|
||||
"action_space": action_interpreter.action_space,
|
||||
}
|
||||
# Create torch network
|
||||
if "network" in config["policy"]:
|
||||
network_config = config["policy"]["network"]
|
||||
network_config["kwargs"] = {
|
||||
**network_config.get("kwargs", {}),
|
||||
**{"obs_space": state_interpreter.observation_space},
|
||||
}
|
||||
additional_policy_kwargs["network"] = init_instance_by_config(network_config)
|
||||
|
||||
# Create policy
|
||||
policy_config = config["policy"]["policy"]
|
||||
policy_config["kwargs"] = {**policy_config.get("kwargs", {}), **additional_policy_kwargs}
|
||||
policy: BasePolicy = init_instance_by_config(policy_config)
|
||||
|
||||
use_cuda = config["runtime"]["use_cuda"]
|
||||
if use_cuda:
|
||||
policy.cuda()
|
||||
|
||||
train_and_test(
|
||||
freq=config["general"]["freq"],
|
||||
concurrency=config["runtime"]["concurrency"],
|
||||
parallel_mode=config["runtime"]["parallel_mode"],
|
||||
training_config=config["training"],
|
||||
simulator_config=config["simulator"],
|
||||
policy=policy,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
reward=reward,
|
||||
run_training=run_training,
|
||||
run_backtest=run_backtest,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
|
||||
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
|
||||
args = parser.parse_args()
|
||||
|
||||
config_parser = TrainingConfigParser(args.config_path)
|
||||
config = config_parser.parse()
|
||||
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||
@@ -1,268 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import yaml
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.data.native import load_handler_intraday_processed_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, backtest, train
|
||||
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
|
||||
from qlib.rl.utils.log import CsvWriter
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
if os.path.isfile(order_dir):
|
||||
return pd.read_pickle(order_dir)
|
||||
else:
|
||||
orders = []
|
||||
for file in order_dir.iterdir():
|
||||
order_data = pd.read_pickle(file)
|
||||
orders.append(order_data)
|
||||
return pd.concat(orders)
|
||||
|
||||
|
||||
class LazyLoadDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
order_file_path: Path,
|
||||
default_start_time_index: int,
|
||||
default_end_time_index: int,
|
||||
) -> None:
|
||||
self._default_start_time_index = default_start_time_index
|
||||
self._default_end_time_index = default_end_time_index
|
||||
|
||||
self._order_df = _read_orders(order_file_path).reset_index()
|
||||
self._ticks_index: Optional[pd.DatetimeIndex] = None
|
||||
self._data_dir = Path(data_dir)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._order_df)
|
||||
|
||||
def __getitem__(self, index: int) -> Order:
|
||||
row = self._order_df.iloc[index]
|
||||
date = pd.Timestamp(str(row["date"]))
|
||||
|
||||
if self._ticks_index is None:
|
||||
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
|
||||
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
|
||||
# TODO: of all dates.
|
||||
|
||||
data = load_handler_intraday_processed_data(
|
||||
data_dir=self._data_dir,
|
||||
stock_id=row["instrument"],
|
||||
date=date,
|
||||
feature_columns_today=[],
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
index_only=True,
|
||||
)
|
||||
self._ticks_index = [t - date for t in data.today.index]
|
||||
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(int(row["order_type"])),
|
||||
start_time=date + self._ticks_index[self._default_start_time_index],
|
||||
end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
|
||||
)
|
||||
|
||||
return order
|
||||
|
||||
|
||||
def train_and_test(
|
||||
env_config: dict,
|
||||
simulator_config: dict,
|
||||
trainer_config: dict,
|
||||
data_config: dict,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
run_training: bool,
|
||||
run_backtest: bool,
|
||||
) -> None:
|
||||
order_root_path = Path(data_config["source"]["order_dir"])
|
||||
|
||||
data_granularity = simulator_config.get("data_granularity", 1)
|
||||
|
||||
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
|
||||
return SingleAssetOrderExecutionSimple(
|
||||
order=order,
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
feature_columns_today=data_config["source"]["feature_columns_today"],
|
||||
feature_columns_yesterday=data_config["source"]["feature_columns_yesterday"],
|
||||
data_granularity=data_granularity,
|
||||
ticks_per_step=simulator_config["time_per_step"],
|
||||
vol_threshold=simulator_config["vol_limit"],
|
||||
)
|
||||
|
||||
assert data_config["source"]["default_start_time_index"] % data_granularity == 0
|
||||
assert data_config["source"]["default_end_time_index"] % data_granularity == 0
|
||||
|
||||
if run_training:
|
||||
train_dataset, valid_dataset = [
|
||||
LazyLoadDataset(
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
order_file_path=order_root_path / tag,
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
for tag in ("train", "valid")
|
||||
]
|
||||
|
||||
callbacks: List[Callback] = []
|
||||
if "checkpoint_path" in trainer_config:
|
||||
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
|
||||
callbacks.append(
|
||||
Checkpoint(
|
||||
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
|
||||
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
|
||||
save_latest="copy",
|
||||
),
|
||||
)
|
||||
if "earlystop_patience" in trainer_config:
|
||||
callbacks.append(
|
||||
EarlyStopping(
|
||||
patience=trainer_config["earlystop_patience"],
|
||||
monitor="val/pa",
|
||||
)
|
||||
)
|
||||
|
||||
train(
|
||||
simulator_fn=_simulator_factory_simple,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
initial_states=cast(List[Order], train_dataset),
|
||||
trainer_kwargs={
|
||||
"max_iters": trainer_config["max_epoch"],
|
||||
"finite_env_type": env_config["parallel_mode"],
|
||||
"concurrency": env_config["concurrency"],
|
||||
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
|
||||
"callbacks": callbacks,
|
||||
},
|
||||
vessel_kwargs={
|
||||
"episode_per_iter": trainer_config["episode_per_collect"],
|
||||
"update_kwargs": {
|
||||
"batch_size": trainer_config["batch_size"],
|
||||
"repeat": trainer_config["repeat_per_collect"],
|
||||
},
|
||||
"val_initial_states": valid_dataset,
|
||||
},
|
||||
)
|
||||
|
||||
if run_backtest:
|
||||
test_dataset = LazyLoadDataset(
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
order_file_path=order_root_path / "test",
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
|
||||
backtest(
|
||||
simulator_fn=_simulator_factory_simple,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
initial_states=test_dataset,
|
||||
policy=policy,
|
||||
logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
|
||||
reward=reward,
|
||||
finite_env_type=env_config["parallel_mode"],
|
||||
concurrency=env_config["concurrency"],
|
||||
)
|
||||
|
||||
|
||||
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
|
||||
if not run_training and not run_backtest:
|
||||
warnings.warn("Skip the entire job since training and backtest are both skipped.")
|
||||
return
|
||||
|
||||
if "seed" in config["runtime"]:
|
||||
seed_everything(config["runtime"]["seed"])
|
||||
|
||||
for extra_module_path in config["env"].get("extra_module_paths", []):
|
||||
sys.path.append(extra_module_path)
|
||||
|
||||
state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
|
||||
action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
|
||||
reward: Reward = init_instance_by_config(config["reward"])
|
||||
|
||||
additional_policy_kwargs = {
|
||||
"obs_space": state_interpreter.observation_space,
|
||||
"action_space": action_interpreter.action_space,
|
||||
}
|
||||
|
||||
# Create torch network
|
||||
if "network" in config:
|
||||
if "kwargs" not in config["network"]:
|
||||
config["network"]["kwargs"] = {}
|
||||
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
|
||||
additional_policy_kwargs["network"] = init_instance_by_config(config["network"])
|
||||
|
||||
# Create policy
|
||||
if "kwargs" not in config["policy"]:
|
||||
config["policy"]["kwargs"] = {}
|
||||
config["policy"]["kwargs"].update(additional_policy_kwargs)
|
||||
policy: BasePolicy = init_instance_by_config(config["policy"])
|
||||
|
||||
use_cuda = config["runtime"].get("use_cuda", False)
|
||||
if use_cuda:
|
||||
policy.cuda()
|
||||
|
||||
train_and_test(
|
||||
env_config=config["env"],
|
||||
simulator_config=config["simulator"],
|
||||
data_config=config["data"],
|
||||
trainer_config=config["trainer"],
|
||||
action_interpreter=action_interpreter,
|
||||
state_interpreter=state_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
run_training=run_training,
|
||||
run_backtest=run_backtest,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
|
||||
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_path, "r") as input_stream:
|
||||
config = yaml.safe_load(input_stream)
|
||||
|
||||
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||
@@ -13,6 +13,7 @@ import os
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.decision import TradeRange, TradeRangeByTime
|
||||
from qlib.constant import EPS_T
|
||||
from qlib.data.dataset import DatasetH
|
||||
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
|
||||
|
||||
@@ -140,6 +141,16 @@ def load_backtest_data(
|
||||
return backtest_data
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(1000),
|
||||
key=lambda path: path,
|
||||
)
|
||||
def _load_handler_pickle(path: str) -> DatasetH:
|
||||
with open(path, "rb") as fstream:
|
||||
obj = pickle.load(fstream)
|
||||
return obj
|
||||
|
||||
|
||||
class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
|
||||
|
||||
@@ -151,7 +162,6 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
index_only: bool = False,
|
||||
) -> None:
|
||||
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = df.reset_index()
|
||||
@@ -161,31 +171,17 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
|
||||
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
with open(path, "rb") as fstream:
|
||||
dataset = pickle.load(fstream)
|
||||
dataset = _load_handler_pickle(path)
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
|
||||
if index_only:
|
||||
self.today = _drop_stock_id(data[[]])
|
||||
self.yesterday = _drop_stock_id(data[[]])
|
||||
else:
|
||||
self.today = _drop_stock_id(data[feature_columns_today])
|
||||
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
|
||||
self.today = _drop_stock_id(data[feature_columns_today])
|
||||
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: (
|
||||
stock_id,
|
||||
date,
|
||||
backtest,
|
||||
index_only,
|
||||
),
|
||||
)
|
||||
def load_handler_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
@@ -193,10 +189,14 @@ def load_handler_intraday_processed_data(
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
index_only: bool = False,
|
||||
) -> HandlerIntradayProcessedData:
|
||||
return HandlerIntradayProcessedData(
|
||||
data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only
|
||||
data_dir,
|
||||
stock_id,
|
||||
date,
|
||||
feature_columns_today,
|
||||
feature_columns_yesterday,
|
||||
backtest,
|
||||
)
|
||||
|
||||
|
||||
@@ -229,5 +229,4 @@ class HandlerProcessedDataProvider(ProcessedDataProvider):
|
||||
self.feature_columns_today,
|
||||
self.feature_columns_yesterday,
|
||||
backtest=self.backtest,
|
||||
index_only=False,
|
||||
)
|
||||
|
||||
@@ -26,7 +26,6 @@ from typing import List, Sequence, cast
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
@@ -158,6 +157,15 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
|
||||
return cast(pd.DatetimeIndex, self.data.index)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(1000),
|
||||
key=lambda path: path,
|
||||
)
|
||||
def _load_df_pickle(path: str) -> pd.DataFrame:
|
||||
df = pd.read_pickle(path)
|
||||
return df
|
||||
|
||||
|
||||
class PickleIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle pickle-styled data."""
|
||||
|
||||
@@ -166,36 +174,18 @@ class PickleIntradayProcessedData(BaseIntradayProcessedData):
|
||||
data_dir: Path | str,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool,
|
||||
) -> None:
|
||||
proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
|
||||
if isinstance(data_dir, str):
|
||||
data_dir = Path(data_dir)
|
||||
path = data_dir / ("backtest" if backtest else "feature") / f"{stock_id}.pkl"
|
||||
df = _load_df_pickle(str(path))
|
||||
df = df.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
|
||||
# We have to infer the names here because,
|
||||
# unfortunately they are not included in the original data.
|
||||
cnames = _infer_processed_data_column_names(feature_dim)
|
||||
|
||||
time_length: int = len(time_index)
|
||||
|
||||
try:
|
||||
# new data format
|
||||
proc = proc.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
assert len(proc) == time_length and len(proc.columns) == feature_dim * 2
|
||||
proc_today = proc[cnames]
|
||||
proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2])
|
||||
except (IndexError, KeyError):
|
||||
# legacy data
|
||||
proc = proc.loc[pd.IndexSlice[stock_id, date]]
|
||||
assert time_length * feature_dim * 2 == len(proc)
|
||||
proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim))
|
||||
proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim))
|
||||
proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames)
|
||||
proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames)
|
||||
|
||||
self.today: pd.DataFrame = proc_today
|
||||
self.yesterday: pd.DataFrame = proc_yesterday
|
||||
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
|
||||
assert len(self.today) == len(self.yesterday) == time_length
|
||||
self.today = df[feature_columns_today]
|
||||
self.yesterday = df[feature_columns_yesterday]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
@@ -213,25 +203,38 @@ def load_simple_intraday_backtest_data(
|
||||
return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
|
||||
)
|
||||
def load_pickle_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
) -> BaseIntradayProcessedData:
|
||||
return PickleIntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
return PickleIntradayProcessedData(
|
||||
data_dir,
|
||||
stock_id,
|
||||
date,
|
||||
feature_columns_today,
|
||||
feature_columns_yesterday,
|
||||
backtest,
|
||||
)
|
||||
|
||||
|
||||
class PickleProcessedDataProvider(ProcessedDataProvider):
|
||||
def __init__(self, data_dir: Path) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._data_dir = data_dir
|
||||
self._backtest = backtest
|
||||
self._feature_columns_today = feature_columns_today
|
||||
self._feature_columns_yesterday = feature_columns_yesterday
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
@@ -244,8 +247,9 @@ class PickleProcessedDataProvider(ProcessedDataProvider):
|
||||
data_dir=self._data_dir,
|
||||
stock_id=stock_id,
|
||||
date=date,
|
||||
feature_dim=feature_dim,
|
||||
time_index=time_index,
|
||||
feature_columns_today=self._feature_columns_today,
|
||||
feature_columns_yesterday=self._feature_columns_yesterday,
|
||||
backtest=self._backtest,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generator, List, Optional
|
||||
import cachetools
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest import collect_data_loop, Exchange, get_exchange, get_strategy_executor
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime
|
||||
from qlib.backtest.executor import NestedExecutor
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
@@ -16,6 +17,18 @@ from .state import SAOEState
|
||||
from .strategy import SAOEStateAdapter, SAOEStrategy
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(1000),
|
||||
key=lambda order, _: order.stock_id,
|
||||
)
|
||||
def _create_exchange(order: Order, exchange_config: dict) -> Exchange:
|
||||
exchange_kwargs = {
|
||||
**exchange_config,
|
||||
"codes": [order.stock_id],
|
||||
}
|
||||
return get_exchange(**exchange_kwargs)
|
||||
|
||||
|
||||
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
"""Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools.
|
||||
|
||||
@@ -76,7 +89,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
executor=executor_config,
|
||||
benchmark=order.stock_id,
|
||||
account=cash_limit if cash_limit is not None else int(1e12),
|
||||
exchange_kwargs=exchange_config,
|
||||
exchange_kwargs=_create_exchange(order, exchange_config),
|
||||
pos_type="Position" if cash_limit is not None else "InfPosition",
|
||||
)
|
||||
|
||||
@@ -90,6 +103,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
trade_strategy=strategy,
|
||||
trade_executor=self._executor,
|
||||
return_value=self.report_dict,
|
||||
show_progress=False,
|
||||
)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
|
||||
@@ -12,7 +12,8 @@ from pathlib import Path
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.constant import EPS, EPS_T, float_or_ndarray
|
||||
from qlib.rl.data.base import BaseIntradayBacktestData
|
||||
from qlib.rl.data.native import DataframeIntradayBacktestData, load_handler_intraday_processed_data
|
||||
from qlib.rl.data.native import DataframeIntradayBacktestData
|
||||
from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data
|
||||
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.utils import LogLevel
|
||||
@@ -42,8 +43,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
Path to load backtest data.
|
||||
feature_columns_today
|
||||
Columns of today's feature.
|
||||
feature_columns_yesterday
|
||||
Columns of yesterday's feature.
|
||||
data_granularity
|
||||
Number of ticks between consecutive data entries.
|
||||
ticks_per_step
|
||||
@@ -80,7 +79,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
order: Order,
|
||||
data_dir: Path,
|
||||
feature_columns_today: List[str] = [],
|
||||
feature_columns_yesterday: List[str] = [],
|
||||
data_granularity: int = 1,
|
||||
ticks_per_step: int = 30,
|
||||
vol_threshold: Optional[float] = None,
|
||||
@@ -92,7 +90,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
self.order = order
|
||||
self.data_dir = data_dir
|
||||
self.feature_columns_today = feature_columns_today
|
||||
self.feature_columns_yesterday = feature_columns_yesterday
|
||||
self.ticks_per_step: int = ticks_per_step // data_granularity
|
||||
self.vol_threshold = vol_threshold
|
||||
|
||||
@@ -122,14 +119,13 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
|
||||
def get_backtest_data(self) -> BaseIntradayBacktestData:
|
||||
try:
|
||||
data = load_handler_intraday_processed_data(
|
||||
data = load_pickle_intraday_processed_data(
|
||||
data_dir=self.data_dir,
|
||||
stock_id=self.order.stock_id,
|
||||
date=pd.Timestamp(self.order.start_time.date()),
|
||||
feature_columns_today=self.feature_columns_today,
|
||||
feature_columns_yesterday=self.feature_columns_yesterday,
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
index_only=False,
|
||||
)
|
||||
return DataframeIntradayBacktestData(data.today)
|
||||
except (AttributeError, FileNotFoundError):
|
||||
|
||||
@@ -451,6 +451,7 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
state_interpreter: dict | StateInterpreter,
|
||||
action_interpreter: dict | ActionInterpreter,
|
||||
network: dict | torch.nn.Module | None = None,
|
||||
immediate_addition: bool = False,
|
||||
outer_trade_decision: BaseTradeDecision | None = None,
|
||||
level_infra: LevelInfrastructure | None = None,
|
||||
common_infra: CommonInfrastructure | None = None,
|
||||
@@ -501,9 +502,12 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
|
||||
if self._policy is not None:
|
||||
self._policy.eval()
|
||||
|
||||
self.immediate_addition = immediate_addition
|
||||
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:
|
||||
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
self.trade_amount_planned = collections.defaultdict(float)
|
||||
|
||||
def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
|
||||
assert hasattr(self.outer_trade_decision, "order_list")
|
||||
@@ -539,9 +543,15 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
|
||||
oh = self.trade_exchange.get_order_helper()
|
||||
order_list = []
|
||||
for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):
|
||||
for decision, exec_vol, state in zip(self.outer_trade_decision.get_decision(), exec_vols, states):
|
||||
order = cast(Order, decision)
|
||||
if self.immediate_addition:
|
||||
self.trade_amount_planned[order.stock_id] += exec_vol
|
||||
amount_planned = self.trade_amount_planned[order.stock_id]
|
||||
amount_finished = order.amount - state.position
|
||||
exec_vol = min(state.position, amount_planned - amount_finished)
|
||||
|
||||
if exec_vol != 0:
|
||||
order = cast(Order, decision)
|
||||
order_list.append(oh.create(order.stock_id, exec_vol, order.direction))
|
||||
|
||||
return TradeDecisionWithDetails(
|
||||
|
||||
@@ -20,7 +20,7 @@ def train(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
initial_states: List[Sequence[InitialStateType]],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
vessel_kwargs: Dict[str, Any],
|
||||
@@ -39,7 +39,9 @@ def train(
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
List of Initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in
|
||||
the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every
|
||||
state will be run exactly once. Otherwise, every worker will have its own iterator.
|
||||
policy
|
||||
Policy to train against.
|
||||
reward
|
||||
@@ -67,7 +69,7 @@ def backtest(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
initial_states: List[Sequence[InitialStateType]],
|
||||
policy: BasePolicy,
|
||||
logger: LogWriter | List[LogWriter],
|
||||
reward: Reward | None = None,
|
||||
@@ -87,7 +89,9 @@ def backtest(
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
List of Initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in
|
||||
the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every
|
||||
state will be run exactly once. Otherwise, every worker will have its own iterator.
|
||||
policy
|
||||
Policy to test against.
|
||||
logger
|
||||
|
||||
@@ -5,8 +5,9 @@ from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import copy
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from contextlib import AbstractContextManager, ExitStack, contextmanager
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast
|
||||
|
||||
@@ -206,45 +207,50 @@ class Trainer:
|
||||
|
||||
self._call_callback_hooks("on_fit_start")
|
||||
|
||||
while not self.should_stop:
|
||||
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
|
||||
_logger.info(msg)
|
||||
with _wrap_context(vessel.train_seed_iterators()) as train_iterators, _wrap_context(
|
||||
vessel.val_seed_iterators()
|
||||
) as valid_iterators:
|
||||
train_vector_env = self.venv_from_iterator(train_iterators)
|
||||
valid_vector_env = self.venv_from_iterator(valid_iterators)
|
||||
|
||||
self.initialize_iter()
|
||||
while not self.should_stop:
|
||||
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
|
||||
print(msg)
|
||||
_logger.info(msg)
|
||||
|
||||
self._call_callback_hooks("on_iter_start")
|
||||
self.initialize_iter()
|
||||
|
||||
self.current_stage = "train"
|
||||
self._call_callback_hooks("on_train_start")
|
||||
self._call_callback_hooks("on_iter_start")
|
||||
|
||||
# TODO
|
||||
# Add a feature that supports reloading the training environment every few iterations.
|
||||
with _wrap_context(vessel.train_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.train(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
self.current_stage = "train"
|
||||
self._call_callback_hooks("on_train_start")
|
||||
|
||||
self._call_callback_hooks("on_train_end")
|
||||
# TODO
|
||||
# Add a feature that supports reloading the training environment every few iterations.
|
||||
self.vessel.train(train_vector_env)
|
||||
|
||||
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
|
||||
# Implementation of validation loop
|
||||
self.current_stage = "val"
|
||||
self._call_callback_hooks("on_validate_start")
|
||||
with _wrap_context(vessel.val_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.validate(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
self._call_callback_hooks("on_train_end")
|
||||
|
||||
self._call_callback_hooks("on_validate_end")
|
||||
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
|
||||
# Implementation of validation loop
|
||||
self.current_stage = "val"
|
||||
self._call_callback_hooks("on_validate_start")
|
||||
|
||||
# This iteration is considered complete.
|
||||
# Bumping the current iteration counter.
|
||||
self.current_iter += 1
|
||||
self.vessel.validate(valid_vector_env)
|
||||
|
||||
if self.max_iters is not None and self.current_iter >= self.max_iters:
|
||||
self.should_stop = True
|
||||
self._call_callback_hooks("on_validate_end")
|
||||
|
||||
self._call_callback_hooks("on_iter_end")
|
||||
# This iteration is considered complete.
|
||||
# Bumping the current iteration counter.
|
||||
self.current_iter += 1
|
||||
|
||||
if self.max_iters is not None and self.current_iter >= self.max_iters:
|
||||
self.should_stop = True
|
||||
|
||||
self._call_callback_hooks("on_iter_end")
|
||||
|
||||
del train_vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
del valid_vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
|
||||
self._call_callback_hooks("on_fit_end")
|
||||
|
||||
@@ -265,16 +271,16 @@ class Trainer:
|
||||
|
||||
self.current_stage = "test"
|
||||
self._call_callback_hooks("on_test_start")
|
||||
with _wrap_context(vessel.test_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
with _wrap_context(vessel.test_seed_iterators()) as iterators:
|
||||
vector_env = self.venv_from_iterator(iterators)
|
||||
self.vessel.test(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
self._call_callback_hooks("on_test_end")
|
||||
|
||||
def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
|
||||
def venv_from_iterator(self, iterators: List[Iterable[InitialStateType]]) -> FiniteVectorEnv:
|
||||
"""Create a vectorized environment from iterator and the training vessel."""
|
||||
|
||||
def env_factory():
|
||||
def env_factory(iterator):
|
||||
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
|
||||
# and could be thread unsafe.
|
||||
# I'm not sure whether it's a design flaw.
|
||||
@@ -300,7 +306,7 @@ class Trainer:
|
||||
)
|
||||
|
||||
return vectorize_env(
|
||||
env_factory,
|
||||
[partial(env_factory, iterator=it) for it in iterators],
|
||||
self.finite_env_type,
|
||||
self.concurrency,
|
||||
self.loggers,
|
||||
@@ -334,8 +340,11 @@ class Trainer:
|
||||
@contextmanager
|
||||
def _wrap_context(obj):
|
||||
"""Make any object a (possibly dummy) context manager."""
|
||||
|
||||
if isinstance(obj, AbstractContextManager):
|
||||
if isinstance(obj, list) and isinstance(obj[0], AbstractContextManager):
|
||||
with ExitStack() as stack:
|
||||
yield [stack.enter_context(e) for e in obj]
|
||||
stack.pop_all().close()
|
||||
elif isinstance(obj, AbstractContextManager):
|
||||
# obj has __enter__ and __exit__
|
||||
with obj as ctx:
|
||||
yield ctx
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
|
||||
from typing import List, TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
@@ -49,19 +49,23 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType,
|
||||
def assign_trainer(self, trainer: Trainer) -> None:
|
||||
self.trainer = weakref.proxy(trainer) # type: ignore
|
||||
|
||||
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for training.
|
||||
def train_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
"""Override this to create a seed iterators for training.
|
||||
If the iterable is a context manager, the whole training will be invoked in the with-block,
|
||||
and the iterator will be automatically closed after the training is done."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for training is not available.")
|
||||
raise SeedIteratorNotAvailable("Seed iterators for training is not available.")
|
||||
|
||||
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for validation."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for validation is not available.")
|
||||
def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
"""Override this to create a seed iterators for validation."""
|
||||
raise SeedIteratorNotAvailable("Seed iterators for validation is not available.")
|
||||
|
||||
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for testing."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for testing is not available.")
|
||||
def test_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
"""Override this to create a seed iterators for testing."""
|
||||
raise SeedIteratorNotAvailable("Seed iterators for testing is not available.")
|
||||
|
||||
def train(self, vector_env: BaseVectorEnv) -> Dict[str, Any]:
|
||||
"""Implement this to train one iteration. In RL, one iteration usually refers to one collect."""
|
||||
@@ -120,9 +124,9 @@ class TrainingVessel(TrainingVesselBase):
|
||||
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
train_initial_states: Sequence[InitialStateType] | None = None,
|
||||
val_initial_states: Sequence[InitialStateType] | None = None,
|
||||
test_initial_states: Sequence[InitialStateType] | None = None,
|
||||
train_initial_states: List[Sequence[InitialStateType]] | None = None,
|
||||
val_initial_states: List[Sequence[InitialStateType]] | None = None,
|
||||
test_initial_states: List[Sequence[InitialStateType]] | None = None,
|
||||
buffer_size: int = 20000,
|
||||
episode_per_iter: int = 1000,
|
||||
update_kwargs: Dict[str, Any] = cast(Dict[str, Any], None),
|
||||
@@ -132,34 +136,49 @@ class TrainingVessel(TrainingVesselBase):
|
||||
self.action_interpreter = action_interpreter
|
||||
self.policy = policy
|
||||
self.reward = reward
|
||||
self.train_initial_states = train_initial_states
|
||||
self.val_initial_states = val_initial_states
|
||||
self.test_initial_states = test_initial_states
|
||||
self.train_initial_states = None if train_initial_states is None else train_initial_states
|
||||
self.val_initial_states = None if val_initial_states is None else val_initial_states
|
||||
self.test_initial_states = None if test_initial_states is None else test_initial_states
|
||||
self.buffer_size = buffer_size
|
||||
self.episode_per_iter = episode_per_iter
|
||||
self.update_kwargs = update_kwargs or {}
|
||||
|
||||
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
def train_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
if self.train_initial_states is not None:
|
||||
_logger.info("Training initial states collection size: %d", len(self.train_initial_states))
|
||||
# Implement fast_dev_run here.
|
||||
train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(train_initial_states, repeat=-1, shuffle=True)
|
||||
return super().train_seed_iterator()
|
||||
_logger.info(f"Training initial states collection sizes: {[len(e) for e in self.train_initial_states]}")
|
||||
train_initial_states = [
|
||||
self._random_subset("train", e, self.trainer.fast_dev_run) for e in self.train_initial_states
|
||||
]
|
||||
iterators = [DataQueue(e, repeat=-1, shuffle=True) for e in train_initial_states]
|
||||
return cast(List[Iterable[InitialStateType]], iterators)
|
||||
else:
|
||||
return super().train_seed_iterators()
|
||||
|
||||
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
if self.val_initial_states is not None:
|
||||
_logger.info("Validation initial states collection size: %d", len(self.val_initial_states))
|
||||
val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(val_initial_states, repeat=1)
|
||||
return super().val_seed_iterator()
|
||||
_logger.info(f"Validation initial states collection sizes: {[len(e) for e in self.val_initial_states]}")
|
||||
val_initial_states = [
|
||||
self._random_subset("val", e, self.trainer.fast_dev_run) for e in self.val_initial_states
|
||||
]
|
||||
iterators = [DataQueue(e, repeat=1) for e in val_initial_states]
|
||||
return cast(List[Iterable[InitialStateType]], iterators)
|
||||
else:
|
||||
return super().val_seed_iterators()
|
||||
|
||||
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
def test_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
if self.test_initial_states is not None:
|
||||
_logger.info("Testing initial states collection size: %d", len(self.test_initial_states))
|
||||
test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(test_initial_states, repeat=1)
|
||||
return super().test_seed_iterator()
|
||||
_logger.info(f"Testing initial states collection sizes: {[len(e) for e in self.test_initial_states]}")
|
||||
test_initial_states = [
|
||||
self._random_subset("test", e, self.trainer.fast_dev_run) for e in self.test_initial_states
|
||||
]
|
||||
iterators = [DataQueue(e, repeat=1) for e in test_initial_states]
|
||||
return cast(List[Iterable[InitialStateType]], iterators)
|
||||
else:
|
||||
return super().test_seed_iterators()
|
||||
|
||||
def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
|
||||
"""Create a collector and collects ``episode_per_iter`` episodes.
|
||||
|
||||
@@ -258,6 +258,46 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
|
||||
return np.stack(obs)
|
||||
|
||||
def step2(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
id: int | List[int] | np.ndarray | None = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert not self._zombie
|
||||
wrapped_id = self._wrap_id(id)
|
||||
id2idx = {i: k for k, i in enumerate(wrapped_id)}
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id))
|
||||
result = {}
|
||||
|
||||
# ask super to step alive envs and remap to current index
|
||||
if request_id:
|
||||
valid_act = np.stack([action[id2idx[i]] for i in request_id])
|
||||
tmp = super().step(valid_act, request_id)
|
||||
|
||||
for obs_next, rew, done, info in zip(*tmp):
|
||||
obs_next = self._postproc_env_obs(obs_next)
|
||||
result[info["env_id"]] = [obs_next, rew, done, info]
|
||||
|
||||
# logging
|
||||
for i, r in result.items():
|
||||
if i in self._alive_env_ids and r[0] is not None:
|
||||
for logger in self._logger:
|
||||
logger.on_env_step(i, *r)
|
||||
|
||||
for _, reward, __, info in result.values():
|
||||
self._set_default_info(info)
|
||||
self._set_default_rew(reward)
|
||||
for r in result.values():
|
||||
if r[0] is None:
|
||||
r[0] = self._get_default_obs()
|
||||
if r[1] is None:
|
||||
r[1] = self._get_default_rew()
|
||||
if r[3] is None:
|
||||
r[3] = self._get_default_info()
|
||||
|
||||
ret = list(map(np.stack, zip(*result.values())))
|
||||
return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret)
|
||||
|
||||
def step(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
@@ -311,7 +351,7 @@ class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):
|
||||
|
||||
|
||||
def vectorize_env(
|
||||
env_factory: Callable[..., gym.Env],
|
||||
env_factories: List[Callable[..., gym.Env]],
|
||||
env_type: FiniteEnvType,
|
||||
concurrency: int,
|
||||
logger: LogWriter | List[LogWriter],
|
||||
@@ -334,9 +374,10 @@ def vectorize_env(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env_factory
|
||||
Callable to instantiate one single ``gym.Env``.
|
||||
All concurrent workers will have the same ``env_factory``.
|
||||
env_factories
|
||||
Callables to instantiate one single ``gym.Env``.
|
||||
There should be 1 or `concurrency` env_factories. If there is 1 env_factory, all concurrent workers will have
|
||||
the same env_factory. Otherwise, each worker will have its own env_factory.
|
||||
env_type
|
||||
dummy or subproc or shmem. Corresponding to
|
||||
`parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.
|
||||
@@ -358,6 +399,8 @@ def vectorize_env(
|
||||
def env_factory(): ...
|
||||
vectorize_env(env_factory, ...)
|
||||
"""
|
||||
assert len(env_factories) in (1, concurrency)
|
||||
|
||||
env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = {
|
||||
"dummy": FiniteDummyVectorEnv,
|
||||
"subproc": FiniteSubprocVectorEnv,
|
||||
@@ -366,4 +409,7 @@ def vectorize_env(
|
||||
|
||||
finite_env_cls = env_type_cls_mapping[env_type]
|
||||
|
||||
return finite_env_cls(logger, [env_factory for _ in range(concurrency)])
|
||||
if len(env_factories) == 1:
|
||||
return finite_env_cls(logger, [env_factories[0] for _ in range(concurrency)])
|
||||
else:
|
||||
return finite_env_cls(logger, env_factories)
|
||||
|
||||
30
qlib/rl/utils/profiling.py
Normal file
30
qlib/rl/utils/profiling.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Generator
|
||||
|
||||
from line_profiler import LineProfiler
|
||||
|
||||
|
||||
@contextmanager
|
||||
def simple_perf(desc: str = "", out_path: str = None) -> Generator[None, None, None]:
|
||||
s = time.perf_counter()
|
||||
yield
|
||||
e = time.perf_counter()
|
||||
msg = f"{desc}: {(e - s) * 1000.0:.4f} ms"
|
||||
|
||||
if out_path is not None:
|
||||
with open(out_path, "a") as fstream:
|
||||
fstream.write(msg + "\n")
|
||||
else:
|
||||
print(msg)
|
||||
|
||||
|
||||
def lprofile(func: Callable) -> Callable:
|
||||
def wrapper(*args, **kwargs):
|
||||
lp = LineProfiler()
|
||||
lpw = lp(func)
|
||||
res = lpw(*args, **kwargs)
|
||||
lp.print_stats()
|
||||
return res
|
||||
|
||||
return wrapper
|
||||
Reference in New Issue
Block a user