diff --git a/.github/workflows/test_qlib_from_source.yml b/.github/workflows/test_qlib_from_source.yml index 220453d60..b7041e33a 100644 --- a/.github/workflows/test_qlib_from_source.yml +++ b/.github/workflows/test_qlib_from_source.yml @@ -86,12 +86,11 @@ jobs: # W1309: f-string-without-interpolation # E1102: not-callable # E1136: unsubscriptable-object - # FIXME: Due to the version change of Pylint, some code will cause W0719 error after PR 1417. W0719 is temporarily disabled in PR 1417 and should be fixed. # References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962 # We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000). - name: Check Qlib with pylint run: | - pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" + pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" # The following flake8 error codes were ignored: # E501 line too long diff --git a/.gitignore b/.gitignore index 51f6654c3..03f9c8b98 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,8 @@ examples/estimator/estimator_example/ examples/rl/data/ examples/rl/checkpoints/ examples/rl/outputs/ +examples/rl_order_execution/data/ +examples/rl_order_execution/outputs/ *.egg-info/ diff --git a/examples/benchmarks/LightGBM/multi_freq_handler.py b/examples/benchmarks/LightGBM/multi_freq_handler.py index 07d7ac27c..b3e138192 100644 --- a/examples/benchmarks/LightGBM/multi_freq_handler.py +++ b/examples/benchmarks/LightGBM/multi_freq_handler.py @@ -29,13 +29,13 @@ class Avg15minHandler(DataHandlerLP): fit_end_time=None, process_type=DataHandlerLP.PTYPE_A, filter_pipe=None, - inst_processor=None, + inst_processors=None, **kwargs, ): infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) data_loader = Avg15minLoader( - config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processor=inst_processor + config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processors=inst_processors ) super().__init__( instruments=instruments, diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml index 3d0a7859c..6b58ea4bd 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml @@ -18,7 +18,7 @@ data_handler_config: &data_handler_config label: day feature: 1min # with label as reference - inst_processor: + inst_processors: feature: - class: Resample1minProcessor module_path: features_sample.py diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml index 20cf7de6e..11b277ce6 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml @@ -19,7 +19,7 @@ data_handler_config: &data_handler_config feature_15min: 1min feature_day: day # with label as reference - inst_processor: + inst_processors: feature_15min: - class: ResampleNProcessor module_path: features_resample_N.py diff --git a/examples/rl/README.md b/examples/rl/README.md deleted file mode 100644 index e4b4488d0..000000000 --- a/examples/rl/README.md +++ /dev/null @@ -1,60 +0,0 @@ -This folder contains a simple example of how to run Qlib RL. It contains: - -``` -. -├── experiment_config -│ ├── backtest # Backtest config -│ └── training # Training config -├── README.md # Readme (the current file) -└── scripts # Scripts for data pre-processing -``` - -## Data preparation - -Use [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10) to download data: - -``` -azcopy copy https://qlibpublic.blob.core.windows.net/data/rl/qlib_rl_example_data ./ --recursive -mv qlib_rl_example_data data -``` - -The downloaded data will be placed at `./data`. The original data are in `data/csv`. To create all data needed by the case, run: - -``` -bash scripts/data_pipeline.sh -``` - -After the execution finishes, the `data/` directory should be like: - -``` -data -├── backtest_orders.csv -├── bin -├── csv -├── pickle -├── pickle_dataframe -└── training_order_split -``` - -## Run training - -Run: - -``` -python -m qlib.rl.contrib.train_onpolicy --config_path ./experiment_config/training/config.yml -``` - -After training, checkpoints will be stored under `checkpoints/`. - -## Run backtest - -``` -python -m qlib.rl.contrib.backtest --config_path ./experiment_config/backtest/config.yml -``` - -The backtest workflow will use the trained model in `checkpoints/`. The backtest summary can be found in `outputs/`. - -## Others -The RL module is designed in a loosely-coupled way. Currently, RL examples are integrated with concrete business logic. -But the core part of RL is much simpler than what you see. -To demonstrate the simple core of RL, [a dedicated notebook](./simple_example.ipynb) for RL without business loss is created. diff --git a/examples/rl/experiment_config/backtest/config.yml b/examples/rl/experiment_config/backtest/config.yml deleted file mode 100644 index 418780c2c..000000000 --- a/examples/rl/experiment_config/backtest/config.yml +++ /dev/null @@ -1,57 +0,0 @@ -order_file: ./data/backtest_orders.csv -start_time: "9:45" -end_time: "14:44" -qlib: - provider_uri_1min: ./data/bin - feature_root_dir: ./data/pickle - feature_columns_today: [ - "$open", "$high", "$low", "$close", "$vwap", "$volume", - ] - feature_columns_yesterday: [ - "$open_v1", "$high_v1", "$low_v1", "$close_v1", "$vwap_v1", "$volume_v1", - ] -exchange: - limit_threshold: ['$close == 0', '$close == 0'] - deal_price: ["If($close == 0, $vwap, $close)", "If($close == 0, $vwap, $close)"] - volume_threshold: - all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"] - buy: ["current", "$close"] - sell: ["current", "$close"] -strategies: - 30min: - class: TWAPStrategy - module_path: qlib.contrib.strategy.rule_strategy - kwargs: {} - 1day: - class: SAOEIntStrategy - module_path: qlib.rl.order_execution.strategy - kwargs: - state_interpreter: - class: FullHistoryStateInterpreter - module_path: qlib.rl.order_execution.interpreter - kwargs: - max_step: 8 - data_ticks: 240 - data_dim: 6 - processed_data_provider: - class: PickleProcessedDataProvider - module_path: qlib.rl.data.pickle_styled - kwargs: - data_dir: ./data/pickle_dataframe/feature - action_interpreter: - class: CategoricalActionInterpreter - module_path: qlib.rl.order_execution.interpreter - kwargs: - values: 14 - max_step: 8 - network: - class: Recurrent - module_path: qlib.rl.order_execution.network - kwargs: {} - policy: - class: PPO - module_path: qlib.rl.order_execution.policy - kwargs: - lr: 1.0e-4 - weight_file: ./checkpoints/latest.pth -concurrency: 5 diff --git a/examples/rl/scripts/data_pipeline.sh b/examples/rl/scripts/data_pipeline.sh deleted file mode 100644 index c15b8fbe5..000000000 --- a/examples/rl/scripts/data_pipeline.sh +++ /dev/null @@ -1,14 +0,0 @@ -# Generate `bin` format data -set -e -python ../../scripts/dump_bin.py dump_all --csv_path ./data/csv --qlib_dir ./data/bin --include_fields open,close,high,low,vwap,volume --symbol_field_name symbol --date_field_name date --freq 1min - -# Generate pickle format data -python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml -if [ -e stat/ ]; then - rm -r stat/ -fi -python scripts/collect_pickle_dataframe.py - -# Sample orders -python scripts/gen_training_orders.py -python scripts/gen_backtest_orders.py diff --git a/examples/rl/scripts/gen_backtest_orders.py b/examples/rl/scripts/gen_backtest_orders.py deleted file mode 100644 index 1857f6447..000000000 --- a/examples/rl/scripts/gen_backtest_orders.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import argparse -import os -import pandas as pd -import numpy as np -import pickle - -parser = argparse.ArgumentParser() -parser.add_argument("--seed", type=int, default=20220926) -parser.add_argument("--num_order", type=int, default=10) -args = parser.parse_args() - -np.random.seed(args.seed) - -path = os.path.join("data", "pickle", "backtesttest.pkl") -df = pickle.load(open(path, "rb")).reset_index() -df["date"] = df["datetime"].dt.date.astype("datetime64") - -instruments = sorted(set(df["instrument"])) - -# TODO: The example is expected to be able to handle data containing missing values. -# TODO: Currently, we just simply skip dates that contain missing data. We will add -# TODO: this feature in the future. -skip_dates = {} -for instrument in instruments: - csv_df = pd.read_csv(os.path.join("data", "csv", f"{instrument}.csv")) - csv_df = csv_df[csv_df["close"].isna()] - dates = set([str(d).split(" ")[0] for d in csv_df["date"]]) - skip_dates[instrument] = dates - -df_list = [] -for instrument in instruments: - print(instrument) - - cur_df = df[df["instrument"] == instrument] - - dates = sorted(set([str(d).split(" ")[0] for d in cur_df["date"]])) - dates = [date for date in dates if date not in skip_dates[instrument]] - - n = args.num_order - df_list.append( - pd.DataFrame( - { - "date": sorted(np.random.choice(dates, size=n, replace=False)), - "instrument": [instrument] * n, - "amount": np.random.randint(low=3, high=11, size=n) * 100.0, - "order_type": np.random.randint(low=0, high=2, size=n), - } - ).set_index(["date", "instrument"]), - ) - -total_df = pd.concat(df_list) -total_df.to_csv("data/backtest_orders.csv") diff --git a/examples/rl/scripts/gen_training_orders.py b/examples/rl/scripts/gen_training_orders.py deleted file mode 100644 index 5dd1e96c6..000000000 --- a/examples/rl/scripts/gen_training_orders.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import argparse -import os -import pandas as pd -import numpy as np -import pickle - -parser = argparse.ArgumentParser() -parser.add_argument("--seed", type=int, default=20220926) -parser.add_argument("--stock", type=str, default="AAPL") -parser.add_argument("--train_size", type=int, default=10) -parser.add_argument("--valid_size", type=int, default=2) -parser.add_argument("--test_size", type=int, default=2) -args = parser.parse_args() - -np.random.seed(args.seed) - -os.makedirs(os.path.join("data", "training_order_split"), exist_ok=True) - -for group, n in zip(("train", "valid", "test"), (args.train_size, args.valid_size, args.test_size)): - path = os.path.join("data", "pickle", f"backtest{group}.pkl") - df = pickle.load(open(path, "rb")).reset_index() - df["date"] = df["datetime"].dt.date.astype("datetime64") - - dates = sorted(set([str(d).split(" ")[0] for d in df["date"]])) - - data_df = pd.DataFrame( - { - "date": sorted(np.random.choice(dates, size=n, replace=False)), - "instrument": [args.stock] * n, - "amount": np.random.randint(low=3, high=11, size=n) * 100.0, - "order_type": [0] * n, - } - ).set_index(["date", "instrument"]) - - os.makedirs(os.path.join("data", "training_order_split", group), exist_ok=True) - pickle.dump(data_df, open(os.path.join("data", "training_order_split", group, f"{args.stock}.pkl"), "wb")) diff --git a/examples/rl_order_execution/README.md b/examples/rl_order_execution/README.md new file mode 100644 index 000000000..197b1605f --- /dev/null +++ b/examples/rl_order_execution/README.md @@ -0,0 +1,100 @@ +# RL Example for Order Execution + +This folder comprises an example of Reinforcement Learning (RL) workflows for order execution scenario, including both training workflows and backtest workflows. + +## Data Processing + +### Get Data + +``` +python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region hs300 --interval 5min +``` + +### Generate Pickle-Style Data + +To run codes in this example, we need data in pickle format. To achieve this, run following commands (might need a few minutes to finish): + +``` +python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml +python scripts/collect_pickle_dataframe.py +python scripts/gen_training_orders.py +python scripts/merge_orders.py +``` + +When finished, the structure under `data/` should be: + +``` +data +├── bin +├── orders +├── pickle +└── pickle_dataframe +``` + +## Training + +Each training task is specified by a config file. The config file for task `TASKNAME` is `exp_configs/train_TASKNAME.yml`. This example provides two training tasks: + +- **PPO**: Method proposed by IJCAL 2020 paper "[An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization](https://www.ijcai.org/proceedings/2020/0627.pdf)". +- **OPDS**: Method proposed by AAAI 2021 paper "[Universal Trading for Order Execution with Oracle Policy Distillation](https://arxiv.org/abs/2103.10860)". + +The main differece between these two methods is their reward functions. Please see their config files for details. + +Take OPDS as an example, to run the training workflow, run: + +``` +python -m qlib.rl.contrib.train_onpolicy --config_path exp_configs/train_opds.yml --run_backtest +``` + +Metrics, logs, and checkpoints will be stored under `outputs/opds` (configured by `exp_configs/train_opds.yml`). + +## Backtest + +Once the training workflow has completed, the trained model can be used for the backtesting workflow. Still taking OPDS as an example, once training is finished, the latest checkpoint of the model can be found at `outputs/opds/checkpoints/latest.pth`. To run backtest workflow: + +1. Uncomment the `weight_file` parameter in `exp_configs/train_opds.yml` (it is commented by default). While it is possible to run the backtesting workflow without setting a checkpoint, this will lead to randomly initialized model results, thus making them meaningless. +2. Run `python -m qlib.rl.contrib.backtest --config_path exp_configs/backtest_opds.yml`. + +The backtest result is stored in `outputs/checkpoints/backtest_result.csv`. + +In addition to OPDS and PPO, we also provide TWAP ([Time-weighted average price](https://en.wikipedia.org/wiki/Time-weighted_average_price)) as a weak baseline. The config file for TWAP is `exp_configs/backtest_twap.yml`. + +### Gap between backtest and training pipeline's testing + +It is worthy to notice that the results of the backtesting process may differ from the results of the testing process used during training. +This is because different simulators are used to simulate market conditions during training and backtesting. +In training pipeline, the simplified simulator called `SingleAssetOrderExecutionSimple` is used for efficiency reasons. +`SingleAssetOrderExecutionSimple` makes no restriction to trading amounts. +No matter what the amount of the order is, it can be completely executed. +However, during backtesting, a more realistic simulator called `SingleAssetOrderExecution` is used. +It takes into account practical constraints in more real-world scenarios (for example, the trading volume must be a multiple of the smallest trading unit). +As a result, the amount of an order that is actually executed during backtesting may differ from the amount expected to be executed. + +If you would like to obtain results that are exactly the same as those obtained during testing in the training pipeline, you could run training pipeline with only backtest phrase. +In order to do this: +- Modify the training config. Add the path of the checkpoint you want to use (see following for an example). +- Run `python -m qlib.rl.contrib.train_onpolicy --config_path PATH/TO/CONFIG --run_backtest --no_training` + +```yaml +... +policy: + class: PPO # PPO, DQN + kwargs: + lr: 0.0001 + weight_file: PATH/TO/CHECKPOINT + module_path: qlib.rl.order_execution.policy +... +``` + +## Benchmarks (TBD) + +To accurately evaluate the performance of models using Reinforcement Learning algorithms, it's best to run experiments multiple times and compute the average performance across all trials. However, given the time-consuming nature of model training, this is not always feasible. An alternative approach is to run each training task only once, selecting the 10 checkpoints with the highest validation performance to simulate multiple trials. In this example, we use "Price Advantage (PA)" as the metric for selecting these checkpoints. The average performance of these 10 checkpoints on the testing set is as follows: + +| **Model** | **PA mean with std.** | +|-----------------------------|-----------------------| +| OPDS (with PPO policy) | 0.4785 ± 0.7815 | +| OPDS (with DQN policy) | -0.0114 ± 0.5780 | +| PPO | -1.0935 ± 0.0922 | +| TWAP | ≈ 0.0 ± 0.0 | + +The table above also includes TWAP as a rule-based baseline. The ideal PA of TWAP should be 0.0, however, in this example, the order execution is divided into two steps: first, the order is split equally among each half hour, and then each five minutes within each half hour. Since trading is forbidden during the last five minutes of the day, this approach may slightly differ from traditional TWAP over the course of a full day (as there are 5 minutes missing in the last "half hour"). Therefore, the PA of TWAP can be considered as a number that is close to 0.0. To verify this, you may run a TWAP backtest and check the results. diff --git a/examples/rl_order_execution/exp_configs/backtest_opds.yml b/examples/rl_order_execution/exp_configs/backtest_opds.yml new file mode 100755 index 000000000..c1c9b929a --- /dev/null +++ b/examples/rl_order_execution/exp_configs/backtest_opds.yml @@ -0,0 +1,59 @@ +order_file: ./data/orders/test_orders.pkl +start_time: "9:30" +end_time: "14:54" +qlib: + provider_uri_5min: ./data/bin/ + feature_root_dir: ./data/pickle/ + feature_columns_today: [ + "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume", + "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5" + ] + feature_columns_yesterday: [ + "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1", + "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1" + ] +exchange: + limit_threshold: null + deal_price: ["$close", "$close"] + volume_threshold: null +strategies: + 1day: + class: SAOEIntStrategy + kwargs: + data_granularity: 5 + action_interpreter: + class: CategoricalActionInterpreter + kwargs: + max_step: 8 + values: 4 + module_path: qlib.rl.order_execution.interpreter + network: + class: Recurrent + kwargs: {} + module_path: qlib.rl.order_execution.network + policy: + class: PPO # PPO, DQN + kwargs: + lr: 0.0001 + # Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use. + # weight_file: outputs/opds/checkpoints/latest.pth + module_path: qlib.rl.order_execution.policy + state_interpreter: + class: FullHistoryStateInterpreter + kwargs: + data_dim: 5 + data_ticks: 48 + max_step: 8 + processed_data_provider: + class: PickleProcessedDataProvider + kwargs: + data_dir: ./data/pickle_dataframe/feature + module_path: qlib.rl.data.pickle_styled + module_path: qlib.rl.order_execution.interpreter + module_path: qlib.rl.order_execution.strategy + 30min: + class: TWAPStrategy + kwargs: {} + module_path: qlib.contrib.strategy.rule_strategy +concurrency: 16 +output_dir: outputs/opds/ diff --git a/examples/rl_order_execution/exp_configs/backtest_ppo.yml b/examples/rl_order_execution/exp_configs/backtest_ppo.yml new file mode 100755 index 000000000..1298626b5 --- /dev/null +++ b/examples/rl_order_execution/exp_configs/backtest_ppo.yml @@ -0,0 +1,59 @@ +order_file: ./data/orders/test_orders.pkl +start_time: "9:30" +end_time: "14:54" +qlib: + provider_uri_5min: ./data/bin/ + feature_root_dir: ./data/pickle/ + feature_columns_today: [ + "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume", + "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5" + ] + feature_columns_yesterday: [ + "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1", + "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1" + ] +exchange: + limit_threshold: null + deal_price: ["$close", "$close"] + volume_threshold: null +strategies: + 1day: + class: SAOEIntStrategy + kwargs: + data_granularity: 5 + action_interpreter: + class: CategoricalActionInterpreter + kwargs: + max_step: 8 + values: 4 + module_path: qlib.rl.order_execution.interpreter + network: + class: Recurrent + kwargs: {} + module_path: qlib.rl.order_execution.network + policy: + class: PPO # PPO, DQN + kwargs: + lr: 0.0001 + # Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use. + # weight_file: outputs/ppo/checkpoints/latest.pth + module_path: qlib.rl.order_execution.policy + state_interpreter: + class: FullHistoryStateInterpreter + kwargs: + data_dim: 5 + data_ticks: 48 + max_step: 8 + processed_data_provider: + class: PickleProcessedDataProvider + kwargs: + data_dir: ./data/pickle_dataframe/feature + module_path: qlib.rl.data.pickle_styled + module_path: qlib.rl.order_execution.interpreter + module_path: qlib.rl.order_execution.strategy + 30min: + class: TWAPStrategy + kwargs: {} + module_path: qlib.contrib.strategy.rule_strategy +concurrency: 16 +output_dir: outputs/ppo/ diff --git a/examples/rl_order_execution/exp_configs/backtest_twap.yml b/examples/rl_order_execution/exp_configs/backtest_twap.yml new file mode 100755 index 000000000..a797e3fd8 --- /dev/null +++ b/examples/rl_order_execution/exp_configs/backtest_twap.yml @@ -0,0 +1,29 @@ +order_file: ./data/orders/test_orders.pkl +start_time: "9:30" +end_time: "14:54" +qlib: + provider_uri_5min: ./data/bin/ + feature_root_dir: ./data/pickle/ + feature_columns_today: [ + "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume", + "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5" + ] + feature_columns_yesterday: [ + "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1", + "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1" + ] +exchange: + limit_threshold: null + deal_price: ["$close", "$close"] + volume_threshold: null +strategies: + 1day: + class: TWAPStrategy + kwargs: {} + module_path: qlib.contrib.strategy.rule_strategy + 30min: + class: TWAPStrategy + kwargs: {} + module_path: qlib.contrib.strategy.rule_strategy +concurrency: 16 +output_dir: outputs/twap/ diff --git a/examples/rl/experiment_config/training/config.yml b/examples/rl_order_execution/exp_configs/train_opds.yml old mode 100644 new mode 100755 similarity index 66% rename from examples/rl/experiment_config/training/config.yml rename to examples/rl_order_execution/exp_configs/train_opds.yml index 7e50d3eee..c69896474 --- a/examples/rl/experiment_config/training/config.yml +++ b/examples/rl_order_execution/exp_configs/train_opds.yml @@ -1,20 +1,21 @@ simulator: + data_granularity: 5 time_per_step: 30 vol_limit: null env: - concurrency: 1 - parallel_mode: dummy + concurrency: 48 + parallel_mode: shmem action_interpreter: class: CategoricalActionInterpreter kwargs: - values: 14 + values: 4 max_step: 8 module_path: qlib.rl.order_execution.interpreter state_interpreter: class: FullHistoryStateInterpreter kwargs: - data_dim: 6 - data_ticks: 240 + data_dim: 5 + data_ticks: 48 # 48 = 240 min / 5 min max_step: 8 processed_data_provider: class: PickleProcessedDataProvider @@ -25,23 +26,24 @@ state_interpreter: reward: class: PAPenaltyReward kwargs: - penalty: 100.0 + penalty: 4.0 + scale: 0.01 module_path: qlib.rl.order_execution.reward data: source: - order_dir: ./data/training_order_split + order_dir: ./data/orders data_dir: ./data/pickle_dataframe/backtest total_time: 240 - default_start_time: 0 - default_end_time: 240 - proc_data_dim: 6 + default_start_time_index: 0 + default_end_time_index: 235 + proc_data_dim: 5 num_workers: 0 queue_size: 20 network: class: Recurrent module_path: qlib.rl.order_execution.network policy: - class: PPO + class: PPO # PPO, DQN kwargs: lr: 0.0001 module_path: qlib.rl.order_execution.policy @@ -49,11 +51,11 @@ runtime: seed: 42 use_cuda: false trainer: - max_epoch: 2 - repeat_per_collect: 5 - earlystop_patience: 2 - episode_per_collect: 20 - batch_size: 16 - val_every_n_epoch: 1 - checkpoint_path: ./checkpoints + max_epoch: 500 + repeat_per_collect: 25 + earlystop_patience: 50 + episode_per_collect: 10000 + batch_size: 1024 + val_every_n_epoch: 4 + checkpoint_path: ./outputs/opds checkpoint_every_n_iters: 1 diff --git a/examples/rl_order_execution/exp_configs/train_ppo.yml b/examples/rl_order_execution/exp_configs/train_ppo.yml new file mode 100755 index 000000000..d0b272238 --- /dev/null +++ b/examples/rl_order_execution/exp_configs/train_ppo.yml @@ -0,0 +1,62 @@ +simulator: + data_granularity: 5 + time_per_step: 30 + vol_limit: null +env: + concurrency: 48 + parallel_mode: shmem +action_interpreter: + class: CategoricalActionInterpreter + kwargs: + values: 4 + max_step: 8 + module_path: qlib.rl.order_execution.interpreter +state_interpreter: + class: FullHistoryStateInterpreter + kwargs: + data_dim: 5 + data_ticks: 48 # 48 = 240 min / 5 min + max_step: 8 + processed_data_provider: + class: PickleProcessedDataProvider + module_path: qlib.rl.data.pickle_styled + kwargs: + data_dir: ./data/pickle_dataframe/feature + module_path: qlib.rl.order_execution.interpreter +reward: + class: PPOReward + kwargs: + max_step: 8 + start_time_index: 0 + end_time_index: 46 # 46 = (240 - 5) min / 5 min - 1 + module_path: qlib.rl.order_execution.reward +data: + source: + order_dir: ./data/orders + data_dir: ./data/pickle_dataframe/backtest + total_time: 240 + default_start_time_index: 0 + default_end_time_index: 235 + proc_data_dim: 5 + num_workers: 0 + queue_size: 20 +network: + class: Recurrent + module_path: qlib.rl.order_execution.network +policy: + class: PPO # PPO, DQN + kwargs: + lr: 0.0001 + module_path: qlib.rl.order_execution.policy +runtime: + seed: 42 + use_cuda: false +trainer: + max_epoch: 500 + repeat_per_collect: 25 + earlystop_patience: 50 + episode_per_collect: 10000 + batch_size: 1024 + val_every_n_epoch: 4 + checkpoint_path: ./outputs/ppo + checkpoint_every_n_iters: 1 diff --git a/examples/rl/scripts/collect_pickle_dataframe.py b/examples/rl_order_execution/scripts/collect_pickle_dataframe.py old mode 100644 new mode 100755 similarity index 54% rename from examples/rl/scripts/collect_pickle_dataframe.py rename to examples/rl_order_execution/scripts/collect_pickle_dataframe.py index 64dc94bdb..4b02c0d36 --- a/examples/rl/scripts/collect_pickle_dataframe.py +++ b/examples/rl_order_execution/scripts/collect_pickle_dataframe.py @@ -4,10 +4,17 @@ import os import pickle import pandas as pd -from tqdm import tqdm +from joblib import Parallel, delayed os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True) + +def _collect(df: pd.DataFrame, instrument: str, tag: str) -> None: + cur = df[df["instrument"] == instrument].sort_values(by=["datetime"]) + cur = cur.set_index(["instrument", "datetime", "date"]) + pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb")) + + for tag in ("backtest", "feature"): df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb")) df = pd.concat(list(df.values())).reset_index() @@ -15,7 +22,5 @@ for tag in ("backtest", "feature"): instruments = sorted(set(df["instrument"])) os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True) - for instrument in tqdm(instruments): - cur = df[df["instrument"] == instrument].sort_values(by=["datetime"]) - cur = cur.set_index(["instrument", "datetime", "date"]) - pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb")) + + Parallel(n_jobs=-1, verbose=10)(delayed(_collect)(df, instrument, tag) for instrument in instruments) diff --git a/examples/rl/scripts/gen_pickle_data.py b/examples/rl_order_execution/scripts/gen_pickle_data.py similarity index 96% rename from examples/rl/scripts/gen_pickle_data.py rename to examples/rl_order_execution/scripts/gen_pickle_data.py index f2dbbf115..75810bddc 100755 --- a/examples/rl/scripts/gen_pickle_data.py +++ b/examples/rl_order_execution/scripts/gen_pickle_data.py @@ -4,6 +4,7 @@ import yaml import argparse import os +import shutil from copy import deepcopy from qlib.contrib.data.highfreq_provider import HighFreqProvider @@ -41,3 +42,5 @@ if __name__ == "__main__": if args.split == "stock" or args.split == "both": provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature") provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest") + + shutil.rmtree("stat/", ignore_errors=True) diff --git a/examples/rl_order_execution/scripts/gen_training_orders.py b/examples/rl_order_execution/scripts/gen_training_orders.py new file mode 100755 index 000000000..5bca0e4ca --- /dev/null +++ b/examples/rl_order_execution/scripts/gen_training_orders.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import numpy as np +import pandas as pd +from tqdm import tqdm +from pathlib import Path + +DATA_PATH = Path(os.path.join("data", "pickle_dataframe", "backtest")) +OUTPUT_PATH = Path(os.path.join("data", "orders")) + + +def generate_order(stock: str, start_idx: int, end_idx: int) -> None: + df = pd.read_pickle(DATA_PATH / f"{stock}.pkl") + df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0) + div = df["$volume0"].rolling((end_idx - start_idx) * 60).mean().shift(1).groupby(level="date").transform("first") + + order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna()) + order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"] + order_all = order_all[order_all["amount"] > 0.0] + order_all["order_type"] = 0 + order_all = order_all.drop(columns=["$volume0"]) + + order_train = order_all[order_all.index.get_level_values(0) <= pd.Timestamp("2021-06-30")] + order_test = order_all[order_all.index.get_level_values(0) > pd.Timestamp("2021-06-30")] + order_valid = order_test[order_test.index.get_level_values(0) <= pd.Timestamp("2021-09-30")] + order_test = order_test[order_test.index.get_level_values(0) > pd.Timestamp("2021-09-30")] + + for order, tag in zip((order_train, order_valid, order_test, order_all), ("train", "valid", "test", "all")): + path = OUTPUT_PATH / tag + os.makedirs(path, exist_ok=True) + if len(order) > 0: + order.to_pickle(path / f"{stock}.pkl.target") + + +np.random.seed(1234) +file_list = sorted(os.listdir(DATA_PATH)) +stocks = [f.replace(".pkl", "") for f in file_list] +stocks = sorted(np.random.choice(stocks, size=100, replace=False)) +for stock in tqdm(stocks): + generate_order(stock, 0, 240 // 5 - 1) diff --git a/examples/rl_order_execution/scripts/merge_orders.py b/examples/rl_order_execution/scripts/merge_orders.py new file mode 100755 index 000000000..64a684e07 --- /dev/null +++ b/examples/rl_order_execution/scripts/merge_orders.py @@ -0,0 +1,15 @@ +import pickle +import os +import pandas as pd +from tqdm import tqdm + +for tag in ["test", "valid"]: + files = os.listdir(os.path.join("data/orders/", tag)) + dfs = [] + for f in tqdm(files): + df = pickle.load(open(os.path.join("data/orders/", tag, f), "rb")) + df = df.drop(["$close0"], axis=1) + dfs.append(df) + + total_df = pd.concat(dfs) + pickle.dump(total_df, open(os.path.join("data", "orders", f"{tag}_orders.pkl"), "wb")) diff --git a/examples/rl/scripts/pickle_data_config.yml b/examples/rl_order_execution/scripts/pickle_data_config.yml similarity index 63% rename from examples/rl/scripts/pickle_data_config.yml rename to examples/rl_order_execution/scripts/pickle_data_config.yml index 7813f7d38..3d7b2aa04 100755 --- a/examples/rl/scripts/pickle_data_config.yml +++ b/examples/rl_order_execution/scripts/pickle_data_config.yml @@ -1,15 +1,16 @@ # start & end time for training/validation/test datasets start_time: !!str &start 2020-01-01 -end_time: !!str &end 2020-07-31 -train_end_time: !!str &tend 2020-03-31 -valid_start_time: !!str &vstart 2020-04-01 -valid_end_time: !!str &vend 2020-05-31 -test_start_time: !!str &tstart 2020-06-01 +end_time: !!str &end 2021-12-31 +train_end_time: !!str &tend 2021-06-30 +valid_start_time: !!str &vstart 2021-07-01 +valid_end_time: !!str &vend 2021-09-30 +test_start_time: !!str &tstart 2021-10-01 # the instrument set -instruments: &ins all +instruments: &ins csi300s19_22 # qlib related configuration qlib_conf: - provider_uri: ./data/bin # path to generated qlib bin + provider_uri: + 5min: ./data/bin # path to generated qlib bin redis_port: 233 feature_conf: path: ./data/pickle/feature.pkl # output path of feature @@ -26,14 +27,23 @@ feature_conf: fit_end_time: *tend instruments: *ins day_length: 240 # how many minutes in one trading day + freq: 5min + columns: ["$open", "$high", "$low", "$close"] infer_processors: - class: HighFreqNorm module_path: qlib.contrib.data.highfreq_processor kwargs: feature_save_dir: ./stat/ # output path of statistics of features (for feature normalization) norm_groups: - price: 10 + price: 8 volume: 2 + inst_processors: + - class: TimeRangeFlt + module_path: qlib.data.dataset.processor + kwargs: + start_time: "2020-01-01" + end_time: "2021-12-31" + freq: 5min segments: train: !!python/tuple [*start, *tend] valid: !!python/tuple [*vstart, *vend] @@ -51,7 +61,17 @@ backtest_conf: end_time: *end instruments: *ins day_length: 240 + freq: 5min + columns: ["$close", "$volume"] + inst_processors: + - class: TimeRangeFlt + module_path: qlib.data.dataset.processor + kwargs: + start_time: "2020-01-01" + end_time: "2021-12-31" + freq: 5min segments: train: !!python/tuple [*start, *tend] valid: !!python/tuple [*vstart, *vend] test: !!python/tuple [*tstart, *end] +freq: 5min diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index ca3ca5545..ce052f550 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -56,7 +56,7 @@ class Alpha360(DataHandlerLP): fit_start_time=None, fit_end_time=None, filter_pipe=None, - inst_processor=None, + inst_processors=None, **kwargs ): infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) @@ -71,7 +71,7 @@ class Alpha360(DataHandlerLP): }, "filter_pipe": filter_pipe, "freq": freq, - "inst_processor": inst_processor, + "inst_processors": inst_processors, }, } @@ -152,7 +152,7 @@ class Alpha158(DataHandlerLP): fit_end_time=None, process_type=DataHandlerLP.PTYPE_A, filter_pipe=None, - inst_processor=None, + inst_processors=None, **kwargs ): infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) @@ -167,7 +167,7 @@ class Alpha158(DataHandlerLP): }, "filter_pipe": filter_pipe, "freq": freq, - "inst_processor": inst_processor, + "inst_processors": inst_processors, }, } super().__init__( diff --git a/qlib/contrib/data/highfreq_handler.py b/qlib/contrib/data/highfreq_handler.py index f69f8195f..638fbf0e8 100644 --- a/qlib/contrib/data/highfreq_handler.py +++ b/qlib/contrib/data/highfreq_handler.py @@ -44,7 +44,7 @@ class HighFreqHandler(DataHandlerLP): names = [] template_if = "If(IsNull({1}), {0}, {1})" - template_paused = "Select(Gt($hx_paused_num, 1.001), {0})" + template_paused = "Select(Gt($paused_num, 1.001), {0})" def get_normalized_price_feature(price_field, shift=0): # norm with the close price of 237th minute of yesterday. @@ -115,6 +115,7 @@ class HighFreqGeneralHandler(DataHandlerLP): day_length=240, freq="1min", columns=["$open", "$high", "$low", "$close", "$vwap"], + inst_processors=None, ): self.day_length = day_length self.columns = columns @@ -128,6 +129,7 @@ class HighFreqGeneralHandler(DataHandlerLP): "config": self.get_feature_config(), "swap_level": False, "freq": freq, + "inst_processors": inst_processors, }, } super().__init__( @@ -257,6 +259,7 @@ class HighFreqGeneralBacktestHandler(DataHandler): day_length=240, freq="1min", columns=["$close", "$vwap", "$volume"], + inst_processors=None, ): self.day_length = day_length self.columns = set(columns) @@ -266,6 +269,7 @@ class HighFreqGeneralBacktestHandler(DataHandler): "config": self.get_feature_config(), "swap_level": False, "freq": freq, + "inst_processors": inst_processors, }, } super().__init__( @@ -311,6 +315,7 @@ class HighFreqOrderHandler(DataHandlerLP): learn_processors=[], fit_start_time=None, fit_end_time=None, + inst_processors=None, drop_raw=True, ): @@ -323,6 +328,7 @@ class HighFreqOrderHandler(DataHandlerLP): "config": self.get_feature_config(), "swap_level": False, "freq": "1min", + "inst_processors": inst_processors, }, } super().__init__( @@ -482,7 +488,7 @@ class HighFreqBacktestOrderHandler(DataHandler): names = [] template_if = "If(IsNull({1}), {0}, {1})" - template_paused = "Select(Gt($hx_paused_num, 1.001), {0})" + template_paused = "Select(Gt($paused_num, 1.001), {0})" template_fillnan = "FFillNan({0})" fields += [ template_fillnan.format(template_paused.format("$close")), diff --git a/qlib/contrib/data/highfreq_provider.py b/qlib/contrib/data/highfreq_provider.py index b499cc68e..611e30d86 100644 --- a/qlib/contrib/data/highfreq_provider.py +++ b/qlib/contrib/data/highfreq_provider.py @@ -128,7 +128,7 @@ class HighFreqProvider: raise ValueError("Must specify the path to save the dataset.") from e if os.path.isfile(path): start = time.time() - self.logger.info("Dataset exists, load from disk.", __name__) + self.logger.info(f"[{__name__}]Dataset exists, load from disk.") # res = dataset.prepare(['train', 'valid', 'test']) with open(path, "rb") as f: @@ -137,11 +137,11 @@ class HighFreqProvider: res = [data[i] for i in datasets] else: res = data.prepare(datasets) - self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}") else: if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) - self.logger.info("Generating dataset", __name__) + self.logger.info(f"[{__name__}]Generating dataset") start_time = time.time() self._prepare_calender_cache() dataset = init_instance_by_config(config) @@ -160,7 +160,7 @@ class HighFreqProvider: with open(path[:-4] + "test.pkl", "wb") as f: pkl.dump(testset, f) res = [data[i] for i in datasets] - self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__) + self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}") return res def _gen_data(self, config, datasets=["train", "valid", "test"]): @@ -170,7 +170,7 @@ class HighFreqProvider: raise ValueError("Must specify the path to save the dataset.") from e if os.path.isfile(path): start = time.time() - self.logger.info("Dataset exists, load from disk.", __name__) + self.logger.info(f"[{__name__}]Dataset exists, load from disk.") # res = dataset.prepare(['train', 'valid', 'test']) with open(path, "rb") as f: @@ -179,18 +179,18 @@ class HighFreqProvider: res = [data[i] for i in datasets] else: res = data.prepare(datasets) - self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}") else: if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) - self.logger.info("Generating dataset", __name__) + self.logger.info(f"[{__name__}]Generating dataset") start_time = time.time() self._prepare_calender_cache() dataset = init_instance_by_config(config) dataset.config(dump_all=True, recursive=True) dataset.to_pickle(path) res = dataset.prepare(datasets) - self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__) + self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}") return res def _gen_dataset(self, config): @@ -200,21 +200,21 @@ class HighFreqProvider: raise ValueError("Must specify the path to save the dataset.") from e if os.path.isfile(path): start = time.time() - self.logger.info("Dataset exists, load from disk.", __name__) + self.logger.info(f"[{__name__}]Dataset exists, load from disk.") with open(path, "rb") as f: dataset = pkl.load(f) - self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}") else: start = time.time() if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) - self.logger.info("Generating dataset", __name__) + self.logger.info(f"[{__name__}]Generating dataset") self._prepare_calender_cache() dataset = init_instance_by_config(config) - self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}") dataset.prepare(["train", "valid", "test"]) - self.logger.info(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"[{__name__}]Dataset prepared, time cost: {time.time() - start:.2f}") dataset.config(dump_all=True, recursive=True) dataset.to_pickle(path) return dataset @@ -227,15 +227,15 @@ class HighFreqProvider: if os.path.isfile(path + "tmp_dataset.pkl"): start = time.time() - self.logger.info("Dataset exists, load from disk.", __name__) + self.logger.info(f"[{__name__}]Dataset exists, load from disk.") else: start = time.time() if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) - self.logger.info("Generating dataset", __name__) + self.logger.info(f"[{__name__}]Generating dataset") self._prepare_calender_cache() dataset = init_instance_by_config(config) - self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}") dataset.config(dump_all=False, recursive=True) dataset.to_pickle(path + "tmp_dataset.pkl") @@ -268,15 +268,15 @@ class HighFreqProvider: if os.path.isfile(path + "tmp_dataset.pkl"): start = time.time() - self.logger.info("Dataset exists, load from disk.", __name__) + self.logger.info(f"[{__name__}]Dataset exists, load from disk.") else: start = time.time() if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) - self.logger.info("Generating dataset", __name__) + self.logger.info(f"[{__name__}]Generating dataset") self._prepare_calender_cache() dataset = init_instance_by_config(config) - self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__) + self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}") dataset.config(dump_all=False, recursive=True) dataset.to_pickle(path + "tmp_dataset.pkl") diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 2fe8f8a63..9b2a6fa32 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -7,6 +7,7 @@ from typing import Callable, Union, Tuple, List, Iterator, Optional import pandas as pd +from qlib.typehint import Literal from ...log import get_module_logger, TimeInspector from ...utils import init_instance_by_config from ...utils.serial import Serializable @@ -49,6 +50,8 @@ class DataHandler(Serializable): - Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc` """ + _data: pd.DataFrame # underlying data. + def __init__( self, instruments=None, @@ -155,6 +158,11 @@ class DataHandler(Serializable): """ fetch data from underlying data source + Design motivation: + - providing a unified interface for underlying data. + - Potential to make the interface more friendly. + - User can improve performance when fetching data in this extra layer + Parameters ---------- selector : Union[pd.Timestamp, slice, str] @@ -328,6 +336,9 @@ class DataHandler(Serializable): yield cur_date, self.fetch(selector, **kwargs) +DATA_KEY_TYPE = Literal["raw", "infer", "learn"] + + class DataHandlerLP(DataHandler): """ DataHandler with **(L)earnable (P)rocessor** @@ -353,10 +364,15 @@ class DataHandlerLP(DataHandler): - `drop_raw=True`: this will modify the data inplace on raw data; """ + # based on `self._data`, _infer and _learn are genrated after processors + _infer: pd.DataFrame # data for inference + _learn: pd.DataFrame # data for learning models + # data key - DK_R = "raw" - DK_I = "infer" - DK_L = "learn" + DK_R: DATA_KEY_TYPE = "raw" + DK_I: DATA_KEY_TYPE = "infer" + DK_L: DATA_KEY_TYPE = "learn" + # map data_key to attribute name ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"} # process type @@ -600,7 +616,7 @@ class DataHandlerLP(DataHandler): # TODO: Be able to cache handler data. Save the memory for data processing - def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame: + def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DK_I) -> pd.DataFrame: if data_key == self.DK_R and self.drop_raw: raise AttributeError( "DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data" @@ -613,7 +629,7 @@ class DataHandlerLP(DataHandler): selector: Union[pd.Timestamp, slice, str] = slice(None, None), level: Union[str, int] = "datetime", col_set=DataHandler.CS_ALL, - data_key: str = DK_I, + data_key: DATA_KEY_TYPE = DK_I, squeeze: bool = False, proc_func: Callable = None, ) -> pd.DataFrame: @@ -647,7 +663,7 @@ class DataHandlerLP(DataHandler): proc_func=proc_func, ) - def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list: + def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DK_I) -> list: """ get the column names @@ -655,7 +671,7 @@ class DataHandlerLP(DataHandler): ---------- col_set : str select a set of meaningful columns.(e.g. features, columns). - data_key : str + data_key : DATA_KEY_TYPE the data to fetch: DK_*. Returns diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index cc9ecd7c4..e9d6f9886 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -153,7 +153,7 @@ class QlibDataLoader(DLWParser): filter_pipe: List = None, swap_level: bool = True, freq: Union[str, dict] = "day", - inst_processor: dict = None, + inst_processors: Union[dict, list] = None, ): """ Parameters @@ -167,16 +167,19 @@ class QlibDataLoader(DLWParser): freq: dict or str If type(config) == dict and type(freq) == str, load config data using freq. If type(config) == dict and type(freq) == dict, load config[] data using freq[] - inst_processor: dict - If inst_processor is not None and type(config) == dict; load config[] data using inst_processor[] + inst_processors: dict | list + If inst_processors is not None and type(config) == dict; load config[] data using inst_processors[] + If inst_processors is a list, then it will be applied to all groups. """ self.filter_pipe = filter_pipe self.swap_level = swap_level self.freq = freq # sample - self.inst_processor = inst_processor if inst_processor is not None else {} - assert isinstance(self.inst_processor, dict), f"inst_processor(={self.inst_processor}) must be dict" + self.inst_processors = inst_processors if inst_processors is not None else {} + assert isinstance( + self.inst_processors, (dict, list) + ), f"inst_processors(={self.inst_processors}) must be dict or list" super().__init__(config) @@ -187,8 +190,8 @@ class QlibDataLoader(DLWParser): if _gp not in freq: raise ValueError(f"freq(={freq}) missing group(={_gp})") assert ( - self.inst_processor - ), f"freq(={self.freq}), inst_processor(={self.inst_processor}) cannot be None/empty" + self.inst_processors + ), f"freq(={self.freq}), inst_processors(={self.inst_processors}) cannot be None/empty" def load_group_df( self, @@ -208,9 +211,10 @@ class QlibDataLoader(DLWParser): warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list") freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq - df = D.features( - instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, []) + inst_processors = ( + self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, []) ) + df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors) df.columns = names if self.swap_level: df = df.swaplevel().sort_index() # NOTE: if swaplevel, return diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index cf4845af8..f7204cf78 100644 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import abc -from typing import Union, Text +from typing import Union, Text, Optional import numpy as np import pandas as pd @@ -11,6 +11,8 @@ from ...constant import EPS from .utils import fetch_df_by_index from ...utils.serial import Serializable from ...utils.paral import datetime_groupby_apply +from qlib.data.inst_processor import InstProcessor +from qlib.data import D def get_group_columns(df: pd.DataFrame, group: Union[Text, None]): @@ -378,3 +380,42 @@ class HashStockFormat(Processor): from .storage import HashingStockStorage # pylint: disable=C0415 return HashingStockStorage.from_df(df) + + +class TimeRangeFlt(InstProcessor): + """ + This is a filter to filter stock. + Only keep the data that exist from start_time to end_time (the existence in the middle is not checked.) + WARNING: It may induce leakage!!! + """ + + def __init__( + self, + start_time: Optional[Union[pd.Timestamp, str]] = None, + end_time: Optional[Union[pd.Timestamp, str]] = None, + freq: str = "day", + ): + """ + Parameters + ---------- + start_time : Optional[Union[pd.Timestamp, str]] + The data must start earlier (or equal) than `start_time` + None indicates data will not be filtered based on `start_time` + end_time : Optional[Union[pd.Timestamp, str]] + similar to start_time + freq : str + The frequency of the calendar + """ + # Align to calendar before filtering + cal = D.calendar(start_time=start_time, end_time=end_time, freq=freq) + self.start_time = None if start_time is None else cal[0] + self.end_time = None if end_time is None else cal[-1] + + def __call__(self, df: pd.DataFrame, instrument, *args, **kwargs): + if ( + df.empty + or (self.start_time is None or df.index.min() <= self.start_time) + and (self.end_time is None or df.index.max() >= self.end_time) + ): + return df + return df.head(0) diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index 2818f788c..6fafa9428 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -357,7 +357,10 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram if not output_path.exists(): os.makedirs(output_path) - res.to_csv(output_path / "summary.csv") + + if "pa" in res.columns: + res["pa"] = res["pa"] * 10000.0 # align with training metrics + res.to_csv(output_path / "backtest_result.csv") return res diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py index 2102ff6ab..a46b587aa 100644 --- a/qlib/rl/order_execution/policy.py +++ b/qlib/rl/order_execution/policy.py @@ -12,11 +12,11 @@ import torch import torch.nn as nn from gym.spaces import Discrete from tianshou.data import Batch, ReplayBuffer, to_torch -from tianshou.policy import BasePolicy, PPOPolicy +from tianshou.policy import BasePolicy, PPOPolicy, DQNPolicy from qlib.rl.trainer.trainer import Trainer -__all__ = ["AllOne", "PPO"] +__all__ = ["AllOne", "PPO", "DQN"] # baselines # @@ -158,6 +158,56 @@ class PPO(PPOPolicy): set_weight(self, Trainer.get_policy_state_dict(weight_file)) +DQNModel = PPOActor # Reuse PPOActor. + + +class DQN(DQNPolicy): + """A wrapper of tianshou DQNPolicy. + + Differences: + + - Auto-create model network. Supports discrete action space only. + - Support a ``weight_file`` that supports loading checkpoint. + """ + + def __init__( + self, + network: nn.Module, + obs_space: gym.Space, + action_space: gym.Space, + lr: float, + weight_decay: float = 0.0, + discount_factor: float = 0.99, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + weight_file: Optional[Path] = None, + ) -> None: + assert isinstance(action_space, Discrete) + + model = DQNModel(network, action_space.n) + optimizer = torch.optim.Adam( + model.parameters(), + lr=lr, + weight_decay=weight_decay, + ) + + super().__init__( + model, + optimizer, + discount_factor=discount_factor, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + reward_normalization=reward_normalization, + is_double=is_double, + clip_loss_grad=clip_loss_grad, + ) + if weight_file is not None: + set_weight(self, Trainer.get_policy_state_dict(weight_file)) + + # utilities: these should be put in a separate (common) file. # diff --git a/qlib/rl/order_execution/reward.py b/qlib/rl/order_execution/reward.py index c6acc4394..0dcfd24bb 100644 --- a/qlib/rl/order_execution/reward.py +++ b/qlib/rl/order_execution/reward.py @@ -70,7 +70,19 @@ class PPOReward(Reward[SAOEState]): def reward(self, simulator_state: SAOEState) -> float: if simulator_state.cur_step == self.max_step - 1 or simulator_state.position < 1e-6: - vwap_price = cast(dict, simulator_state.metrics)["trade_price"] + if simulator_state.history_exec["deal_amount"].sum() == 0.0: + vwap_price = cast( + float, + np.average(simulator_state.history_exec["market_price"]), + ) + else: + vwap_price = cast( + float, + np.average( + simulator_state.history_exec["market_price"], + weights=simulator_state.history_exec["deal_amount"], + ), + ) twap_price = simulator_state.backtest_data.get_deal_price().mean() if simulator_state.order.direction == OrderDir.SELL: diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py index b6f5e12b2..7e66a1f08 100644 --- a/qlib/rl/order_execution/strategy.py +++ b/qlib/rl/order_execution/strategy.py @@ -7,6 +7,7 @@ import collections from types import GeneratorType from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union +import warnings import numpy as np import pandas as pd import torch @@ -137,7 +138,12 @@ class SAOEStateAdapter: exec_vol[idx - last_step_range[0]] = order.deal_amount if exec_vol.sum() > self.position and exec_vol.sum() > 0.0: - assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large" + if exec_vol.sum() > self.position + 1.0: + warnings.warn( + f"Sum of execution volume is {exec_vol.sum()} which is larger than " + f"position + 1.0 = {self.position} + 1.0 = {self.position + 1.0}. " + f"All execution volume is scaled down linearly to ensure that their sum does not position." + ) exec_vol *= self.position / (exec_vol.sum()) market_volume = cast( diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index ea2f0cdec..4908f438f 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -224,7 +224,7 @@ def requests_with_retry(url, retry=5, **kwargs): except Exception as e: log.warning("exception encountered {}".format(e)) continue - raise Exception("ERROR: requests failed!") + raise TimeoutError("ERROR: requests failed!") #################### Parse #################### diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index d0adda66e..ae165ef1f 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -333,7 +333,7 @@ class MLflowExperiment(Experiment): recorder = self._get_recorder(recorder_name=recorder_name) self._client.delete_run(recorder.id) except MlflowException as e: - raise Exception( + raise ValueError( f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct." ) from e diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 3059eecd1..94d17beaf 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -415,7 +415,7 @@ class MLflowExpManager(ExpManager): raise MlflowException("No valid experiment has been found.") self.client.delete_experiment(experiment.experiment_id) except MlflowException as e: - raise Exception( + raise ValueError( f"Error: {e}. Something went wrong when deleting experiment. Please check if the name/id of the experiment is correct." ) from e diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 4502a6c04..25f465936 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -324,7 +324,7 @@ class MLflowRecorder(Recorder): raise RuntimeError("This recorder is not saved in the local file system.") else: - raise Exception( + raise ValueError( "Please make sure the recorder has been created and started properly before getting artifact uri." ) @@ -464,7 +464,7 @@ class MLflowRecorder(Recorder): if self.artifact_uri is not None: return self.artifact_uri else: - raise Exception( + raise ValueError( "Please make sure the recorder has been created and started properly before getting artifact uri." )