1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

Migrate amc4th training (#1316)

* Migrate amc4th training

* Refine RL example scripts

* Resolve PR comments

Co-authored-by: luocy16 <luocy16@mails.tsinghua.edu.cn>
This commit is contained in:
Huoran Li
2022-10-19 10:17:43 +08:00
committed by GitHub
parent bc06f0301e
commit 3c62d131a5
19 changed files with 676 additions and 50 deletions

3
.gitignore vendored
View File

@@ -24,6 +24,9 @@ qlib/VERSION.txt
qlib/data/_libs/expanding.cpp
qlib/data/_libs/rolling.cpp
examples/estimator/estimator_example/
examples/rl/data/
examples/rl/checkpoints/
examples/rl/outputs/
*.egg-info/

55
examples/rl/README.md Normal file
View File

@@ -0,0 +1,55 @@
This folder contains a simple example of how to run Qlib RL. It contains:
```
.
├── experiment_config
│ ├── backtest # Backtest config
│ └── training # Training config
├── README.md # Readme (the current file)
└── scripts # Scripts for data pre-processing
```
## Data preparation
Use [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10) to download data:
```
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl/qlib_rl_example_data ./ --recursive
mv qlib_rl_example_data data
```
The downloaded data will be placed in `./data`, with the original CSV files under `data/csv`. To generate all the data needed by this example, run:
```
bash scripts/data_pipeline.sh
```
After the script finishes, the `data/` directory should look like this:
```
data
├── backtest_orders.csv
├── bin
├── csv
├── pickle
├── pickle_dataframe
└── training_order_split
```
## Run training
Run:
```
python ../../qlib/rl/contrib/train_onpolicy.py --config_path ./experiment_config/training/config.yml
```
After training, checkpoints will be stored under `checkpoints/`.
## Run backtest
```
python ../../qlib/rl/contrib/backtest.py --config_path ./experiment_config/backtest/config.py
```
The backtest workflow will use the trained model in `checkpoints/`. The backtest summary can be found in `outputs/`.

View File

@@ -0,0 +1,53 @@
# Backtest configuration for the RL order-execution example.
# NOTE(review): `_base_` presumably makes the config loader inherit every
# setting from ./twap.yml before applying the overrides below — confirm
# against the loader used by qlib/rl/contrib/backtest.py.
_base_ = ["./twap.yml"]
strategies = {
    # NOTE(review): "_delete_" presumably tells the loader to drop the
    # inherited `strategies` entry instead of merging into it — confirm.
    "_delete_": True,
    # Strategy used at the 30-minute level.
    "30min": {
        "class": "TWAPStrategy",
        "module_path": "qlib.contrib.strategy.rule_strategy",
        "kwargs": {},
    },
    # Strategy used at the daily level: the RL policy and its interpreters.
    "1day": {
        "class": "SAOEIntStrategy",
        "module_path": "qlib.rl.order_execution.strategy",
        "kwargs": {
            # The interpreter settings below (max_step, data_ticks, data_dim,
            # values) use the same values as the training config in this
            # example; keep the two files in sync when changing them.
            "state_interpreter": {
                "class": "FullHistoryStateInterpreter",
                "module_path": "qlib.rl.order_execution.interpreter",
                "kwargs": {
                    "max_step": 8,
                    "data_ticks": 240,
                    "data_dim": 6,
                    "processed_data_provider": {
                        "class": "PickleProcessedDataProvider",
                        "module_path": "qlib.rl.data.pickle_styled",
                        "kwargs": {
                            # Per-instrument feature pickles written by
                            # scripts/collect_pickle_dataframe.py.
                            "data_dir": "./data/pickle_dataframe/feature",
                        },
                    },
                },
            },
            "action_interpreter": {
                "class": "CategoricalActionInterpreter",
                "module_path": "qlib.rl.order_execution.interpreter",
                "kwargs": {
                    "values": 14,
                    "max_step": 8,
                },
            },
            "network": {
                "class": "Recurrent",
                "module_path": "qlib.rl.order_execution.network",
                "kwargs": {},
            },
            "policy": {
                "class": "PPO",
                "module_path": "qlib.rl.order_execution.policy",
                "kwargs": {
                    "lr": 1.0e-4,
                    # Latest checkpoint written by the training run
                    # (the training config sets checkpoint_path to ./checkpoints).
                    "weight_file": "./checkpoints/latest.pth",
                },
            },
        },
    },
}

View File

@@ -0,0 +1,21 @@
order_file: ./data/backtest_orders.csv
start_time: "9:45"
end_time: "14:44"
qlib:
provider_uri_1min: ./data/bin
feature_root_dir: ./data/pickle
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$volume",
]
feature_columns_yesterday: [
"$open_v1", "$high_v1", "$low_v1", "$close_v1", "$vwap_v1", "$volume_v1",
]
exchange:
limit_threshold: ['$close == 0', '$close == 0']
deal_price: ["If($close == 0, $vwap, $close)", "If($close == 0, $vwap, $close)"]
volume_threshold:
all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"]
buy: ["current", "$close"]
sell: ["current", "$close"]
strategies: {} # Placeholder
concurrency: 5

View File

@@ -0,0 +1,59 @@
simulator:
time_per_step: 30
vol_limit: null
env:
concurrency: 1
parallel_mode: dummy
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
values: 14
max_step: 8
module_path: qlib.rl.order_execution.interpreter
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 6
data_ticks: 240
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.order_execution.interpreter
reward:
class: PAPenaltyReward
kwargs:
penalty: 100.0
module_path: qlib.rl.order_execution.reward
data:
source:
order_dir: ./data/training_order_split
data_dir: ./data/pickle_dataframe/backtest
total_time: 240
default_start_time: 0
default_end_time: 240
proc_data_dim: 6
num_workers: 0
queue_size: 20
network:
class: Recurrent
module_path: qlib.rl.order_execution.network
policy:
class: PPO
kwargs:
lr: 0.0001
module_path: qlib.rl.order_execution.policy
runtime:
seed: 42
use_cuda: false
trainer:
max_epoch: 2
repeat_per_collect: 5
earlystop_patience: 2
episode_per_collect: 20
batch_size: 16
val_every_n_epoch: 1
checkpoint_path: ./checkpoints
checkpoint_every_n_iters: 1

View File

@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import pickle
import pandas as pd
from tqdm import tqdm
# Re-shape the pickled datasets under data/pickle into one DataFrame pickle per
# instrument under data/pickle_dataframe/<tag>/<instrument>.pkl.
os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)

for tag in ("backtest", "feature"):
    # NOTE(review): these pickles are produced locally by gen_pickle_data.py;
    # never point this script at pickle files from an untrusted source.
    with open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb") as src_file:
        df = pickle.load(src_file)
    # The loaded object is a mapping whose values are DataFrames; stack them.
    df = pd.concat(list(df.values())).reset_index()
    # An explicit unit is required: pandas >= 2.0 rejects the unit-less
    # "datetime64" dtype in astype().
    df["date"] = df["datetime"].dt.date.astype("datetime64[ns]")
    instruments = sorted(set(df["instrument"]))

    os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
    for instrument in tqdm(instruments):
        cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
        cur = cur.set_index(["instrument", "datetime", "date"])
        # Close the handle deterministically so the file is fully flushed
        # before later pipeline steps read it.
        with open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb") as dst_file:
            pickle.dump(cur, dst_file)

View File

@@ -0,0 +1,14 @@
# Data pre-processing pipeline for the RL example.
# All paths are relative, so run this from the directory that contains
# `scripts/` and `data/` (the README invokes it as `bash scripts/data_pipeline.sh`).

# Generate `bin` format data
# Abort on the first failing command so later steps never run on partial data.
set -e
python ../../scripts/dump_bin.py dump_all --csv_path ./data/csv --qlib_dir ./data/bin --include_fields open,close,high,low,vwap,volume --symbol_field_name symbol --date_field_name date --freq 1min
# Generate pickle format data
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
# `./stat/` is the feature_save_dir configured in pickle_data_config.yml
# (normalization statistics); remove it once the pickles have been generated.
if [ -e stat/ ]; then
rm -r stat/
fi
# Split the pickled datasets into one DataFrame pickle per instrument.
python scripts/collect_pickle_dataframe.py
# Sample orders
python scripts/gen_training_orders.py
python scripts/gen_backtest_orders.py

View File

@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import os
import pandas as pd
import numpy as np
import pickle
# Sample random backtest orders for every instrument and write them to
# data/backtest_orders.csv.
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--num_order", type=int, default=10)
args = parser.parse_args()

# Seed NumPy so the sampled orders are reproducible for a given --seed.
np.random.seed(args.seed)

# "backtesttest.pkl" is the test split of backtest.pkl: HighFreqProvider saves
# it as `path[:-4] + "test.pkl"`.
path = os.path.join("data", "pickle", "backtesttest.pkl")  # TODO: rename file
with open(path, "rb") as pickle_file:
    df = pickle.load(pickle_file).reset_index()
# An explicit unit is required: pandas >= 2.0 rejects the unit-less
# "datetime64" dtype in astype().
df["date"] = df["datetime"].dt.date.astype("datetime64[ns]")

instruments = sorted(set(df["instrument"]))
n = args.num_order

df_list = []
for instrument in instruments:
    print(instrument)

    cur_df = df[df["instrument"] == instrument]

    # Keep only the calendar-date part of each timestamp, de-duplicated.
    dates = sorted({str(d).split(" ")[0] for d in cur_df["date"]})

    # The three np.random calls stay in this order so that a given seed keeps
    # producing the same orders.
    df_list.append(
        pd.DataFrame(
            {
                "date": sorted(np.random.choice(dates, size=n, replace=False)),
                "instrument": [instrument] * n,
                "amount": np.random.randint(low=3, high=11, size=n) * 100.0,
                "order_type": np.random.randint(low=0, high=2, size=n),
            }
        ).set_index(["date", "instrument"]),
    )

total_df = pd.concat(df_list)
total_df.to_csv("data/backtest_orders.csv")

View File

@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import yaml
import argparse
import os
from copy import deepcopy
from qlib.contrib.data.highfreq_provider import HighFreqProvider
# NOTE(review): FullLoader (not safe_load) is used because the example config
# contains `!!python/tuple` tags; only load config files you trust.
loader = yaml.FullLoader

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="config.yml")
    parser.add_argument("-d", "--dest", type=str, default=".")
    parser.add_argument("-s", "--split", type=str, choices=["none", "date", "stock", "both"], default="stock")
    args = parser.parse_args()

    # Close the config file as soon as it has been parsed.
    with open(args.config) as config_file:
        conf = yaml.load(config_file, Loader=loader)

    # Re-root every configured output path under the destination directory.
    for section in conf.values():
        if isinstance(section, dict) and "path" in section:
            section["path"] = os.path.join(args.dest, section["path"])
    provider = HighFreqProvider(**conf)

    # Gen dataframe (called for its side effects; the return value is not used)
    if "feature_conf" in conf:
        provider._gen_dataframe(deepcopy(provider.feature_conf))
    if "backtest_conf" in conf:
        provider._gen_dataframe(deepcopy(provider.backtest_conf))

    # Turn "<name>.pkl" into the directory "<name>/" for the split datasets below.
    provider.feature_conf["path"] = os.path.splitext(provider.feature_conf["path"])[0] + "/"
    provider.backtest_conf["path"] = os.path.splitext(provider.backtest_conf["path"])[0] + "/"

    # Split by date
    if args.split in ("date", "both"):
        provider._gen_day_dataset(deepcopy(provider.feature_conf), "feature")
        provider._gen_day_dataset(deepcopy(provider.backtest_conf), "backtest")

    # Split by stock
    if args.split in ("stock", "both"):
        provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature")
        provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest")

View File

@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import os
import pandas as pd
import numpy as np
import pickle
# Sample random training/validation/test orders for a single stock and pickle
# them under data/training_order_split/<group>/<stock>.pkl.
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--stock", type=str, default="AAPL")
parser.add_argument("--train_size", type=int, default=10)
parser.add_argument("--valid_size", type=int, default=2)
parser.add_argument("--test_size", type=int, default=2)
args = parser.parse_args()

# Seed NumPy so the sampled orders are reproducible for a given --seed.
np.random.seed(args.seed)

os.makedirs(os.path.join("data", "training_order_split"), exist_ok=True)

for group, n in zip(("train", "valid", "test"), (args.train_size, args.valid_size, args.test_size)):
    # backtest{train,valid,test}.pkl are the per-segment pickles saved next to
    # backtest.pkl by HighFreqProvider.
    path = os.path.join("data", "pickle", f"backtest{group}.pkl")
    with open(path, "rb") as pickle_file:
        df = pickle.load(pickle_file).reset_index()
    # An explicit unit is required: pandas >= 2.0 rejects the unit-less
    # "datetime64" dtype in astype().
    df["date"] = df["datetime"].dt.date.astype("datetime64[ns]")

    # Keep only the calendar-date part of each timestamp, de-duplicated.
    dates = sorted({str(d).split(" ")[0] for d in df["date"]})

    # The np.random calls stay in this order so that a given seed keeps
    # producing the same orders.
    data_df = pd.DataFrame(
        {
            "date": sorted(np.random.choice(dates, size=n, replace=False)),
            "instrument": [args.stock] * n,
            "amount": np.random.randint(low=3, high=11, size=n) * 100.0,
            "order_type": [0] * n,
        }
    ).set_index(["date", "instrument"])

    os.makedirs(os.path.join("data", "training_order_split", group), exist_ok=True)
    # Close the handle deterministically so the pickle is fully written.
    with open(os.path.join("data", "training_order_split", group, f"{args.stock}.pkl"), "wb") as out_file:
        pickle.dump(data_df, out_file)

View File

@@ -0,0 +1,57 @@
# start & end time for training/validation/test datasets
start_time: !!str &start 2020-01-01
end_time: !!str &end 2020-07-31
train_end_time: !!str &tend 2020-03-31
valid_start_time: !!str &vstart 2020-04-01
valid_end_time: !!str &vend 2020-05-31
test_start_time: !!str &tstart 2020-06-01
# the instrument set
instruments: &ins all
# qlib related configuration
qlib_conf:
provider_uri: ./data/bin # path to generated qlib bin
redis_port: 233
feature_conf:
path: ./data/pickle/feature.pkl # output path of feature
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: HighFreqGeneralHandler
module_path: qlib.contrib.data.highfreq_handler
kwargs:
start_time: *start
end_time: *end
fit_start_time: *start
fit_end_time: *tend
instruments: *ins
day_length: 240 # how many minutes in one trading day
infer_processors:
- class: HighFreqNorm
module_path: qlib.contrib.data.highfreq_processor
kwargs:
feature_save_dir: ./stat/ # output path of statistics of features (for feature normalization)
norm_groups:
price: 10
volume: 2
segments:
train: !!python/tuple [*start, *tend]
valid: !!python/tuple [*vstart, *vend]
test: !!python/tuple [*tstart, *end]
backtest_conf:
path: ./data/pickle/backtest.pkl # output path of backtest
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: HighFreqGeneralBacktestHandler
module_path: qlib.contrib.data.highfreq_handler
kwargs:
start_time: *start
end_time: *end
instruments: *ins
day_length: 240
segments:
train: !!python/tuple [*start, *tend]
valid: !!python/tuple [*vstart, *vend]
test: !!python/tuple [*tstart, *end]

View File

@@ -4,6 +4,7 @@ import datetime
from typing import Optional
import qlib
from qlib import get_module_logger
from qlib.data import D
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
@@ -12,7 +13,6 @@ from qlib.data.data import Cal
from qlib.contrib.ops.high_freq import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut
import pickle as pkl
from joblib import Parallel, delayed
from utilsd.logging import print_log
class HighFreqProvider:
@@ -41,6 +41,7 @@ class HighFreqProvider:
self.label_conf = label_conf
self.backtest_conf = backtest_conf
self.qlib_conf = qlib_conf
self.logger = get_module_logger("HighFreqProvider")
def get_pre_datasets(self):
"""Generate the training, validation and test datasets for prediction
@@ -125,7 +126,7 @@ class HighFreqProvider:
raise ValueError("Must specify the path to save the dataset.") from e
if os.path.isfile(path):
start = time.time()
print_log("Dataset exists, load from disk.", __name__)
self.logger.info("Dataset exists, load from disk.", __name__)
# res = dataset.prepare(['train', 'valid', 'test'])
with open(path, "rb") as f:
@@ -134,11 +135,11 @@ class HighFreqProvider:
res = [data[i] for i in datasets]
else:
res = data.prepare(datasets)
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
else:
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
print_log("Generating dataset", __name__)
self.logger.info("Generating dataset", __name__)
start_time = time.time()
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
@@ -157,7 +158,7 @@ class HighFreqProvider:
with open(path[:-4] + "test.pkl", "wb") as f:
pkl.dump(testset, f)
res = [data[i] for i in datasets]
print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
return res
def _gen_data(self, config, datasets=["train", "valid", "test"]):
@@ -167,7 +168,7 @@ class HighFreqProvider:
raise ValueError("Must specify the path to save the dataset.") from e
if os.path.isfile(path):
start = time.time()
print_log("Dataset exists, load from disk.", __name__)
self.logger.info("Dataset exists, load from disk.", __name__)
# res = dataset.prepare(['train', 'valid', 'test'])
with open(path, "rb") as f:
@@ -176,18 +177,18 @@ class HighFreqProvider:
res = [data[i] for i in datasets]
else:
res = data.prepare(datasets)
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
else:
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
print_log("Generating dataset", __name__)
self.logger.info("Generating dataset", __name__)
start_time = time.time()
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
dataset.config(dump_all=True, recursive=True)
dataset.to_pickle(path)
res = dataset.prepare(datasets)
print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
return res
def _gen_dataset(self, config):
@@ -197,21 +198,21 @@ class HighFreqProvider:
raise ValueError("Must specify the path to save the dataset.") from e
if os.path.isfile(path):
start = time.time()
print_log("Dataset exists, load from disk.", __name__)
self.logger.info("Dataset exists, load from disk.", __name__)
with open(path, "rb") as f:
dataset = pkl.load(f)
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
else:
start = time.time()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
print_log("Generating dataset", __name__)
self.logger.info("Generating dataset", __name__)
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
dataset.prepare(["train", "valid", "test"])
print_log(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__)
dataset.config(dump_all=True, recursive=True)
dataset.to_pickle(path)
return dataset
@@ -224,15 +225,15 @@ class HighFreqProvider:
if os.path.isfile(path + "tmp_dataset.pkl"):
start = time.time()
print_log("Dataset exists, load from disk.", __name__)
self.logger.info("Dataset exists, load from disk.", __name__)
else:
start = time.time()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
print_log("Generating dataset", __name__)
self.logger.info("Generating dataset", __name__)
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
dataset.config(dump_all=False, recursive=True)
dataset.to_pickle(path + "tmp_dataset.pkl")
@@ -265,15 +266,15 @@ class HighFreqProvider:
if os.path.isfile(path + "tmp_dataset.pkl"):
start = time.time()
print_log("Dataset exists, load from disk.", __name__)
self.logger.info("Dataset exists, load from disk.", __name__)
else:
start = time.time()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
print_log("Generating dataset", __name__)
self.logger.info("Generating dataset", __name__)
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
dataset.config(dump_all=False, recursive=True)
dataset.to_pickle(path + "tmp_dataset.pkl")

View File

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import argparse
import copy
import os
import pickle
from collections import defaultdict
from pathlib import Path
@@ -365,6 +366,8 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram
else:
res = pd.concat(res)
if not output_path.exists():
os.makedirs(output_path)
res.to_csv(output_path / "summary.csv")
return res

View File

@@ -0,0 +1,219 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import os
import random
from pathlib import Path
from typing import cast, List, Optional
import numpy as np
import pandas as pd
import torch
import yaml
from qlib.backtest import Order
from qlib.backtest.decision import OrderDir
from qlib.constant import ONE_MIN
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
from qlib.rl.reward import Reward
from qlib.rl.trainer import Checkpoint, train
from qlib.utils import init_instance_by_config
from tianshou.policy import BasePolicy
from torch import nn
from torch.utils.data import Dataset
def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN to deterministic mode.

    Args:
        seed: Value passed to every seeding function.
    """
    random.seed(seed)
    np.random.seed(seed)
    for seed_torch in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_torch(seed)
    torch.backends.cudnn.deterministic = True
def _read_orders(order_dir: Path) -> pd.DataFrame:
if os.path.isfile(order_dir):
return pd.read_pickle(order_dir)
else:
orders = []
for file in order_dir.iterdir():
order_data = pd.read_pickle(file)
orders.append(order_data)
return pd.concat(orders)
class LazyLoadDataset(Dataset):
    """Dataset that turns rows of an order table into ``Order`` objects on demand.

    The order table is read once at construction time. The intraday tick
    offsets are loaded lazily on the first ``__getitem__`` call and then
    reused for every later order.
    """

    def __init__(
        self,
        order_file_path: Path,
        data_dir: Path,
        default_start_time_index: int,
        default_end_time_index: int,
    ) -> None:
        """Read the order table and remember where the backtest data lives.

        Args:
            order_file_path: File or directory of pickled order DataFrames.
            data_dir: Directory passed to ``load_simple_intraday_backtest_data``.
            default_start_time_index: Index into the tick offsets used as the
                start time of every order.
            default_end_time_index: Exclusive index into the tick offsets; the
                order ends one minute after tick ``default_end_time_index - 1``.
        """
        self._default_start_time_index = default_start_time_index
        self._default_end_time_index = default_end_time_index

        self._order_file_path = order_file_path
        self._order_df = _read_orders(order_file_path).reset_index()

        self._data_dir = data_dir
        # Offsets of each tick from midnight of its trading day; filled lazily.
        self._ticks_index: Optional[List[pd.Timedelta]] = None

    def __len__(self) -> int:
        """Return the number of orders in the table."""
        return len(self._order_df)

    def __getitem__(self, index: int) -> Order:
        """Build the ``Order`` described by row ``index`` of the order table."""
        row = self._order_df.iloc[index]
        date = pd.Timestamp(str(row["date"]))

        if self._ticks_index is None:
            # TODO: We only load ticks index once based on the assumption that ticks index of different dates
            # TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
            # TODO: of all dates.
            backtest_data = load_simple_intraday_backtest_data(
                data_dir=self._data_dir,
                stock_id=row["instrument"],
                date=date,
            )
            # Store offsets relative to the day so they can be re-applied to any date.
            self._ticks_index = [t - date for t in backtest_data.get_time_index()]

        order = Order(
            stock_id=row["instrument"],
            amount=row["amount"],
            direction=OrderDir(int(row["order_type"])),
            start_time=date + self._ticks_index[self._default_start_time_index],
            # The end index is exclusive: take the last included tick and add one minute.
            end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
        )

        return order
def train_and_test(
    env_config: dict,
    simulator_config: dict,
    trainer_config: dict,
    data_config: dict,
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    policy: BasePolicy,
    reward: Reward,
) -> None:
    """Build the datasets, simulator factory and callbacks, then run ``train``.

    Args:
        env_config: Provides ``parallel_mode`` and ``concurrency``.
        simulator_config: Provides ``time_per_step`` and ``vol_limit``.
        trainer_config: Provides the epoch, collect, batch and checkpoint settings.
        data_config: Its ``source`` entry provides the order/data directories
            and the default start/end tick indices.
        state_interpreter: State interpreter handed to ``train``.
        action_interpreter: Action interpreter handed to ``train``.
        policy: Policy to be trained.
        reward: Reward function handed to ``train``.
    """
    source_config = data_config["source"]
    order_root_path = Path(source_config["order_dir"])

    def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
        # One simulator per order; the deal price column defaults to "close".
        return SingleAssetOrderExecutionSimple(
            order=order,
            data_dir=Path(source_config["data_dir"]),
            ticks_per_step=simulator_config["time_per_step"],
            deal_price_type=source_config.get("deal_price_column", "close"),
            vol_threshold=simulator_config["vol_limit"],
        )

    def _load_split(split_name: str) -> LazyLoadDataset:
        # Every split shares the same settings except for its sub-directory.
        return LazyLoadDataset(
            order_file_path=order_root_path / split_name,
            data_dir=Path(source_config["data_dir"]),
            default_start_time_index=source_config["default_start_time"],
            default_end_time_index=source_config["default_end_time"],
        )

    train_dataset = _load_split("train")
    valid_dataset = _load_split("valid")

    # Checkpointing is optional and enabled only when a path is configured.
    callbacks = []
    if "checkpoint_path" in trainer_config:
        checkpoint_callback = Checkpoint(
            dirpath=Path(trainer_config["checkpoint_path"]),
            every_n_iters=trainer_config["checkpoint_every_n_iters"],
            save_latest="copy",
        )
        callbacks.append(checkpoint_callback)

    trainer_kwargs = dict(
        max_iters=trainer_config["max_epoch"],
        finite_env_type=env_config["parallel_mode"],
        concurrency=env_config["concurrency"],
        val_every_n_iters=trainer_config.get("val_every_n_epoch", None),
        callbacks=callbacks,
    )

    update_kwargs = dict(
        batch_size=trainer_config["batch_size"],
        repeat=trainer_config["repeat_per_collect"],
    )
    vessel_kwargs = dict(
        episode_per_iter=trainer_config["episode_per_collect"],
        update_kwargs=update_kwargs,
        val_initial_states=valid_dataset,
    )

    train(
        simulator_fn=_simulator_factory_simple,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        reward=reward,
        initial_states=cast(List[Order], train_dataset),
        trainer_kwargs=trainer_kwargs,
        vessel_kwargs=vessel_kwargs,
    )
def main(config: dict) -> None:
    """Instantiate every component described by ``config`` and start training.

    Args:
        config: Parsed experiment configuration. It must provide the
            ``runtime``, ``state_interpreter``, ``action_interpreter``,
            ``reward``, ``network``, ``policy``, ``env``, ``simulator``,
            ``data`` and ``trainer`` sections. The ``network`` and ``policy``
            sections are updated in place with the objects they depend on.
    """
    runtime_config = config["runtime"]
    if "seed" in runtime_config:
        seed_everything(runtime_config["seed"])

    state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
    action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
    reward: Reward = init_instance_by_config(config["reward"])

    # Create torch network
    # `kwargs` may be absent, or null when written as a bare `kwargs:` in YAML.
    if config["network"].get("kwargs") is None:
        config["network"]["kwargs"] = {}
    config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
    network: nn.Module = init_instance_by_config(config["network"])

    # Create policy
    # Same tolerance as for the network: a policy without `kwargs` is valid.
    if config["policy"].get("kwargs") is None:
        config["policy"]["kwargs"] = {}
    config["policy"]["kwargs"].update(
        {
            "network": network,
            "obs_space": state_interpreter.observation_space,
            "action_space": action_interpreter.action_space,
        }
    )
    policy: BasePolicy = init_instance_by_config(config["policy"])

    use_cuda = runtime_config.get("use_cuda", False)
    if use_cuda:
        policy.cuda()

    train_and_test(
        env_config=config["env"],
        simulator_config=config["simulator"],
        data_config=config["data"],
        trainer_config=config["trainer"],
        action_interpreter=action_interpreter,
        state_interpreter=state_interpreter,
        policy=policy,
        reward=reward,
    )
if __name__ == "__main__":
    import warnings

    # Silence the two warning categories this script chooses to ignore.
    for ignored_category in (DeprecationWarning, RuntimeWarning):
        warnings.filterwarnings("ignore", category=ignored_category)

    # Command line: a single required path to the YAML experiment config.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
    args = parser.parse_args()

    with open(args.config_path, "r") as config_file:
        config = yaml.safe_load(config_file)

    main(config)

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from typing import Any, Callable, Sequence, cast
from typing import Any, Callable, Dict, List, Sequence, cast
from tianshou.policy import BasePolicy
@@ -23,8 +23,8 @@ def train(
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
reward: Reward,
vessel_kwargs: dict[str, Any],
trainer_kwargs: dict[str, Any],
vessel_kwargs: Dict[str, Any],
trainer_kwargs: Dict[str, Any],
) -> None:
"""Train a policy with the parallelism provided by RL framework.
@@ -69,7 +69,7 @@ def backtest(
action_interpreter: ActionInterpreter,
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
logger: LogWriter | list[LogWriter],
logger: LogWriter | List[LogWriter],
reward: Reward | None = None,
finite_env_type: FiniteEnvType = "subproc",
concurrency: int = 2,

View File

@@ -8,6 +8,7 @@ Mimicks the hooks of Keras / PyTorch-Lightning, but tailored for the context of
from __future__ import annotations
import copy
import os
import shutil
import time
from datetime import datetime
@@ -253,7 +254,7 @@ class Checkpoint(Callback):
latest_pth = self.dirpath / "latest.pth"
# Remove first before saving
if self.save_latest and latest_pth.exists():
if self.save_latest and (latest_pth.exists() or os.path.islink(latest_pth)):
latest_pth.unlink()
if self.save_latest == "link":

View File

@@ -3,10 +3,11 @@
from __future__ import annotations
import collections
import copy
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any, Iterable, Sequence, TypeVar, cast
from typing import Any, Dict, Iterable, List, Sequence, TypeVar, cast
import torch
@@ -83,7 +84,7 @@ class Trainer:
current_iter: int
"""Current iteration (collect) of training."""
loggers: list[LogWriter]
loggers: List[LogWriter]
"""A list of log writers."""
def __init__(
@@ -91,8 +92,8 @@ class Trainer:
*,
max_iters: int | None = None,
val_every_n_iters: int | None = None,
loggers: LogWriter | list[LogWriter] | None = None,
callbacks: list[Callback] | None = None,
loggers: LogWriter | List[LogWriter] | None = None,
callbacks: List[Callback] | None = None,
finite_env_type: FiniteEnvType = "subproc",
concurrency: int = 2,
fast_dev_run: int | None = None,
@@ -109,7 +110,7 @@ class Trainer:
self.loggers.append(LogBuffer(self._metrics_callback, loglevel=self._min_loglevel()))
self.callbacks: list[Callback] = callbacks if callbacks is not None else []
self.callbacks: List[Callback] = callbacks if callbacks is not None else []
self.finite_env_type = finite_env_type
self.concurrency = concurrency
self.fast_dev_run = fast_dev_run
@@ -164,13 +165,13 @@ class Trainer:
self.current_stage = state_dict["current_stage"]
self.metrics = state_dict["metrics"]
def named_callbacks(self) -> dict[str, Callback]:
def named_callbacks(self) -> Dict[str, Callback]:
"""Retrieve a collection of callbacks where each one has a name.
Useful when saving checkpoints.
"""
return _named_collection(self.callbacks)
def named_loggers(self) -> dict[str, LogWriter]:
def named_loggers(self) -> Dict[str, LogWriter]:
"""Retrieve a collection of loggers where each one has a name.
Useful when saving checkpoints.
"""
@@ -328,16 +329,13 @@ def _wrap_context(obj):
yield obj
def _named_collection(seq: Sequence[T]) -> dict[str, T]:
def _named_collection(seq: Sequence[T]) -> Dict[str, T]:
"""Convert a list into a dict, where each item is named with its type."""
res = {}
retry_cnt: collections.Counter = collections.Counter()
for item in seq:
typename = type(item).__name__.lower()
if typename not in res:
res[typename] = item
else:
# names are auto-labelled as earlystop1, earlystop2, ...
for retry in range(1, 1000):
if f"{typename}{retry}" not in res:
res[f"{typename}{retry}"] = item
key = typename if retry_cnt[typename] == 0 else f"{typename}{retry_cnt[typename]}"
retry_cnt[typename] += 1
res[key] = item
return res

View File

@@ -63,15 +63,15 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType,
"""Override this to create a seed iterator for testing."""
raise SeedIteratorNotAvailable("Seed iterator for testing is not available.")
def train(self, vector_env: BaseVectorEnv) -> dict[str, Any]:
def train(self, vector_env: BaseVectorEnv) -> Dict[str, Any]:
"""Implement this to train one iteration. In RL, one iteration usually refers to one collect."""
raise NotImplementedError()
def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
def validate(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
"""Implement this to validate the policy once."""
raise NotImplementedError()
def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
def test(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
"""Implement this to evaluate the policy on test environment once."""
raise NotImplementedError()
@@ -82,15 +82,15 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType,
value = np.mean(value)
_logger.info(f"[Iter {self.trainer.current_iter + 1}] {name} = {value}")
def log_dict(self, data: dict[str, Any]) -> None:
def log_dict(self, data: Dict[str, Any]) -> None:
for name, value in data.items():
self.log(name, value)
def state_dict(self) -> dict:
def state_dict(self) -> Dict:
"""Return a checkpoint of current vessel state."""
return {"policy": self.policy.state_dict()}
def load_state_dict(self, state_dict: dict) -> None:
def load_state_dict(self, state_dict: Dict) -> None:
"""Restore a checkpoint from a previously saved state dict."""
self.policy.load_state_dict(state_dict["policy"])
@@ -125,7 +125,7 @@ class TrainingVessel(TrainingVesselBase):
test_initial_states: Sequence[InitialStateType] | None = None,
buffer_size: int = 20000,
episode_per_iter: int = 1000,
update_kwargs: dict[str, Any] = cast(Dict[str, Any], None),
update_kwargs: Dict[str, Any] = cast(Dict[str, Any], None),
):
self.simulator_fn = simulator_fn # type: ignore
self.state_interpreter = state_interpreter
@@ -161,7 +161,7 @@ class TrainingVessel(TrainingVesselBase):
return DataQueue(test_initial_states, repeat=1)
return super().test_seed_iterator()
def train(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
"""Create a collector and collects ``episode_per_iter`` episodes.
Update the policy on the collected replay buffer.
"""
@@ -182,7 +182,7 @@ class TrainingVessel(TrainingVesselBase):
self.log_dict(res)
return res
def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
def validate(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
self.policy.eval()
with vector_env.collector_guard():
@@ -191,7 +191,7 @@ class TrainingVessel(TrainingVesselBase):
self.log_dict(res)
return res
def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
def test(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
self.policy.eval()
with vector_env.collector_guard():