1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Order execution open source (#1447)

* Waiting for bin data

* Complete readme

* CI

* Add inst filter by time

* Update qlib/data/dataset/processor.py

* typo

* Fix time filter bug

* Add Filter and set Universe

* Complete data pipeline

* Fix Provider Logger Info Args

* Add DQN; a minor bugfix in ppo reward.

* update readme. modify assertion logic in strategy check.

* Fix Doc issues and fix black

* Fix pylint Error

---------

Co-authored-by: Young <afe.young@gmail.com>
Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
Huoran Li
2023-03-13 12:06:28 +08:00
committed by GitHub
parent f98e04ca9d
commit 653c082e7a
24 changed files with 742 additions and 42 deletions

2
.gitignore vendored
View File

@@ -27,6 +27,8 @@ examples/estimator/estimator_example/
examples/rl/data/
examples/rl/checkpoints/
examples/rl/outputs/
examples/rl_order_execution/data/
examples/rl_order_execution/outputs/
*.egg-info/

View File

@@ -29,13 +29,13 @@ class Avg15minHandler(DataHandlerLP):
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processor=None,
inst_processors=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = Avg15minLoader(
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processor=inst_processor
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processors=inst_processors
)
super().__init__(
instruments=instruments,

View File

@@ -18,7 +18,7 @@ data_handler_config: &data_handler_config
label: day
feature: 1min
# with label as reference
inst_processor:
inst_processors:
feature:
- class: Resample1minProcessor
module_path: features_sample.py

View File

@@ -19,7 +19,7 @@ data_handler_config: &data_handler_config
feature_15min: 1min
feature_day: day
# with label as reference
inst_processor:
inst_processors:
feature_15min:
- class: ResampleNProcessor
module_path: features_resample_N.py

View File

@@ -0,0 +1,100 @@
# RL Example for Order Execution
This folder comprises an example of Reinforcement Learning (RL) workflows for order execution scenario, including both training workflows and backtest workflows.
## Data Processing
### Get Data
```
python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region hs300 --interval 5min
```
### Generate Pickle-Style Data
To run codes in this example, we need data in pickle format. To achieve this, run following commands (might need a few minutes to finish):
```
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
python scripts/collect_pickle_dataframe.py
python scripts/gen_training_orders.py
python scripts/merge_orders.py
```
When finished, the structure under `data/` should be:
```
data
├── bin
├── orders
├── pickle
└── pickle_dataframe
```
## Training
Each training task is specified by a config file. The config file for task `TASKNAME` is `exp_configs/train_TASKNAME.yml`. This example provides two training tasks:
- **PPO**: Method proposed by IJCAL 2020 paper "[An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization](https://www.ijcai.org/proceedings/2020/0627.pdf)".
- **OPDS**: Method proposed by AAAI 2021 paper "[Universal Trading for Order Execution with Oracle Policy Distillation](https://arxiv.org/abs/2103.10860)".
The main differece between these two methods is their reward functions. Please see their config files for details.
Take OPDS as an example, to run the training workflow, run:
```
python -m qlib.rl.contrib.train_onpolicy --config_path exp_configs/train_opds.yml --run_backtest
```
Metrics, logs, and checkpoints will be stored under `outputs/opds` (configured by `exp_configs/train_opds.yml`).
## Backtest
Once the training workflow has completed, the trained model can be used for the backtesting workflow. Still taking OPDS as an example, once training is finished, the latest checkpoint of the model can be found at `outputs/opds/checkpoints/latest.pth`. To run backtest workflow:
1. Uncomment the `weight_file` parameter in `exp_configs/train_opds.yml` (it is commented by default). While it is possible to run the backtesting workflow without setting a checkpoint, this will lead to randomly initialized model results, thus making them meaningless.
2. Run `python -m qlib.rl.contrib.backtest --config_path exp_configs/backtest_opds.yml`.
The backtest result is stored in `outputs/checkpoints/backtest_result.csv`.
In addition to OPDS and PPO, we also provide TWAP ([Time-weighted average price](https://en.wikipedia.org/wiki/Time-weighted_average_price)) as a weak baseline. The config file for TWAP is `exp_configs/backtest_twap.yml`.
### Gap between backtest and training pipeline's testing
It is worthy to notice that the results of the backtesting process may differ from the results of the testing process used during training.
This is because different simulators are used to simulate market conditions during training and backtesting.
In training pipeline, the simplified simulator called `SingleAssetOrderExecutionSimple` is used for efficiency reasons.
`SingleAssetOrderExecutionSimple` makes no restriction to trading amounts.
No matter what the amount of the order is, it can be completely executed.
However, during backtesting, a more realistic simulator called `SingleAssetOrderExecution` is used.
It takes into account practical constraints in more real-world scenarios (for example, the trading volume must be a multiple of the smallest trading unit).
As a result, the amount of an order that is actually executed during backtesting may differ from the amount expected to be executed.
If you would like to obtain results that are exactly the same as those obtained during testing in the training pipeline, you could run training pipeline with only backtest phrase.
In order to do this:
- Modify the training config. Add the path of the checkpoint you want to use (see following for an example).
- Run `python -m qlib.rl.contrib.train_onpolicy --config_path PATH/TO/CONFIG --run_backtest --no_training`
```yaml
...
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
weight_file: PATH/TO/CHECKPOINT
module_path: qlib.rl.order_execution.policy
...
```
## Benchmarks (TBD)
To accurately evaluate the performance of models using Reinforcement Learning algorithms, it's best to run experiments multiple times and compute the average performance across all trials. However, given the time-consuming nature of model training, this is not always feasible. An alternative approach is to run each training task only once, selecting the 10 checkpoints with the highest validation performance to simulate multiple trials. In this example, we use "Price Advantage (PA)" as the metric for selecting these checkpoints. The average performance of these 10 checkpoints on the testing set is as follows:
| **Model** | **PA mean with std.** |
|-----------------------------|-----------------------|
| OPDS (with PPO policy) | 0.4785 ± 0.7815 |
| OPDS (with DQN policy) | -0.0114 ± 0.5780 |
| PPO | -1.0935 ± 0.0922 |
| TWAP | ≈ 0.0 ± 0.0 |
The table above also includes TWAP as a rule-based baseline. The ideal PA of TWAP should be 0.0, however, in this example, the order execution is divided into two steps: first, the order is split equally among each half hour, and then each five minutes within each half hour. Since trading is forbidden during the last five minutes of the day, this approach may slightly differ from traditional TWAP over the course of a full day (as there are 5 minutes missing in the last "half hour"). Therefore, the PA of TWAP can be considered as a number that is close to 0.0. To verify this, you may run a TWAP backtest and check the results.

View File

@@ -0,0 +1,59 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
volume_threshold: null
strategies:
1day:
class: SAOEIntStrategy
kwargs:
data_granularity: 5
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
max_step: 8
values: 4
module_path: qlib.rl.order_execution.interpreter
network:
class: Recurrent
kwargs: {}
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
# Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use.
# weight_file: outputs/opds/checkpoints/latest.pth
module_path: qlib.rl.order_execution.policy
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.data.pickle_styled
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
concurrency: 16
output_dir: outputs/opds/

View File

@@ -0,0 +1,59 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
volume_threshold: null
strategies:
1day:
class: SAOEIntStrategy
kwargs:
data_granularity: 5
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
max_step: 8
values: 4
module_path: qlib.rl.order_execution.interpreter
network:
class: Recurrent
kwargs: {}
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
# Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use.
# weight_file: outputs/ppo/checkpoints/latest.pth
module_path: qlib.rl.order_execution.policy
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.data.pickle_styled
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
concurrency: 16
output_dir: outputs/ppo/

View File

@@ -0,0 +1,29 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
volume_threshold: null
strategies:
1day:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
30min:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
concurrency: 16
output_dir: outputs/twap/

View File

@@ -0,0 +1,61 @@
simulator:
data_granularity: 5
time_per_step: 30
vol_limit: null
env:
concurrency: 48
parallel_mode: shmem
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
values: 4
max_step: 8
module_path: qlib.rl.order_execution.interpreter
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.order_execution.interpreter
reward:
class: PAPenaltyReward
kwargs:
penalty: 4.0
scale: 0.01
module_path: qlib.rl.order_execution.reward
data:
source:
order_dir: ./data/orders
data_dir: ./data/pickle_dataframe/backtest
total_time: 240
default_start_time_index: 0
default_end_time_index: 235
proc_data_dim: 5
num_workers: 0
queue_size: 20
network:
class: Recurrent
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
module_path: qlib.rl.order_execution.policy
runtime:
seed: 42
use_cuda: false
trainer:
max_epoch: 500
repeat_per_collect: 25
earlystop_patience: 50
episode_per_collect: 10000
batch_size: 1024
val_every_n_epoch: 4
checkpoint_path: ./outputs/opds
checkpoint_every_n_iters: 1

View File

@@ -0,0 +1,62 @@
simulator:
data_granularity: 5
time_per_step: 30
vol_limit: null
env:
concurrency: 48
parallel_mode: shmem
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
values: 4
max_step: 8
module_path: qlib.rl.order_execution.interpreter
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.order_execution.interpreter
reward:
class: PPOReward
kwargs:
max_step: 8
start_time_index: 0
end_time_index: 46 # 46 = (240 - 5) min / 5 min - 1
module_path: qlib.rl.order_execution.reward
data:
source:
order_dir: ./data/orders
data_dir: ./data/pickle_dataframe/backtest
total_time: 240
default_start_time_index: 0
default_end_time_index: 235
proc_data_dim: 5
num_workers: 0
queue_size: 20
network:
class: Recurrent
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
module_path: qlib.rl.order_execution.policy
runtime:
seed: 42
use_cuda: false
trainer:
max_epoch: 500
repeat_per_collect: 25
earlystop_patience: 50
episode_per_collect: 10000
batch_size: 1024
val_every_n_epoch: 4
checkpoint_path: ./outputs/ppo
checkpoint_every_n_iters: 1

View File

@@ -0,0 +1,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import pickle
import pandas as pd
from joblib import Parallel, delayed
os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)
def _collect(df: pd.DataFrame, instrument: str, tag: str) -> None:
cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
cur = cur.set_index(["instrument", "datetime", "date"])
pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb"))
for tag in ("backtest", "feature"):
df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb"))
df = pd.concat(list(df.values())).reset_index()
df["date"] = df["datetime"].dt.date.astype("datetime64")
instruments = sorted(set(df["instrument"]))
os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
Parallel(n_jobs=-1, verbose=10)(delayed(_collect)(df, instrument, tag) for instrument in instruments)

View File

@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import yaml
import argparse
import os
import shutil
from copy import deepcopy
from qlib.contrib.data.highfreq_provider import HighFreqProvider
loader = yaml.FullLoader
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", type=str, default="config.yml")
parser.add_argument("-d", "--dest", type=str, default=".")
parser.add_argument("-s", "--split", type=str, choices=["none", "date", "stock", "both"], default="stock")
args = parser.parse_args()
conf = yaml.load(open(args.config), Loader=loader)
for k, v in conf.items():
if isinstance(v, dict) and "path" in v:
v["path"] = os.path.join(args.dest, v["path"])
provider = HighFreqProvider(**conf)
# Gen dataframe
if "feature_conf" in conf:
feature = provider._gen_dataframe(deepcopy(provider.feature_conf))
if "backtest_conf" in conf:
backtest = provider._gen_dataframe(deepcopy(provider.backtest_conf))
provider.feature_conf["path"] = os.path.splitext(provider.feature_conf["path"])[0] + "/"
provider.backtest_conf["path"] = os.path.splitext(provider.backtest_conf["path"])[0] + "/"
# Split by date
if args.split == "date" or args.split == "both":
provider._gen_day_dataset(deepcopy(provider.feature_conf), "feature")
provider._gen_day_dataset(deepcopy(provider.backtest_conf), "backtest")
# Split by stock
if args.split == "stock" or args.split == "both":
provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature")
provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest")
shutil.rmtree("stat/", ignore_errors=True)

View File

@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
DATA_PATH = Path(os.path.join("data", "pickle_dataframe", "backtest"))
OUTPUT_PATH = Path(os.path.join("data", "orders"))
def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
df = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
div = df["$volume0"].rolling((end_idx - start_idx) * 60).mean().shift(1).groupby(level="date").transform("first")
order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
order_all = order_all[order_all["amount"] > 0.0]
order_all["order_type"] = 0
order_all = order_all.drop(columns=["$volume0"])
order_train = order_all[order_all.index.get_level_values(0) <= pd.Timestamp("2021-06-30")]
order_test = order_all[order_all.index.get_level_values(0) > pd.Timestamp("2021-06-30")]
order_valid = order_test[order_test.index.get_level_values(0) <= pd.Timestamp("2021-09-30")]
order_test = order_test[order_test.index.get_level_values(0) > pd.Timestamp("2021-09-30")]
for order, tag in zip((order_train, order_valid, order_test, order_all), ("train", "valid", "test", "all")):
path = OUTPUT_PATH / tag
os.makedirs(path, exist_ok=True)
if len(order) > 0:
order.to_pickle(path / f"{stock}.pkl.target")
np.random.seed(1234)
file_list = sorted(os.listdir(DATA_PATH))
stocks = [f.replace(".pkl", "") for f in file_list]
stocks = sorted(np.random.choice(stocks, size=100, replace=False))
for stock in tqdm(stocks):
generate_order(stock, 0, 240 // 5 - 1)

View File

@@ -0,0 +1,15 @@
import pickle
import os
import pandas as pd
from tqdm import tqdm
for tag in ["test", "valid"]:
files = os.listdir(os.path.join("data/orders/", tag))
dfs = []
for f in tqdm(files):
df = pickle.load(open(os.path.join("data/orders/", tag, f), "rb"))
df = df.drop(["$close0"], axis=1)
dfs.append(df)
total_df = pd.concat(dfs)
pickle.dump(total_df, open(os.path.join("data", "orders", f"{tag}_orders.pkl"), "wb"))

View File

@@ -0,0 +1,77 @@
# start & end time for training/validation/test datasets
start_time: !!str &start 2020-01-01
end_time: !!str &end 2021-12-31
train_end_time: !!str &tend 2021-06-30
valid_start_time: !!str &vstart 2021-07-01
valid_end_time: !!str &vend 2021-09-30
test_start_time: !!str &tstart 2021-10-01
# the instrument set
instruments: &ins csi300s19_22
# qlib related configuration
qlib_conf:
provider_uri:
5min: ./data/bin # path to generated qlib bin
redis_port: 233
feature_conf:
path: ./data/pickle/feature.pkl # output path of feature
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: HighFreqGeneralHandler
module_path: qlib.contrib.data.highfreq_handler
kwargs:
start_time: *start
end_time: *end
fit_start_time: *start
fit_end_time: *tend
instruments: *ins
day_length: 240 # how many minutes in one trading day
freq: 5min
columns: ["$open", "$high", "$low", "$close"]
infer_processors:
- class: HighFreqNorm
module_path: qlib.contrib.data.highfreq_processor
kwargs:
feature_save_dir: ./stat/ # output path of statistics of features (for feature normalization)
norm_groups:
price: 8
volume: 2
inst_processors:
- class: TimeRangeFlt
module_path: qlib.data.dataset.processor
kwargs:
start_time: "2020-01-01"
end_time: "2021-12-31"
freq: 5min
segments:
train: !!python/tuple [*start, *tend]
valid: !!python/tuple [*vstart, *vend]
test: !!python/tuple [*tstart, *end]
backtest_conf:
path: ./data/pickle/backtest.pkl # output path of backtest
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: HighFreqGeneralBacktestHandler
module_path: qlib.contrib.data.highfreq_handler
kwargs:
start_time: *start
end_time: *end
instruments: *ins
day_length: 240
freq: 5min
columns: ["$close", "$volume"]
inst_processors:
- class: TimeRangeFlt
module_path: qlib.data.dataset.processor
kwargs:
start_time: "2020-01-01"
end_time: "2021-12-31"
freq: 5min
segments:
train: !!python/tuple [*start, *tend]
valid: !!python/tuple [*vstart, *vend]
test: !!python/tuple [*tstart, *end]
freq: 5min

View File

@@ -56,7 +56,7 @@ class Alpha360(DataHandlerLP):
fit_start_time=None,
fit_end_time=None,
filter_pipe=None,
inst_processor=None,
inst_processors=None,
**kwargs
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -71,7 +71,7 @@ class Alpha360(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"inst_processor": inst_processor,
"inst_processors": inst_processors,
},
}
@@ -152,7 +152,7 @@ class Alpha158(DataHandlerLP):
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processor=None,
inst_processors=None,
**kwargs
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -167,7 +167,7 @@ class Alpha158(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"inst_processor": inst_processor,
"inst_processors": inst_processors,
},
}
super().__init__(

View File

@@ -115,6 +115,7 @@ class HighFreqGeneralHandler(DataHandlerLP):
day_length=240,
freq="1min",
columns=["$open", "$high", "$low", "$close", "$vwap"],
inst_processors=None,
):
self.day_length = day_length
self.columns = columns
@@ -128,6 +129,7 @@ class HighFreqGeneralHandler(DataHandlerLP):
"config": self.get_feature_config(),
"swap_level": False,
"freq": freq,
"inst_processors": inst_processors,
},
}
super().__init__(
@@ -257,6 +259,7 @@ class HighFreqGeneralBacktestHandler(DataHandler):
day_length=240,
freq="1min",
columns=["$close", "$vwap", "$volume"],
inst_processors=None,
):
self.day_length = day_length
self.columns = set(columns)
@@ -266,6 +269,7 @@ class HighFreqGeneralBacktestHandler(DataHandler):
"config": self.get_feature_config(),
"swap_level": False,
"freq": freq,
"inst_processors": inst_processors,
},
}
super().__init__(
@@ -311,6 +315,7 @@ class HighFreqOrderHandler(DataHandlerLP):
learn_processors=[],
fit_start_time=None,
fit_end_time=None,
inst_processors=None,
drop_raw=True,
):
@@ -323,6 +328,7 @@ class HighFreqOrderHandler(DataHandlerLP):
"config": self.get_feature_config(),
"swap_level": False,
"freq": "1min",
"inst_processors": inst_processors,
},
}
super().__init__(

View File

@@ -128,7 +128,7 @@ class HighFreqProvider:
raise ValueError("Must specify the path to save the dataset.") from e
if os.path.isfile(path):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
# res = dataset.prepare(['train', 'valid', 'test'])
with open(path, "rb") as f:
@@ -137,11 +137,11 @@ class HighFreqProvider:
res = [data[i] for i in datasets]
else:
res = data.prepare(datasets)
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
else:
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
start_time = time.time()
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
@@ -160,7 +160,7 @@ class HighFreqProvider:
with open(path[:-4] + "test.pkl", "wb") as f:
pkl.dump(testset, f)
res = [data[i] for i in datasets]
self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}")
return res
def _gen_data(self, config, datasets=["train", "valid", "test"]):
@@ -170,7 +170,7 @@ class HighFreqProvider:
raise ValueError("Must specify the path to save the dataset.") from e
if os.path.isfile(path):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
# res = dataset.prepare(['train', 'valid', 'test'])
with open(path, "rb") as f:
@@ -179,18 +179,18 @@ class HighFreqProvider:
res = [data[i] for i in datasets]
else:
res = data.prepare(datasets)
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
else:
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
start_time = time.time()
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
dataset.config(dump_all=True, recursive=True)
dataset.to_pickle(path)
res = dataset.prepare(datasets)
self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}")
return res
def _gen_dataset(self, config):
@@ -200,21 +200,21 @@ class HighFreqProvider:
raise ValueError("Must specify the path to save the dataset.") from e
if os.path.isfile(path):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
with open(path, "rb") as f:
dataset = pkl.load(f)
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
else:
start = time.time()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
dataset.prepare(["train", "valid", "test"])
self.logger.info(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Dataset prepared, time cost: {time.time() - start:.2f}")
dataset.config(dump_all=True, recursive=True)
dataset.to_pickle(path)
return dataset
@@ -227,15 +227,15 @@ class HighFreqProvider:
if os.path.isfile(path + "tmp_dataset.pkl"):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
else:
start = time.time()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
dataset.config(dump_all=False, recursive=True)
dataset.to_pickle(path + "tmp_dataset.pkl")
@@ -268,15 +268,15 @@ class HighFreqProvider:
if os.path.isfile(path + "tmp_dataset.pkl"):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
else:
start = time.time()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
dataset.config(dump_all=False, recursive=True)
dataset.to_pickle(path + "tmp_dataset.pkl")

View File

@@ -153,7 +153,7 @@ class QlibDataLoader(DLWParser):
filter_pipe: List = None,
swap_level: bool = True,
freq: Union[str, dict] = "day",
inst_processor: dict = None,
inst_processors: Union[dict, list] = None,
):
"""
Parameters
@@ -167,16 +167,19 @@ class QlibDataLoader(DLWParser):
freq: dict or str
If type(config) == dict and type(freq) == str, load config data using freq.
If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]
inst_processor: dict
If inst_processor is not None and type(config) == dict; load config[<group_name>] data using inst_processor[<group_name>]
inst_processors: dict | list
If inst_processors is not None and type(config) == dict; load config[<group_name>] data using inst_processors[<group_name>]
If inst_processors is a list, then it will be applied to all groups.
"""
self.filter_pipe = filter_pipe
self.swap_level = swap_level
self.freq = freq
# sample
self.inst_processor = inst_processor if inst_processor is not None else {}
assert isinstance(self.inst_processor, dict), f"inst_processor(={self.inst_processor}) must be dict"
self.inst_processors = inst_processors if inst_processors is not None else {}
assert isinstance(
self.inst_processors, (dict, list)
), f"inst_processors(={self.inst_processors}) must be dict or list"
super().__init__(config)
@@ -187,8 +190,8 @@ class QlibDataLoader(DLWParser):
if _gp not in freq:
raise ValueError(f"freq(={freq}) missing group(={_gp})")
assert (
self.inst_processor
), f"freq(={self.freq}), inst_processor(={self.inst_processor}) cannot be None/empty"
self.inst_processors
), f"freq(={self.freq}), inst_processors(={self.inst_processors}) cannot be None/empty"
def load_group_df(
self,
@@ -208,9 +211,10 @@ class QlibDataLoader(DLWParser):
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
df = D.features(
instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, [])
inst_processors = (
self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, [])
)
df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors)
df.columns = names
if self.swap_level:
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
import abc
from typing import Union, Text
from typing import Union, Text, Optional
import numpy as np
import pandas as pd
@@ -11,6 +11,8 @@ from ...constant import EPS
from .utils import fetch_df_by_index
from ...utils.serial import Serializable
from ...utils.paral import datetime_groupby_apply
from qlib.data.inst_processor import InstProcessor
from qlib.data import D
def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
@@ -378,3 +380,42 @@ class HashStockFormat(Processor):
from .storage import HashingStockStorage # pylint: disable=C0415
return HashingStockStorage.from_df(df)
class TimeRangeFlt(InstProcessor):
"""
This is a filter to filter stock.
Only keep the data that exist from start_time to end_time (the existence in the middle is not checked.)
WARNING: It may induce leakage!!!
"""
def __init__(
self,
start_time: Optional[Union[pd.Timestamp, str]] = None,
end_time: Optional[Union[pd.Timestamp, str]] = None,
freq: str = "day",
):
"""
Parameters
----------
start_time : Optional[Union[pd.Timestamp, str]]
The data must start earlier (or equal) than `start_time`
None indicates data will not be filtered based on `start_time`
end_time : Optional[Union[pd.Timestamp, str]]
similar to start_time
freq : str
The frequency of the calendar
"""
# Align to calendar before filtering
cal = D.calendar(start_time=start_time, end_time=end_time, freq=freq)
self.start_time = None if start_time is None else cal[0]
self.end_time = None if end_time is None else cal[-1]
def __call__(self, df: pd.DataFrame, instrument, *args, **kwargs):
if (
df.empty
or (self.start_time is None or df.index.min() <= self.start_time)
and (self.end_time is None or df.index.max() >= self.end_time)
):
return df
return df.head(0)

View File

@@ -357,7 +357,10 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram
if not output_path.exists():
os.makedirs(output_path)
res.to_csv(output_path / "summary.csv")
if "pa" in res.columns:
res["pa"] = res["pa"] * 10000.0 # align with training metrics
res.to_csv(output_path / "backtest_result.csv")
return res

View File

@@ -12,11 +12,11 @@ import torch
import torch.nn as nn
from gym.spaces import Discrete
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.policy import BasePolicy, PPOPolicy, DQNPolicy
from qlib.rl.trainer.trainer import Trainer
__all__ = ["AllOne", "PPO"]
__all__ = ["AllOne", "PPO", "DQN"]
# baselines #
@@ -158,6 +158,56 @@ class PPO(PPOPolicy):
set_weight(self, Trainer.get_policy_state_dict(weight_file))
DQNModel = PPOActor # Reuse PPOActor.
class DQN(DQNPolicy):
"""A wrapper of tianshou DQNPolicy.
Differences:
- Auto-create model network. Supports discrete action space only.
- Support a ``weight_file`` that supports loading checkpoint.
"""
def __init__(
self,
network: nn.Module,
obs_space: gym.Space,
action_space: gym.Space,
lr: float,
weight_decay: float = 0.0,
discount_factor: float = 0.99,
estimation_step: int = 1,
target_update_freq: int = 0,
reward_normalization: bool = False,
is_double: bool = True,
clip_loss_grad: bool = False,
weight_file: Optional[Path] = None,
) -> None:
assert isinstance(action_space, Discrete)
model = DQNModel(network, action_space.n)
optimizer = torch.optim.Adam(
model.parameters(),
lr=lr,
weight_decay=weight_decay,
)
super().__init__(
model,
optimizer,
discount_factor=discount_factor,
estimation_step=estimation_step,
target_update_freq=target_update_freq,
reward_normalization=reward_normalization,
is_double=is_double,
clip_loss_grad=clip_loss_grad,
)
if weight_file is not None:
set_weight(self, Trainer.get_policy_state_dict(weight_file))
# utilities: these should be put in a separate (common) file. #

View File

@@ -70,7 +70,19 @@ class PPOReward(Reward[SAOEState]):
def reward(self, simulator_state: SAOEState) -> float:
if simulator_state.cur_step == self.max_step - 1 or simulator_state.position < 1e-6:
vwap_price = cast(dict, simulator_state.metrics)["trade_price"]
if simulator_state.history_exec["deal_amount"].sum() == 0.0:
vwap_price = cast(
float,
np.average(simulator_state.history_exec["market_price"]),
)
else:
vwap_price = cast(
float,
np.average(
simulator_state.history_exec["market_price"],
weights=simulator_state.history_exec["deal_amount"],
),
)
twap_price = simulator_state.backtest_data.get_deal_price().mean()
if simulator_state.order.direction == OrderDir.SELL:

View File

@@ -7,6 +7,7 @@ import collections
from types import GeneratorType
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union
import warnings
import numpy as np
import pandas as pd
import torch
@@ -137,7 +138,12 @@ class SAOEStateAdapter:
exec_vol[idx - last_step_range[0]] = order.deal_amount
if exec_vol.sum() > self.position and exec_vol.sum() > 0.0:
assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large"
if exec_vol.sum() > self.position + 1.0:
warnings.warn(
f"Sum of execution volume is {exec_vol.sum()} which is larger than "
f"position + 1.0 = {self.position} + 1.0 = {self.position + 1.0}. "
f"All execution volume is scaled down linearly to ensure that their sum does not position."
)
exec_vol *= self.position / (exec_vol.sum())
market_volume = cast(