diff --git a/.gitignore b/.gitignore index a3cc7c0e3..51f6654c3 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,9 @@ qlib/VERSION.txt qlib/data/_libs/expanding.cpp qlib/data/_libs/rolling.cpp examples/estimator/estimator_example/ +examples/rl/data/ +examples/rl/checkpoints/ +examples/rl/outputs/ *.egg-info/ diff --git a/examples/rl/README.md b/examples/rl/README.md new file mode 100644 index 000000000..db5cdf20d --- /dev/null +++ b/examples/rl/README.md @@ -0,0 +1,55 @@ +This folder contains a simple example of how to run Qlib RL. It contains: + +``` +. +├── experiment_config +│ ├── backtest # Backtest config +│ └── training # Training config +├── README.md # Readme (the current file) +└── scripts # Scripts for data pre-processing +``` + +## Data preparation + +Use [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10) to download data: + +``` +azcopy copy https://qlibpublic.blob.core.windows.net/data/rl/qlib_rl_example_data ./ --recursive +mv qlib_rl_example_data data +``` + +The downloaded data will be placed at `./data`. The original data are in `data/csv`. To create all data needed by the case, run: + +``` +bash scripts/data_pipeline.sh +``` + +After the execution finishes, the `data/` directory should be like: + +``` +data +├── backtest_orders.csv +├── bin +├── csv +├── pickle +├── pickle_dataframe +└── training_order_split +``` + +## Run training + +Run: + +``` +python ../../qlib/rl/contrib/train_onpolicy.py --config_path ./experiment_config/training/config.yml +``` + +After training, checkpoints will be stored under `checkpoints/`. + +## Run backtest + +``` +python ../../qlib/rl/contrib/backtest.py --config_path ./experiment_config/backtest/config.py +``` + +The backtest workflow will use the trained model in `checkpoints/`. The backtest summary can be found in `outputs/`. diff --git a/examples/rl/experiment_config/backtest/config.py b/examples/rl/experiment_config/backtest/config.py new file mode 100644 index 000000000..9ac835789 --- /dev/null +++ b/examples/rl/experiment_config/backtest/config.py @@ -0,0 +1,53 @@ +_base_ = ["./twap.yml"] + +strategies = { + "_delete_": True, + "30min": { + "class": "TWAPStrategy", + "module_path": "qlib.contrib.strategy.rule_strategy", + "kwargs": {}, + }, + "1day": { + "class": "SAOEIntStrategy", + "module_path": "qlib.rl.order_execution.strategy", + "kwargs": { + "state_interpreter": { + "class": "FullHistoryStateInterpreter", + "module_path": "qlib.rl.order_execution.interpreter", + "kwargs": { + "max_step": 8, + "data_ticks": 240, + "data_dim": 6, + "processed_data_provider": { + "class": "PickleProcessedDataProvider", + "module_path": "qlib.rl.data.pickle_styled", + "kwargs": { + "data_dir": "./data/pickle_dataframe/feature", + }, + }, + }, + }, + "action_interpreter": { + "class": "CategoricalActionInterpreter", + "module_path": "qlib.rl.order_execution.interpreter", + "kwargs": { + "values": 14, + "max_step": 8, + }, + }, + "network": { + "class": "Recurrent", + "module_path": "qlib.rl.order_execution.network", + "kwargs": {}, + }, + "policy": { + "class": "PPO", + "module_path": "qlib.rl.order_execution.policy", + "kwargs": { + "lr": 1.0e-4, + "weight_file": "./checkpoints/latest.pth", + }, + }, + }, + }, +} diff --git a/examples/rl/experiment_config/backtest/twap.yml b/examples/rl/experiment_config/backtest/twap.yml new file mode 100644 index 000000000..e0c342502 --- /dev/null +++ b/examples/rl/experiment_config/backtest/twap.yml @@ -0,0 +1,21 @@ +order_file: ./data/backtest_orders.csv +start_time: "9:45" +end_time: "14:44" +qlib: + provider_uri_1min: ./data/bin + feature_root_dir: ./data/pickle + feature_columns_today: [ + "$open", "$high", "$low", "$close", "$vwap", "$volume", + ] + feature_columns_yesterday: [ + "$open_v1", "$high_v1", "$low_v1", "$close_v1", "$vwap_v1", "$volume_v1", + ] +exchange: + limit_threshold: ['$close == 0', '$close == 0'] + deal_price: ["If($close == 0, $vwap, $close)", "If($close == 0, $vwap, $close)"] + volume_threshold: + all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"] + buy: ["current", "$close"] + sell: ["current", "$close"] +strategies: {} # Placeholder +concurrency: 5 diff --git a/examples/rl/experiment_config/training/config.yml b/examples/rl/experiment_config/training/config.yml new file mode 100644 index 000000000..7e50d3eee --- /dev/null +++ b/examples/rl/experiment_config/training/config.yml @@ -0,0 +1,59 @@ +simulator: + time_per_step: 30 + vol_limit: null +env: + concurrency: 1 + parallel_mode: dummy +action_interpreter: + class: CategoricalActionInterpreter + kwargs: + values: 14 + max_step: 8 + module_path: qlib.rl.order_execution.interpreter +state_interpreter: + class: FullHistoryStateInterpreter + kwargs: + data_dim: 6 + data_ticks: 240 + max_step: 8 + processed_data_provider: + class: PickleProcessedDataProvider + module_path: qlib.rl.data.pickle_styled + kwargs: + data_dir: ./data/pickle_dataframe/feature + module_path: qlib.rl.order_execution.interpreter +reward: + class: PAPenaltyReward + kwargs: + penalty: 100.0 + module_path: qlib.rl.order_execution.reward +data: + source: + order_dir: ./data/training_order_split + data_dir: ./data/pickle_dataframe/backtest + total_time: 240 + default_start_time: 0 + default_end_time: 240 + proc_data_dim: 6 + num_workers: 0 + queue_size: 20 +network: + class: Recurrent + module_path: qlib.rl.order_execution.network +policy: + class: PPO + kwargs: + lr: 0.0001 + module_path: qlib.rl.order_execution.policy +runtime: + seed: 42 + use_cuda: false +trainer: + max_epoch: 2 + repeat_per_collect: 5 + earlystop_patience: 2 + episode_per_collect: 20 + batch_size: 16 + val_every_n_epoch: 1 + checkpoint_path: ./checkpoints + checkpoint_every_n_iters: 1 diff --git a/examples/rl/scripts/collect_pickle_dataframe.py b/examples/rl/scripts/collect_pickle_dataframe.py new file mode 100644 index 000000000..8950ec203 --- /dev/null +++ b/examples/rl/scripts/collect_pickle_dataframe.py @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import pickle +import pandas as pd +from tqdm import tqdm + +os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True) + +for tag in ("backtest", "feature"): + df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb")) + df = pd.concat(list(df.values())).reset_index() + df["date"] = df["datetime"].dt.date.astype("datetime64") + instruments = sorted(set(df["instrument"])) + + os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True) + for instrument in tqdm(instruments): + cur = df[df["instrument"] == instrument].sort_values(by=["datetime"]) + cur = cur.set_index(["instrument", "datetime", "date"]) + pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb")) diff --git a/examples/rl/scripts/data_pipeline.sh b/examples/rl/scripts/data_pipeline.sh new file mode 100644 index 000000000..c15b8fbe5 --- /dev/null +++ b/examples/rl/scripts/data_pipeline.sh @@ -0,0 +1,14 @@ +# Generate `bin` format data +set -e +python ../../scripts/dump_bin.py dump_all --csv_path ./data/csv --qlib_dir ./data/bin --include_fields open,close,high,low,vwap,volume --symbol_field_name symbol --date_field_name date --freq 1min + +# Generate pickle format data +python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml +if [ -e stat/ ]; then + rm -r stat/ +fi +python scripts/collect_pickle_dataframe.py + +# Sample orders +python scripts/gen_training_orders.py +python scripts/gen_backtest_orders.py diff --git a/examples/rl/scripts/gen_backtest_orders.py b/examples/rl/scripts/gen_backtest_orders.py new file mode 100644 index 000000000..c3d0e4ef9 --- /dev/null +++ b/examples/rl/scripts/gen_backtest_orders.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +import os +import pandas as pd +import numpy as np +import pickle + +parser = argparse.ArgumentParser() +parser.add_argument("--seed", type=int, default=20220926) +parser.add_argument("--num_order", type=int, default=10) +args = parser.parse_args() + +np.random.seed(args.seed) + +path = os.path.join("data", "pickle", "backtesttest.pkl") # TODO: rename file +df = pickle.load(open(path, "rb")).reset_index() +df["date"] = df["datetime"].dt.date.astype("datetime64") + +instruments = sorted(set(df["instrument"])) +df_list = [] +for instrument in instruments: + print(instrument) + + cur_df = df[df["instrument"] == instrument] + + dates = sorted(set([str(d).split(" ")[0] for d in cur_df["date"]])) + + n = args.num_order + df_list.append( + pd.DataFrame({ + "date": sorted(np.random.choice(dates, size=n, replace=False)), + "instrument": [instrument] * n, + "amount": np.random.randint(low=3, high=11, size=n) * 100.0, + "order_type": np.random.randint(low=0, high=2, size=n), + }).set_index(["date", "instrument"]), + ) + +total_df = pd.concat(df_list) +total_df.to_csv("data/backtest_orders.csv") diff --git a/examples/rl/scripts/gen_pickle_data.py b/examples/rl/scripts/gen_pickle_data.py new file mode 100755 index 000000000..3cb74f314 --- /dev/null +++ b/examples/rl/scripts/gen_pickle_data.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import yaml +import argparse +import os +from copy import deepcopy + +from qlib.contrib.data.highfreq_provider import HighFreqProvider + +loader = yaml.FullLoader + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", type=str, default="config.yml") + parser.add_argument("-d", "--dest", type=str, default=".") + parser.add_argument("-s", "--split", type=str, choices=["none", "date", "stock", "both"], default="stock") + args = parser.parse_args() + + conf = yaml.load(open(args.config), Loader=loader) + + for k, v in conf.items(): + if isinstance(v, dict) and "path" in v: + v["path"] = os.path.join(args.dest, v["path"]) + provider = HighFreqProvider(**conf) + + # Gen dataframe + if "feature_conf" in conf: + feature = provider._gen_dataframe(deepcopy(provider.feature_conf)) + if "backtest_conf" in conf: + backtest = provider._gen_dataframe(deepcopy(provider.backtest_conf)) + + provider.feature_conf['path'] = os.path.splitext(provider.feature_conf['path'])[0] + '/' + provider.backtest_conf['path'] = os.path.splitext(provider.backtest_conf['path'])[0] + '/' + # Split by date + if args.split == "date" or args.split == "both": + provider._gen_day_dataset(deepcopy(provider.feature_conf), "feature") + provider._gen_day_dataset(deepcopy(provider.backtest_conf), "backtest") + + # Split by stock + if args.split == "stock" or args.split == "both": + provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature") + provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest") diff --git a/examples/rl/scripts/gen_training_orders.py b/examples/rl/scripts/gen_training_orders.py new file mode 100644 index 000000000..07383c860 --- /dev/null +++ b/examples/rl/scripts/gen_training_orders.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +import os +import pandas as pd +import numpy as np +import pickle + +parser = argparse.ArgumentParser() +parser.add_argument("--seed", type=int, default=20220926) +parser.add_argument("--stock", type=str, default="AAPL") +parser.add_argument("--train_size", type=int, default=10) +parser.add_argument("--valid_size", type=int, default=2) +parser.add_argument("--test_size", type=int, default=2) +args = parser.parse_args() + +np.random.seed(args.seed) + +os.makedirs(os.path.join("data", "training_order_split"), exist_ok=True) + +for group, n in zip(("train", "valid", "test"), (args.train_size, args.valid_size, args.test_size)): + path = os.path.join("data", "pickle", f"backtest{group}.pkl") + df = pickle.load(open(path, "rb")).reset_index() + df["date"] = df["datetime"].dt.date.astype("datetime64") + + dates = sorted(set([str(d).split(" ")[0] for d in df["date"]])) + + data_df = pd.DataFrame({ + "date": sorted(np.random.choice(dates, size=n, replace=False)), + "instrument": [args.stock] * n, + "amount": np.random.randint(low=3, high=11, size=n) * 100.0, + "order_type": [0] * n, + }).set_index(["date", "instrument"]) + + os.makedirs(os.path.join("data", "training_order_split", group), exist_ok=True) + pickle.dump(data_df, open(os.path.join("data", "training_order_split", group, f"{args.stock}.pkl"), "wb")) diff --git a/examples/rl/scripts/pickle_data_config.yml b/examples/rl/scripts/pickle_data_config.yml new file mode 100755 index 000000000..7813f7d38 --- /dev/null +++ b/examples/rl/scripts/pickle_data_config.yml @@ -0,0 +1,57 @@ +# start & end time for training/validation/test datasets +start_time: !!str &start 2020-01-01 +end_time: !!str &end 2020-07-31 +train_end_time: !!str &tend 2020-03-31 +valid_start_time: !!str &vstart 2020-04-01 +valid_end_time: !!str &vend 2020-05-31 +test_start_time: !!str &tstart 2020-06-01 +# the instrument set +instruments: &ins all +# qlib related configuration +qlib_conf: + provider_uri: ./data/bin # path to generated qlib bin + redis_port: 233 +feature_conf: + path: ./data/pickle/feature.pkl # output path of feature + class: DatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: HighFreqGeneralHandler + module_path: qlib.contrib.data.highfreq_handler + kwargs: + start_time: *start + end_time: *end + fit_start_time: *start + fit_end_time: *tend + instruments: *ins + day_length: 240 # how many minutes in one trading day + infer_processors: + - class: HighFreqNorm + module_path: qlib.contrib.data.highfreq_processor + kwargs: + feature_save_dir: ./stat/ # output path of statistics of features (for feature normalization) + norm_groups: + price: 10 + volume: 2 + segments: + train: !!python/tuple [*start, *tend] + valid: !!python/tuple [*vstart, *vend] + test: !!python/tuple [*tstart, *end] +backtest_conf: + path: ./data/pickle/backtest.pkl # output path of backtest + class: DatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: HighFreqGeneralBacktestHandler + module_path: qlib.contrib.data.highfreq_handler + kwargs: + start_time: *start + end_time: *end + instruments: *ins + day_length: 240 + segments: + train: !!python/tuple [*start, *tend] + valid: !!python/tuple [*vstart, *vend] + test: !!python/tuple [*tstart, *end] diff --git a/qlib/contrib/data/highfreq_provider.py b/qlib/contrib/data/highfreq_provider.py index 7e47da0bf..704b37f72 100644 --- a/qlib/contrib/data/highfreq_provider.py +++ b/qlib/contrib/data/highfreq_provider.py @@ -4,6 +4,7 @@ import datetime from typing import Optional import qlib +from qlib import get_module_logger from qlib.data import D from qlib.config import REG_CN from qlib.utils import init_instance_by_config @@ -12,7 +13,6 @@ from qlib.data.data import Cal from qlib.contrib.ops.high_freq import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut import pickle as pkl from joblib import Parallel, delayed -from utilsd.logging import print_log class HighFreqProvider: @@ -41,6 +41,7 @@ class HighFreqProvider: self.label_conf = label_conf self.backtest_conf = backtest_conf self.qlib_conf = qlib_conf + self.logger = get_module_logger("HighFreqProvider") def get_pre_datasets(self): """Generate the training, validation and test datasets for prediction @@ -125,7 +126,7 @@ class HighFreqProvider: raise ValueError("Must specify the path to save the dataset.") from e if os.path.isfile(path): start = time.time() - print_log("Dataset exists, load from disk.", __name__) + self.logger.info("Dataset exists, load from disk.", __name__) # res = dataset.prepare(['train', 'valid', 'test']) with open(path, "rb") as f: @@ -134,11 +135,11 @@ class HighFreqProvider: res = [data[i] for i in datasets] else: res = data.prepare(datasets) - print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__) else: if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) - print_log("Generating dataset", __name__) + self.logger.info("Generating dataset", __name__) start_time = time.time() self._prepare_calender_cache() dataset = init_instance_by_config(config) @@ -157,7 +158,7 @@ class HighFreqProvider: with open(path[:-4] + "test.pkl", "wb") as f: pkl.dump(testset, f) res = [data[i] for i in datasets] - print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__) + self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__) return res def _gen_data(self, config, datasets=["train", "valid", "test"]): @@ -167,7 +168,7 @@ class HighFreqProvider: raise ValueError("Must specify the path to save the dataset.") from e if os.path.isfile(path): start = time.time() - print_log("Dataset exists, load from disk.", __name__) + self.logger.info("Dataset exists, load from disk.", __name__) # res = dataset.prepare(['train', 'valid', 'test']) with open(path, "rb") as f: @@ -176,18 +177,18 @@ class HighFreqProvider: res = [data[i] for i in datasets] else: res = data.prepare(datasets) - print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__) else: if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) - print_log("Generating dataset", __name__) + self.logger.info("Generating dataset", __name__) start_time = time.time() self._prepare_calender_cache() dataset = init_instance_by_config(config) dataset.config(dump_all=True, recursive=True) dataset.to_pickle(path) res = dataset.prepare(datasets) - print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__) + self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__) return res def _gen_dataset(self, config): @@ -197,21 +198,21 @@ class HighFreqProvider: raise ValueError("Must specify the path to save the dataset.") from e if os.path.isfile(path): start = time.time() - print_log("Dataset exists, load from disk.", __name__) + self.logger.info("Dataset exists, load from disk.", __name__) with open(path, "rb") as f: dataset = pkl.load(f) - print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__) else: start = time.time() if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) - print_log("Generating dataset", __name__) + self.logger.info("Generating dataset", __name__) self._prepare_calender_cache() dataset = init_instance_by_config(config) - print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__) dataset.prepare(["train", "valid", "test"]) - print_log(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__) dataset.config(dump_all=True, recursive=True) dataset.to_pickle(path) return dataset @@ -224,15 +225,15 @@ class HighFreqProvider: if os.path.isfile(path + "tmp_dataset.pkl"): start = time.time() - print_log("Dataset exists, load from disk.", __name__) + self.logger.info("Dataset exists, load from disk.", __name__) else: start = time.time() if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) - print_log("Generating dataset", __name__) + self.logger.info("Generating dataset", __name__) self._prepare_calender_cache() dataset = init_instance_by_config(config) - print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__) dataset.config(dump_all=False, recursive=True) dataset.to_pickle(path + "tmp_dataset.pkl") @@ -265,15 +266,15 @@ class HighFreqProvider: if os.path.isfile(path + "tmp_dataset.pkl"): start = time.time() - print_log("Dataset exists, load from disk.", __name__) + self.logger.info("Dataset exists, load from disk.", __name__) else: start = time.time() if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) - print_log("Generating dataset", __name__) + self.logger.info("Generating dataset", __name__) self._prepare_calender_cache() dataset = init_instance_by_config(config) - print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__) dataset.config(dump_all=False, recursive=True) dataset.to_pickle(path + "tmp_dataset.pkl") diff --git a/qlib/rl/contrib/__init__.py b/qlib/rl/contrib/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index 4d3d3cf4b..4cd101150 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -4,6 +4,7 @@ from __future__ import annotations import argparse import copy +import os import pickle from collections import defaultdict from pathlib import Path @@ -365,6 +366,8 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram else: res = pd.concat(res) + if not output_path.exists(): + os.makedirs(output_path) res.to_csv(output_path / "summary.csv") return res diff --git a/qlib/rl/contrib/train_onpolicy.py b/qlib/rl/contrib/train_onpolicy.py new file mode 100644 index 000000000..f043dda64 --- /dev/null +++ b/qlib/rl/contrib/train_onpolicy.py @@ -0,0 +1,219 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import argparse +import os +import random +from pathlib import Path +from typing import cast, List, Optional + +import numpy as np +import pandas as pd +import torch +import yaml +from qlib.backtest import Order +from qlib.backtest.decision import OrderDir +from qlib.constant import ONE_MIN +from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data +from qlib.rl.interpreter import ActionInterpreter, StateInterpreter +from qlib.rl.order_execution import SingleAssetOrderExecutionSimple +from qlib.rl.reward import Reward +from qlib.rl.trainer import Checkpoint, train +from qlib.utils import init_instance_by_config +from tianshou.policy import BasePolicy +from torch import nn +from torch.utils.data import Dataset + + +def seed_everything(seed: int) -> None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +def _read_orders(order_dir: Path) -> pd.DataFrame: + if os.path.isfile(order_dir): + return pd.read_pickle(order_dir) + else: + orders = [] + for file in order_dir.iterdir(): + order_data = pd.read_pickle(file) + orders.append(order_data) + return pd.concat(orders) + + +class LazyLoadDataset(Dataset): + def __init__( + self, + order_file_path: Path, + data_dir: Path, + default_start_time_index: int, + default_end_time_index: int, + ) -> None: + self._default_start_time_index = default_start_time_index + self._default_end_time_index = default_end_time_index + + self._order_file_path = order_file_path + self._order_df = _read_orders(order_file_path).reset_index() + + self._data_dir = data_dir + self._ticks_index: Optional[pd.DatetimeIndex] = None + + def __len__(self) -> int: + return len(self._order_df) + + def __getitem__(self, index: int) -> Order: + row = self._order_df.iloc[index] + date = pd.Timestamp(str(row["date"])) + + if self._ticks_index is None: + # TODO: We only load ticks index once based on the assumption that ticks index of different dates + # TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index + # TODO: of all dates. + backtest_data = load_simple_intraday_backtest_data( + data_dir=self._data_dir, + stock_id=row["instrument"], + date=date, + ) + self._ticks_index = [t - date for t in backtest_data.get_time_index()] + + order = Order( + stock_id=row["instrument"], + amount=row["amount"], + direction=OrderDir(int(row["order_type"])), + start_time=date + self._ticks_index[self._default_start_time_index], + end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN, + ) + + return order + + +def train_and_test( + env_config: dict, + simulator_config: dict, + trainer_config: dict, + data_config: dict, + state_interpreter: StateInterpreter, + action_interpreter: ActionInterpreter, + policy: BasePolicy, + reward: Reward, +) -> None: + order_root_path = Path(data_config["source"]["order_dir"]) + + def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple: + return SingleAssetOrderExecutionSimple( + order=order, + data_dir=Path(data_config["source"]["data_dir"]), + ticks_per_step=simulator_config["time_per_step"], + deal_price_type=data_config["source"].get("deal_price_column", "close"), + vol_threshold=simulator_config["vol_limit"], + ) + + train_dataset = LazyLoadDataset( + order_file_path=order_root_path / "train", + data_dir=Path(data_config["source"]["data_dir"]), + default_start_time_index=data_config["source"]["default_start_time"], + default_end_time_index=data_config["source"]["default_end_time"], + ) + valid_dataset = LazyLoadDataset( + order_file_path=order_root_path / "valid", + data_dir=Path(data_config["source"]["data_dir"]), + default_start_time_index=data_config["source"]["default_start_time"], + default_end_time_index=data_config["source"]["default_end_time"], + ) + + callbacks = [] + if "checkpoint_path" in trainer_config: + callbacks.append( + Checkpoint( + dirpath=Path(trainer_config["checkpoint_path"]), + every_n_iters=trainer_config["checkpoint_every_n_iters"], + save_latest="copy", + ), + ) + + trainer_kwargs = { + "max_iters": trainer_config["max_epoch"], + "finite_env_type": env_config["parallel_mode"], + "concurrency": env_config["concurrency"], + "val_every_n_iters": trainer_config.get("val_every_n_epoch", None), + "callbacks": callbacks, + } + vessel_kwargs = { + "episode_per_iter": trainer_config["episode_per_collect"], + "update_kwargs": { + "batch_size": trainer_config["batch_size"], + "repeat": trainer_config["repeat_per_collect"], + }, + "val_initial_states": valid_dataset, + } + + train( + simulator_fn=_simulator_factory_simple, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + policy=policy, + reward=reward, + initial_states=cast(List[Order], train_dataset), + trainer_kwargs=trainer_kwargs, + vessel_kwargs=vessel_kwargs, + ) + + +def main(config: dict) -> None: + if "seed" in config["runtime"]: + seed_everything(config["runtime"]["seed"]) + + state_config = config["state_interpreter"] + state_interpreter: StateInterpreter = init_instance_by_config(state_config) + + action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"]) + reward: Reward = init_instance_by_config(config["reward"]) + + # Create torch network + if "kwargs" not in config["network"]: + config["network"]["kwargs"] = {} + config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space}) + network: nn.Module = init_instance_by_config(config["network"]) + + # Create policy + config["policy"]["kwargs"].update( + { + "network": network, + "obs_space": state_interpreter.observation_space, + "action_space": action_interpreter.action_space, + } + ) + policy: BasePolicy = init_instance_by_config(config["policy"]) + + use_cuda = config["runtime"].get("use_cuda", False) + if use_cuda: + policy.cuda() + + train_and_test( + env_config=config["env"], + simulator_config=config["simulator"], + data_config=config["data"], + trainer_config=config["trainer"], + action_interpreter=action_interpreter, + state_interpreter=state_interpreter, + policy=policy, + reward=reward, + ) + + +if __name__ == "__main__": + import warnings + + warnings.filterwarnings("ignore", category=DeprecationWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) + + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") + args = parser.parse_args() + + with open(args.config_path, "r") as input_stream: + config = yaml.safe_load(input_stream) + + main(config) diff --git a/qlib/rl/trainer/api.py b/qlib/rl/trainer/api.py index e9f48df24..aea99dc3d 100644 --- a/qlib/rl/trainer/api.py +++ b/qlib/rl/trainer/api.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Any, Callable, Sequence, cast +from typing import Any, Callable, Dict, List, Sequence, cast from tianshou.policy import BasePolicy @@ -23,8 +23,8 @@ def train( initial_states: Sequence[InitialStateType], policy: BasePolicy, reward: Reward, - vessel_kwargs: dict[str, Any], - trainer_kwargs: dict[str, Any], + vessel_kwargs: Dict[str, Any], + trainer_kwargs: Dict[str, Any], ) -> None: """Train a policy with the parallelism provided by RL framework. @@ -69,7 +69,7 @@ def backtest( action_interpreter: ActionInterpreter, initial_states: Sequence[InitialStateType], policy: BasePolicy, - logger: LogWriter | list[LogWriter], + logger: LogWriter | List[LogWriter], reward: Reward | None = None, finite_env_type: FiniteEnvType = "subproc", concurrency: int = 2, diff --git a/qlib/rl/trainer/callbacks.py b/qlib/rl/trainer/callbacks.py index c76b674c6..e5422075e 100644 --- a/qlib/rl/trainer/callbacks.py +++ b/qlib/rl/trainer/callbacks.py @@ -8,6 +8,7 @@ Mimicks the hooks of Keras / PyTorch-Lightning, but tailored for the context of from __future__ import annotations import copy +import os import shutil import time from datetime import datetime @@ -253,7 +254,7 @@ class Checkpoint(Callback): latest_pth = self.dirpath / "latest.pth" # Remove first before saving - if self.save_latest and latest_pth.exists(): + if self.save_latest and (latest_pth.exists() or os.path.islink(latest_pth)): latest_pth.unlink() if self.save_latest == "link": diff --git a/qlib/rl/trainer/trainer.py b/qlib/rl/trainer/trainer.py index f8f4c548d..66a185447 100644 --- a/qlib/rl/trainer/trainer.py +++ b/qlib/rl/trainer/trainer.py @@ -3,10 +3,11 @@ from __future__ import annotations +import collections import copy from contextlib import AbstractContextManager, contextmanager from pathlib import Path -from typing import Any, Iterable, Sequence, TypeVar, cast +from typing import Any, Dict, Iterable, List, Sequence, TypeVar, cast import torch @@ -83,7 +84,7 @@ class Trainer: current_iter: int """Current iteration (collect) of training.""" - loggers: list[LogWriter] + loggers: List[LogWriter] """A list of log writers.""" def __init__( @@ -91,8 +92,8 @@ class Trainer: *, max_iters: int | None = None, val_every_n_iters: int | None = None, - loggers: LogWriter | list[LogWriter] | None = None, - callbacks: list[Callback] | None = None, + loggers: LogWriter | List[LogWriter] | None = None, + callbacks: List[Callback] | None = None, finite_env_type: FiniteEnvType = "subproc", concurrency: int = 2, fast_dev_run: int | None = None, @@ -109,7 +110,7 @@ class Trainer: self.loggers.append(LogBuffer(self._metrics_callback, loglevel=self._min_loglevel())) - self.callbacks: list[Callback] = callbacks if callbacks is not None else [] + self.callbacks: List[Callback] = callbacks if callbacks is not None else [] self.finite_env_type = finite_env_type self.concurrency = concurrency self.fast_dev_run = fast_dev_run @@ -164,13 +165,13 @@ class Trainer: self.current_stage = state_dict["current_stage"] self.metrics = state_dict["metrics"] - def named_callbacks(self) -> dict[str, Callback]: + def named_callbacks(self) -> Dict[str, Callback]: """Retrieve a collection of callbacks where each one has a name. Useful when saving checkpoints. """ return _named_collection(self.callbacks) - def named_loggers(self) -> dict[str, LogWriter]: + def named_loggers(self) -> Dict[str, LogWriter]: """Retrieve a collection of loggers where each one has a name. Useful when saving checkpoints. """ @@ -328,16 +329,13 @@ def _wrap_context(obj): yield obj -def _named_collection(seq: Sequence[T]) -> dict[str, T]: +def _named_collection(seq: Sequence[T]) -> Dict[str, T]: """Convert a list into a dict, where each item is named with its type.""" res = {} + retry_cnt: collections.Counter = collections.Counter() for item in seq: typename = type(item).__name__.lower() - if typename not in res: - res[typename] = item - else: - # names are auto-labelled as earlystop1, earlystop2, ... - for retry in range(1, 1000): - if f"{typename}{retry}" not in res: - res[f"{typename}{retry}"] = item + key = typename if retry_cnt[typename] == 0 else f"{typename}{retry_cnt[typename]}" + retry_cnt[typename] += 1 + res[key] = item return res diff --git a/qlib/rl/trainer/vessel.py b/qlib/rl/trainer/vessel.py index e1ad0cb98..6cd2eb3e9 100644 --- a/qlib/rl/trainer/vessel.py +++ b/qlib/rl/trainer/vessel.py @@ -63,15 +63,15 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, """Override this to create a seed iterator for testing.""" raise SeedIteratorNotAvailable("Seed iterator for testing is not available.") - def train(self, vector_env: BaseVectorEnv) -> dict[str, Any]: + def train(self, vector_env: BaseVectorEnv) -> Dict[str, Any]: """Implement this to train one iteration. In RL, one iteration usually refers to one collect.""" raise NotImplementedError() - def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]: + def validate(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]: """Implement this to validate the policy once.""" raise NotImplementedError() - def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]: + def test(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]: """Implement this to evaluate the policy on test environment once.""" raise NotImplementedError() @@ -82,15 +82,15 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, value = np.mean(value) _logger.info(f"[Iter {self.trainer.current_iter + 1}] {name} = {value}") - def log_dict(self, data: dict[str, Any]) -> None: + def log_dict(self, data: Dict[str, Any]) -> None: for name, value in data.items(): self.log(name, value) - def state_dict(self) -> dict: + def state_dict(self) -> Dict: """Return a checkpoint of current vessel state.""" return {"policy": self.policy.state_dict()} - def load_state_dict(self, state_dict: dict) -> None: + def load_state_dict(self, state_dict: Dict) -> None: """Restore a checkpoint from a previously saved state dict.""" self.policy.load_state_dict(state_dict["policy"]) @@ -125,7 +125,7 @@ class TrainingVessel(TrainingVesselBase): test_initial_states: Sequence[InitialStateType] | None = None, buffer_size: int = 20000, episode_per_iter: int = 1000, - update_kwargs: dict[str, Any] = cast(Dict[str, Any], None), + update_kwargs: Dict[str, Any] = cast(Dict[str, Any], None), ): self.simulator_fn = simulator_fn # type: ignore self.state_interpreter = state_interpreter @@ -161,7 +161,7 @@ class TrainingVessel(TrainingVesselBase): return DataQueue(test_initial_states, repeat=1) return super().test_seed_iterator() - def train(self, vector_env: FiniteVectorEnv) -> dict[str, Any]: + def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]: """Create a collector and collects ``episode_per_iter`` episodes. Update the policy on the collected replay buffer. """ @@ -182,7 +182,7 @@ class TrainingVessel(TrainingVesselBase): self.log_dict(res) return res - def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]: + def validate(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]: self.policy.eval() with vector_env.collector_guard(): @@ -191,7 +191,7 @@ class TrainingVessel(TrainingVesselBase): self.log_dict(res) return res - def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]: + def test(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]: self.policy.eval() with vector_env.collector_guard():