mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
Migrate amc4th training (#1316)
* Migrate amc4th training * Refine RL example scripts * Resolve PR comments Co-authored-by: luocy16 <luocy16@mails.tsinghua.edu.cn>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -24,6 +24,9 @@ qlib/VERSION.txt
|
||||
qlib/data/_libs/expanding.cpp
|
||||
qlib/data/_libs/rolling.cpp
|
||||
examples/estimator/estimator_example/
|
||||
examples/rl/data/
|
||||
examples/rl/checkpoints/
|
||||
examples/rl/outputs/
|
||||
|
||||
*.egg-info/
|
||||
|
||||
|
||||
55
examples/rl/README.md
Normal file
55
examples/rl/README.md
Normal file
@@ -0,0 +1,55 @@
|
||||
This folder contains a simple example of how to run Qlib RL. It contains:
|
||||
|
||||
```
|
||||
.
|
||||
├── experiment_config
|
||||
│ ├── backtest # Backtest config
|
||||
│ └── training # Training config
|
||||
├── README.md # Readme (the current file)
|
||||
└── scripts # Scripts for data pre-processing
|
||||
```
|
||||
|
||||
## Data preparation
|
||||
|
||||
Use [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10) to download data:
|
||||
|
||||
```
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl/qlib_rl_example_data ./ --recursive
|
||||
mv qlib_rl_example_data data
|
||||
```
|
||||
|
||||
The downloaded data will be placed at `./data`. The original data are in `data/csv`. To create all the data needed by this example, run:
|
||||
|
||||
```
|
||||
bash scripts/data_pipeline.sh
|
||||
```
|
||||
|
||||
After the execution finishes, the `data/` directory should look like this:
|
||||
|
||||
```
|
||||
data
|
||||
├── backtest_orders.csv
|
||||
├── bin
|
||||
├── csv
|
||||
├── pickle
|
||||
├── pickle_dataframe
|
||||
└── training_order_split
|
||||
```
|
||||
|
||||
## Run training
|
||||
|
||||
Run:
|
||||
|
||||
```
|
||||
python ../../qlib/rl/contrib/train_onpolicy.py --config_path ./experiment_config/training/config.yml
|
||||
```
|
||||
|
||||
After training, checkpoints will be stored under `checkpoints/`.
|
||||
|
||||
## Run backtest
|
||||
|
||||
```
|
||||
python ../../qlib/rl/contrib/backtest.py --config_path ./experiment_config/backtest/config.py
|
||||
```
|
||||
|
||||
The backtest workflow will use the trained model in `checkpoints/`. The backtest summary can be found in `outputs/`.
|
||||
53
examples/rl/experiment_config/backtest/config.py
Normal file
53
examples/rl/experiment_config/backtest/config.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# Backtest configuration for the RL order-execution example.
#
# The sibling file `twap.yml` is listed as the base config; it declares
# `strategies: {}` as a placeholder, and this file supplies the actual strategies.
# NOTE(review): how `_base_` files are merged is defined by the config loader,
# which is not visible from this file.
_base_ = ["./twap.yml"]

strategies = {
    # Presumably tells the config loader to replace the inherited `strategies`
    # value instead of merging into it -- TODO confirm against the loader.
    "_delete_": True,
    # NOTE(review): the keys look like trading frequencies -- confirm with the backtest code.
    "30min": {
        "class": "TWAPStrategy",
        "module_path": "qlib.contrib.strategy.rule_strategy",
        "kwargs": {},
    },
    "1day": {
        "class": "SAOEIntStrategy",
        "module_path": "qlib.rl.order_execution.strategy",
        "kwargs": {
            # The interpreter settings below use the same values as
            # experiment_config/training/config.yml (max_step=8, data_ticks=240,
            # data_dim=6, values=14).
            "state_interpreter": {
                "class": "FullHistoryStateInterpreter",
                "module_path": "qlib.rl.order_execution.interpreter",
                "kwargs": {
                    "max_step": 8,
                    "data_ticks": 240,
                    "data_dim": 6,
                    "processed_data_provider": {
                        "class": "PickleProcessedDataProvider",
                        "module_path": "qlib.rl.data.pickle_styled",
                        "kwargs": {
                            "data_dir": "./data/pickle_dataframe/feature",
                        },
                    },
                },
            },
            "action_interpreter": {
                "class": "CategoricalActionInterpreter",
                "module_path": "qlib.rl.order_execution.interpreter",
                "kwargs": {
                    "values": 14,
                    "max_step": 8,
                },
            },
            "network": {
                "class": "Recurrent",
                "module_path": "qlib.rl.order_execution.network",
                "kwargs": {},
            },
            "policy": {
                "class": "PPO",
                "module_path": "qlib.rl.order_execution.policy",
                "kwargs": {
                    "lr": 1.0e-4,
                    # Checkpoint written by the training run (see README: checkpoints
                    # are stored under `checkpoints/`).
                    "weight_file": "./checkpoints/latest.pth",
                },
            },
        },
    },
}
|
||||
21
examples/rl/experiment_config/backtest/twap.yml
Normal file
21
examples/rl/experiment_config/backtest/twap.yml
Normal file
@@ -0,0 +1,21 @@
|
||||
order_file: ./data/backtest_orders.csv
|
||||
start_time: "9:45"
|
||||
end_time: "14:44"
|
||||
qlib:
|
||||
provider_uri_1min: ./data/bin
|
||||
feature_root_dir: ./data/pickle
|
||||
feature_columns_today: [
|
||||
"$open", "$high", "$low", "$close", "$vwap", "$volume",
|
||||
]
|
||||
feature_columns_yesterday: [
|
||||
"$open_v1", "$high_v1", "$low_v1", "$close_v1", "$vwap_v1", "$volume_v1",
|
||||
]
|
||||
exchange:
|
||||
limit_threshold: ['$close == 0', '$close == 0']
|
||||
deal_price: ["If($close == 0, $vwap, $close)", "If($close == 0, $vwap, $close)"]
|
||||
volume_threshold:
|
||||
all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"]
|
||||
buy: ["current", "$close"]
|
||||
sell: ["current", "$close"]
|
||||
strategies: {} # Placeholder
|
||||
concurrency: 5
|
||||
59
examples/rl/experiment_config/training/config.yml
Normal file
59
examples/rl/experiment_config/training/config.yml
Normal file
@@ -0,0 +1,59 @@
|
||||
simulator:
|
||||
time_per_step: 30
|
||||
vol_limit: null
|
||||
env:
|
||||
concurrency: 1
|
||||
parallel_mode: dummy
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
values: 14
|
||||
max_step: 8
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
kwargs:
|
||||
data_dim: 6
|
||||
data_ticks: 240
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
reward:
|
||||
class: PAPenaltyReward
|
||||
kwargs:
|
||||
penalty: 100.0
|
||||
module_path: qlib.rl.order_execution.reward
|
||||
data:
|
||||
source:
|
||||
order_dir: ./data/training_order_split
|
||||
data_dir: ./data/pickle_dataframe/backtest
|
||||
total_time: 240
|
||||
default_start_time: 0
|
||||
default_end_time: 240
|
||||
proc_data_dim: 6
|
||||
num_workers: 0
|
||||
queue_size: 20
|
||||
network:
|
||||
class: Recurrent
|
||||
module_path: qlib.rl.order_execution.network
|
||||
policy:
|
||||
class: PPO
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
runtime:
|
||||
seed: 42
|
||||
use_cuda: false
|
||||
trainer:
|
||||
max_epoch: 2
|
||||
repeat_per_collect: 5
|
||||
earlystop_patience: 2
|
||||
episode_per_collect: 20
|
||||
batch_size: 16
|
||||
val_every_n_epoch: 1
|
||||
checkpoint_path: ./checkpoints
|
||||
checkpoint_every_n_iters: 1
|
||||
21
examples/rl/scripts/collect_pickle_dataframe.py
Normal file
21
examples/rl/scripts/collect_pickle_dataframe.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
# Re-pack the "backtest" and "feature" pickles into one pickle per instrument under
# data/pickle_dataframe/<tag>/<instrument>.pkl.
os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)

for tag in ("backtest", "feature"):
    # NOTE: pickle must only be used on trusted, locally generated files.
    # A context manager closes the handle instead of leaking it.
    with open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb") as f:
        df = pickle.load(f)
    # The loaded object is a mapping whose values are concatenated into a single frame.
    df = pd.concat(list(df.values())).reset_index()
    # An explicit unit is required: pandas >= 2.0 rejects the unit-less "datetime64" dtype.
    df["date"] = df["datetime"].dt.date.astype("datetime64[ns]")
    instruments = sorted(set(df["instrument"]))

    os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
    for instrument in tqdm(instruments):
        cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
        cur = cur.set_index(["instrument", "datetime", "date"])
        with open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb") as f:
            pickle.dump(cur, f)
|
||||
14
examples/rl/scripts/data_pipeline.sh
Normal file
14
examples/rl/scripts/data_pipeline.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
# Abort the whole pipeline as soon as any step fails.
set -e

# Step 1: convert the raw 1-minute CSV files into qlib `bin` format.
python ../../scripts/dump_bin.py dump_all \
    --csv_path ./data/csv \
    --qlib_dir ./data/bin \
    --include_fields open,close,high,low,vwap,volume \
    --symbol_field_name symbol \
    --date_field_name date \
    --freq 1min

# Step 2: generate pickle-format data, then split it into per-instrument dataframes.
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
# Clean up ./stat/ if it exists (pickle_data_config.yml sets feature_save_dir to ./stat/).
[ ! -e stat/ ] || rm -r stat/
python scripts/collect_pickle_dataframe.py

# Step 3: sample the orders used for training and for backtest.
python scripts/gen_training_orders.py
python scripts/gen_backtest_orders.py
|
||||
41
examples/rl/scripts/gen_backtest_orders.py
Normal file
41
examples/rl/scripts/gen_backtest_orders.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pickle
|
||||
|
||||
# Sample random backtest orders for every instrument found in the backtest test split
# and write them to data/backtest_orders.csv.
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--num_order", type=int, default=10)
args = parser.parse_args()

# Seed NumPy so that the sampled orders are reproducible.
np.random.seed(args.seed)

path = os.path.join("data", "pickle", "backtesttest.pkl")  # TODO: rename file
# NOTE: pickle must only be used on trusted, locally generated files.
# A context manager closes the handle instead of leaking it.
with open(path, "rb") as f:
    df = pickle.load(f).reset_index()
# An explicit unit is required: pandas >= 2.0 rejects the unit-less "datetime64" dtype.
df["date"] = df["datetime"].dt.date.astype("datetime64[ns]")

instruments = sorted(set(df["instrument"]))
df_list = []
for instrument in instruments:
    print(instrument)

    cur_df = df[df["instrument"] == instrument]

    # Distinct dates (date part of the string form only) on which this instrument has data.
    dates = sorted(set([str(d).split(" ")[0] for d in cur_df["date"]]))

    n = args.num_order
    # The three np.random calls are kept in this order so a given seed yields the same orders.
    df_list.append(
        pd.DataFrame(
            {
                # `replace=False` raises ValueError when fewer than `n` dates are available.
                "date": sorted(np.random.choice(dates, size=n, replace=False)),
                "instrument": [instrument] * n,
                # Amounts are multiples of 100 between 300 and 1000.
                "amount": np.random.randint(low=3, high=11, size=n) * 100.0,
                # Randomly 0 or 1.
                "order_type": np.random.randint(low=0, high=2, size=n),
            }
        ).set_index(["date", "instrument"]),
    )

total_df = pd.concat(df_list)
total_df.to_csv("data/backtest_orders.csv")
|
||||
43
examples/rl/scripts/gen_pickle_data.py
Executable file
43
examples/rl/scripts/gen_pickle_data.py
Executable file
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import yaml
|
||||
import argparse
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
from qlib.contrib.data.highfreq_provider import HighFreqProvider
|
||||
|
||||
loader = yaml.FullLoader
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: build the pickle datasets described by a YAML config.
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="config.yml")
    parser.add_argument("-d", "--dest", type=str, default=".")
    parser.add_argument("-s", "--split", type=str, choices=["none", "date", "stock", "both"], default="stock")
    args = parser.parse_args()

    # NOTE(security): the full YAML loader can construct Python objects, so only load
    # trusted configs. The example config (pickle_data_config.yml) uses `!!python/tuple`
    # tags, which `yaml.safe_load` does not construct.
    # A context manager closes the config file instead of leaking the handle.
    with open(args.config) as config_file:
        conf = yaml.load(config_file, Loader=loader)

    # Re-root every configured output path under the destination directory.
    for v in conf.values():
        if isinstance(v, dict) and "path" in v:
            v["path"] = os.path.join(args.dest, v["path"])
    provider = HighFreqProvider(**conf)

    # Gen dataframe (the return values are not needed here)
    if "feature_conf" in conf:
        provider._gen_dataframe(deepcopy(provider.feature_conf))
    if "backtest_conf" in conf:
        provider._gen_dataframe(deepcopy(provider.backtest_conf))

    # The split datasets go into a directory named after the pickle file (extension removed).
    provider.feature_conf["path"] = os.path.splitext(provider.feature_conf["path"])[0] + "/"
    provider.backtest_conf["path"] = os.path.splitext(provider.backtest_conf["path"])[0] + "/"

    # Split by date
    if args.split in ("date", "both"):
        provider._gen_day_dataset(deepcopy(provider.feature_conf), "feature")
        provider._gen_day_dataset(deepcopy(provider.backtest_conf), "backtest")

    # Split by stock
    if args.split in ("stock", "both"):
        provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature")
        provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest")
|
||||
37
examples/rl/scripts/gen_training_orders.py
Normal file
37
examples/rl/scripts/gen_training_orders.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pickle
|
||||
|
||||
# Sample random training / validation / test orders for a single stock and store them
# as data/training_order_split/<group>/<stock>.pkl.
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--stock", type=str, default="AAPL")
parser.add_argument("--train_size", type=int, default=10)
parser.add_argument("--valid_size", type=int, default=2)
parser.add_argument("--test_size", type=int, default=2)
args = parser.parse_args()

# Seed NumPy so that the sampled orders are reproducible.
np.random.seed(args.seed)

os.makedirs(os.path.join("data", "training_order_split"), exist_ok=True)

for group, n in zip(("train", "valid", "test"), (args.train_size, args.valid_size, args.test_size)):
    path = os.path.join("data", "pickle", f"backtest{group}.pkl")
    # NOTE: pickle must only be used on trusted, locally generated files.
    # A context manager closes the handle instead of leaking it.
    with open(path, "rb") as f:
        df = pickle.load(f).reset_index()
    # An explicit unit is required: pandas >= 2.0 rejects the unit-less "datetime64" dtype.
    df["date"] = df["datetime"].dt.date.astype("datetime64[ns]")

    # NOTE(review): dates are collected across every instrument in the file; they are
    # not filtered down to `args.stock`.
    dates = sorted(set([str(d).split(" ")[0] for d in df["date"]]))

    # The two np.random calls are kept in this order so a given seed yields the same orders.
    data_df = pd.DataFrame(
        {
            # `replace=False` raises ValueError when fewer than `n` dates are available.
            "date": sorted(np.random.choice(dates, size=n, replace=False)),
            "instrument": [args.stock] * n,
            # Amounts are multiples of 100 between 300 and 1000.
            "amount": np.random.randint(low=3, high=11, size=n) * 100.0,
            "order_type": [0] * n,
        }
    ).set_index(["date", "instrument"])

    os.makedirs(os.path.join("data", "training_order_split", group), exist_ok=True)
    with open(os.path.join("data", "training_order_split", group, f"{args.stock}.pkl"), "wb") as f:
        pickle.dump(data_df, f)
|
||||
57
examples/rl/scripts/pickle_data_config.yml
Executable file
57
examples/rl/scripts/pickle_data_config.yml
Executable file
@@ -0,0 +1,57 @@
|
||||
# start & end time for training/validation/test datasets
|
||||
start_time: !!str &start 2020-01-01
|
||||
end_time: !!str &end 2020-07-31
|
||||
train_end_time: !!str &tend 2020-03-31
|
||||
valid_start_time: !!str &vstart 2020-04-01
|
||||
valid_end_time: !!str &vend 2020-05-31
|
||||
test_start_time: !!str &tstart 2020-06-01
|
||||
# the instrument set
|
||||
instruments: &ins all
|
||||
# qlib related configuration
|
||||
qlib_conf:
|
||||
provider_uri: ./data/bin # path to generated qlib bin
|
||||
redis_port: 233
|
||||
feature_conf:
|
||||
path: ./data/pickle/feature.pkl # output path of feature
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: HighFreqGeneralHandler
|
||||
module_path: qlib.contrib.data.highfreq_handler
|
||||
kwargs:
|
||||
start_time: *start
|
||||
end_time: *end
|
||||
fit_start_time: *start
|
||||
fit_end_time: *tend
|
||||
instruments: *ins
|
||||
day_length: 240 # how many minutes in one trading day
|
||||
infer_processors:
|
||||
- class: HighFreqNorm
|
||||
module_path: qlib.contrib.data.highfreq_processor
|
||||
kwargs:
|
||||
feature_save_dir: ./stat/ # output path of statistics of features (for feature normalization)
|
||||
norm_groups:
|
||||
price: 10
|
||||
volume: 2
|
||||
segments:
|
||||
train: !!python/tuple [*start, *tend]
|
||||
valid: !!python/tuple [*vstart, *vend]
|
||||
test: !!python/tuple [*tstart, *end]
|
||||
backtest_conf:
|
||||
path: ./data/pickle/backtest.pkl # output path of backtest
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: HighFreqGeneralBacktestHandler
|
||||
module_path: qlib.contrib.data.highfreq_handler
|
||||
kwargs:
|
||||
start_time: *start
|
||||
end_time: *end
|
||||
instruments: *ins
|
||||
day_length: 240
|
||||
segments:
|
||||
train: !!python/tuple [*start, *tend]
|
||||
valid: !!python/tuple [*vstart, *vend]
|
||||
test: !!python/tuple [*tstart, *end]
|
||||
@@ -4,6 +4,7 @@ import datetime
|
||||
from typing import Optional
|
||||
|
||||
import qlib
|
||||
from qlib import get_module_logger
|
||||
from qlib.data import D
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
@@ -12,7 +13,6 @@ from qlib.data.data import Cal
|
||||
from qlib.contrib.ops.high_freq import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut
|
||||
import pickle as pkl
|
||||
from joblib import Parallel, delayed
|
||||
from utilsd.logging import print_log
|
||||
|
||||
|
||||
class HighFreqProvider:
|
||||
@@ -41,6 +41,7 @@ class HighFreqProvider:
|
||||
self.label_conf = label_conf
|
||||
self.backtest_conf = backtest_conf
|
||||
self.qlib_conf = qlib_conf
|
||||
self.logger = get_module_logger("HighFreqProvider")
|
||||
|
||||
def get_pre_datasets(self):
|
||||
"""Generate the training, validation and test datasets for prediction
|
||||
@@ -125,7 +126,7 @@ class HighFreqProvider:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info("Dataset exists, load from disk.", __name__)
|
||||
|
||||
# res = dataset.prepare(['train', 'valid', 'test'])
|
||||
with open(path, "rb") as f:
|
||||
@@ -134,11 +135,11 @@ class HighFreqProvider:
|
||||
res = [data[i] for i in datasets]
|
||||
else:
|
||||
res = data.prepare(datasets)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self.logger.info("Generating dataset", __name__)
|
||||
start_time = time.time()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
@@ -157,7 +158,7 @@ class HighFreqProvider:
|
||||
with open(path[:-4] + "test.pkl", "wb") as f:
|
||||
pkl.dump(testset, f)
|
||||
res = [data[i] for i in datasets]
|
||||
print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
return res
|
||||
|
||||
def _gen_data(self, config, datasets=["train", "valid", "test"]):
|
||||
@@ -167,7 +168,7 @@ class HighFreqProvider:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info("Dataset exists, load from disk.", __name__)
|
||||
|
||||
# res = dataset.prepare(['train', 'valid', 'test'])
|
||||
with open(path, "rb") as f:
|
||||
@@ -176,18 +177,18 @@ class HighFreqProvider:
|
||||
res = [data[i] for i in datasets]
|
||||
else:
|
||||
res = data.prepare(datasets)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self.logger.info("Generating dataset", __name__)
|
||||
start_time = time.time()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
dataset.config(dump_all=True, recursive=True)
|
||||
dataset.to_pickle(path)
|
||||
res = dataset.prepare(datasets)
|
||||
print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
return res
|
||||
|
||||
def _gen_dataset(self, config):
|
||||
@@ -197,21 +198,21 @@ class HighFreqProvider:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info("Dataset exists, load from disk.", __name__)
|
||||
|
||||
with open(path, "rb") as f:
|
||||
dataset = pkl.load(f)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self.logger.info("Generating dataset", __name__)
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.prepare(["train", "valid", "test"])
|
||||
print_log(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.config(dump_all=True, recursive=True)
|
||||
dataset.to_pickle(path)
|
||||
return dataset
|
||||
@@ -224,15 +225,15 @@ class HighFreqProvider:
|
||||
|
||||
if os.path.isfile(path + "tmp_dataset.pkl"):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info("Dataset exists, load from disk.", __name__)
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self.logger.info("Generating dataset", __name__)
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
dataset.to_pickle(path + "tmp_dataset.pkl")
|
||||
|
||||
@@ -265,15 +266,15 @@ class HighFreqProvider:
|
||||
|
||||
if os.path.isfile(path + "tmp_dataset.pkl"):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info("Dataset exists, load from disk.", __name__)
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self.logger.info("Generating dataset", __name__)
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
dataset.to_pickle(path + "tmp_dataset.pkl")
|
||||
|
||||
|
||||
0
qlib/rl/contrib/__init__.py
Normal file
0
qlib/rl/contrib/__init__.py
Normal file
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
@@ -365,6 +366,8 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram
|
||||
else:
|
||||
res = pd.concat(res)
|
||||
|
||||
if not output_path.exists():
|
||||
os.makedirs(output_path)
|
||||
res.to_csv(output_path / "summary.csv")
|
||||
return res
|
||||
|
||||
|
||||
219
qlib/rl/contrib/train_onpolicy.py
Normal file
219
qlib/rl/contrib/train_onpolicy.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import yaml
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, train
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
    """Seed all random number generators used by the training script.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA devices),
    and sets the cuDNN ``deterministic`` flag.

    Parameters
    ----------
    seed
        The seed value applied to every generator.
    """
    # Python-level generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU and all CUDA devices).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
if os.path.isfile(order_dir):
|
||||
return pd.read_pickle(order_dir)
|
||||
else:
|
||||
orders = []
|
||||
for file in order_dir.iterdir():
|
||||
order_data = pd.read_pickle(file)
|
||||
orders.append(order_data)
|
||||
return pd.concat(orders)
|
||||
|
||||
|
||||
class LazyLoadDataset(Dataset):
    """Dataset that turns rows of an order table into ``Order`` objects on demand.

    The order table is read eagerly in the constructor; the intraday tick offsets
    needed to compute each order's start and end time are loaded on first item access.
    """

    def __init__(
        self,
        order_file_path: Path,
        data_dir: Path,
        default_start_time_index: int,
        default_end_time_index: int,
    ) -> None:
        """Read the order table and remember where the intraday data lives.

        Parameters
        ----------
        order_file_path
            Pickle file or directory of pickle files holding the orders.
        data_dir
            Directory passed to ``load_simple_intraday_backtest_data``.
        default_start_time_index
            Position in the tick index used as every order's start time.
        default_end_time_index
            Exclusive position in the tick index used for every order's end time.
        """
        self._order_file_path = order_file_path
        self._data_dir = data_dir
        self._default_start_time_index = default_start_time_index
        self._default_end_time_index = default_end_time_index

        # Flatten the index so its levels can be read as plain columns in __getitem__.
        self._order_df = _read_orders(order_file_path).reset_index()

        # NOTE(review): annotated as DatetimeIndex, but __getitem__ stores a plain list
        # of offsets from the date here.
        self._ticks_index: Optional[pd.DatetimeIndex] = None

    def __len__(self) -> int:
        """Return the number of orders in the table."""
        return len(self._order_df)

    def __getitem__(self, index: int) -> Order:
        """Build the ``Order`` for the row at position ``index``."""
        record = self._order_df.iloc[index]
        day = pd.Timestamp(str(record["date"]))
        stock_id = record["instrument"]

        if self._ticks_index is None:
            # TODO: We only load ticks index once based on the assumption that ticks index of different dates
            # TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
            # TODO: of all dates.
            intraday_data = load_simple_intraday_backtest_data(
                data_dir=self._data_dir,
                stock_id=stock_id,
                date=day,
            )
            # Store the ticks as offsets from the date so they can be reused for other dates.
            self._ticks_index = [tick - day for tick in intraday_data.get_time_index()]

        offsets = self._ticks_index
        return Order(
            stock_id=stock_id,
            amount=record["amount"],
            direction=OrderDir(int(record["order_type"])),
            start_time=day + offsets[self._default_start_time_index],
            # The end index is exclusive: take the last included tick and add one minute.
            end_time=day + offsets[self._default_end_time_index - 1] + ONE_MIN,
        )
|
||||
|
||||
|
||||
def train_and_test(
    env_config: dict,
    simulator_config: dict,
    trainer_config: dict,
    data_config: dict,
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    policy: BasePolicy,
    reward: Reward,
) -> None:
    """Build the datasets, callbacks and keyword arguments, then call ``train``.

    Only training (with validation on the "valid" split) is launched here; no
    separate test run is started.

    Parameters
    ----------
    env_config
        Must provide ``parallel_mode`` and ``concurrency``.
    simulator_config
        Must provide ``time_per_step`` and ``vol_limit``.
    trainer_config
        Must provide ``max_epoch``, ``episode_per_collect``, ``batch_size`` and
        ``repeat_per_collect``. ``val_every_n_epoch`` is optional. When
        ``checkpoint_path`` is present, ``checkpoint_every_n_iters`` is required too.
    data_config
        Must provide a ``source`` section with ``order_dir``, ``data_dir``,
        ``default_start_time`` and ``default_end_time``. ``deal_price_column`` is
        optional and defaults to ``"close"``.
    state_interpreter, action_interpreter, policy, reward
        Already constructed components, forwarded to ``train`` unchanged.
    """
    source_config = data_config["source"]
    order_root_path = Path(source_config["order_dir"])

    def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
        # One simulator per order; the settings are read when the factory is called.
        return SingleAssetOrderExecutionSimple(
            order=order,
            data_dir=Path(source_config["data_dir"]),
            ticks_per_step=simulator_config["time_per_step"],
            deal_price_type=source_config.get("deal_price_column", "close"),
            vol_threshold=simulator_config["vol_limit"],
        )

    def _build_dataset(split: str) -> LazyLoadDataset:
        # The "train" and "valid" datasets differ only in their order sub-directory.
        return LazyLoadDataset(
            order_file_path=order_root_path / split,
            data_dir=Path(source_config["data_dir"]),
            default_start_time_index=source_config["default_start_time"],
            default_end_time_index=source_config["default_end_time"],
        )

    train_dataset = _build_dataset("train")
    valid_dataset = _build_dataset("valid")

    # Checkpointing is enabled only when a checkpoint path is configured.
    callbacks = []
    if "checkpoint_path" in trainer_config:
        checkpoint_callback = Checkpoint(
            dirpath=Path(trainer_config["checkpoint_path"]),
            every_n_iters=trainer_config["checkpoint_every_n_iters"],
            save_latest="copy",
        )
        callbacks.append(checkpoint_callback)

    # The "epoch" settings in the config map onto the trainer's "iters" arguments.
    trainer_kwargs = {
        "max_iters": trainer_config["max_epoch"],
        "finite_env_type": env_config["parallel_mode"],
        "concurrency": env_config["concurrency"],
        "val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
        "callbacks": callbacks,
    }
    update_kwargs = {
        "batch_size": trainer_config["batch_size"],
        "repeat": trainer_config["repeat_per_collect"],
    }
    vessel_kwargs = {
        "episode_per_iter": trainer_config["episode_per_collect"],
        "update_kwargs": update_kwargs,
        "val_initial_states": valid_dataset,
    }

    train(
        simulator_fn=_simulator_factory_simple,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        reward=reward,
        initial_states=cast(List[Order], train_dataset),
        trainer_kwargs=trainer_kwargs,
        vessel_kwargs=vessel_kwargs,
    )
|
||||
|
||||
|
||||
def main(config: dict) -> None:
    """Instantiate every component described by ``config`` and start training.

    Parameters
    ----------
    config
        Parsed experiment configuration. It must contain the sections ``runtime``,
        ``state_interpreter``, ``action_interpreter``, ``reward``, ``network``,
        ``policy``, ``env``, ``simulator``, ``data`` and ``trainer``. The ``network``
        and ``policy`` sections are modified in place: their ``kwargs`` receive the
        observation/action spaces and the created network.
    """

    def _ensure_kwargs(section: dict) -> dict:
        # Return the section's ``kwargs`` dict, creating it when it is missing or null
        # (an empty ``kwargs:`` key in YAML is parsed as None).
        if section.get("kwargs") is None:
            section["kwargs"] = {}
        return section["kwargs"]

    if "seed" in config["runtime"]:
        seed_everything(config["runtime"]["seed"])

    state_config = config["state_interpreter"]
    state_interpreter: StateInterpreter = init_instance_by_config(state_config)

    action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
    reward: Reward = init_instance_by_config(config["reward"])

    # Create torch network; it is given the observation space of the state interpreter.
    _ensure_kwargs(config["network"]).update({"obs_space": state_interpreter.observation_space})
    network: nn.Module = init_instance_by_config(config["network"])

    # Create policy; `kwargs` is defaulted the same way as for the network.
    _ensure_kwargs(config["policy"]).update(
        {
            "network": network,
            "obs_space": state_interpreter.observation_space,
            "action_space": action_interpreter.action_space,
        }
    )
    policy: BasePolicy = init_instance_by_config(config["policy"])

    use_cuda = config["runtime"].get("use_cuda", False)
    if use_cuda:
        policy.cuda()

    train_and_test(
        env_config=config["env"],
        simulator_config=config["simulator"],
        data_config=config["data"],
        trainer_config=config["trainer"],
        action_interpreter=action_interpreter,
        state_interpreter=state_interpreter,
        policy=policy,
        reward=reward,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: silence noisy warning categories, read the YAML config named
    # on the command line and hand it to `main`.
    import warnings

    for category in (DeprecationWarning, RuntimeWarning):
        warnings.filterwarnings("ignore", category=category)

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
    args = parser.parse_args()

    with open(args.config_path, "r") as config_file:
        config = yaml.safe_load(config_file)

    main(config)
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Sequence, cast
|
||||
from typing import Any, Callable, Dict, List, Sequence, cast
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
@@ -23,8 +23,8 @@ def train(
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
vessel_kwargs: dict[str, Any],
|
||||
trainer_kwargs: dict[str, Any],
|
||||
vessel_kwargs: Dict[str, Any],
|
||||
trainer_kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Train a policy with the parallelism provided by RL framework.
|
||||
|
||||
@@ -69,7 +69,7 @@ def backtest(
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
logger: LogWriter | list[LogWriter],
|
||||
logger: LogWriter | List[LogWriter],
|
||||
reward: Reward | None = None,
|
||||
finite_env_type: FiniteEnvType = "subproc",
|
||||
concurrency: int = 2,
|
||||
|
||||
@@ -8,6 +8,7 @@ Mimicks the hooks of Keras / PyTorch-Lightning, but tailored for the context of
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -253,7 +254,7 @@ class Checkpoint(Callback):
|
||||
latest_pth = self.dirpath / "latest.pth"
|
||||
|
||||
# Remove first before saving
|
||||
if self.save_latest and latest_pth.exists():
|
||||
if self.save_latest and (latest_pth.exists() or os.path.islink(latest_pth)):
|
||||
latest_pth.unlink()
|
||||
|
||||
if self.save_latest == "link":
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import copy
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Sequence, TypeVar, cast
|
||||
from typing import Any, Dict, Iterable, List, Sequence, TypeVar, cast
|
||||
|
||||
import torch
|
||||
|
||||
@@ -83,7 +84,7 @@ class Trainer:
|
||||
current_iter: int
|
||||
"""Current iteration (collect) of training."""
|
||||
|
||||
loggers: list[LogWriter]
|
||||
loggers: List[LogWriter]
|
||||
"""A list of log writers."""
|
||||
|
||||
def __init__(
|
||||
@@ -91,8 +92,8 @@ class Trainer:
|
||||
*,
|
||||
max_iters: int | None = None,
|
||||
val_every_n_iters: int | None = None,
|
||||
loggers: LogWriter | list[LogWriter] | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
loggers: LogWriter | List[LogWriter] | None = None,
|
||||
callbacks: List[Callback] | None = None,
|
||||
finite_env_type: FiniteEnvType = "subproc",
|
||||
concurrency: int = 2,
|
||||
fast_dev_run: int | None = None,
|
||||
@@ -109,7 +110,7 @@ class Trainer:
|
||||
|
||||
self.loggers.append(LogBuffer(self._metrics_callback, loglevel=self._min_loglevel()))
|
||||
|
||||
self.callbacks: list[Callback] = callbacks if callbacks is not None else []
|
||||
self.callbacks: List[Callback] = callbacks if callbacks is not None else []
|
||||
self.finite_env_type = finite_env_type
|
||||
self.concurrency = concurrency
|
||||
self.fast_dev_run = fast_dev_run
|
||||
@@ -164,13 +165,13 @@ class Trainer:
|
||||
self.current_stage = state_dict["current_stage"]
|
||||
self.metrics = state_dict["metrics"]
|
||||
|
||||
def named_callbacks(self) -> Dict[str, Callback]:
    """Retrieve a collection of callbacks where each one has a name.

    Useful when saving checkpoints.

    The mapping is produced by ``_named_collection(self.callbacks)``, so each key
    is derived from the lowercased class name of the corresponding callback.
    """
    return _named_collection(self.callbacks)
|
||||
|
||||
def named_loggers(self) -> dict[str, LogWriter]:
|
||||
def named_loggers(self) -> Dict[str, LogWriter]:
|
||||
"""Retrieve a collection of loggers where each one has a name.
|
||||
Useful when saving checkpoints.
|
||||
"""
|
||||
@@ -328,16 +329,13 @@ def _wrap_context(obj):
|
||||
yield obj
|
||||
|
||||
|
||||
def _named_collection(seq: Sequence[T]) -> dict[str, T]:
|
||||
def _named_collection(seq: Sequence[T]) -> Dict[str, T]:
|
||||
"""Convert a list into a dict, where each item is named with its type."""
|
||||
res = {}
|
||||
retry_cnt: collections.Counter = collections.Counter()
|
||||
for item in seq:
|
||||
typename = type(item).__name__.lower()
|
||||
if typename not in res:
|
||||
res[typename] = item
|
||||
else:
|
||||
# names are auto-labelled as earlystop1, earlystop2, ...
|
||||
for retry in range(1, 1000):
|
||||
if f"{typename}{retry}" not in res:
|
||||
res[f"{typename}{retry}"] = item
|
||||
key = typename if retry_cnt[typename] == 0 else f"{typename}{retry_cnt[typename]}"
|
||||
retry_cnt[typename] += 1
|
||||
res[key] = item
|
||||
return res
|
||||
|
||||
@@ -63,15 +63,15 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType,
|
||||
"""Override this to create a seed iterator for testing."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for testing is not available.")
|
||||
|
||||
def train(self, vector_env: BaseVectorEnv) -> Dict[str, Any]:
    """Implement this to train one iteration. In RL, one iteration usually refers to one collect.

    The base implementation is a placeholder and always raises
    ``NotImplementedError``; subclasses are expected to override it and
    return a string-keyed dict.
    """
    raise NotImplementedError()
|
||||
|
||||
def validate(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
    """Implement this to validate the policy once.

    The base implementation is a placeholder and always raises
    ``NotImplementedError``; subclasses are expected to override it and
    return a string-keyed dict.
    """
    raise NotImplementedError()
|
||||
|
||||
def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
def test(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
|
||||
"""Implement this to evaluate the policy on test environment once."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -82,15 +82,15 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType,
|
||||
value = np.mean(value)
|
||||
_logger.info(f"[Iter {self.trainer.current_iter + 1}] {name} = {value}")
|
||||
|
||||
def log_dict(self, data: Dict[str, Any]) -> None:
    """Log every entry of ``data`` by calling ``self.log(name, value)`` once per item.

    Entries are forwarded in the dict's iteration order; an empty dict logs nothing.
    """
    for name, value in data.items():
        self.log(name, value)
|
||||
|
||||
def state_dict(self) -> Dict:
    """Return a checkpoint of current vessel state.

    The checkpoint is a dict with a single ``"policy"`` entry holding the
    result of ``self.policy.state_dict()``.
    """
    return {"policy": self.policy.state_dict()}
|
||||
|
||||
def load_state_dict(self, state_dict: Dict) -> None:
    """Restore a checkpoint from a previously saved state dict.

    Only the ``"policy"`` entry is read; it is passed to
    ``self.policy.load_state_dict``. A ``state_dict`` without that key
    raises ``KeyError``.
    """
    self.policy.load_state_dict(state_dict["policy"])
|
||||
|
||||
@@ -125,7 +125,7 @@ class TrainingVessel(TrainingVesselBase):
|
||||
test_initial_states: Sequence[InitialStateType] | None = None,
|
||||
buffer_size: int = 20000,
|
||||
episode_per_iter: int = 1000,
|
||||
update_kwargs: dict[str, Any] = cast(Dict[str, Any], None),
|
||||
update_kwargs: Dict[str, Any] = cast(Dict[str, Any], None),
|
||||
):
|
||||
self.simulator_fn = simulator_fn # type: ignore
|
||||
self.state_interpreter = state_interpreter
|
||||
@@ -161,7 +161,7 @@ class TrainingVessel(TrainingVesselBase):
|
||||
return DataQueue(test_initial_states, repeat=1)
|
||||
return super().test_seed_iterator()
|
||||
|
||||
def train(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
|
||||
"""Create a collector and collects ``episode_per_iter`` episodes.
|
||||
Update the policy on the collected replay buffer.
|
||||
"""
|
||||
@@ -182,7 +182,7 @@ class TrainingVessel(TrainingVesselBase):
|
||||
self.log_dict(res)
|
||||
return res
|
||||
|
||||
def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
def validate(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
|
||||
self.policy.eval()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
@@ -191,7 +191,7 @@ class TrainingVessel(TrainingVesselBase):
|
||||
self.log_dict(res)
|
||||
return res
|
||||
|
||||
def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
def test(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
|
||||
self.policy.eval()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
|
||||
Reference in New Issue
Block a user