mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
19 Commits
you-n-g-pa
...
mini_proje
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
949d96d768 | ||
|
|
597359f98f | ||
|
|
75aae820e8 | ||
|
|
558603beca | ||
|
|
157481abd1 | ||
|
|
9d7a0f032a | ||
|
|
58f9eed3c9 | ||
|
|
8f1e28c43f | ||
|
|
e7c660f0d4 | ||
|
|
2752bdc92c | ||
|
|
687edd79d0 | ||
|
|
ba705d39e0 | ||
|
|
a53f59cdf7 | ||
|
|
8e063828f9 | ||
|
|
86f08e47e8 | ||
|
|
8199822ca0 | ||
|
|
1b9915501c | ||
|
|
c65c598bde | ||
|
|
fb5779a64c |
@@ -8,7 +8,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
timeout-minutes: 360
|
||||
timeout-minutes: 720
|
||||
# we may retry for 3 times for `Unit tests with Pytest`
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
@@ -29,7 +29,9 @@ jobs:
|
||||
|
||||
- name: Set up Python tools
|
||||
run: |
|
||||
pip install --upgrade cython numpy pip
|
||||
python -m pip install --upgrade pip
|
||||
# python -m pip is necessary to upgrade pip.
|
||||
pip install --upgrade cython numpy
|
||||
pip install -e .[dev]
|
||||
|
||||
- name: Downloads dependencies data
|
||||
@@ -50,7 +52,7 @@ jobs:
|
||||
- name: Unit tests with Pytest
|
||||
uses: nick-fields/retry@v2
|
||||
with:
|
||||
timeout_minutes: 120
|
||||
timeout_minutes: 240
|
||||
max_attempts: 3
|
||||
command: |
|
||||
cd tests
|
||||
|
||||
13
README.md
13
README.md
@@ -176,6 +176,19 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
|
||||
## Data Preparation
|
||||
Load and prepare data by running the following code:
|
||||
|
||||
### Get with module
|
||||
```bash
|
||||
# get 1d data
|
||||
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# get 1min data
|
||||
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
|
||||
|
||||
```
|
||||
|
||||
### Get from source
|
||||
|
||||
```bash
|
||||
# get 1d data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
@@ -51,7 +51,7 @@ Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency
|
||||
|
||||
Qlib Format Dataset
|
||||
-------------------
|
||||
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows.
|
||||
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows. User can also use numpy to load `.bin` file to validate data.
|
||||
The price volume data look different from the actual dealling price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
|
||||
Users can leverage `$factor` to get the original trading price (e.g. `$close / $factor` to get the original close price).
|
||||
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: CatBoostModel
|
||||
module_path: qlib.contrib.model.catboost_model
|
||||
kwargs:
|
||||
loss: RMSE
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
max_depth: 6
|
||||
num_leaves: 100
|
||||
thread_count: 20
|
||||
grow_policy: Lossguide
|
||||
bootstrap_type: Poisson
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,79 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: CatBoostModel
|
||||
module_path: qlib.contrib.model.catboost_model
|
||||
kwargs:
|
||||
loss: RMSE
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
max_depth: 6
|
||||
num_leaves: 100
|
||||
thread_count: 20
|
||||
grow_policy: Lossguide
|
||||
bootstrap_type: Poisson
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,97 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 28
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,104 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 136
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -35,13 +35,13 @@ task:
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
colsample_bytree: 0.9
|
||||
learning_rate: 0.1
|
||||
subsample: 0.9
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_leaves: 250
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LinearModel
|
||||
module_path: qlib.contrib.model.linear
|
||||
kwargs:
|
||||
estimator: ols
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: True
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
102
examples/benchmarks/MLP/workflow_config_mlp_Alpha158_csi500.yaml
Normal file
102
examples/benchmarks/MLP/workflow_config_mlp_Alpha158_csi500.yaml
Normal file
@@ -0,0 +1,102 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: [
|
||||
{
|
||||
"class" : "DropCol",
|
||||
"kwargs":{"col_list": ["VWAP0"]}
|
||||
},
|
||||
{
|
||||
"class" : "CSZFillna",
|
||||
"kwargs":{"fields_group": "feature"}
|
||||
}
|
||||
]
|
||||
learn_processors: [
|
||||
{
|
||||
"class" : "DropCol",
|
||||
"kwargs":{"col_list": ["VWAP0"]}
|
||||
},
|
||||
{
|
||||
"class" : "DropnaProcessor",
|
||||
"kwargs":{"fields_group": "feature"}
|
||||
},
|
||||
"DropnaLabel",
|
||||
{
|
||||
"class": "CSZScoreNorm",
|
||||
"kwargs": {"fields_group": "label"}
|
||||
}
|
||||
]
|
||||
process_type: "independent"
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DNNModelPytorch
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 8192
|
||||
GPU: 0
|
||||
weight_decay: 0.0002
|
||||
pt_model_kwargs:
|
||||
input_dim: 157
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,89 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DNNModelPytorch
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
GPU: 0
|
||||
pt_model_kwargs:
|
||||
input_dim: 360
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -74,10 +74,15 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
- The base model of TCTS is GRU.
|
||||
- About the datasets
|
||||
- Alpha158 is a tabular dataset. There are less spatial relationships between different features. Each feature are carefully desgined by human (a.k.a feature engineering)
|
||||
- Alpha158 is a tabular dataset. There are less spatial relationships between different features. Each feature are carefully designed by human (a.k.a feature engineering)
|
||||
- Alpha360 contains raw price and volue data without much feature engineering. There are strong strong spatial relationships between the features in the time dimension.
|
||||
- The metrics can be categorized into two
|
||||
- Signal-based evaluation: IC, ICIR, Rank IC, Rank ICIR
|
||||
- 
|
||||
- 
|
||||
- 
|
||||
- 
|
||||
- 
|
||||
- Portfolio-based metrics: Annualized Return, Information Ratio, Max Drawdown
|
||||
|
||||
## Results on CSI500
|
||||
@@ -102,16 +107,21 @@ python run_all_model.py run 3 lightgbm Alpha158 csi500 # for models with random
|
||||
```
|
||||
|
||||
### Alpha158 dataset
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| LightGBM | Alpha158 | 0.0377±0.00 | 0.3860±0.00 | 0.0448±0.00 | 0.4675±0.00 | 0.1151±0.00 | 1.3884±0.00 | -0.0898±0.00 |
|
||||
| Linear | Alpha158 | 0.0332±0.00 | 0.3044±0.00 | 0.0462±0.00 | 0.4326±0.00 | 0.0382±0.00 | 0.1723±0.00 | -0.4876±0.00 |
|
||||
| MLP | Alpha158 | 0.0229±0.01 | 0.2181±0.05 | 0.0360±0.00 | 0.3409±0.02 | 0.0043±0.02 | 0.0602±0.27 | -0.2184±0.04 |
|
||||
| LightGBM | Alpha158 | 0.0399±0.00 | 0.4065±0.00 | 0.0482±0.00 | 0.5101±0.00 | 0.1284±0.00 | 1.5650±0.00 | -0.0635±0.00 |
|
||||
| CatBoost | Alpha158 | 0.0345±0.00 | 0.2855±0.00 | 0.0417±0.00 | 0.3740±0.00 | 0.0496±0.00 | 0.5977±0.00 | -0.1496±0.00 |
|
||||
| DoubleEnsemble | Alpha158 | 0.0380±0.00 | 0.3659±0.00 | 0.0442±0.00 | 0.4324±0.00 | 0.0382±0.00 | 0.1723±0.00 | -0.4876±0.00 |
|
||||
|
||||
### Alpha360 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| MLP | Alpha360 | 0.0258±0.00 | 0.2021±0.02 | 0.0426±0.00 | 0.3840±0.02 | 0.0022±0.02 | 0.0301±0.26 | -0.2064±0.02 |
|
||||
| LightGBM | Alpha360 | 0.0400±0.00 | 0.3605±0.00 | 0.0536±0.00 | 0.5431±0.00 | 0.0505±0.00 | 0.7658±0.02 | -0.1880±0.00 |
|
||||
|
||||
| CatBoost | Alpha360 | 0.0382±0.00 | 0.3229±0.00 | 0.0489±0.00 | 0.4649±0.00 | 0.0297±0.00 | 0.4227±0.02 | -0.1499±0.01 |
|
||||
| DoubleEnsemble | Alpha360 | 0.0361±0.00 | 0.3092±0.00 | 0.0499±0.00 | 0.4793±0.00 | 0.0382±0.00 | 0.1723±0.02 | -0.4876±0.00 |
|
||||
|
||||
# Contributing
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ def get_exchange(
|
||||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
deal_price: Union[str, Tuple[str, str], List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Exchange:
|
||||
"""get_exchange
|
||||
@@ -70,10 +70,10 @@ def get_exchange(
|
||||
min_cost : float
|
||||
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
|
||||
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
|
||||
deal_price: Union[str, Tuple[str], List[str]]
|
||||
deal_price: Union[str, Tuple[str, str], List[str]]
|
||||
The `deal_price` supports following two types of input
|
||||
- <deal_price> : str
|
||||
- (<buy_price>, <sell_price>): Tuple[str] or List[str]
|
||||
- (<buy_price>, <sell_price>): Tuple[str, str] or List[str]
|
||||
|
||||
<deal_price>, <buy_price> or <sell_price> := <price>
|
||||
<price> := str
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import time
|
||||
from enum import IntEnum
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, List, Optional, Tuple, TypeVar, Union, cast
|
||||
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
from qlib.data.data import Cal
|
||||
@@ -23,7 +24,6 @@ from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
DecisionType = TypeVar("DecisionType")
|
||||
|
||||
|
||||
@@ -182,8 +182,8 @@ class OrderHelper:
|
||||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time if start_time is not None else pd.Timestamp(start_time),
|
||||
end_time=end_time if end_time is not None else pd.Timestamp(end_time),
|
||||
start_time=None if start_time is None else pd.Timestamp(start_time),
|
||||
end_time=None if end_time is None else pd.Timestamp(end_time),
|
||||
direction=direction,
|
||||
)
|
||||
|
||||
@@ -249,7 +249,7 @@ class IdxTradeRange(TradeRange):
|
||||
class TradeRangeByTime(TradeRange):
|
||||
"""This is a helper function for make decisions"""
|
||||
|
||||
def __init__(self, start_time: str, end_time: str) -> None:
|
||||
def __init__(self, start_time: str | time, end_time: str | time) -> None:
|
||||
"""
|
||||
This is a callable class.
|
||||
|
||||
@@ -259,13 +259,13 @@ class TradeRangeByTime(TradeRange):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : str
|
||||
start_time : str | time
|
||||
e.g. "9:30"
|
||||
end_time : str
|
||||
end_time : str | time
|
||||
e.g. "14:30"
|
||||
"""
|
||||
self.start_time = pd.Timestamp(start_time).time()
|
||||
self.end_time = pd.Timestamp(end_time).time()
|
||||
self.start_time = pd.Timestamp(start_time).time() if isinstance(start_time, str) else start_time
|
||||
self.end_time = pd.Timestamp(end_time).time() if isinstance(end_time, str) else end_time
|
||||
assert self.start_time < self.end_time
|
||||
|
||||
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
|
||||
@@ -535,7 +535,12 @@ class TradeDecisionWO(BaseTradeDecision[Order]):
|
||||
Besides, the time_range is also included.
|
||||
"""
|
||||
|
||||
def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
order_list: List[Order],
|
||||
strategy: BaseStrategy,
|
||||
trade_range: Union[Tuple[int, int], TradeRange] = None,
|
||||
) -> None:
|
||||
super().__init__(strategy, trade_range=trade_range)
|
||||
self.order_list = cast(List[Order], order_list)
|
||||
start, end = strategy.trade_calendar.get_step_time()
|
||||
|
||||
@@ -32,7 +32,7 @@ class Exchange:
|
||||
start_time: Union[pd.Timestamp, str] = None,
|
||||
end_time: Union[pd.Timestamp, str] = None,
|
||||
codes: Union[list, str] = "all",
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
deal_price: Union[str, Tuple[str, str], List[str]] = None,
|
||||
subscribe_fields: list = [],
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
volume_threshold: Union[tuple, dict] = None,
|
||||
@@ -448,9 +448,9 @@ class Exchange:
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
method: Optional[str] = "sum",
|
||||
) -> float:
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
|
||||
return cast(float, self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method))
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
|
||||
|
||||
def get_deal_price(
|
||||
self,
|
||||
@@ -459,7 +459,7 @@ class Exchange:
|
||||
end_time: pd.Timestamp,
|
||||
direction: OrderDir,
|
||||
method: Optional[str] = "ts_data_last",
|
||||
) -> float:
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
if direction == OrderDir.SELL:
|
||||
pstr = self.sell_price
|
||||
elif direction == OrderDir.BUY:
|
||||
@@ -472,7 +472,7 @@ class Exchange:
|
||||
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
|
||||
self.logger.warning(f"setting deal_price to close price")
|
||||
deal_price = self.get_close(stock_id, start_time, end_time, method)
|
||||
return cast(float, deal_price)
|
||||
return deal_price
|
||||
|
||||
def get_factor(
|
||||
self,
|
||||
@@ -832,8 +832,11 @@ class Exchange:
|
||||
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
|
||||
:return: trade_price, trade_val, trade_cost
|
||||
"""
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
|
||||
total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price
|
||||
trade_price = cast(
|
||||
float,
|
||||
self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction),
|
||||
)
|
||||
total_trade_val = cast(float, self.get_volume(order.stock_id, order.start_time, order.end_time)) * trade_price
|
||||
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
|
||||
order.deal_amount = order.amount # set to full amount and clip it step by step
|
||||
# Clipping amount first
|
||||
|
||||
@@ -484,6 +484,7 @@ class NestedExecutor(BaseExecutor):
|
||||
inner_exe_res :
|
||||
the execution result of inner task
|
||||
"""
|
||||
self.inner_strategy.post_exe_step(inner_exe_res)
|
||||
|
||||
def get_all_executors(self) -> List[BaseExecutor]:
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
|
||||
@@ -259,79 +259,119 @@ class Alpha158(DataHandlerLP):
|
||||
def use(x):
|
||||
return x not in exclude and (include is None or x in include)
|
||||
|
||||
# Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf
|
||||
if use("ROC"):
|
||||
# https://www.investopedia.com/terms/r/rateofchange.asp
|
||||
# Rate of change, the price change in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
# https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp
|
||||
# Simple Moving Average, the simple moving average in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
# The standard diviation of close price for the past d days, divided by latest close price to remove unit
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
# The rate of close price change in the past d days, divided by latest close price to remove unit
|
||||
# For example, price increase 10 dollar per day in the past d days, then Slope will be 10.
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
# The R-sqaure value of linear regression for the past d days, represent the trend linear
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
# The max price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
# The low price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
# Used with MIN and MAX
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
# Get the percentile of current close price in past d day's close price.
|
||||
# Represent the current price level comparing to past N days, add additional information to moving average.
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
# Represent the price position between upper and lower resistent price for past d days.
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
# The number of days between current date and previous highest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
# The number of days between current date and previous lowest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
# The time period between previous lowest-price date occur after highest price date.
|
||||
# Large value suggest downward momemtum.
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
# The correlation between absolute close price and log scaled trading volume
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
# The correlation between price change ratio and volume change ratio
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
# The percentage of days in past d days that price go up.
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
# The percentage of days in past d days that price go down.
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
# The diff between past up day and past down day
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
# The total gain / the absolute total price changed
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
# The total lose / the absolute total price changed
|
||||
# Can be derived from SUMP by SUMN = 1 - SUMP
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
# The diff ratio between total gain and total lose
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
|
||||
@@ -339,12 +379,15 @@ class Alpha158(DataHandlerLP):
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
# Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
# The standard deviation for volume in past d days.
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
# The volume weighted price change volatility
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
|
||||
% (d, d)
|
||||
@@ -352,6 +395,7 @@ class Alpha158(DataHandlerLP):
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
@@ -359,6 +403,8 @@ class Alpha158(DataHandlerLP):
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
# Can be derived from VSUMP by VSUMN = 1 - VSUMP
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
@@ -366,6 +412,8 @@ class Alpha158(DataHandlerLP):
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
# The diff ratio between total volume increase and total volume decrease
|
||||
# RSI indicator for volume
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
|
||||
|
||||
@@ -137,8 +137,7 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
# template_paused = "{0}"
|
||||
template_paused = "Select(Gt($paused_num, 1.001), {0})"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
@@ -162,3 +161,249 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
names += ["$factor0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class HighFreqOrderHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
p["kwargs"].update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
drop_raw=drop_raw,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_ifinf = "If(IsInf({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($paused_num, 1.001), {0})"
|
||||
|
||||
def get_normalized_price_feature(price_field, shift=0):
|
||||
# norm with the close price of 237th minute of yesterday.
|
||||
if shift == 0:
|
||||
template_norm = "{0}/DayLast(Ref({1}, 243))"
|
||||
else:
|
||||
template_norm = "Ref({0}, " + str(shift) + ")/DayLast(Ref({1}, 243))"
|
||||
|
||||
template_fillnan = "FFillNan({0})"
|
||||
# calculate -> ffill -> remove paused
|
||||
feature_ops = template_paused.format(
|
||||
template_fillnan.format(
|
||||
template_norm.format(template_if.format("$close", price_field), template_fillnan.format("$close"))
|
||||
)
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
def get_normalized_vwap_price_feature(price_field, shift=0):
|
||||
# norm with the close price of 237th minute of yesterday.
|
||||
if shift == 0:
|
||||
template_norm = "{0}/DayLast(Ref({1}, 243))"
|
||||
else:
|
||||
template_norm = "Ref({0}, " + str(shift) + ")/DayLast(Ref({1}, 243))"
|
||||
|
||||
template_fillnan = "FFillNan({0})"
|
||||
# calculate -> ffill -> remove paused
|
||||
feature_ops = template_paused.format(
|
||||
template_fillnan.format(
|
||||
template_norm.format(
|
||||
template_if.format("$close", template_ifinf.format("$close", price_field)),
|
||||
template_fillnan.format("$close"),
|
||||
)
|
||||
)
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 0)]
|
||||
fields += [get_normalized_price_feature("$high", 0)]
|
||||
fields += [get_normalized_price_feature("$low", 0)]
|
||||
fields += [get_normalized_price_feature("$close", 0)]
|
||||
fields += [get_normalized_vwap_price_feature("$vwap", 0)]
|
||||
names += ["$open", "$high", "$low", "$close", "$vwap"]
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 240)]
|
||||
fields += [get_normalized_price_feature("$high", 240)]
|
||||
fields += [get_normalized_price_feature("$low", 240)]
|
||||
fields += [get_normalized_price_feature("$close", 240)]
|
||||
fields += [get_normalized_vwap_price_feature("$vwap", 240)]
|
||||
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
|
||||
|
||||
fields += [get_normalized_price_feature("$bid", 0)]
|
||||
fields += [get_normalized_price_feature("$ask", 0)]
|
||||
names += ["$bid", "$ask"]
|
||||
|
||||
fields += [get_normalized_price_feature("$bid", 240)]
|
||||
fields += [get_normalized_price_feature("$ask", 240)]
|
||||
names += ["$bid_1", "$ask_1"]
|
||||
|
||||
# calculate and fill nan with 0
|
||||
|
||||
def get_volume_feature(volume_field, shift=0):
|
||||
template_gzero = "If(Ge({0}, 0), {0}, 0)"
|
||||
if shift == 0:
|
||||
feature_ops = template_gzero.format(
|
||||
template_paused.format(
|
||||
"If(IsInf({0}), 0, {0})".format(
|
||||
"If(IsNull({0}), 0, {0})".format(
|
||||
"{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format(volume_field)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
feature_ops = template_gzero.format(
|
||||
template_paused.format(
|
||||
"If(IsInf({0}), 0, {0})".format(
|
||||
"If(IsNull({0}), 0, {0})".format(
|
||||
f"Ref({{0}}, {shift})/Ref(DayLast(Mean({{0}}, 7200)), 240)".format(volume_field)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
fields += [get_volume_feature("$volume", 0)]
|
||||
names += ["$volume"]
|
||||
|
||||
fields += [get_volume_feature("$volume", 240)]
|
||||
names += ["$volume_1"]
|
||||
|
||||
fields += [get_volume_feature("$bidV", 0)]
|
||||
fields += [get_volume_feature("$bidV1", 0)]
|
||||
fields += [get_volume_feature("$bidV3", 0)]
|
||||
fields += [get_volume_feature("$bidV5", 0)]
|
||||
fields += [get_volume_feature("$askV", 0)]
|
||||
fields += [get_volume_feature("$askV1", 0)]
|
||||
fields += [get_volume_feature("$askV3", 0)]
|
||||
fields += [get_volume_feature("$askV5", 0)]
|
||||
names += ["$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"]
|
||||
|
||||
fields += [get_volume_feature("$bidV", 240)]
|
||||
fields += [get_volume_feature("$bidV1", 240)]
|
||||
fields += [get_volume_feature("$bidV3", 240)]
|
||||
fields += [get_volume_feature("$bidV5", 240)]
|
||||
fields += [get_volume_feature("$askV", 240)]
|
||||
fields += [get_volume_feature("$askV1", 240)]
|
||||
fields += [get_volume_feature("$askV3", 240)]
|
||||
fields += [get_volume_feature("$askV5", 240)]
|
||||
names += ["$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class HighFreqBacktestOrderHandler(DataHandler):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
# template_paused = "{0}"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
]
|
||||
names += ["$close0"]
|
||||
|
||||
fields += [
|
||||
template_paused.format(
|
||||
template_if.format(
|
||||
template_fillnan.format("$close"),
|
||||
"$vwap",
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
|
||||
names += ["$volume0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$bid"))]
|
||||
names += ["$bid0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$bidV"))]
|
||||
names += ["$bidV0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$ask"))]
|
||||
names += ["$ask0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$askV"))]
|
||||
names += ["$askV0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("($bid + $ask) / 2"))]
|
||||
names += ["$median0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$factor"))]
|
||||
names += ["$factor0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$downlimitmarket"))]
|
||||
names += ["$downlimitmarket0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$uplimitmarket"))]
|
||||
names += ["$uplimitmarket0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$highmarket"))]
|
||||
names += ["$highmarket0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$lowmarket"))]
|
||||
names += ["$lowmarket0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
@@ -104,9 +104,9 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
only_tradable : bool
|
||||
will the strategy only consider the tradable stock when buying and selling.
|
||||
if only_tradable:
|
||||
strategy will make buy sell decision without checking the tradable state of the stock.
|
||||
else:
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
|
||||
else:
|
||||
strategy will make buy sell decision without checking the tradable state of the stock.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.topk = topk
|
||||
|
||||
@@ -102,11 +102,22 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
||||
self._freq_file_cache = freq
|
||||
return self._freq_file_cache
|
||||
|
||||
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
|
||||
def _read_calendar(self) -> List[CalVT]:
|
||||
# NOTE:
|
||||
# if we want to accelerate partial reading calendar
|
||||
# we can add parameters like `skip_rows: int = 0, n_rows: int = None` to the interface.
|
||||
# Currently, it is not supported for the txt-based calendar
|
||||
|
||||
if not self.uri.exists():
|
||||
self._write_calendar(values=[])
|
||||
with self.uri.open("rb") as fp:
|
||||
return [str(x) for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, encoding="utf-8")]
|
||||
|
||||
with self.uri.open("r") as fp:
|
||||
res = []
|
||||
for line in fp.readlines():
|
||||
line = line.strip()
|
||||
if len(line) > 0:
|
||||
res.append(line)
|
||||
return res
|
||||
|
||||
def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
|
||||
with self.uri.open(mode=mode) as fp:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, TYPE_CHECKING, TypeVar
|
||||
from typing import Optional, TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
@@ -21,7 +21,7 @@ AuxInfoType = TypeVar("AuxInfoType")
|
||||
class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):
|
||||
"""Override this class to collect customized auxiliary information from environment."""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@final
|
||||
def __call__(self, simulator_state: StateType) -> AuxInfoType:
|
||||
|
||||
58
qlib/rl/data/exchange_wrapper.py
Normal file
58
qlib/rl/data/exchange_wrapper.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import Exchange, Order
|
||||
from .pickle_styled import IntradayBacktestData
|
||||
|
||||
|
||||
class QlibIntradayBacktestData(IntradayBacktestData):
|
||||
"""Backtest data for Qlib simulator"""
|
||||
|
||||
def __init__(self, order: Order, exchange: Exchange, start_time: pd.Timestamp, end_time: pd.Timestamp) -> None:
|
||||
super(QlibIntradayBacktestData, self).__init__()
|
||||
self._order = order
|
||||
self._exchange = exchange
|
||||
self._start_time = start_time
|
||||
self._end_time = end_time
|
||||
|
||||
self._deal_price = cast(
|
||||
pd.Series,
|
||||
self._exchange.get_deal_price(
|
||||
self._order.stock_id,
|
||||
self._start_time,
|
||||
self._end_time,
|
||||
direction=self._order.direction,
|
||||
method=None,
|
||||
),
|
||||
)
|
||||
self._volume = cast(
|
||||
pd.Series,
|
||||
self._exchange.get_volume(
|
||||
self._order.stock_id,
|
||||
self._start_time,
|
||||
self._end_time,
|
||||
method=None,
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Order: {self._order}, Exchange: {self._exchange}, "
|
||||
f"Start time: {self._start_time}, End time: {self._end_time}"
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._deal_price)
|
||||
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
return self._deal_price
|
||||
|
||||
def get_volume(self) -> pd.Series:
|
||||
return self._volume
|
||||
|
||||
def get_time_index(self) -> pd.DatetimeIndex:
|
||||
return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])
|
||||
@@ -19,19 +19,19 @@ This file shows resemblence to qlib.backtest.high_performance_ds. We might merge
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from functools import lru_cache
|
||||
from typing import List, Sequence, cast
|
||||
from pathlib import Path
|
||||
from typing import List, Sequence, cast
|
||||
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from qlib.backtest.decision import OrderDir, Order
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
||||
DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
|
||||
"""Several ad-hoc deal price.
|
||||
``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``.
|
||||
@@ -40,7 +40,7 @@ DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
|
||||
"""
|
||||
|
||||
|
||||
def _infer_processed_data_column_names(shape: int) -> list[str]:
|
||||
def _infer_processed_data_column_names(shape: int) -> List[str]:
|
||||
if shape == 16:
|
||||
return [
|
||||
"$open",
|
||||
@@ -87,7 +87,36 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
|
||||
|
||||
|
||||
class IntradayBacktestData:
|
||||
"""Raw market data that is often used in backtesting (thus called BacktestData)."""
|
||||
"""
|
||||
Raw market data that is often used in backtesting (thus called BacktestData).
|
||||
|
||||
Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest
|
||||
data type.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_volume(self) -> pd.Series:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_time_index(self) -> pd.DatetimeIndex:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SimpleIntradayBacktestData(IntradayBacktestData):
|
||||
"""Backtest data for simple simulator"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -95,8 +124,10 @@ class IntradayBacktestData:
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
deal_price: DealPriceType = "close",
|
||||
order_dir: int | None = None,
|
||||
):
|
||||
order_dir: int = None,
|
||||
) -> None:
|
||||
super(SimpleIntradayBacktestData, self).__init__()
|
||||
|
||||
backtest = _read_pickle(data_dir / stock_id)
|
||||
backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
|
||||
@@ -105,13 +136,13 @@ class IntradayBacktestData:
|
||||
|
||||
self.data: pd.DataFrame = backtest
|
||||
self.deal_price_type: DealPriceType = deal_price
|
||||
self.order_dir: int | None = order_dir
|
||||
self.order_dir = order_dir
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.data})"
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
@@ -162,7 +193,14 @@ class IntradayProcessedData:
|
||||
"""Processed data for "yesterday".
|
||||
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
|
||||
|
||||
def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> None:
|
||||
proc = _read_pickle(data_dir / stock_id)
|
||||
# We have to infer the names here because,
|
||||
# unfortunately they are not included in the original data.
|
||||
@@ -190,16 +228,20 @@ class IntradayProcessedData:
|
||||
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
|
||||
assert len(self.today) == len(self.yesterday) == time_length
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
|
||||
|
||||
|
||||
@lru_cache(maxsize=100) # 100 * 50K = 5MB
|
||||
def load_intraday_backtest_data(
|
||||
data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None
|
||||
) -> IntradayBacktestData:
|
||||
return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
def load_simple_intraday_backtest_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
deal_price: DealPriceType = "close",
|
||||
order_dir: int = None,
|
||||
) -> SimpleIntradayBacktestData:
|
||||
return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
@@ -207,13 +249,19 @@ def load_intraday_backtest_data(
|
||||
key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
|
||||
)
|
||||
def load_intraday_processed_data(
|
||||
data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> IntradayProcessedData:
|
||||
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
|
||||
|
||||
def load_orders(
|
||||
order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None
|
||||
order_path: Path,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> Sequence[Order]:
|
||||
"""Load orders, and set start time and end time for the orders."""
|
||||
|
||||
@@ -251,7 +299,7 @@ def load_orders(
|
||||
OrderDir(int(row["order_type"])),
|
||||
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
|
||||
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return orders
|
||||
|
||||
4
qlib/rl/from_neutrader/__init__.py
Normal file
4
qlib/rl/from_neutrader/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# TODO: find a better way to organize contents under this module.
|
||||
20
qlib/rl/from_neutrader/config.py
Normal file
20
qlib/rl/from_neutrader/config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
# TODO: In the future we should merge the dataclass-based config with Qlib's dict-based config.
|
||||
@dataclass
|
||||
class ExchangeConfig:
|
||||
limit_threshold: Union[float, Tuple[str, str]]
|
||||
deal_price: Union[str, Tuple[str, str]]
|
||||
volume_threshold: dict
|
||||
open_cost: float = 0.0005
|
||||
close_cost: float = 0.0015
|
||||
min_cost: float = 5.0
|
||||
trade_unit: Optional[float] = 100.0
|
||||
cash_limit: Optional[Union[Path, float]] = None
|
||||
generate_report: bool = False
|
||||
109
qlib/rl/from_neutrader/feature.py
Normal file
109
qlib/rl/from_neutrader/feature.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import collections
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
class LRUCache:
|
||||
def __init__(self, pool_size: int = 200):
|
||||
self.pool_size = pool_size
|
||||
self.contents: dict = {}
|
||||
self.keys: collections.deque = collections.deque()
|
||||
|
||||
def put(self, key, item):
|
||||
if self.has(key):
|
||||
self.keys.remove(key)
|
||||
self.keys.append(key)
|
||||
self.contents[key] = item
|
||||
while len(self.contents) > self.pool_size:
|
||||
self.contents.pop(self.keys.popleft())
|
||||
|
||||
def get(self, key):
|
||||
return self.contents[key]
|
||||
|
||||
def has(self, key):
|
||||
return key in self.contents
|
||||
|
||||
|
||||
class DataWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
feature_dataset: DatasetH,
|
||||
backtest_dataset: DatasetH,
|
||||
columns_today: List[str],
|
||||
columns_yesterday: List[str],
|
||||
_internal: bool = False,
|
||||
):
|
||||
assert _internal, "Init function of data wrapper is for internal use only."
|
||||
|
||||
self.feature_dataset = feature_dataset
|
||||
self.backtest_dataset = backtest_dataset
|
||||
self.columns_today = columns_today
|
||||
self.columns_yesterday = columns_yesterday
|
||||
|
||||
# TODO: We might have the chance to merge them.
|
||||
self.feature_cache = LRUCache()
|
||||
self.backtest_cache = LRUCache()
|
||||
|
||||
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
|
||||
if backtest:
|
||||
dataset = self.backtest_dataset
|
||||
cache = self.backtest_cache
|
||||
else:
|
||||
dataset = self.feature_dataset
|
||||
cache = self.feature_cache
|
||||
|
||||
if cache.has((start_time, end_time, stock_id)):
|
||||
return cache.get((start_time, end_time, stock_id))
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
cache.put((start_time, end_time, stock_id), data)
|
||||
return data
|
||||
|
||||
|
||||
def init_qlib(config: dict, part: Optional[str] = None) -> None:
|
||||
provider_uri_map = {
|
||||
"day": config["provider_uri_day"].as_posix(),
|
||||
"1min": config["provider_uri_1min"].as_posix(),
|
||||
}
|
||||
qlib.init(
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum],
|
||||
expression_cache=None,
|
||||
calendar_provider={
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileCalendarStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
},
|
||||
},
|
||||
},
|
||||
feature_provider={
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileFeatureStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
},
|
||||
},
|
||||
},
|
||||
provider_uri=provider_uri_map,
|
||||
kernels=1,
|
||||
redis_port=-1,
|
||||
clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance
|
||||
)
|
||||
@@ -3,13 +3,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypeVar, Generic, Any
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
from .simulator import StateType, ActType
|
||||
from .simulator import ActType, StateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils.env_wrapper import EnvWrapper
|
||||
@@ -40,7 +40,7 @@ class Interpreter:
|
||||
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gym.Space:
|
||||
@@ -74,7 +74,7 @@ class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
|
||||
"""Action Interpreter that interpret rl agent action into qlib orders"""
|
||||
|
||||
env: "EnvWrapper" | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@property
|
||||
def action_space(self) -> gym.Space:
|
||||
@@ -141,10 +141,10 @@ def _gym_space_contains(space: gym.Space, x: Any) -> None:
|
||||
|
||||
|
||||
class GymSpaceValidationError(Exception):
|
||||
def __init__(self, message: str, space: gym.Space, x: Any):
|
||||
def __init__(self, message: str, space: gym.Space, x: Any) -> None:
|
||||
self.message = message
|
||||
self.space = space
|
||||
self.x = x
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"{self.message}\n Space: {self.space}\n Sample: {self.x}"
|
||||
|
||||
@@ -5,15 +5,15 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from typing import Any, List, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym import spaces
|
||||
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .simulator_simple import SAOEState
|
||||
@@ -99,18 +99,18 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
|
||||
"data_processed": self._mask_future_info(processed.today, state.cur_time),
|
||||
"data_processed_prev": processed.yesterday,
|
||||
"acquiring": state.order.direction == state.order.BUY,
|
||||
"cur_tick": min(np.sum(state.ticks_index < state.cur_time), self.data_ticks - 1),
|
||||
"cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1),
|
||||
"cur_step": min(self.env.status["cur_step"], self.max_step - 1),
|
||||
"num_step": self.max_step,
|
||||
"target": state.order.amount,
|
||||
"position": state.position,
|
||||
"position_history": position_history[: self.max_step],
|
||||
}
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
def observation_space(self) -> spaces.Dict:
|
||||
space = {
|
||||
"data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
|
||||
"data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
|
||||
@@ -147,11 +147,11 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
|
||||
The key list is not full. You can add more if more information is needed by your policy.
|
||||
"""
|
||||
|
||||
def __init__(self, max_step: int):
|
||||
def __init__(self, max_step: int) -> None:
|
||||
self.max_step = max_step
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
def observation_space(self) -> spaces.Dict:
|
||||
space = {
|
||||
"acquiring": spaces.Discrete(2),
|
||||
"cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
|
||||
@@ -165,13 +165,11 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
|
||||
assert self.env is not None
|
||||
assert self.env.status["cur_step"] <= self.max_step
|
||||
obs = CurrentStateObs(
|
||||
{
|
||||
"acquiring": state.order.direction == state.order.BUY,
|
||||
"cur_step": self.env.status["cur_step"],
|
||||
"num_step": self.max_step,
|
||||
"target": state.order.amount,
|
||||
"position": state.position,
|
||||
}
|
||||
acquiring=state.order.direction == state.order.BUY,
|
||||
cur_step=self.env.status["cur_step"],
|
||||
num_step=self.max_step,
|
||||
target=state.order.amount,
|
||||
position=state.position,
|
||||
)
|
||||
return obs
|
||||
|
||||
@@ -188,7 +186,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
|
||||
i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
|
||||
"""
|
||||
|
||||
def __init__(self, values: int | list[float]):
|
||||
def __init__(self, values: int | List[float]) -> None:
|
||||
if isinstance(values, int):
|
||||
values = [i / values for i in range(0, values + 1)]
|
||||
self.action_values = values
|
||||
@@ -203,7 +201,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
|
||||
|
||||
|
||||
class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
|
||||
"""Convert a continous ratio to deal amount.
|
||||
"""Convert a continuous ratio to deal amount.
|
||||
|
||||
The ratio is relative to TWAP on the remainder of the day.
|
||||
For example, there are 5 steps left, and the left position is 300.
|
||||
|
||||
@@ -3,13 +3,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tianshou.data import Batch
|
||||
|
||||
from qlib.typehint import Literal
|
||||
|
||||
from .interpreter import FullHistoryObs
|
||||
|
||||
__all__ = ["Recurrent"]
|
||||
@@ -18,7 +19,7 @@ __all__ = ["Recurrent"]
|
||||
class Recurrent(nn.Module):
|
||||
"""The network architecture proposed in `OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_.
|
||||
|
||||
At every timestep the input of policy network is divided into two parts,
|
||||
At every time step the input of policy network is divided into two parts,
|
||||
the public variables and the private variables. which are handled by ``raw_rnn``
|
||||
and ``pri_rnn`` in this network, respectively.
|
||||
|
||||
@@ -33,7 +34,7 @@ class Recurrent(nn.Module):
|
||||
output_dim: int = 32,
|
||||
rnn_type: Literal["rnn", "lstm", "gru"] = "gru",
|
||||
rnn_num_layers: int = 1,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
@@ -62,10 +63,10 @@ class Recurrent(nn.Module):
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
def _init_extra_branches(self):
|
||||
def _init_extra_branches(self) -> None:
|
||||
pass
|
||||
|
||||
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]:
|
||||
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]:
|
||||
bs, _, data_dim = obs["data_processed"].size()
|
||||
data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1)
|
||||
cur_step = obs["cur_step"].long()
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
from typing import Any, Dict, Generator, Iterable, Optional, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gym.spaces import Discrete
|
||||
from tianshou.data import Batch, to_torch
|
||||
from tianshou.policy import PPOPolicy, BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch
|
||||
from tianshou.policy import BasePolicy, PPOPolicy
|
||||
|
||||
__all__ = ["AllOne", "PPO"]
|
||||
|
||||
@@ -18,29 +19,39 @@ __all__ = ["AllOne", "PPO"]
|
||||
# baselines #
|
||||
|
||||
|
||||
class NonlearnablePolicy(BasePolicy):
|
||||
class NonLearnablePolicy(BasePolicy):
|
||||
"""Tianshou's BasePolicy with empty ``learn`` and ``process_fn``.
|
||||
|
||||
This could be moved outside in future.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space: gym.Space, action_space: gym.Space):
|
||||
def __init__(self, obs_space: gym.Space, action_space: gym.Space) -> None:
|
||||
super().__init__()
|
||||
|
||||
def learn(self, batch, batch_size, repeat):
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
def process_fn(
|
||||
self,
|
||||
batch: Batch,
|
||||
buffer: ReplayBuffer,
|
||||
indices: np.ndarray,
|
||||
) -> Batch:
|
||||
pass
|
||||
|
||||
|
||||
class AllOne(NonlearnablePolicy):
|
||||
class AllOne(NonLearnablePolicy):
|
||||
"""Forward returns a batch full of 1.
|
||||
|
||||
Useful when implementing some baselines (e.g., TWAP).
|
||||
"""
|
||||
|
||||
def forward(self, batch, state=None, **kwargs):
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: dict | Batch | np.ndarray = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
return Batch(act=np.full(len(batch), 1.0), state=state)
|
||||
|
||||
|
||||
@@ -48,24 +59,34 @@ class AllOne(NonlearnablePolicy):
|
||||
|
||||
|
||||
class PPOActor(nn.Module):
|
||||
def __init__(self, extractor: nn.Module, action_dim: int):
|
||||
def __init__(self, extractor: nn.Module, action_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1))
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
obs: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
info: dict = {},
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
feature = self.extractor(to_torch(obs, device=auto_device(self)))
|
||||
out = self.layer_out(feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class PPOCritic(nn.Module):
|
||||
def __init__(self, extractor: nn.Module):
|
||||
def __init__(self, extractor: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(cast(int, extractor.output_dim), 1)
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
obs: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
info: dict = {},
|
||||
) -> torch.Tensor:
|
||||
feature = self.extractor(to_torch(obs, device=auto_device(self)))
|
||||
return self.value_out(feature).squeeze(dim=-1)
|
||||
|
||||
@@ -93,18 +114,20 @@ class PPO(PPOPolicy):
|
||||
max_grad_norm: float = 100.0,
|
||||
reward_normalization: bool = True,
|
||||
eps_clip: float = 0.3,
|
||||
value_clip: float = True,
|
||||
value_clip: bool = True,
|
||||
vf_coef: float = 1.0,
|
||||
gae_lambda: float = 1.0,
|
||||
max_batchsize: int = 256,
|
||||
max_batch_size: int = 256,
|
||||
deterministic_eval: bool = True,
|
||||
weight_file: Optional[Path] = None,
|
||||
):
|
||||
) -> None:
|
||||
assert isinstance(action_space, Discrete)
|
||||
actor = PPOActor(network, action_space.n)
|
||||
critic = PPOCritic(network)
|
||||
optimizer = torch.optim.Adam(
|
||||
chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay
|
||||
chain_dedup(actor.parameters(), critic.parameters()),
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
super().__init__(
|
||||
actor,
|
||||
@@ -118,7 +141,7 @@ class PPO(PPOPolicy):
|
||||
value_clip=value_clip,
|
||||
vf_coef=vf_coef,
|
||||
gae_lambda=gae_lambda,
|
||||
max_batchsize=max_batchsize,
|
||||
max_batchsize=max_batch_size,
|
||||
deterministic_eval=deterministic_eval,
|
||||
observation_space=obs_space,
|
||||
action_space=action_space,
|
||||
@@ -136,7 +159,7 @@ def auto_device(module: nn.Module) -> torch.device:
|
||||
return torch.device("cpu") # fallback to cpu
|
||||
|
||||
|
||||
def load_weight(policy, path):
|
||||
def load_weight(policy: nn.Module, path: Path) -> None:
|
||||
assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
|
||||
loaded_weight = torch.load(path, map_location="cpu")
|
||||
try:
|
||||
@@ -149,7 +172,7 @@ def load_weight(policy, path):
|
||||
policy.load_state_dict(loaded_weight)
|
||||
|
||||
|
||||
def chain_dedup(*iterables):
|
||||
def chain_dedup(*iterables: Iterable) -> Generator[Any, None, None]:
|
||||
seen = set()
|
||||
for iterable in iterables:
|
||||
for i in iterable:
|
||||
|
||||
@@ -6,9 +6,10 @@ from __future__ import annotations
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qlib.rl.reward import Reward
|
||||
|
||||
from .simulator_simple import SAOEState, SAOEMetrics
|
||||
from .simulator_simple import SAOEMetrics, SAOEState
|
||||
|
||||
__all__ = ["PAPenaltyReward"]
|
||||
|
||||
|
||||
@@ -1,4 +1,424 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Placeholder for qlib-based simulator."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, cast, Generator, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderHelper, TradeDecisionWO, TradeRange, TradeRangeByTime
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.data.exchange_wrapper import QlibIntradayBacktestData
|
||||
from qlib.rl.from_neutrader.config import ExchangeConfig
|
||||
from qlib.rl.from_neutrader.feature import init_qlib
|
||||
from qlib.rl.order_execution.simulator_simple import SAOEMetrics, SAOEState
|
||||
from qlib.rl.order_execution.utils import (
|
||||
dataframe_append,
|
||||
get_common_infra,
|
||||
get_portfolio_and_indicator,
|
||||
get_ticks_slice,
|
||||
price_advantage,
|
||||
)
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
|
||||
|
||||
class DecomposedStrategy(BaseStrategy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.execute_order: Optional[Order] = None
|
||||
self.execute_result: List[Tuple[Order, float, float, float]] = []
|
||||
|
||||
def generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
|
||||
# Once the following line is executed, this DecomposedStrategy (self) will be yielded to the outside
|
||||
# of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`,
|
||||
# the sent item will be captured by `exec_vol`. The outside policy could communicate with the inner
|
||||
# level strategy through this way.
|
||||
exec_vol = yield self
|
||||
|
||||
oh = self.trade_exchange.get_order_helper()
|
||||
order = oh.create(self._order.stock_id, exec_vol, self._order.direction)
|
||||
|
||||
self.execute_order = order
|
||||
|
||||
return TradeDecisionWO([order], self)
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
return outer_trade_decision
|
||||
|
||||
def post_exe_step(self, execute_result: list) -> None:
|
||||
self.execute_result = execute_result
|
||||
|
||||
def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs: Any) -> None:
|
||||
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
order_list = outer_trade_decision.order_list
|
||||
assert len(order_list) == 1
|
||||
self._order = order_list[0]
|
||||
|
||||
|
||||
class SingleOrderStrategy(BaseStrategy):
|
||||
# this logic is copied from FileOrderStrategy
|
||||
def __init__(
|
||||
self,
|
||||
common_infra: CommonInfrastructure,
|
||||
order: Order,
|
||||
trade_range: TradeRange,
|
||||
instrument: str,
|
||||
) -> None:
|
||||
super().__init__(common_infra=common_infra)
|
||||
self._order = order
|
||||
self._trade_range = trade_range
|
||||
self._instrument = instrument
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
return outer_trade_decision
|
||||
|
||||
def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO:
|
||||
oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
|
||||
order_list = [
|
||||
oh.create(
|
||||
code=self._instrument,
|
||||
amount=self._order.amount,
|
||||
direction=self._order.direction,
|
||||
),
|
||||
]
|
||||
return TradeDecisionWO(order_list, self, self._trade_range)
|
||||
|
||||
|
||||
# TODO: move these to the configuration files
|
||||
FINEST_GRANULARITY = "1min"
|
||||
COARSEST_GRANULARITY = "1day"
|
||||
|
||||
|
||||
class StateMaintainer:
|
||||
"""
|
||||
Maintain states of the environment.
|
||||
|
||||
Example usage::
|
||||
|
||||
maintainer = StateMaintainer(...) # in reset
|
||||
maintainer.update(...) # in step
|
||||
# get states in get_state from maintainer
|
||||
"""
|
||||
|
||||
def __init__(self, order: Order, time_per_step: str, tick_index: pd.DatetimeIndex, twap_price: float) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.position = order.amount
|
||||
self._order = order
|
||||
self._time_per_step = time_per_step
|
||||
self._tick_index = tick_index
|
||||
self._twap_price = twap_price
|
||||
|
||||
metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member
|
||||
self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.metrics: Optional[SAOEMetrics] = None
|
||||
|
||||
def update(
|
||||
self,
|
||||
inner_executor: BaseExecutor,
|
||||
inner_strategy: DecomposedStrategy,
|
||||
done: bool,
|
||||
all_indicators: dict,
|
||||
) -> None:
|
||||
execute_order = inner_strategy.execute_order
|
||||
execute_result = inner_strategy.execute_result
|
||||
exec_vol = np.array([e[0].deal_amount for e in execute_result])
|
||||
num_step = len(execute_result)
|
||||
|
||||
assert execute_order is not None
|
||||
|
||||
if num_step == 0:
|
||||
market_volume = np.array([])
|
||||
market_price = np.array([])
|
||||
datetime_list = pd.DatetimeIndex([])
|
||||
else:
|
||||
market_volume = np.array(
|
||||
inner_executor.trade_exchange.get_volume(
|
||||
execute_order.stock_id,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
method=None,
|
||||
),
|
||||
)
|
||||
|
||||
trade_value = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["value"].values
|
||||
deal_amount = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["deal_amount"].values
|
||||
market_price = trade_value / deal_amount
|
||||
|
||||
datetime_list = all_indicators[FINEST_GRANULARITY].index[-num_step:]
|
||||
|
||||
assert market_price.shape == market_volume.shape == exec_vol.shape
|
||||
|
||||
self.history_exec = dataframe_append(
|
||||
self.history_exec,
|
||||
self._collect_multi_order_metric(
|
||||
order=self._order,
|
||||
datetime=datetime_list,
|
||||
market_vol=market_volume,
|
||||
market_price=market_price,
|
||||
exec_vol=exec_vol,
|
||||
pa=all_indicators[self._time_per_step].iloc[-1]["pa"],
|
||||
),
|
||||
)
|
||||
|
||||
self.history_steps = dataframe_append(
|
||||
self.history_steps,
|
||||
[
|
||||
self._collect_single_order_metric(
|
||||
execute_order,
|
||||
execute_order.start_time,
|
||||
market_volume,
|
||||
market_price,
|
||||
exec_vol.sum(),
|
||||
exec_vol,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if done:
|
||||
self.metrics = self._collect_single_order_metric(
|
||||
self._order,
|
||||
self._tick_index[0], # start time
|
||||
self.history_exec["market_volume"],
|
||||
self.history_exec["market_price"],
|
||||
self.history_steps["amount"].sum(),
|
||||
self.history_exec["deal_amount"],
|
||||
)
|
||||
|
||||
# TODO: check whether we need this. Can we get this information from Account?
|
||||
# Do this at the end
|
||||
self.position -= exec_vol.sum()
|
||||
|
||||
def _collect_multi_order_metric(
|
||||
self,
|
||||
order: Order,
|
||||
datetime: pd.Timestamp,
|
||||
market_vol: np.ndarray,
|
||||
market_price: np.ndarray,
|
||||
exec_vol: np.ndarray,
|
||||
pa: float,
|
||||
) -> SAOEMetrics:
|
||||
return SAOEMetrics(
|
||||
# It should have the same keys with SAOEMetrics,
|
||||
# but the values do not necessarily have the annotated type.
|
||||
# Some values could be vectorized (e.g., exec_vol).
|
||||
stock_id=order.stock_id,
|
||||
datetime=datetime,
|
||||
direction=order.direction,
|
||||
market_volume=market_vol,
|
||||
market_price=market_price,
|
||||
amount=exec_vol,
|
||||
inner_amount=exec_vol,
|
||||
deal_amount=exec_vol,
|
||||
trade_price=market_price,
|
||||
trade_value=market_price * exec_vol,
|
||||
position=self.position - np.cumsum(exec_vol),
|
||||
ffr=exec_vol / order.amount,
|
||||
pa=pa,
|
||||
)
|
||||
|
||||
def _collect_single_order_metric(
|
||||
self,
|
||||
order: Order,
|
||||
datetime: pd.Timestamp,
|
||||
market_vol: np.ndarray,
|
||||
market_price: np.ndarray,
|
||||
amount: float, # intended to trade such amount
|
||||
exec_vol: np.ndarray,
|
||||
) -> SAOEMetrics:
|
||||
assert len(market_vol) == len(market_price) == len(exec_vol)
|
||||
|
||||
if np.abs(np.sum(exec_vol)) < EPS:
|
||||
exec_avg_price = 0.0
|
||||
else:
|
||||
exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan
|
||||
if hasattr(exec_avg_price, "item"): # could be numpy scalar
|
||||
exec_avg_price = exec_avg_price.item() # type: ignore
|
||||
|
||||
exec_sum = exec_vol.sum()
|
||||
return SAOEMetrics(
|
||||
stock_id=order.stock_id,
|
||||
datetime=datetime,
|
||||
direction=order.direction,
|
||||
market_volume=market_vol.sum(),
|
||||
market_price=market_price.mean() if len(market_price) > 0 else np.nan,
|
||||
amount=amount,
|
||||
inner_amount=exec_sum,
|
||||
deal_amount=exec_sum, # in this simulator, there's no other restrictions
|
||||
trade_price=exec_avg_price,
|
||||
trade_value=float(np.sum(market_price * exec_vol)),
|
||||
position=self.position - exec_sum,
|
||||
ffr=float(exec_sum / order.amount),
|
||||
pa=price_advantage(exec_avg_price, self._twap_price, order.direction),
|
||||
)
|
||||
|
||||
|
||||
class SingleAssetOrderExecutionQlib(Simulator[Order, SAOEState, float]):
|
||||
"""Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
order (Order):
|
||||
The seed to start an SAOE simulator is an order.
|
||||
time_per_step (str):
|
||||
A string to describe the time granularity of each step. Current support "1min", "30min", and "1day"
|
||||
qlib_config (dict):
|
||||
Configuration used to initialize Qlib.
|
||||
inner_executor_fn (Callable[[str, CommonInfrastructure], BaseExecutor]):
|
||||
Function used to get the inner level executor.
|
||||
exchange_config (ExchangeConfig):
|
||||
Configuration used to create the Exchange instance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
order: Order,
|
||||
time_per_step: str, # "1min", "30min", "1day"
|
||||
qlib_config: dict,
|
||||
inner_executor_fn: Callable[[str, CommonInfrastructure], BaseExecutor],
|
||||
exchange_config: ExchangeConfig,
|
||||
) -> None:
|
||||
assert time_per_step in ("1min", "30min", "1day")
|
||||
|
||||
super().__init__(initial=order)
|
||||
|
||||
assert order.start_time.date() == order.end_time.date(), "Start date and end date must be the same."
|
||||
|
||||
self._order = order
|
||||
self._order_date = pd.Timestamp(order.start_time.date())
|
||||
self._trade_range = TradeRangeByTime(order.start_time.time(), order.end_time.time())
|
||||
self._qlib_config = qlib_config
|
||||
self._inner_executor_fn = inner_executor_fn
|
||||
self._exchange_config = exchange_config
|
||||
|
||||
self._time_per_step = time_per_step
|
||||
self._ticks_per_step = int(pd.Timedelta(time_per_step).total_seconds() // 60)
|
||||
|
||||
self._executor: Optional[NestedExecutor] = None
|
||||
self._collect_data_loop: Optional[Generator] = None
|
||||
|
||||
self._done = False
|
||||
|
||||
self._inner_strategy = DecomposedStrategy()
|
||||
|
||||
self.reset(self._order)
|
||||
|
||||
def reset(self, order: Order) -> None:
|
||||
instrument = order.stock_id
|
||||
|
||||
# TODO: Check this logic. Make sure we need to do this every time we reset the simulator.
|
||||
init_qlib(self._qlib_config, instrument)
|
||||
|
||||
common_infra = get_common_infra(
|
||||
self._exchange_config,
|
||||
trade_date=pd.Timestamp(self._order_date),
|
||||
codes=[instrument],
|
||||
)
|
||||
|
||||
# TODO: We can leverage interfaces like (https://tinyurl.com/y8f8fhv4) to create trading environment.
|
||||
# TODO: By aligning the interface to create environments with Qlib, it will be easier to share the config and
|
||||
# TODO: code between backtesting and training.
|
||||
self._inner_executor = self._inner_executor_fn(self._time_per_step, common_infra)
|
||||
self._executor = NestedExecutor(
|
||||
time_per_step=COARSEST_GRANULARITY,
|
||||
inner_executor=self._inner_executor,
|
||||
inner_strategy=self._inner_strategy,
|
||||
track_data=True,
|
||||
common_infra=common_infra,
|
||||
)
|
||||
|
||||
exchange = self._inner_executor.trade_exchange
|
||||
self._ticks_index = pd.DatetimeIndex([e[1] for e in list(exchange.quote_df.index)])
|
||||
self._ticks_for_order = get_ticks_slice(
|
||||
self._ticks_index,
|
||||
self._order.start_time,
|
||||
self._order.end_time,
|
||||
include_end=True,
|
||||
)
|
||||
|
||||
self._backtest_data = QlibIntradayBacktestData(
|
||||
order=self._order,
|
||||
exchange=exchange,
|
||||
start_time=self._ticks_for_order[0],
|
||||
end_time=self._ticks_for_order[-1],
|
||||
)
|
||||
|
||||
self.twap_price = self._backtest_data.get_deal_price().mean()
|
||||
|
||||
top_strategy = SingleOrderStrategy(common_infra, order, self._trade_range, instrument)
|
||||
self._executor.reset(start_time=pd.Timestamp(self._order_date), end_time=pd.Timestamp(self._order_date))
|
||||
top_strategy.reset(level_infra=self._executor.get_level_infra())
|
||||
|
||||
self._collect_data_loop = self._executor.collect_data(top_strategy.generate_trade_decision(), level=0)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
self._iter_strategy(action=None)
|
||||
self._done = False
|
||||
|
||||
self._maintainer = StateMaintainer(
|
||||
order=self._order,
|
||||
time_per_step=self._time_per_step,
|
||||
tick_index=self._ticks_index,
|
||||
twap_price=self.twap_price,
|
||||
)
|
||||
|
||||
def _iter_strategy(self, action: float = None) -> DecomposedStrategy:
|
||||
"""Iterate the _collect_data_loop until we get the next yield DecomposedStrategy."""
|
||||
assert self._collect_data_loop is not None
|
||||
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
while not isinstance(strategy, DecomposedStrategy):
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
assert isinstance(strategy, DecomposedStrategy)
|
||||
return strategy
|
||||
|
||||
def step(self, action: float) -> None:
|
||||
"""Execute one step or SAOE.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
action (float):
|
||||
The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt.
|
||||
"""
|
||||
|
||||
assert not self._done, "Simulator has already done!"
|
||||
|
||||
try:
|
||||
self._iter_strategy(action=action)
|
||||
except StopIteration:
|
||||
self._done = True
|
||||
|
||||
assert self._executor is not None
|
||||
_, all_indicators = get_portfolio_and_indicator(self._executor)
|
||||
|
||||
self._maintainer.update(
|
||||
inner_executor=self._inner_executor,
|
||||
inner_strategy=self._inner_strategy,
|
||||
done=self._done,
|
||||
all_indicators=all_indicators,
|
||||
)
|
||||
|
||||
def get_state(self) -> SAOEState:
|
||||
return SAOEState(
|
||||
order=self._order,
|
||||
cur_time=self._inner_executor.trade_calendar.get_step_time()[0],
|
||||
position=self._maintainer.position,
|
||||
history_exec=self._maintainer.history_exec,
|
||||
history_steps=self._maintainer.history_steps,
|
||||
metrics=self._maintainer.metrics,
|
||||
backtest_data=self._backtest_data,
|
||||
ticks_per_step=self._ticks_per_step,
|
||||
ticks_index=self._ticks_index,
|
||||
ticks_for_order=self._ticks_for_order,
|
||||
)
|
||||
|
||||
def done(self) -> bool:
|
||||
return self._done
|
||||
|
||||
@@ -4,18 +4,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, Any, TypeVar, cast
|
||||
from typing import Any, NamedTuple, Optional, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.data.pickle_styled import DealPriceType, IntradayBacktestData, load_simple_intraday_backtest_data
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType
|
||||
from qlib.rl.utils import LogLevel
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
# TODO: Integrating Qlib's native data with simulator_simple
|
||||
|
||||
__all__ = ["SAOEMetrics", "SAOEState", "SingleAssetOrderExecution"]
|
||||
|
||||
ONE_SEC = pd.Timedelta("1s") # use 1 second to exclude the right interval point
|
||||
@@ -33,40 +35,40 @@ class SAOEMetrics(TypedDict):
|
||||
|
||||
stock_id: str
|
||||
"""Stock ID of this record."""
|
||||
datetime: pd.Timestamp
|
||||
datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this
|
||||
"""Datetime of this record (this is index in the dataframe)."""
|
||||
direction: int
|
||||
"""Direction of the order. 0 for sell, 1 for buy."""
|
||||
|
||||
# Market information.
|
||||
market_volume: float
|
||||
market_volume: np.ndarray | float
|
||||
"""(total) market volume traded in the period."""
|
||||
market_price: float
|
||||
market_price: np.ndarray | float
|
||||
"""Deal price. If it's a period of time, this is the average market deal price."""
|
||||
|
||||
# Strategy records.
|
||||
|
||||
amount: float
|
||||
amount: np.ndarray | float
|
||||
"""Total amount (volume) strategy intends to trade."""
|
||||
inner_amount: float
|
||||
inner_amount: np.ndarray | float
|
||||
"""Total amount that the lower-level strategy intends to trade
|
||||
(might be larger than amount, e.g., to ensure ffr)."""
|
||||
|
||||
deal_amount: float
|
||||
deal_amount: np.ndarray | float
|
||||
"""Amount that successfully takes effect (must be less than inner_amount)."""
|
||||
trade_price: float
|
||||
trade_price: np.ndarray | float
|
||||
"""The average deal price for this strategy."""
|
||||
trade_value: float
|
||||
"""Total worth of trading. In the simple simulaton, trade_value = deal_amount * price."""
|
||||
position: float
|
||||
trade_value: np.ndarray | float
|
||||
"""Total worth of trading. In the simple simulation, trade_value = deal_amount * price."""
|
||||
position: np.ndarray | float
|
||||
"""Position left after this "period"."""
|
||||
|
||||
# Accumulated metrics
|
||||
|
||||
ffr: float
|
||||
ffr: np.ndarray | float
|
||||
"""Completed how much percent of the daily order."""
|
||||
|
||||
pa: float
|
||||
pa: np.ndarray | float
|
||||
"""Price advantage compared to baseline (i.e., trade with baseline market price).
|
||||
The baseline is trade price when using TWAP strategy to execute this order.
|
||||
Please note that there could be data leak here).
|
||||
@@ -87,7 +89,7 @@ class SAOEState(NamedTuple):
|
||||
history_steps: pd.DataFrame
|
||||
"""See :attr:`SingleAssetOrderExecution.history_steps`."""
|
||||
|
||||
metrics: SAOEMetrics | None
|
||||
metrics: Optional[SAOEMetrics]
|
||||
"""Daily metric, only available when the trading is in "done" state."""
|
||||
|
||||
backtest_data: IntradayBacktestData
|
||||
@@ -114,13 +116,13 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
If such fine granularity is not needed, use ``ticks_per_step`` to
|
||||
lengthen the ticks for each step.
|
||||
|
||||
In each step, the traded amount are "equally" splitted to each tick,
|
||||
then bounded by volume maximum exeuction volume (i.e., ``vol_threshold``),
|
||||
In each step, the traded amount are "equally" separated to each tick,
|
||||
then bounded by volume maximum execution volume (i.e., ``vol_threshold``),
|
||||
and if it's the last step, try to ensure all the amount to be executed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
initial
|
||||
order
|
||||
The seed to start an SAOE simulator is an order.
|
||||
ticks_per_step
|
||||
How many ticks per step.
|
||||
@@ -140,7 +142,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
See :class:`SAOEMetrics` for available columns.
|
||||
Index is ``datetime``, which is the **starting** time of each step."""
|
||||
|
||||
metrics: SAOEMetrics | None
|
||||
metrics: Optional[SAOEMetrics]
|
||||
"""Metrics. Only available when done."""
|
||||
|
||||
twap_price: float
|
||||
@@ -159,15 +161,21 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
data_dir: Path,
|
||||
ticks_per_step: int = 30,
|
||||
deal_price_type: DealPriceType = "close",
|
||||
vol_threshold: float | None = None,
|
||||
vol_threshold: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(initial=order)
|
||||
|
||||
self.order = order
|
||||
self.ticks_per_step: int = ticks_per_step
|
||||
self.deal_price_type = deal_price_type
|
||||
self.vol_threshold = vol_threshold
|
||||
self.data_dir = data_dir
|
||||
self.backtest_data = load_intraday_backtest_data(
|
||||
self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), self.deal_price_type, order.direction
|
||||
self.backtest_data = load_simple_intraday_backtest_data(
|
||||
self.data_dir,
|
||||
order.stock_id,
|
||||
pd.Timestamp(order.start_time.date()),
|
||||
self.deal_price_type,
|
||||
order.direction,
|
||||
)
|
||||
|
||||
self.ticks_index = self.backtest_data.get_time_index()
|
||||
@@ -188,9 +196,9 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.metrics = None
|
||||
|
||||
self.market_price: np.ndarray | None = None
|
||||
self.market_vol: np.ndarray | None = None
|
||||
self.market_vol_limit: np.ndarray | None = None
|
||||
self.market_price: Optional[np.ndarray] = None
|
||||
self.market_vol: Optional[np.ndarray] = None
|
||||
self.market_vol_limit: Optional[np.ndarray] = None
|
||||
|
||||
def step(self, amount: float) -> None:
|
||||
"""Execute one step or SAOE.
|
||||
@@ -205,7 +213,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
|
||||
self.market_price = self.market_vol = None # avoid misuse
|
||||
exec_vol = self._split_exec_vol(amount)
|
||||
assert self.market_price is not None and self.market_vol is not None
|
||||
assert self.market_price is not None
|
||||
assert self.market_vol is not None
|
||||
|
||||
ticks_position = self.position - np.cumsum(exec_vol)
|
||||
|
||||
@@ -363,7 +372,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
inner_amount=exec_vol.sum(),
|
||||
deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions
|
||||
trade_price=exec_avg_price,
|
||||
trade_value=np.sum(market_price * exec_vol),
|
||||
trade_value=float(np.sum(market_price * exec_vol)),
|
||||
position=self.position,
|
||||
ffr=float(exec_vol.sum() / self.order.amount),
|
||||
pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction),
|
||||
@@ -386,7 +395,9 @@ _float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
|
||||
|
||||
|
||||
def price_advantage(
|
||||
exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int
|
||||
exec_price: _float_or_ndarray,
|
||||
baseline_price: float,
|
||||
direction: OrderDir | int,
|
||||
) -> _float_or_ndarray:
|
||||
if baseline_price == 0: # something is wrong with data. Should be nan here
|
||||
if isinstance(exec_price, float):
|
||||
|
||||
111
qlib/rl/order_execution/utils.py
Normal file
111
qlib/rl/order_execution/utils.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import CommonInfrastructure, get_exchange
|
||||
from qlib.backtest.account import Account
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from qlib.rl.from_neutrader.config import ExchangeConfig
|
||||
from qlib.rl.order_execution.simulator_simple import ONE_SEC, _float_or_ndarray
|
||||
from qlib.utils.time import Freq
|
||||
|
||||
|
||||
def get_common_infra(
|
||||
config: ExchangeConfig,
|
||||
trade_date: pd.Timestamp,
|
||||
codes: List[str],
|
||||
cash_limit: float = None,
|
||||
) -> CommonInfrastructure:
|
||||
# need to specify a range here for acceleration
|
||||
if cash_limit is None:
|
||||
trade_account = Account(init_cash=int(1e12), benchmark_config={}, pos_type="InfPosition")
|
||||
else:
|
||||
trade_account = Account(
|
||||
init_cash=cash_limit,
|
||||
benchmark_config={},
|
||||
pos_type="Position",
|
||||
position_dict={code: {"amount": 1e12, "price": 1.0} for code in codes},
|
||||
)
|
||||
|
||||
exchange = get_exchange(
|
||||
codes=codes,
|
||||
freq="1min",
|
||||
limit_threshold=config.limit_threshold,
|
||||
deal_price=config.deal_price,
|
||||
open_cost=config.open_cost,
|
||||
close_cost=config.close_cost,
|
||||
min_cost=config.min_cost if config.trade_unit is not None else 0,
|
||||
start_time=trade_date,
|
||||
end_time=trade_date + pd.DateOffset(1),
|
||||
trade_unit=config.trade_unit,
|
||||
volume_threshold=config.volume_threshold,
|
||||
)
|
||||
|
||||
return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
|
||||
|
||||
|
||||
def get_ticks_slice(
|
||||
ticks_index: pd.DatetimeIndex,
|
||||
start: pd.Timestamp,
|
||||
end: pd.Timestamp,
|
||||
include_end: bool = False,
|
||||
) -> pd.DatetimeIndex:
|
||||
if not include_end:
|
||||
end = end - ONE_SEC
|
||||
return ticks_index[ticks_index.slice_indexer(start, end)]
|
||||
|
||||
|
||||
def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
|
||||
# dataframe.append is deprecated
|
||||
other_df = pd.DataFrame(other).set_index("datetime")
|
||||
other_df.index.name = "datetime"
|
||||
|
||||
res = pd.concat([df, other_df], axis=0)
|
||||
return res
|
||||
|
||||
|
||||
def price_advantage(
|
||||
exec_price: _float_or_ndarray,
|
||||
baseline_price: float,
|
||||
direction: OrderDir | int,
|
||||
) -> _float_or_ndarray:
|
||||
if baseline_price == 0: # something is wrong with data. Should be nan here
|
||||
if isinstance(exec_price, float):
|
||||
return 0.0
|
||||
else:
|
||||
return np.zeros_like(exec_price)
|
||||
if direction == OrderDir.BUY:
|
||||
res = (1 - exec_price / baseline_price) * 10000
|
||||
elif direction == OrderDir.SELL:
|
||||
res = (exec_price / baseline_price - 1) * 10000
|
||||
else:
|
||||
raise ValueError(f"Unexpected order direction: {direction}")
|
||||
res_wo_nan: np.ndarray = np.nan_to_num(res, nan=0.0)
|
||||
if res_wo_nan.size == 1:
|
||||
return res_wo_nan.item()
|
||||
else:
|
||||
return cast(_float_or_ndarray, res_wo_nan)
|
||||
|
||||
|
||||
def get_portfolio_and_indicator(executor: BaseExecutor) -> Tuple[dict, dict]:
|
||||
all_executors = executor.get_all_executors()
|
||||
all_portfolio_metrics = {
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
|
||||
for _executor in all_executors
|
||||
if _executor.trade_account.is_port_metr_enabled()
|
||||
}
|
||||
|
||||
all_indicators = {}
|
||||
for _executor in all_executors:
|
||||
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
|
||||
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
|
||||
|
||||
return all_portfolio_metrics, all_indicators
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, Any, TypeVar, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, TypeVar
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
@@ -20,7 +20,7 @@ class Reward(Generic[SimulatorState]):
|
||||
Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe.
|
||||
"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@final
|
||||
def __call__(self, simulator_state: SimulatorState) -> float:
|
||||
@@ -30,14 +30,15 @@ class Reward(Generic[SimulatorState]):
|
||||
"""Implement this method for your own reward."""
|
||||
raise NotImplementedError("Implement reward calculation recipe in `reward()`.")
|
||||
|
||||
def log(self, name, value):
|
||||
def log(self, name: str, value: Any) -> None:
|
||||
assert self.env is not None
|
||||
self.env.logger.add_scalar(name, value)
|
||||
|
||||
|
||||
class RewardCombination(Reward):
|
||||
"""Combination of multiple reward."""
|
||||
|
||||
def __init__(self, rewards: dict[str, tuple[Reward, float]]):
|
||||
def __init__(self, rewards: Dict[str, Tuple[Reward, float]]) -> None:
|
||||
self.rewards = rewards
|
||||
|
||||
def reward(self, simulator_state: Any) -> float:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypeVar, Generic, Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
|
||||
from .seed import InitialStateType
|
||||
|
||||
@@ -49,7 +49,7 @@ class Simulator(Generic[InitialStateType, StateType, ActType]):
|
||||
Simulators are discouraged to use this, because it's prone to induce errors.
|
||||
"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
def __init__(self, initial: InitialStateType, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
@@ -3,17 +3,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, Sequence, cast, Any
|
||||
from typing import Any, Callable, Sequence, cast
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.utils import FiniteEnvType, LogWriter
|
||||
|
||||
from .vessel import TrainingVessel
|
||||
from .trainer import Trainer
|
||||
from .vessel import TrainingVessel
|
||||
|
||||
|
||||
def train(
|
||||
|
||||
@@ -12,7 +12,7 @@ import shutil
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@@ -6,13 +6,13 @@ from __future__ import annotations
|
||||
import copy
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, TypeVar, Sequence, cast
|
||||
from typing import Any, Iterable, Sequence, TypeVar, cast
|
||||
|
||||
import torch
|
||||
|
||||
from qlib.rl.simulator import InitialStateType
|
||||
from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogCollector, LogWriter, LogBuffer, vectorize_env, LogLevel
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.simulator import InitialStateType
|
||||
from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogBuffer, LogCollector, LogLevel, LogWriter, vectorize_env
|
||||
from qlib.rl.utils.finite_env import FiniteVectorEnv
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import Callable, ContextManager, Generic, Iterable, TYPE_CHECKING, Sequence, Any, TypeVar, cast, Dict
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
@@ -12,12 +12,11 @@ from tianshou.env import BaseVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.constant import INF
|
||||
from qlib.rl.interpreter import StateType, ActType, ObsType, PolicyActType
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.utils import DataQueue
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.interpreter import ActionInterpreter, ActType, ObsType, PolicyActType, StateInterpreter, StateType
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.utils import DataQueue
|
||||
from qlib.rl.utils.finite_env import FiniteVectorEnv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -209,6 +208,9 @@ class TrainingVessel(TrainingVesselBase):
|
||||
order = np.random.permutation(len(collection))
|
||||
res = [collection[o] for o in order[:size]]
|
||||
_logger.info(
|
||||
"Fast running in development mode. Cut %s initial states from %d to %d.", name, len(collection), len(res)
|
||||
"Fast running in development mode. Cut %s initial states from %d to %d.",
|
||||
name,
|
||||
len(collection),
|
||||
len(res),
|
||||
)
|
||||
return res
|
||||
|
||||
@@ -1,7 +1,21 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .data_queue import *
|
||||
from .env_wrapper import *
|
||||
from .finite_env import *
|
||||
from .log import *
|
||||
from .data_queue import DataQueue
|
||||
from .env_wrapper import EnvWrapper, EnvWrapperStatus
|
||||
from .finite_env import FiniteEnvType, vectorize_env
|
||||
from .log import ConsoleWriter, CsvWriter, LogBuffer, LogCollector, LogLevel, LogWriter
|
||||
|
||||
__all__ = [
|
||||
"LogLevel",
|
||||
"DataQueue",
|
||||
"EnvWrapper",
|
||||
"FiniteEnvType",
|
||||
"LogCollector",
|
||||
"LogWriter",
|
||||
"vectorize_env",
|
||||
"ConsoleWriter",
|
||||
"CsvWriter",
|
||||
"EnvWrapperStatus",
|
||||
"LogBuffer",
|
||||
]
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from queue import Empty
|
||||
from typing import TypeVar, Generic, Sequence, cast
|
||||
from typing import Any, Generator, Generic, Sequence, TypeVar, cast
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
@@ -60,7 +62,7 @@ class DataQueue(Generic[T]):
|
||||
shuffle: bool = True,
|
||||
producer_num_workers: int = 0,
|
||||
queue_maxsize: int = 0,
|
||||
):
|
||||
) -> None:
|
||||
if queue_maxsize == 0:
|
||||
if os.cpu_count() is not None:
|
||||
queue_maxsize = cast(int, os.cpu_count())
|
||||
@@ -78,14 +80,14 @@ class DataQueue(Generic[T]):
|
||||
self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)
|
||||
self._done = multiprocessing.Value("i", 0)
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> DataQueue:
|
||||
self.activate()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.cleanup()
|
||||
|
||||
def cleanup(self):
|
||||
def cleanup(self) -> None:
|
||||
with self._done.get_lock():
|
||||
self._done.value += 1
|
||||
for repeat in range(500):
|
||||
@@ -105,7 +107,7 @@ class DataQueue(Generic[T]):
|
||||
break
|
||||
_logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}")
|
||||
|
||||
def get(self, block=True):
|
||||
def get(self, block: bool = True) -> Any:
|
||||
if not hasattr(self, "_first_get"):
|
||||
self._first_get = True
|
||||
if self._first_get:
|
||||
@@ -120,17 +122,17 @@ class DataQueue(Generic[T]):
|
||||
if self._done.value:
|
||||
raise StopIteration # pylint: disable=raise-missing-from
|
||||
|
||||
def put(self, obj, block=True, timeout=None):
|
||||
return self._queue.put(obj, block=block, timeout=timeout)
|
||||
def put(self, obj: Any, block: bool = True, timeout: int = None) -> None:
|
||||
self._queue.put(obj, block=block, timeout=timeout)
|
||||
|
||||
def mark_as_done(self):
|
||||
def mark_as_done(self) -> None:
|
||||
with self._done.get_lock():
|
||||
self._done.value = 1
|
||||
|
||||
def done(self):
|
||||
def done(self) -> int:
|
||||
return self._done.value
|
||||
|
||||
def activate(self):
|
||||
def activate(self) -> DataQueue:
|
||||
if self._activated:
|
||||
raise ValueError("DataQueue can not activate twice.")
|
||||
thread = threading.Thread(target=self._producer, daemon=True)
|
||||
@@ -138,20 +140,20 @@ class DataQueue(Generic[T]):
|
||||
self._activated = True
|
||||
return self
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
_logger.debug(f"__del__ of {__name__}.DataQueue")
|
||||
self.cleanup()
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Generator[Any, None, None]:
|
||||
if not self._activated:
|
||||
raise ValueError(
|
||||
"Need to call activate() to launch a daemon worker "
|
||||
"to produce data into data queue before using it. "
|
||||
"You probably have forgotten to use the DataQueue in a with block."
|
||||
"You probably have forgotten to use the DataQueue in a with block.",
|
||||
)
|
||||
return self._consumer()
|
||||
|
||||
def _consumer(self):
|
||||
def _consumer(self) -> Generator[Any, None, None]:
|
||||
while True:
|
||||
try:
|
||||
yield self.get()
|
||||
@@ -159,7 +161,7 @@ class DataQueue(Generic[T]):
|
||||
_logger.debug("Data consumer timed-out from get.")
|
||||
return
|
||||
|
||||
def _producer(self):
|
||||
def _producer(self) -> None:
|
||||
# pytorch dataloader is used here only because we need its sampler and multi-processing
|
||||
from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel
|
||||
|
||||
|
||||
@@ -4,14 +4,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import Callable, Any, Iterable, Iterator, Generic, cast
|
||||
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, cast
|
||||
|
||||
import gym
|
||||
from gym import Space
|
||||
|
||||
from qlib.rl.aux_info import AuxiliaryInfoCollector
|
||||
from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType
|
||||
from qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, StateInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .finite_env import generate_nan_observation
|
||||
@@ -28,7 +29,7 @@ class InfoDict(TypedDict):
|
||||
|
||||
aux_info: dict
|
||||
"""Any information depends on auxiliary info collector."""
|
||||
log: dict[str, Any]
|
||||
log: Dict[str, Any]
|
||||
"""Collected by LogCollector."""
|
||||
|
||||
|
||||
@@ -42,14 +43,15 @@ class EnvWrapperStatus(TypedDict):
|
||||
|
||||
cur_step: int
|
||||
done: bool
|
||||
initial_state: Any | None
|
||||
initial_state: Optional[Any]
|
||||
obs_history: list
|
||||
action_history: list
|
||||
reward_history: list
|
||||
|
||||
|
||||
class EnvWrapper(
|
||||
gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]
|
||||
gym.Env[ObsType, PolicyActType],
|
||||
Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType],
|
||||
):
|
||||
"""Qlib-based RL environment, subclassing ``gym.Env``.
|
||||
A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
|
||||
@@ -97,11 +99,11 @@ class EnvWrapper(
|
||||
simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]],
|
||||
state_interpreter: StateInterpreter[StateType, ObsType],
|
||||
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
|
||||
seed_iterator: Iterable[InitialStateType] | None,
|
||||
reward_fn: Reward | None = None,
|
||||
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,
|
||||
logger: LogCollector | None = None,
|
||||
):
|
||||
seed_iterator: Optional[Iterable[InitialStateType]],
|
||||
reward_fn: Reward = None,
|
||||
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] = None,
|
||||
logger: LogCollector = None,
|
||||
) -> None:
|
||||
# Assign weak reference to wrapper.
|
||||
#
|
||||
# Use weak reference here, because:
|
||||
@@ -135,11 +137,11 @@ class EnvWrapper(
|
||||
self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
def action_space(self) -> Space:
|
||||
return self.action_interpreter.action_space
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
def observation_space(self) -> Space:
|
||||
return self.state_interpreter.observation_space
|
||||
|
||||
def reset(self, **kwargs: Any) -> ObsType:
|
||||
@@ -191,7 +193,7 @@ class EnvWrapper(
|
||||
self.seed_iterator = None
|
||||
return generate_nan_observation(self.observation_space)
|
||||
|
||||
def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]:
|
||||
def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, float, bool, InfoDict]:
|
||||
"""Environment step.
|
||||
|
||||
See the code along with comments to get a sequence of things happening here.
|
||||
@@ -245,5 +247,5 @@ class EnvWrapper(
|
||||
info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
|
||||
return obs, rew, done, info_dict
|
||||
|
||||
def render(self):
|
||||
def render(self, mode: str = "human") -> None:
|
||||
raise NotImplementedError("Render is not implemented in EnvWrapper.")
|
||||
|
||||
@@ -11,11 +11,10 @@ from __future__ import annotations
|
||||
import copy
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from typing import Any, Set, Callable, Type
|
||||
|
||||
from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
|
||||
|
||||
from qlib.typehint import Literal
|
||||
@@ -32,11 +31,11 @@ __all__ = [
|
||||
"vectorize_env",
|
||||
]
|
||||
|
||||
|
||||
FiniteEnvType = Literal["dummy", "subproc", "shmem"]
|
||||
T = Union[dict, list, tuple, np.ndarray]
|
||||
|
||||
|
||||
def fill_invalid(obj):
|
||||
def fill_invalid(obj: int | float | bool | T) -> T:
|
||||
if isinstance(obj, (int, float, bool)):
|
||||
return fill_invalid(np.array(obj))
|
||||
if hasattr(obj, "dtype"):
|
||||
@@ -55,11 +54,11 @@ def fill_invalid(obj):
|
||||
raise ValueError(f"Unsupported value to fill with invalid: {obj}")
|
||||
|
||||
|
||||
def is_invalid(arr):
|
||||
if hasattr(arr, "dtype"):
|
||||
def is_invalid(arr: int | float | bool | T) -> bool:
|
||||
if isinstance(arr, np.ndarray):
|
||||
if np.issubdtype(arr.dtype, np.floating):
|
||||
return np.isnan(arr).all()
|
||||
return (np.iinfo(arr.dtype).max == arr).all()
|
||||
return cast(bool, cast(np.ndarray, np.iinfo(arr.dtype).max == arr).all())
|
||||
if isinstance(arr, dict):
|
||||
return all(is_invalid(o) for o in arr.values())
|
||||
if isinstance(arr, (list, tuple)):
|
||||
@@ -140,44 +139,44 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
|
||||
self._collector_guarded: bool = False
|
||||
|
||||
def _reset_alive_envs(self):
|
||||
def _reset_alive_envs(self) -> None:
|
||||
if not self._alive_env_ids:
|
||||
# starting or running out
|
||||
self._alive_env_ids = set(range(self.env_num))
|
||||
|
||||
# to workaround with tianshou's buffer and batch
|
||||
def _set_default_obs(self, obs):
|
||||
def _set_default_obs(self, obs: Any) -> None:
|
||||
if obs is not None and self._default_obs is None:
|
||||
self._default_obs = copy.deepcopy(obs)
|
||||
|
||||
def _set_default_info(self, info):
|
||||
def _set_default_info(self, info: Any) -> None:
|
||||
if info is not None and self._default_info is None:
|
||||
self._default_info = copy.deepcopy(info)
|
||||
|
||||
def _set_default_rew(self, rew):
|
||||
def _set_default_rew(self, rew: Any) -> None:
|
||||
if rew is not None and self._default_rew is None:
|
||||
self._default_rew = copy.deepcopy(rew)
|
||||
|
||||
def _get_default_obs(self):
|
||||
def _get_default_obs(self) -> Any:
|
||||
return copy.deepcopy(self._default_obs)
|
||||
|
||||
def _get_default_info(self):
|
||||
def _get_default_info(self) -> Any:
|
||||
return copy.deepcopy(self._default_info)
|
||||
|
||||
def _get_default_rew(self):
|
||||
def _get_default_rew(self) -> Any:
|
||||
return copy.deepcopy(self._default_rew)
|
||||
|
||||
# END
|
||||
|
||||
@staticmethod
|
||||
def _postproc_env_obs(obs):
|
||||
def _postproc_env_obs(obs: Any) -> Optional[Any]:
|
||||
# reserved for shmem vector env to restore empty observation
|
||||
if obs is None or check_nan_observation(obs):
|
||||
return None
|
||||
return obs
|
||||
|
||||
@contextmanager
|
||||
def collector_guard(self):
|
||||
def collector_guard(self) -> Generator[FiniteVectorEnv, None, None]:
|
||||
"""Guard the collector. Recommended to guard every collect.
|
||||
|
||||
This guard is for two purposes.
|
||||
@@ -207,7 +206,10 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
for logger in self._logger:
|
||||
logger.on_env_all_done()
|
||||
|
||||
def reset(self, id=None):
|
||||
def reset(
|
||||
self,
|
||||
id: int | List[int] | np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
assert not self._zombie
|
||||
|
||||
# Check whether it's guarded by collector_guard()
|
||||
@@ -219,23 +221,23 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
RuntimeWarning,
|
||||
)
|
||||
|
||||
id = self._wrap_id(id)
|
||||
wrapped_id = self._wrap_id(id)
|
||||
self._reset_alive_envs()
|
||||
|
||||
# ask super to reset alive envs and remap to current index
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
|
||||
obs = [None] * len(id)
|
||||
id2idx = {i: k for k, i in enumerate(id)}
|
||||
request_id = [i for i in wrapped_id if i in self._alive_env_ids]
|
||||
obs = [None] * len(wrapped_id)
|
||||
id2idx = {i: k for k, i in enumerate(wrapped_id)}
|
||||
if request_id:
|
||||
for i, o in zip(request_id, super().reset(request_id)):
|
||||
obs[id2idx[i]] = self._postproc_env_obs(o)
|
||||
|
||||
for i, o in zip(id, obs):
|
||||
for i, o in zip(wrapped_id, obs):
|
||||
if o is None and i in self._alive_env_ids:
|
||||
self._alive_env_ids.remove(i)
|
||||
|
||||
# logging
|
||||
for i, o in zip(id, obs):
|
||||
for i, o in zip(wrapped_id, obs):
|
||||
if i in self._alive_env_ids:
|
||||
for logger in self._logger:
|
||||
logger.on_env_reset(i, obs)
|
||||
@@ -248,19 +250,23 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
obs[i] = self._get_default_obs()
|
||||
|
||||
if not self._alive_env_ids:
|
||||
# comment this line so that the env becomes indisposable
|
||||
# comment this line so that the env becomes indispensable
|
||||
# self.reset()
|
||||
self._zombie = True
|
||||
raise StopIteration
|
||||
|
||||
return np.stack(obs)
|
||||
|
||||
def step(self, action, id=None):
|
||||
def step(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
id: int | List[int] | np.ndarray | None = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert not self._zombie
|
||||
id = self._wrap_id(id)
|
||||
id2idx = {i: k for k, i in enumerate(id)}
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
|
||||
result = [[None, None, False, None] for _ in range(len(id))]
|
||||
wrapped_id = self._wrap_id(id)
|
||||
id2idx = {i: k for k, i in enumerate(wrapped_id)}
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id))
|
||||
result = [[None, None, False, None] for _ in range(len(wrapped_id))]
|
||||
|
||||
# ask super to step alive envs and remap to current index
|
||||
if request_id:
|
||||
@@ -270,7 +276,7 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
result[id2idx[i]][0] = self._postproc_env_obs(result[id2idx[i]][0])
|
||||
|
||||
# logging
|
||||
for i, r in zip(id, result):
|
||||
for i, r in zip(wrapped_id, result):
|
||||
if i in self._alive_env_ids:
|
||||
for logger in self._logger:
|
||||
logger.on_env_step(i, *r)
|
||||
@@ -287,7 +293,8 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
if r[3] is None:
|
||||
result[i][3] = self._get_default_info()
|
||||
|
||||
return list(map(np.stack, zip(*result)))
|
||||
ret = list(map(np.stack, zip(*result)))
|
||||
return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret)
|
||||
|
||||
|
||||
class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
|
||||
@@ -306,7 +313,7 @@ def vectorize_env(
|
||||
env_factory: Callable[..., gym.Env],
|
||||
env_type: FiniteEnvType,
|
||||
concurrency: int,
|
||||
logger: LogWriter | list[LogWriter],
|
||||
logger: LogWriter | List[LogWriter],
|
||||
) -> FiniteVectorEnv:
|
||||
"""Helper function to create a vector env. Can be used to replace usual VectorEnv.
|
||||
|
||||
@@ -350,7 +357,7 @@ def vectorize_env(
|
||||
def env_factory(): ...
|
||||
vectorize_env(env_factory, ...)
|
||||
"""
|
||||
env_type_cls_mapping: dict[str, Type[FiniteVectorEnv]] = {
|
||||
env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = {
|
||||
"dummy": FiniteDummyVectorEnv,
|
||||
"subproc": FiniteSubprocVectorEnv,
|
||||
"shmem": FiniteShmemVectorEnv,
|
||||
|
||||
@@ -21,7 +21,7 @@ import logging
|
||||
from collections import defaultdict
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence, Callable
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, List, Sequence, Set, Tuple, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -65,13 +65,13 @@ class LogCollector:
|
||||
``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipe.
|
||||
"""
|
||||
|
||||
_logged: dict[str, tuple[int, Any]]
|
||||
_logged: Dict[str, Tuple[int, Any]]
|
||||
_min_loglevel: int
|
||||
|
||||
def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC):
|
||||
def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
self._min_loglevel = int(min_loglevel)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""Clear all collected contents."""
|
||||
self._logged = {}
|
||||
|
||||
@@ -104,7 +104,10 @@ class LogCollector:
|
||||
self._add_metric(name, scalar, loglevel)
|
||||
|
||||
def add_array(
|
||||
self, name: str, array: np.ndarray | pd.DataFrame | pd.Series, loglevel: int | LogLevel = LogLevel.PERIODIC
|
||||
self,
|
||||
name: str,
|
||||
array: np.ndarray | pd.DataFrame | pd.Series,
|
||||
loglevel: int | LogLevel = LogLevel.PERIODIC,
|
||||
) -> None:
|
||||
"""Add an array with name into logging."""
|
||||
if loglevel < self._min_loglevel:
|
||||
@@ -127,7 +130,7 @@ class LogCollector:
|
||||
|
||||
self._add_metric(name, obj, loglevel)
|
||||
|
||||
def logs(self) -> dict[str, np.ndarray]:
|
||||
def logs(self) -> Dict[str, np.ndarray]:
|
||||
return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()}
|
||||
|
||||
|
||||
@@ -154,16 +157,16 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
active_env_ids: Set[int]
|
||||
"""Active environment ids in vector env."""
|
||||
|
||||
episode_lengths: dict[int, int]
|
||||
episode_lengths: Dict[int, int]
|
||||
"""Map from environment id to episode length."""
|
||||
|
||||
episode_rewards: dict[int, list[float]]
|
||||
episode_rewards: Dict[int, List[float]]
|
||||
"""Map from environment id to episode total reward."""
|
||||
|
||||
episode_logs: dict[int, list]
|
||||
episode_logs: Dict[int, list]
|
||||
"""Map from environment id to episode logs."""
|
||||
|
||||
def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC):
|
||||
def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
self.loglevel = loglevel
|
||||
|
||||
self.global_step = 0
|
||||
@@ -207,11 +210,12 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
# These are runtime infos.
|
||||
# Though they are loaded, I don't think it really helps.
|
||||
self.active_env_ids = state_dict["active_env_ids"]
|
||||
self.episode_lenghts = state_dict["episode_lengths"]
|
||||
self.episode_lengths = state_dict["episode_lengths"]
|
||||
self.episode_rewards = state_dict["episode_rewards"]
|
||||
self.episode_logs = state_dict["episode_logs"]
|
||||
|
||||
def aggregation(self, array: Sequence[Any], name: str | None = None) -> Any:
|
||||
@staticmethod
|
||||
def aggregation(array: Sequence[Any], name: str | None = None) -> Any:
|
||||
"""Aggregation function from step-wise to episode-wise.
|
||||
|
||||
If it's a sequence of float, take the mean.
|
||||
@@ -229,7 +233,7 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
else:
|
||||
return array[0]
|
||||
|
||||
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
|
||||
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
|
||||
"""This is triggered at the end of each trajectory.
|
||||
|
||||
Parameters
|
||||
@@ -242,7 +246,7 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
Logged contents for every steps.
|
||||
"""
|
||||
|
||||
def log_step(self, reward: float, contents: dict[str, Any]) -> None:
|
||||
def log_step(self, reward: float, contents: Dict[str, Any]) -> None:
|
||||
"""This is triggered at each step.
|
||||
|
||||
Parameters
|
||||
@@ -265,7 +269,7 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
# TODO: reward can be a list of list for MARL
|
||||
self.episode_rewards[env_id].append(rew)
|
||||
|
||||
values: dict[str, Any] = {}
|
||||
values: Dict[str, Any] = {}
|
||||
|
||||
for key, (loglevel, value) in info["log"].items():
|
||||
if loglevel >= self.loglevel: # FIXME: this is actually incorrect (see last FIXME)
|
||||
@@ -393,11 +397,11 @@ class ConsoleWriter(LogWriter):
|
||||
def __init__(
|
||||
self,
|
||||
log_every_n_episode: int = 20,
|
||||
total_episodes: int | None = None,
|
||||
total_episodes: int = None,
|
||||
float_format: str = ":.4f",
|
||||
counter_format: str = ":4d",
|
||||
loglevel: int | LogLevel = LogLevel.PERIODIC,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(loglevel)
|
||||
# TODO: support log_every_n_step
|
||||
self.log_every_n_episode = log_every_n_episode
|
||||
@@ -412,15 +416,15 @@ class ConsoleWriter(LogWriter):
|
||||
|
||||
# FIXME: save & reload
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
# Clear average meters
|
||||
self.metric_counts: dict[str, int] = defaultdict(int)
|
||||
self.metric_sums: dict[str, float] = defaultdict(float)
|
||||
self.metric_counts: Dict[str, int] = defaultdict(int)
|
||||
self.metric_sums: Dict[str, float] = defaultdict(float)
|
||||
|
||||
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
|
||||
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
|
||||
# Aggregate step-wise to episode-wise
|
||||
episode_wise_contents: dict[str, list] = defaultdict(list)
|
||||
episode_wise_contents: Dict[str, list] = defaultdict(list)
|
||||
|
||||
for step_contents in contents:
|
||||
for name, value in step_contents.items():
|
||||
@@ -429,7 +433,7 @@ class ConsoleWriter(LogWriter):
|
||||
|
||||
# Generate log contents and track them in average-meter.
|
||||
# This should be done at every step, regardless of periodic or not.
|
||||
logs: dict[str, float] = {}
|
||||
logs: Dict[str, float] = {}
|
||||
for name, values in episode_wise_contents.items():
|
||||
logs[name] = self.aggregation(values, name) # type: ignore
|
||||
|
||||
@@ -441,7 +445,7 @@ class ConsoleWriter(LogWriter):
|
||||
# Only log periodically or at the end
|
||||
self.console_logger.info(self.generate_log_message(logs))
|
||||
|
||||
def generate_log_message(self, logs: dict[str, float]) -> str:
|
||||
def generate_log_message(self, logs: Dict[str, float]) -> str:
|
||||
if self.prefix:
|
||||
msg_prefix = self.prefix + " "
|
||||
else:
|
||||
@@ -471,29 +475,29 @@ class CsvWriter(LogWriter):
|
||||
|
||||
SUPPORTED_TYPES = (float, str, pd.Timestamp)
|
||||
|
||||
all_records: list[dict[str, Any]]
|
||||
all_records: List[Dict[str, Any]]
|
||||
|
||||
# FIXME: save & reload
|
||||
|
||||
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
|
||||
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
super().__init__(loglevel)
|
||||
self.output_dir = output_dir
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.all_records = []
|
||||
|
||||
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
|
||||
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
|
||||
# FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup
|
||||
episode_wise_contents: dict[str, list] = defaultdict(list)
|
||||
episode_wise_contents: Dict[str, list] = defaultdict(list)
|
||||
|
||||
for step_contents in contents:
|
||||
for name, value in step_contents.items():
|
||||
if isinstance(value, self.SUPPORTED_TYPES):
|
||||
episode_wise_contents[name].append(value)
|
||||
|
||||
logs: dict[str, float] = {}
|
||||
logs: Dict[str, float] = {}
|
||||
for name, values in episode_wise_contents.items():
|
||||
logs[name] = self.aggregation(values, name) # type: ignore
|
||||
|
||||
|
||||
0
qlib/run/__init__.py
Normal file
0
qlib/run/__init__.py
Normal file
9
qlib/run/get_data.py
Normal file
9
qlib/run/get_data.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import fire
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(GetData)
|
||||
@@ -2,14 +2,14 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Generator, Optional
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Generator, Optional, TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.position import BasePosition
|
||||
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple
|
||||
|
||||
from ..backtest.decision import BaseTradeDecision
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
@@ -207,8 +207,18 @@ class BaseStrategy:
|
||||
range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype)
|
||||
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
|
||||
|
||||
def post_exe_step(self, execute_result: list) -> None:
|
||||
"""
|
||||
A hook for doing sth after the corresponding executor finished its execution.
|
||||
|
||||
class RLStrategy(BaseStrategy):
|
||||
Parameters
|
||||
----------
|
||||
execute_result :
|
||||
the execution result
|
||||
"""
|
||||
|
||||
|
||||
class RLStrategy(BaseStrategy, metaclass=ABCMeta):
|
||||
"""RL-based strategy"""
|
||||
|
||||
def __init__(
|
||||
@@ -229,14 +239,14 @@ class RLStrategy(BaseStrategy):
|
||||
self.policy = policy
|
||||
|
||||
|
||||
class RLIntStrategy(RLStrategy):
|
||||
class RLIntStrategy(RLStrategy, metaclass=ABCMeta):
|
||||
"""(RL)-based (Strategy) with (Int)erpreter"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy,
|
||||
state_interpreter: Union[dict, StateInterpreter],
|
||||
action_interpreter: Union[dict, ActionInterpreter],
|
||||
state_interpreter: dict | StateInterpreter,
|
||||
action_interpreter: dict | ActionInterpreter,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
"""Commonly used types."""
|
||||
|
||||
import sys
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
|
||||
__all__ = ["Literal", "TypedDict", "final"]
|
||||
|
||||
@@ -11,3 +13,51 @@ if sys.version_info >= (3, 8):
|
||||
from typing import Literal, TypedDict, final # type: ignore # pylint: disable=no-name-in-module
|
||||
else:
|
||||
from typing_extensions import Literal, TypedDict, final
|
||||
|
||||
|
||||
class InstDictConf(TypedDict):
|
||||
"""
|
||||
InstDictConf is a Dict-based config to describe an instance
|
||||
|
||||
case 1)
|
||||
{
|
||||
'class': 'ClassName',
|
||||
'kwargs': dict, # It is optional. {} will be used if not given
|
||||
'model_path': path, # It is optional if module is given in the class
|
||||
}
|
||||
case 2)
|
||||
{
|
||||
'class': <The class it self>,
|
||||
'kwargs': dict, # It is optional. {} will be used if not given
|
||||
}
|
||||
"""
|
||||
|
||||
# class: str # because class is a keyword of Python. We have to comment it
|
||||
kwargs: dict # It is optional. {} will be used if not given
|
||||
module_path: str # It is optional if module is given in the class
|
||||
|
||||
|
||||
InstConf = Union[InstDictConf, str, object, Path]
|
||||
"""
|
||||
InstConf is a type to describe an instance; it will be passed into init_instance_by_config for Qlib
|
||||
|
||||
config : Union[str, dict, object, Path]
|
||||
|
||||
InstDictConf example.
|
||||
please refer to the docs of InstDictConf
|
||||
|
||||
str example.
|
||||
1) specify a pickle object
|
||||
- path like 'file:///<path to pickle file>/obj.pkl'
|
||||
2) specify a class name
|
||||
- "ClassName": getattr(module, "ClassName")() will be used.
|
||||
3) specify module path with class name
|
||||
- "a.b.c.ClassName" getattr(<a.b.c.module>, "ClassName")() will be used.
|
||||
|
||||
object example:
|
||||
instance of accept_types
|
||||
|
||||
Path example:
|
||||
specify a pickle object
|
||||
- it will be treated like 'file:///<path to pickle file>/obj.pkl'
|
||||
"""
|
||||
|
||||
@@ -11,6 +11,7 @@ import re
|
||||
import sys
|
||||
import copy
|
||||
import json
|
||||
from qlib.typehint import InstConf
|
||||
import yaml
|
||||
import redis
|
||||
import bisect
|
||||
@@ -291,7 +292,11 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]):
|
||||
|
||||
:param module_path:
|
||||
:return:
|
||||
:raises: ModuleNotFoundError
|
||||
"""
|
||||
if module_path is None:
|
||||
raise ModuleNotFoundError("None is passed in as parameters as module_path")
|
||||
|
||||
if isinstance(module_path, ModuleType):
|
||||
module = module_path
|
||||
else:
|
||||
@@ -324,7 +329,7 @@ def split_module_path(module_path: str) -> Tuple[str, str]:
|
||||
return m_path, cls
|
||||
|
||||
|
||||
def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
|
||||
def get_callable_kwargs(config: InstConf, default_module: Union[str, ModuleType] = None) -> (type, dict):
|
||||
"""
|
||||
extract class/func and kwargs from config info
|
||||
|
||||
@@ -343,6 +348,10 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod
|
||||
-------
|
||||
(type, dict):
|
||||
the class/func object and it's arguments.
|
||||
|
||||
Raises
|
||||
------
|
||||
ModuleNotFoundError
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
key = "class" if "class" in config else "func"
|
||||
@@ -376,7 +385,7 @@ get_cls_kwargs = get_callable_kwargs # NOTE: this is for compatibility for the
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
config: Union[str, dict, object, Path], # TODO: use a user-defined type to replace this Union.
|
||||
config: InstConf,
|
||||
default_module=None,
|
||||
accept_types: Union[type, Tuple[type]] = (),
|
||||
try_kwargs: Dict = {},
|
||||
@@ -387,31 +396,8 @@ def init_instance_by_config(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : Union[str, dict, object]
|
||||
dict example.
|
||||
case 1)
|
||||
{
|
||||
'class': 'ClassName',
|
||||
'kwargs': dict, # It is optional. {} will be used if not given
|
||||
'model_path': path, # It is optional if module is given
|
||||
}
|
||||
case 2)
|
||||
{
|
||||
'class': <The class it self>,
|
||||
'kwargs': dict, # It is optional. {} will be used if not given
|
||||
}
|
||||
str example.
|
||||
1) specify a pickle object
|
||||
- path like 'file:///<path to pickle file>/obj.pkl'
|
||||
2) specify a class name
|
||||
- "ClassName": getattr(module, "ClassName")() will be used.
|
||||
3) specify module path with class name
|
||||
- "a.b.c.ClassName" getattr(<a.b.c.module>, "ClassName")() will be used.
|
||||
object example:
|
||||
instance of accept_types
|
||||
Path example:
|
||||
specify a pickle object
|
||||
- it will be treated like 'file:///<path to pickle file>/obj.pkl'
|
||||
config : InstConf
|
||||
|
||||
default_module : Python module
|
||||
Optional. It should be a python module.
|
||||
NOTE: the "module_path" will be override by `module` arguments
|
||||
@@ -518,7 +504,7 @@ def remove_fields_space(fields: [list, str, tuple]):
|
||||
"""
|
||||
if isinstance(fields, str):
|
||||
return fields.replace(" ", "")
|
||||
return [i.replace(" ", "") for i in fields if isinstance(i, str)]
|
||||
return [i.replace(" ", "") if isinstance(i, str) else str(i) for i in fields]
|
||||
|
||||
|
||||
def normalize_cache_fields(fields: [list, tuple]):
|
||||
|
||||
@@ -271,7 +271,7 @@ class LocIndexer:
|
||||
if isinstance(_indexing, IndexData):
|
||||
_indexing = _indexing.data
|
||||
assert _indexing.ndim == 1
|
||||
if _indexing.dtype != np.bool:
|
||||
if _indexing.dtype != bool:
|
||||
_indexing = np.array(list(index.index(i) for i in _indexing))
|
||||
else:
|
||||
_indexing = index.index(_indexing)
|
||||
@@ -431,7 +431,7 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
|
||||
# The code below could be simpler like methods in __getattribute__
|
||||
def __invert__(self):
|
||||
return self.__class__(~self.data.astype(np.bool), *self.indices)
|
||||
return self.__class__(~self.data.astype(bool), *self.indices)
|
||||
|
||||
def abs(self):
|
||||
"""get the abs of data except np.NaN."""
|
||||
|
||||
@@ -575,6 +575,44 @@ class QlibRecorder:
|
||||
"""
|
||||
self.get_exp(start=True).get_recorder(start=True).log_metrics(step, **kwargs)
|
||||
|
||||
def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):
|
||||
"""
|
||||
Log a local file or directory as an artifact of the currently active run
|
||||
|
||||
- If `active recorder` exists: it will set tags through the active recorder.
|
||||
- If `active recorder` not exists: the system will create a default experiment as well as a new recorder, and set the tags under it.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
local_path : str
|
||||
Path to the file to write.
|
||||
artifact_path : Optional[str]
|
||||
If provided, the directory in ``artifact_uri`` to write to.
|
||||
"""
|
||||
self.get_exp(start=True).get_recorder(start=True).log_artifact(local_path, artifact_path)
|
||||
|
||||
def download_artifact(self, path: str, dst_path: Optional[str] = None) -> str:
|
||||
"""
|
||||
Download an artifact file or directory from a run to a local directory if applicable,
|
||||
and return a local path for it.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
Relative source path to the desired artifact.
|
||||
dst_path : Optional[str]
|
||||
Absolute path of the local filesystem destination directory to which to
|
||||
download the specified artifacts. This directory must already exist.
|
||||
If unspecified, the artifacts will either be downloaded to a new
|
||||
uniquely-named directory on the local filesystem.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Local path of desired artifact.
|
||||
"""
|
||||
self.get_exp(start=True).get_recorder(start=True).download_artifact(path, dst_path)
|
||||
|
||||
def set_tags(self, **kwargs):
|
||||
"""
|
||||
Method for setting tags for a recorder. In addition to using ``R``, one can also set the tag to a specific recorder after getting it with `get_recorder` API.
|
||||
@@ -611,7 +649,7 @@ class RecorderWrapper(Wrapper):
|
||||
expm = getattr(self._provider, "exp_manager")
|
||||
if expm.active_experiment is not None:
|
||||
raise RecorderInitializationError(
|
||||
"Please don't reinitialize Qlib if QlibRecorder is already acivated. Otherwise, the experiment stored location will be modified."
|
||||
"Please don't reinitialize Qlib if QlibRecorder is already activated. Otherwise, the experiment stored location will be modified."
|
||||
)
|
||||
self._provider = provider
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ class Experiment:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_recorder` method.")
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True, start: bool = False):
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True, start: bool = False) -> Recorder:
|
||||
"""
|
||||
Retrieve a Recorder for user. When user specify recorder id and name, the method will try to return the
|
||||
specific recorder. When user does not provide recorder id or name, the method will try to return the current
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
import mlflow
|
||||
import logging
|
||||
import shutil
|
||||
@@ -138,6 +139,19 @@ class Recorder:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `log_metrics` method.")
|
||||
|
||||
def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):
|
||||
"""
|
||||
Log a local file or directory as an artifact of the currently active run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
local_path : str
|
||||
Path to the file to write.
|
||||
artifact_path : Optional[str]
|
||||
If provided, the directory in ``artifact_uri`` to write to.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `log_metrics` method.")
|
||||
|
||||
def set_tags(self, **kwargs):
|
||||
"""
|
||||
Log a batch of tags for the current run.
|
||||
@@ -175,6 +189,28 @@ class Recorder:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `list_artifacts` method.")
|
||||
|
||||
def download_artifact(self, path: str, dst_path: Optional[str] = None) -> str:
|
||||
"""
|
||||
Download an artifact file or directory from a run to a local directory if applicable,
|
||||
and return a local path for it.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
Relative source path to the desired artifact.
|
||||
dst_path : Optional[str]
|
||||
Absolute path of the local filesystem destination directory to which to
|
||||
download the specified artifacts. This directory must already exist.
|
||||
If unspecified, the artifacts will either be downloaded to a new
|
||||
uniquely-named directory on the local filesystem.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Local path of desired artifact.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `list_artifacts` method.")
|
||||
|
||||
def list_metrics(self):
|
||||
"""
|
||||
List all the metrics of a recorder.
|
||||
@@ -212,6 +248,14 @@ class MLflowRecorder(Recorder):
|
||||
|
||||
Due to the fact that mlflow will only log artifact from a file or directory, we decide to
|
||||
use file manager to help maintain the objects in the project.
|
||||
|
||||
Instead of using mlflow directly, we use another interface wrapping mlflow to log experiments.
|
||||
Though it takes extra efforts, but it brings users benefits due to following reasons.
|
||||
- It will be more convenient to change the experiment logging backend without changing any code in upper level
|
||||
- We can provide more convenience to automatically do some extra things and make interface easier. For examples:
|
||||
- Automatically logging the uncommitted code
|
||||
- Automatically logging part of environment variables
|
||||
- User can control several different runs by just creating different Recorder (in mlflow, you always have to switch artifact_uri and pass in run ids frequently)
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_id, uri, name=None, mlflow_run=None):
|
||||
@@ -304,6 +348,9 @@ class MLflowRecorder(Recorder):
|
||||
self._log_uncommitted_code()
|
||||
|
||||
self.log_params(**{"cmd-sys.argv": " ".join(sys.argv)}) # log the command to produce current experiment
|
||||
self.log_params(
|
||||
**{k: v for k, v in os.environ.items() if k.startswith("_QLIB_")}
|
||||
) # Log necessary environment variables
|
||||
return run
|
||||
|
||||
def _log_uncommitted_code(self):
|
||||
@@ -398,6 +445,9 @@ class MLflowRecorder(Recorder):
|
||||
for name, data in kwargs.items():
|
||||
self.client.log_metric(self.id, name, data, step=step)
|
||||
|
||||
def log_artifact(self, local_path, artifact_path: Optional[str] = None):
|
||||
self.client.log_artifact(self.id, local_path=local_path, artifact_path=artifact_path)
|
||||
|
||||
@AsyncCaller.async_dec(ac_attr="async_log")
|
||||
def set_tags(self, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
@@ -420,6 +470,9 @@ class MLflowRecorder(Recorder):
|
||||
artifacts = self.client.list_artifacts(self.id, artifact_path)
|
||||
return [art.path for art in artifacts]
|
||||
|
||||
def download_artifact(self, path: str, dst_path: Optional[str] = None) -> str:
|
||||
return self.client.download_artifacts(self.id, path, dst_path)
|
||||
|
||||
def list_metrics(self):
|
||||
run = self.client.get_run(self.id)
|
||||
return run.data.metrics
|
||||
|
||||
@@ -67,3 +67,10 @@ from qlib.constant import REG_CN
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
```
|
||||
|
||||
## Use Crowd Sourced Data
|
||||
The is also a [crowd sourced version of qlib data](data_collector/crowd_source/README.md): https://github.com/chenditc/investment_data/releases
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||
```
|
||||
|
||||
32
scripts/data_collector/crowd_source/README.md
Normal file
32
scripts/data_collector/crowd_source/README.md
Normal file
@@ -0,0 +1,32 @@
|
||||
# Crowd Source Data
|
||||
|
||||
## Initiative
|
||||
Public data source like yahoo is flawed, it might miss data for stock which is delisted and it might has data which is wrong. This can introduce survivorship bias into our training process.
|
||||
|
||||
The crowd sourced data is introduced to merged data from multiple data source and cross validate against each other, so that:
|
||||
1. We will have a more complete history record.
|
||||
2. We can identify the anomaly data and apply correction when necessary.
|
||||
|
||||
## Related Repo
|
||||
The raw data is hosted on dolthub repo: https://www.dolthub.com/repositories/chenditc/investment_data
|
||||
|
||||
The processing script and sql is hosted on github repo: https://github.com/chenditc/investment_data
|
||||
|
||||
The pakcaged docker runtime is hosted on dockerhub: https://hub.docker.com/repository/docker/chenditc/investment_data
|
||||
|
||||
## How to use it in qlib
|
||||
### Option 1: Download release bin data
|
||||
User can download data in qlib bin format and use it directly: https://github.com/chenditc/investment_data/releases/tag/20220720
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||
```
|
||||
|
||||
### Option 2: Generate qlib data from dolthub
|
||||
Dolthub data will be update daily, so that if user wants to get up to date data, they can dump qlib bin using docker:
|
||||
```
|
||||
docker run -v /<some output directory>:/output -it --rm chenditc/investment_data bash dump_qlib_bin.sh && cp ./qlib_bin.tar.gz /output/
|
||||
```
|
||||
|
||||
## FAQ and other info
|
||||
See: https://github.com/chenditc/investment_data/blob/main/README.md
|
||||
@@ -49,3 +49,7 @@ pythono collector.py collector_data --help
|
||||
|
||||
- interval: 1d
|
||||
- region: CN
|
||||
|
||||
## 免责声明
|
||||
|
||||
本项目仅供学习研究使用,不作为任何行为的指导和建议,由此而引发任何争议和纠纷,与本项目无任何关系
|
||||
|
||||
@@ -36,7 +36,7 @@ pip install -r requirements.txt
|
||||
- `target_dir`: save dir, by default *~/.qlib/qlib_data/cn_data*
|
||||
- `version`: dataset version, value from [`v1`, `v2`], by default `v1`
|
||||
- `v2` end date is *2021-06*, `v1` end date is *2020-09*
|
||||
- user can append data to `v2`: [automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
- If users want to incrementally update data, they need to use yahoo collector to [collect data from scratch](#collector-yahoofinance-data-to-qlib).
|
||||
- **the [benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks) for qlib use `v1`**, *due to the unstable access to historical data by YahooFinance, there are some differences between `v2` and `v1`*
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
- `region`: `cn` or `us` or `in`, by default `cn`
|
||||
@@ -62,6 +62,8 @@ pip install -r requirements.txt
|
||||
> collector *YahooFinance* data and *dump* into `qlib` format.
|
||||
> If the above ready-made data can't meet users' requirements, users can follow this section to crawl the latest data and convert it to qlib-data.
|
||||
1. download data to csv: `python scripts/data_collector/yahoo/collector.py download_data`
|
||||
|
||||
This will download the raw data such as high, low, open, close, adjclose price from yahoo to a local directory. One file per symbol.
|
||||
|
||||
- parameters:
|
||||
- `source_dir`: save the directory
|
||||
@@ -99,6 +101,10 @@ pip install -r requirements.txt
|
||||
```
|
||||
2. normalize data: `python scripts/data_collector/yahoo/collector.py normalize_data`
|
||||
|
||||
This will:
|
||||
1. Normalize high, low, close, open price using adjclose.
|
||||
2. Normalize the high, low, close, open price so that the first valid trading date's close price is 1.
|
||||
|
||||
- parameters:
|
||||
- `source_dir`: csv directory
|
||||
- `normalize_dir`: result directory
|
||||
@@ -136,6 +142,8 @@ pip install -r requirements.txt
|
||||
```
|
||||
3. dump data: `python scripts/dump_bin.py dump_all`
|
||||
|
||||
This will convert the normalized csv in `feature` directory as numpy array and store the normalized data one file per column and one symbol per directory.
|
||||
|
||||
- parameters:
|
||||
- `csv_path`: stock data path or directory, **normalize result(normalize_dir)**
|
||||
- `qlib_dir`: qlib(dump) data director
|
||||
|
||||
@@ -5,6 +5,8 @@ from random import randint, choice
|
||||
from pathlib import Path
|
||||
|
||||
import re
|
||||
from typing import Any, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -24,16 +26,16 @@ from qlib.rl.utils.finite_env import vectorize_env
|
||||
|
||||
|
||||
class SimpleEnv(gym.Env[int, int]):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.logger = LogCollector()
|
||||
self.observation_space = gym.spaces.Discrete(2)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
def reset(self, *args: Any, **kwargs: Any) -> int:
|
||||
self.step_count = 0
|
||||
return 0
|
||||
|
||||
def step(self, action: int):
|
||||
def step(self, action: int) -> Tuple[int, float, bool, dict]:
|
||||
self.logger.reset()
|
||||
|
||||
self.logger.add_scalar("reward", 42.0)
|
||||
@@ -53,6 +55,9 @@ class SimpleEnv(gym.Env[int, int]):
|
||||
|
||||
return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={})
|
||||
|
||||
def render(self, mode: str = "human") -> None:
|
||||
pass
|
||||
|
||||
|
||||
class AnyPolicy(BasePolicy):
|
||||
def forward(self, batch, state=None):
|
||||
@@ -86,7 +91,8 @@ def test_simple_env_logger(caplog):
|
||||
|
||||
|
||||
class SimpleSimulator(Simulator[int, float, float]):
|
||||
def __init__(self, initial: int, **kwargs) -> None:
|
||||
def __init__(self, initial: int, **kwargs: Any) -> None:
|
||||
super(SimpleSimulator, self).__init__(initial, **kwargs)
|
||||
self.initial = float(initial)
|
||||
|
||||
def step(self, action: float) -> None:
|
||||
|
||||
177
tests/rl/test_qlib_simulator.py
Normal file
177
tests/rl/test_qlib_simulator.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.backtest.executor import NestedExecutor, SimulatorExecutor
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.contrib.strategy import TWAPStrategy
|
||||
from qlib.rl.order_execution import CategoricalActionInterpreter
|
||||
from qlib.rl.order_execution.simulator_qlib import ExchangeConfig, SingleAssetOrderExecutionQlib
|
||||
|
||||
TOTAL_POSITION = 2100.0
|
||||
|
||||
python_version_request = pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
|
||||
|
||||
|
||||
def is_close(a: float, b: float, epsilon: float = 1e-4) -> bool:
|
||||
return abs(a - b) <= epsilon
|
||||
|
||||
|
||||
def get_order() -> Order:
|
||||
return Order(
|
||||
stock_id="SH600000",
|
||||
amount=TOTAL_POSITION,
|
||||
direction=OrderDir.BUY,
|
||||
start_time=pd.Timestamp("2019-03-04 09:30:00"),
|
||||
end_time=pd.Timestamp("2019-03-04 14:29:00"),
|
||||
)
|
||||
|
||||
|
||||
def get_simulator(order: Order) -> SingleAssetOrderExecutionQlib:
|
||||
def _inner_executor_fn(time_per_step: str, common_infra: CommonInfrastructure) -> NestedExecutor:
|
||||
return NestedExecutor(
|
||||
time_per_step=time_per_step,
|
||||
inner_strategy=TWAPStrategy(),
|
||||
inner_executor=SimulatorExecutor(
|
||||
time_per_step="1min",
|
||||
verbose=False,
|
||||
trade_type=SimulatorExecutor.TT_SERIAL,
|
||||
generate_report=False,
|
||||
common_infra=common_infra,
|
||||
track_data=True,
|
||||
),
|
||||
common_infra=common_infra,
|
||||
track_data=True,
|
||||
)
|
||||
|
||||
DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "qlib_simulator"
|
||||
|
||||
# fmt: off
|
||||
qlib_config = {
|
||||
"provider_uri_day": DATA_ROOT_DIR / "qlib_1d",
|
||||
"provider_uri_1min": DATA_ROOT_DIR / "qlib_1min",
|
||||
"feature_root_dir": DATA_ROOT_DIR / "qlib_handler_stock",
|
||||
"feature_columns_today": [
|
||||
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
|
||||
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5",
|
||||
],
|
||||
"feature_columns_yesterday": [
|
||||
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
|
||||
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
|
||||
],
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
exchange_config = ExchangeConfig(
|
||||
limit_threshold=("$ask == 0", "$bid == 0"),
|
||||
deal_price=("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"),
|
||||
volume_threshold={
|
||||
"all": ("cum", "0.2 * DayCumsum($volume, '9:30', '14:29')"),
|
||||
"buy": ("current", "$askV1"),
|
||||
"sell": ("current", "$bidV1"),
|
||||
},
|
||||
open_cost=0.0005,
|
||||
close_cost=0.0015,
|
||||
min_cost=5.0,
|
||||
trade_unit=None,
|
||||
cash_limit=None,
|
||||
generate_report=False,
|
||||
)
|
||||
|
||||
return SingleAssetOrderExecutionQlib(
|
||||
order=order,
|
||||
time_per_step="30min",
|
||||
qlib_config=qlib_config,
|
||||
inner_executor_fn=_inner_executor_fn,
|
||||
exchange_config=exchange_config,
|
||||
)
|
||||
|
||||
|
||||
@python_version_request
|
||||
def test_simulator_first_step():
|
||||
order = get_order()
|
||||
simulator = get_simulator(order)
|
||||
state = simulator.get_state()
|
||||
assert state.cur_time == pd.Timestamp("2019-03-04 09:30:00")
|
||||
assert state.position == TOTAL_POSITION
|
||||
|
||||
AMOUNT = 300.0
|
||||
simulator.step(AMOUNT)
|
||||
state = simulator.get_state()
|
||||
assert state.cur_time == pd.Timestamp("2019-03-04 10:00:00")
|
||||
assert state.position == TOTAL_POSITION - AMOUNT
|
||||
assert len(state.history_exec) == 30
|
||||
assert state.history_exec.index[0] == pd.Timestamp("2019-03-04 09:30:00")
|
||||
|
||||
assert is_close(state.history_exec["market_volume"].iloc[0], 109382.382812)
|
||||
assert is_close(state.history_exec["market_price"].iloc[0], 149.566483)
|
||||
assert (state.history_exec["amount"] == AMOUNT / 30).all()
|
||||
assert (state.history_exec["deal_amount"] == AMOUNT / 30).all()
|
||||
assert is_close(state.history_exec["trade_price"].iloc[0], 149.566483)
|
||||
assert is_close(state.history_exec["trade_value"].iloc[0], 1495.664825)
|
||||
assert is_close(state.history_exec["position"].iloc[0], TOTAL_POSITION - AMOUNT / 30)
|
||||
# assert state.history_exec["ffr"].iloc[0] == 1 / 60 # FIXME
|
||||
|
||||
assert is_close(state.history_steps["market_volume"].iloc[0], 1254848.5756835938)
|
||||
assert state.history_steps["amount"].iloc[0] == AMOUNT
|
||||
assert state.history_steps["deal_amount"].iloc[0] == AMOUNT
|
||||
assert state.history_steps["ffr"].iloc[0] == 1.0
|
||||
assert is_close(
|
||||
state.history_steps["pa"].iloc[0] * (1.0 if order.direction == OrderDir.SELL else -1.0),
|
||||
(state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000,
|
||||
)
|
||||
|
||||
|
||||
@python_version_request
|
||||
def test_simulator_stop_twap() -> None:
|
||||
order = get_order()
|
||||
simulator = get_simulator(order)
|
||||
NUM_STEPS = 7
|
||||
for i in range(NUM_STEPS):
|
||||
simulator.step(TOTAL_POSITION / NUM_STEPS)
|
||||
|
||||
HISTORY_STEP_LENGTH = 30 * NUM_STEPS
|
||||
state = simulator.get_state()
|
||||
assert len(state.history_exec) == HISTORY_STEP_LENGTH
|
||||
|
||||
assert (state.history_exec["deal_amount"] == TOTAL_POSITION / HISTORY_STEP_LENGTH).all()
|
||||
assert is_close(state.history_steps["position"].iloc[0], TOTAL_POSITION * (NUM_STEPS - 1) / NUM_STEPS)
|
||||
assert is_close(state.history_steps["position"].iloc[-1], 0.0)
|
||||
assert is_close(state.position, 0.0)
|
||||
assert is_close(state.metrics["ffr"], 1.0)
|
||||
|
||||
assert is_close(state.metrics["market_price"], state.backtest_data.get_deal_price().mean())
|
||||
assert is_close(state.metrics["market_volume"], state.backtest_data.get_volume().sum())
|
||||
assert is_close(state.metrics["trade_price"], state.metrics["market_price"])
|
||||
assert is_close(state.metrics["pa"], 0.0)
|
||||
|
||||
assert simulator.done()
|
||||
|
||||
|
||||
@python_version_request
|
||||
def test_interpreter() -> None:
|
||||
NUM_EXECUTION = 3
|
||||
order = get_order()
|
||||
simulator = get_simulator(order)
|
||||
interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
|
||||
|
||||
NUM_STEPS = 7
|
||||
state = simulator.get_state()
|
||||
position_history = []
|
||||
for i in range(NUM_STEPS):
|
||||
simulator.step(interpreter_action(state, 1))
|
||||
state = simulator.get_state()
|
||||
position_history.append(state.position)
|
||||
|
||||
assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_simulator_first_step()
|
||||
test_simulator_stop_twap()
|
||||
test_interpreter()
|
||||
@@ -9,7 +9,6 @@ from typing import NamedTuple
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
from tianshou.data import Batch
|
||||
|
||||
@@ -17,8 +16,8 @@ from qlib.backtest import Order
|
||||
from qlib.config import C
|
||||
from qlib.log import set_log_with_config
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.trainer import backtest, train
|
||||
from qlib.rl.order_execution import *
|
||||
from qlib.rl.trainer import backtest, train
|
||||
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
|
||||
|
||||
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
|
||||
@@ -38,7 +37,7 @@ CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
|
||||
|
||||
|
||||
def test_pickle_data_inspect():
|
||||
data = pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
|
||||
data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
|
||||
assert len(data) == 390
|
||||
|
||||
data = pickle_styled.load_intraday_processed_data(
|
||||
|
||||
Reference in New Issue
Block a user