mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 00:51:19 +08:00
Compare commits
3 Commits
fix_pip_in
...
6cma
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1f5f3a6af0 | ||
|
|
2f8fc8d28a | ||
|
|
3e9ccd3ad2 |
5
.github/release-drafter.yml
vendored
5
.github/release-drafter.yml
vendored
@@ -14,9 +14,6 @@ categories:
|
||||
label:
|
||||
- 'doc'
|
||||
- 'documentation'
|
||||
- title: '🧹 Maintenance'
|
||||
label:
|
||||
- 'maintenance'
|
||||
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
|
||||
change-title-escapes: '\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks.
|
||||
version-resolver:
|
||||
@@ -33,4 +30,4 @@ version-resolver:
|
||||
template: |
|
||||
## Changes
|
||||
|
||||
$CHANGES
|
||||
$CHANGES
|
||||
16
.github/workflows/test_qlib_from_pip.yml
vendored
16
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -19,20 +19,10 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Test qlib from pip
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
@@ -60,7 +50,7 @@ jobs:
|
||||
|
||||
- name: Downloads dependencies data
|
||||
run: |
|
||||
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Test workflow by config
|
||||
run: |
|
||||
|
||||
21
.github/workflows/test_qlib_from_source.yml
vendored
21
.github/workflows/test_qlib_from_source.yml
vendored
@@ -20,28 +20,18 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Update pip to the latest version
|
||||
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0
|
||||
# The pip version has been temporarily fixed to 23.0.1
|
||||
run: |
|
||||
python -m pip install pip==23.0
|
||||
python -m pip install pip==23.0.1
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
@@ -139,7 +129,8 @@ jobs:
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl /tmp/qlibpublic/data --recursive
|
||||
mv /tmp/qlibpublic/data tests/.data
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
|
||||
18
.github/workflows/test_qlib_from_source_slow.yml
vendored
18
.github/workflows/test_qlib_from_source_slow.yml
vendored
@@ -20,28 +20,18 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source slow
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python tools
|
||||
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0
|
||||
# The pip version has been temporarily fixed to 23.0.1
|
||||
run: |
|
||||
python -m pip install pip==23.0
|
||||
python -m pip install pip==23.0.1
|
||||
pip install --upgrade cython numpy
|
||||
pip install -e .[dev]
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| KRNN and Sandwich models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1414/) on May 26, 2023 |
|
||||
| Release Qlib v0.9.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.9.0) on Dec 9, 2022 |
|
||||
| RL Learning Framework | :hammer: :chart_with_upwards_trend: Released on Nov 10, 2022. [#1332](https://github.com/microsoft/qlib/pull/1332), [#1322](https://github.com/microsoft/qlib/pull/1322), [#1316](https://github.com/microsoft/qlib/pull/1316),[#1299](https://github.com/microsoft/qlib/pull/1299),[#1263](https://github.com/microsoft/qlib/pull/1263), [#1244](https://github.com/microsoft/qlib/pull/1244), [#1169](https://github.com/microsoft/qlib/pull/1169), [#1125](https://github.com/microsoft/qlib/pull/1125), [#1076](https://github.com/microsoft/qlib/pull/1076)|
|
||||
| HIST and IGMTF models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1040) on Apr 10, 2022 |
|
||||
@@ -354,8 +353,6 @@ Here is a list of models built on `Qlib`.
|
||||
- [ADD based on pytorch (Hongshun Tang, et al.2020)](examples/benchmarks/ADD/)
|
||||
- [IGMTF based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/IGMTF/)
|
||||
- [HIST based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/HIST/)
|
||||
- [KRNN based on pytorch](examples/benchmarks/KRNN/)
|
||||
- [Sandwich based on pytorch](examples/benchmarks/Sandwich/)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ Here are some example:
|
||||
for daily data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py download_data --file_name csv_data_cn.zip --target_dir ~/.qlib/csv_data/cn_data
|
||||
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
|
||||
for 1min data:
|
||||
.. code-block:: bash
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
# KRNN
|
||||
* Code: [https://github.com/microsoft/FOST/blob/main/fostool/model/krnn.py](https://github.com/microsoft/FOST/blob/main/fostool/model/krnn.py)
|
||||
|
||||
|
||||
# Introductions about the settings/configs.
|
||||
* Torch_geometric is used in the original model in FOST, but we didn't use it.
|
||||
* make use your CUDA version matches the torch version to allow the usage of GPU, we use CUDA==10.2 and torch.__version__==1.12.1
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
numpy==1.23.4
|
||||
pandas==1.5.2
|
||||
@@ -1,91 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: KRNN
|
||||
module_path: qlib.contrib.model.pytorch_krnn
|
||||
kwargs:
|
||||
fea_dim: 6
|
||||
cnn_dim: 8
|
||||
cnn_kernel_size: 3
|
||||
rnn_dim: 8
|
||||
rnn_dups: 2
|
||||
rnn_layers: 2
|
||||
n_epochs: 200
|
||||
lr: 0.001
|
||||
early_stop: 20
|
||||
batch_size: 2000
|
||||
metric: loss
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -26,7 +26,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------------------------------------|-------------------------------------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| TCN(Shaojie Bai, et al.) | Alpha158 | 0.0279±0.00 | 0.2181±0.01 | 0.0421±0.00 | 0.3429±0.01 | 0.0262±0.02 | 0.4133±0.25 | -0.1090±0.03 |
|
||||
| TCN(Shaojie Bai, et al.) | Alpha158 | 0.0275±0.00 | 0.2157±0.01 | 0.0411±0.00 | 0.3379±0.01 | 0.0190±0.02 | 0.2887±0.27 | -0.1202±0.03 |
|
||||
| TabNet(Sercan O. Arik, et al.) | Alpha158 | 0.0204±0.01 | 0.1554±0.07 | 0.0333±0.00 | 0.2552±0.05 | 0.0227±0.04 | 0.3676±0.54 | -0.1089±0.08 |
|
||||
| Transformer(Ashish Vaswani, et al.) | Alpha158 | 0.0264±0.00 | 0.2053±0.02 | 0.0407±0.00 | 0.3273±0.02 | 0.0273±0.02 | 0.3970±0.26 | -0.1101±0.02 |
|
||||
| GRU(Kyunghyun Cho, et al.) | Alpha158(with selected 20 features) | 0.0315±0.00 | 0.2450±0.04 | 0.0428±0.00 | 0.3440±0.03 | 0.0344±0.02 | 0.5160±0.25 | -0.1017±0.02 |
|
||||
@@ -68,8 +68,6 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
|
||||
| IGMTF(Wentao Xu, et al.) | Alpha360 | 0.0480±0.00 | 0.3589±0.02 | 0.0606±0.00 | 0.4773±0.01 | 0.0946±0.02 | 1.3509±0.25 | -0.0716±0.02 |
|
||||
| HIST(Wentao Xu, et al.) | Alpha360 | 0.0522±0.00 | 0.3530±0.01 | 0.0667±0.00 | 0.4576±0.01 | 0.0987±0.02 | 1.3726±0.27 | -0.0681±0.01 |
|
||||
| KRNN | Alpha360 | 0.0173±0.01 | 0.1210±0.06 | 0.0270±0.01 | 0.2018±0.04 | -0.0465±0.05 | -0.5415±0.62 | -0.2919±0.13 |
|
||||
| Sandwich | Alpha360 | 0.0258±0.00 | 0.1924±0.04 | 0.0337±0.00 | 0.2624±0.03 | 0.0005±0.03 | 0.0001±0.33 | -0.1752±0.05 |
|
||||
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
@@ -136,7 +134,7 @@ If you want to contribute your new models, you can follow the steps below.
|
||||
- `README.md`: a brief introduction to your models
|
||||
- `workflow_config_<model name>_<dataset>.yaml`: a configuration which can read by `qrun`. You are encouraged to run your model in all datasets.
|
||||
3. You can integrate your model as a module [in this folder](https://github.com/microsoft/qlib/tree/main/qlib/contrib/model).
|
||||
4. Please update your results in the above **Benchmark Tables**, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on **20 Runs** with different random seeds. You can accomplish the above operations through the automated [script](https://github.com/microsoft/qlib/blob/main/examples/run_all_model.py#LL286C22-L286C22) provided by Qlib, and get the final result in the .md file. if you don't have enough computational resource, you can ask for help in the PR).
|
||||
4. Please updated your results in the benchmark tables, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on 20 runs with different random seeds, if you don't have enough computational resource, you can ask for help in the PR).
|
||||
5. Update the info in the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).
|
||||
|
||||
Finally, you can send PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
# Sandwich
|
||||
* Code: [https://github.com/microsoft/FOST/blob/main/fostool/model/sandwich.py](https://github.com/microsoft/FOST/blob/main/fostool/model/sandwich.py)
|
||||
|
||||
|
||||
# Introductions about the settings/configs.
|
||||
* Torch_geometric is used in the original model in FOST, but we didn't use it.
|
||||
make use your CUDA version matches the torch version to allow the usage of GPU, we use CUDA==10.2 and torch.version==1.12.1
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
numpy==1.23.4
|
||||
pandas==1.5.2
|
||||
@@ -1,93 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: Sandwich
|
||||
module_path: qlib.contrib.model.pytorch_sandwich
|
||||
kwargs:
|
||||
fea_dim: 6
|
||||
cnn_dim_1: 16
|
||||
cnn_dim_2: 16
|
||||
cnn_kernel_size: 3
|
||||
rnn_dim_1: 8
|
||||
rnn_dim_2: 8
|
||||
rnn_dups: 2
|
||||
rnn_layers: 2
|
||||
n_epochs: 200
|
||||
lr: 0.001
|
||||
early_stop: 20
|
||||
batch_size: 2000
|
||||
metric: loss
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
.PHONY: clean
|
||||
|
||||
clean:
|
||||
-rm -r *.pkl mlruns || true
|
||||
@@ -34,14 +34,14 @@ class DDGDA:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sim_task_model: Literal["linear", "gbdt"] = "gbdt",
|
||||
sim_task_model: Literal["linear", "gbdt"] = "linear",
|
||||
forecast_model: Literal["linear", "gbdt"] = "linear",
|
||||
h_path: Optional[str] = None,
|
||||
test_end: Optional[str] = None,
|
||||
train_start: Optional[str] = None,
|
||||
meta_1st_train_end: Optional[str] = None,
|
||||
task_ext_conf: Optional[dict] = None,
|
||||
alpha: float = 0.01,
|
||||
alpha: float = 0.0,
|
||||
proxy_hd: str = "handler_proxy.pkl",
|
||||
):
|
||||
"""
|
||||
@@ -116,9 +116,7 @@ class DDGDA:
|
||||
|
||||
feature_selected = feature_df.loc[:, col_selected.index]
|
||||
|
||||
feature_selected = feature_selected.groupby("datetime", group_keys=False).apply(
|
||||
lambda df: (df - df.mean()).div(df.std())
|
||||
)
|
||||
feature_selected = feature_selected.groupby("datetime").apply(lambda df: (df - df.mean()).div(df.std()))
|
||||
feature_selected = feature_selected.fillna(0.0)
|
||||
|
||||
df_all = {
|
||||
@@ -170,8 +168,7 @@ class DDGDA:
|
||||
# - Only the dataset part is important, in current version of meta model will integrate the
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
|
||||
sim_task = rb.basic_task()
|
||||
# the train_start for training meta model does not necessarily align with final rolling
|
||||
train_start = "2008-01-01" if self.rb_kwargs.get("train_start") is None else self.rb_kwargs.get("train_start")
|
||||
train_start = self.rb_kwargs.get("train_start", "2008-01-01")
|
||||
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
|
||||
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
proxy_forecast_model_task = {
|
||||
@@ -215,7 +212,7 @@ class DDGDA:
|
||||
with R.start(experiment_name=self.meta_exp_name):
|
||||
R.log_params(**kwargs)
|
||||
mm = MetaModelDS(
|
||||
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=30, seed=43, alpha=self.alpha
|
||||
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=100, seed=43, alpha=self.alpha
|
||||
)
|
||||
mm.fit(md)
|
||||
R.save_objects(model=mm)
|
||||
|
||||
@@ -8,17 +8,15 @@ The table below shows the performances of different solutions on different forec
|
||||
Here is the [crowd sourced version of qlib data](data_collector/crowd_source/README.md): https://github.com/chenditc/investment_data/releases
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||
mkdir -p ~/.qlib/qlib_data/cn_data
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||
rm -f qlib_bin.tar.gz
|
||||
```
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------------|---------|------|------|---------|-----------|-------------------|-------------------|--------------|
|
||||
| RR[Linear] |Alpha158 |0.0945|0.5989|0.1069 |0.6495 |0.0857 |1.3682 |-0.0986 |
|
||||
| DDG-DA[Linear] |Alpha158 |0.0983|0.6157|0.1108 |0.6646 |0.0764 |1.1904 |-0.0769 |
|
||||
| RR[LightGBM] |Alpha158 |0.0816|0.5887|0.0912 |0.6263 |0.0771 |1.3196 |-0.0909 |
|
||||
| DDG-DA[LightGBM] |Alpha158 |0.0878|0.6185|0.0975 |0.6524 |0.1261 |2.0096 |-0.0744 |
|
||||
|------------------|---------|----|------|---------|-----------|-------------------|-------------------|--------------|
|
||||
| RR[Linear] |Alpha158 |0.089|0.577|0.102 |0.627 |0.093 |1.458 |-0.073 |
|
||||
| DDG-DA[Linear] |Alpha158 |0.096|0.636|0.107 |0.677 |0.067 |0.996 |-0.091 |
|
||||
| RR[LightGBM] |Alpha158 |0.082|0.589|0.091 |0.626 |0.077 |1.320 |-0.091 |
|
||||
| DDG-DA[LightGBM] |Alpha158 |0.085|0.658|0.094 |0.686 |0.115 |1.792 |-0.068 |
|
||||
|
||||
- The label horizon of the `Alpha158` dataset is set to 20.
|
||||
- The rolling time intervals are set to 20 trading days.
|
||||
|
||||
@@ -67,12 +67,11 @@ class RollingBenchmark:
|
||||
def basic_task(self):
|
||||
"""For fast training rolling"""
|
||||
if self.model_type == "gbdt":
|
||||
conf_path = DIRNAME / "workflow_config_lightgbm_Alpha158.yaml"
|
||||
conf_path = DIRNAME.parent.parent / "benchmarks" / "LightGBM" / "workflow_config_lightgbm_Alpha158.yaml"
|
||||
# dump the processed data on to disk for later loading to speed up the processing
|
||||
h_path = DIRNAME / "lightgbm_alpha158_handler_horizon{}.pkl".format(self.horizon)
|
||||
elif self.model_type == "linear":
|
||||
# We use ridge regression to stabilize the performance
|
||||
conf_path = DIRNAME / "workflow_config_linear_Alpha158.yaml"
|
||||
conf_path = DIRNAME.parent.parent / "benchmarks" / "Linear" / "workflow_config_linear_Alpha158.yaml"
|
||||
h_path = DIRNAME / "linear_alpha158_handler_horizon{}.pkl".format(self.horizon)
|
||||
else:
|
||||
raise AssertionError("Model type is not supported!")
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,79 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LinearModel
|
||||
module_path: qlib.contrib.model.linear
|
||||
kwargs:
|
||||
estimator: ridge
|
||||
alpha: 0.05
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: True
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.9.2"
|
||||
__version__ = "0.9.1.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
@@ -179,7 +179,7 @@ def get_strategy_executor(
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: Optional[str] = "SH000300",
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
exchange_kwargs: Union[dict, Exchange] = {}, # TODO: rename parameter
|
||||
pos_type: str = "Position",
|
||||
) -> Tuple[BaseStrategy, BaseExecutor]:
|
||||
|
||||
@@ -197,12 +197,15 @@ def get_strategy_executor(
|
||||
pos_type=pos_type,
|
||||
)
|
||||
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
if "start_time" not in exchange_kwargs:
|
||||
exchange_kwargs["start_time"] = start_time
|
||||
if "end_time" not in exchange_kwargs:
|
||||
exchange_kwargs["end_time"] = end_time
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
if isinstance(exchange_kwargs, Exchange):
|
||||
trade_exchange = exchange_kwargs
|
||||
else:
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
if "start_time" not in exchange_kwargs:
|
||||
exchange_kwargs["start_time"] = start_time
|
||||
if "end_time" not in exchange_kwargs:
|
||||
exchange_kwargs["end_time"] = end_time
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
|
||||
|
||||
@@ -56,6 +56,7 @@ def collect_data_loop(
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
return_value: dict | None = None,
|
||||
show_progress: bool = True,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
@@ -74,6 +75,8 @@ def collect_data_loop(
|
||||
the outermost executor
|
||||
return_value : dict
|
||||
used for backtest_loop
|
||||
show_progress: bool
|
||||
whether to show execution progress
|
||||
|
||||
Yields
|
||||
-------
|
||||
@@ -83,7 +86,8 @@ def collect_data_loop(
|
||||
trade_executor.reset(start_time=start_time, end_time=end_time)
|
||||
trade_strategy.reset(level_infra=trade_executor.get_level_infra())
|
||||
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
|
||||
disable = not show_progress
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop", disable=disable) as bar:
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
|
||||
@@ -177,7 +177,7 @@ class Exchange:
|
||||
|
||||
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
|
||||
if self.limit_type == self.LT_TP_EXP:
|
||||
assert isinstance(limit_threshold, tuple)
|
||||
assert isinstance(limit_threshold, tuple) or (isinstance(limit_threshold, list) and len(limit_threshold) == 2)
|
||||
for exp in limit_threshold:
|
||||
necessary_fields.add(exp)
|
||||
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
|
||||
@@ -263,6 +263,9 @@ class Exchange:
|
||||
"""get limit type"""
|
||||
if isinstance(limit_threshold, tuple):
|
||||
return self.LT_TP_EXP
|
||||
if isinstance(limit_threshold, list):
|
||||
assert len(limit_threshold) == 2
|
||||
return self.LT_TP_EXP
|
||||
elif isinstance(limit_threshold, float):
|
||||
return self.LT_FLT
|
||||
elif limit_threshold is None:
|
||||
@@ -325,7 +328,7 @@ class Exchange:
|
||||
|
||||
assert isinstance(volume_threshold, dict)
|
||||
for key, vol_limit in volume_threshold.items():
|
||||
assert isinstance(vol_limit, tuple)
|
||||
assert isinstance(vol_limit, tuple) or (isinstance(vol_limit, list) and len(vol_limit) == 2)
|
||||
fields.add(vol_limit[1])
|
||||
|
||||
if key in ("buy", "all"):
|
||||
@@ -803,7 +806,7 @@ class Exchange:
|
||||
|
||||
vol_limit_num: List[float] = []
|
||||
for limit in vol_limit:
|
||||
assert isinstance(limit, tuple)
|
||||
assert isinstance(limit, tuple) or (isinstance(limit, list) and len(limit) == 2)
|
||||
if limit[0] == "current":
|
||||
limit_value = self.quote.get_data(
|
||||
order.stock_id,
|
||||
|
||||
@@ -1,511 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
########################################################################
|
||||
########################################################################
|
||||
########################################################################
|
||||
|
||||
|
||||
class CNNEncoderBase(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, kernel_size, device):
|
||||
"""Build a basic CNN encoder
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_dim : int
|
||||
The input dimension
|
||||
output_dim : int
|
||||
The output dimension
|
||||
kernel_size : int
|
||||
The size of convolutional kernels
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.device = device
|
||||
|
||||
# set padding to ensure the same length
|
||||
# it is correct only when kernel_size is odd, dilation is 1, stride is 1
|
||||
self.conv = nn.Conv1d(input_dim, output_dim, kernel_size, padding=(kernel_size - 1) // 2)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
input data
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Updated representations
|
||||
"""
|
||||
|
||||
# input shape: [batch_size, seq_len*input_dim]
|
||||
# output shape: [batch_size, seq_len, input_dim]
|
||||
x = x.view(x.shape[0], -1, self.input_dim).permute(0, 2, 1).to(self.device)
|
||||
y = self.conv(x) # [batch_size, output_dim, conved_seq_len]
|
||||
y = y.permute(0, 2, 1) # [batch_size, conved_seq_len, output_dim]
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class KRNNEncoderBase(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, dup_num, rnn_layers, dropout, device):
|
||||
"""Build K parallel RNNs
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_dim : int
|
||||
The input dimension
|
||||
output_dim : int
|
||||
The output dimension
|
||||
dup_num : int
|
||||
The number of parallel RNNs
|
||||
rnn_layers: int
|
||||
The number of RNN layers
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.dup_num = dup_num
|
||||
self.rnn_layers = rnn_layers
|
||||
self.dropout = dropout
|
||||
self.device = device
|
||||
|
||||
self.rnn_modules = nn.ModuleList()
|
||||
for _ in range(dup_num):
|
||||
self.rnn_modules.append(nn.GRU(input_dim, output_dim, num_layers=self.rnn_layers, dropout=dropout))
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input data
|
||||
n_id : torch.Tensor
|
||||
Node indices
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Updated representations
|
||||
"""
|
||||
|
||||
# input shape: [batch_size, seq_len, input_dim]
|
||||
# output shape: [batch_size, seq_len, output_dim]
|
||||
# [seq_len, batch_size, input_dim]
|
||||
batch_size, seq_len, input_dim = x.shape
|
||||
x = x.permute(1, 0, 2).to(self.device)
|
||||
|
||||
hids = []
|
||||
for rnn in self.rnn_modules:
|
||||
h, _ = rnn(x) # [seq_len, batch_size, output_dim]
|
||||
hids.append(h)
|
||||
# [seq_len, batch_size, output_dim, num_dups]
|
||||
hids = torch.stack(hids, dim=-1)
|
||||
hids = hids.view(seq_len, batch_size, self.output_dim, self.dup_num)
|
||||
hids = hids.mean(dim=3)
|
||||
hids = hids.permute(1, 0, 2)
|
||||
|
||||
return hids
|
||||
|
||||
|
||||
class CNNKRNNEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, cnn_input_dim, cnn_output_dim, cnn_kernel_size, rnn_output_dim, rnn_dup_num, rnn_layers, dropout, device
|
||||
):
|
||||
"""Build an encoder composed of CNN and KRNN
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cnn_input_dim : int
|
||||
The input dimension of CNN
|
||||
cnn_output_dim : int
|
||||
The output dimension of CNN
|
||||
cnn_kernel_size : int
|
||||
The size of convolutional kernels
|
||||
rnn_output_dim : int
|
||||
The output dimension of KRNN
|
||||
rnn_dup_num : int
|
||||
The number of parallel duplicates for KRNN
|
||||
rnn_layers : int
|
||||
The number of RNN layers
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.cnn_encoder = CNNEncoderBase(cnn_input_dim, cnn_output_dim, cnn_kernel_size, device)
|
||||
self.krnn_encoder = KRNNEncoderBase(cnn_output_dim, rnn_output_dim, rnn_dup_num, rnn_layers, dropout, device)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input data
|
||||
n_id : torch.Tensor
|
||||
Node indices
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Updated representations
|
||||
"""
|
||||
cnn_out = self.cnn_encoder(x)
|
||||
krnn_out = self.krnn_encoder(cnn_out)
|
||||
|
||||
return krnn_out
|
||||
|
||||
|
||||
class KRNNModel(nn.Module):
|
||||
def __init__(self, fea_dim, cnn_dim, cnn_kernel_size, rnn_dim, rnn_dups, rnn_layers, dropout, device, **params):
|
||||
"""Build a KRNN model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fea_dim : int
|
||||
The feature dimension
|
||||
cnn_dim : int
|
||||
The hidden dimension of CNN
|
||||
cnn_kernel_size : int
|
||||
The size of convolutional kernels
|
||||
rnn_dim : int
|
||||
The hidden dimension of KRNN
|
||||
rnn_dups : int
|
||||
The number of parallel duplicates
|
||||
rnn_layers: int
|
||||
The number of RNN layers
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.encoder = CNNKRNNEncoder(
|
||||
cnn_input_dim=fea_dim,
|
||||
cnn_output_dim=cnn_dim,
|
||||
cnn_kernel_size=cnn_kernel_size,
|
||||
rnn_output_dim=rnn_dim,
|
||||
rnn_dup_num=rnn_dups,
|
||||
rnn_layers=rnn_layers,
|
||||
dropout=dropout,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.out_fc = nn.Linear(rnn_dim, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, x):
|
||||
# x: [batch_size, node_num, seq_len, input_dim]
|
||||
encode = self.encoder(x)
|
||||
out = self.out_fc(encode[:, -1, :]).squeeze().to(self.device)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class KRNN(Model):
|
||||
"""KRNN Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fea_dim=6,
|
||||
cnn_dim=64,
|
||||
cnn_kernel_size=3,
|
||||
rnn_dim=64,
|
||||
rnn_dups=3,
|
||||
rnn_layers=2,
|
||||
dropout=0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("KRNN")
|
||||
self.logger.info("KRNN pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.fea_dim = fea_dim
|
||||
self.cnn_dim = cnn_dim
|
||||
self.cnn_kernel_size = cnn_kernel_size
|
||||
self.rnn_dim = rnn_dim
|
||||
self.rnn_dups = rnn_dups
|
||||
self.rnn_layers = rnn_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"KRNN parameters setting:"
|
||||
"\nfea_dim : {}"
|
||||
"\ncnn_dim : {}"
|
||||
"\ncnn_kernel_size : {}"
|
||||
"\nrnn_dim : {}"
|
||||
"\nrnn_dups : {}"
|
||||
"\nrnn_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size: {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
fea_dim,
|
||||
cnn_dim,
|
||||
cnn_kernel_size,
|
||||
rnn_dim,
|
||||
rnn_dups,
|
||||
rnn_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.krnn_model = KRNNModel(
|
||||
fea_dim=self.fea_dim,
|
||||
cnn_dim=self.cnn_dim,
|
||||
cnn_kernel_size=self.cnn_kernel_size,
|
||||
rnn_dim=self.rnn_dim,
|
||||
rnn_dups=self.rnn_dups,
|
||||
rnn_layers=self.rnn_layers,
|
||||
dropout=self.dropout,
|
||||
device=self.device,
|
||||
)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.krnn_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.krnn_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.krnn_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
return daily_index, daily_count
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
self.krnn_model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.krnn_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.krnn_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.krnn_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.krnn_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.krnn_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.krnn_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.krnn_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
with torch.no_grad():
|
||||
pred = self.krnn_model(x_batch).detach().cpu().numpy()
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
@@ -1,381 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from .pytorch_krnn import CNNKRNNEncoder
|
||||
|
||||
|
||||
class SandwichModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
fea_dim,
|
||||
cnn_dim_1,
|
||||
cnn_dim_2,
|
||||
cnn_kernel_size,
|
||||
rnn_dim_1,
|
||||
rnn_dim_2,
|
||||
rnn_dups,
|
||||
rnn_layers,
|
||||
dropout,
|
||||
device,
|
||||
**params
|
||||
):
|
||||
"""Build a Sandwich model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fea_dim : int
|
||||
The feature dimension
|
||||
cnn_dim_1 : int
|
||||
The hidden dimension of the first CNN
|
||||
cnn_dim_2 : int
|
||||
The hidden dimension of the second CNN
|
||||
cnn_kernel_size : int
|
||||
The size of convolutional kernels
|
||||
rnn_dim_1 : int
|
||||
The hidden dimension of the first KRNN
|
||||
rnn_dim_2 : int
|
||||
The hidden dimension of the second KRNN
|
||||
rnn_dups : int
|
||||
The number of parallel duplicates
|
||||
rnn_layers: int
|
||||
The number of RNN layers
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.first_encoder = CNNKRNNEncoder(
|
||||
cnn_input_dim=fea_dim,
|
||||
cnn_output_dim=cnn_dim_1,
|
||||
cnn_kernel_size=cnn_kernel_size,
|
||||
rnn_output_dim=rnn_dim_1,
|
||||
rnn_dup_num=rnn_dups,
|
||||
rnn_layers=rnn_layers,
|
||||
dropout=dropout,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.second_encoder = CNNKRNNEncoder(
|
||||
cnn_input_dim=rnn_dim_1,
|
||||
cnn_output_dim=cnn_dim_2,
|
||||
cnn_kernel_size=cnn_kernel_size,
|
||||
rnn_output_dim=rnn_dim_2,
|
||||
rnn_dup_num=rnn_dups,
|
||||
rnn_layers=rnn_layers,
|
||||
dropout=dropout,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.out_fc = nn.Linear(rnn_dim_2, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, x):
|
||||
# x: [batch_size, node_num, seq_len, input_dim]
|
||||
encode = self.first_encoder(x)
|
||||
encode = self.second_encoder(encode)
|
||||
out = self.out_fc(encode[:, -1, :]).squeeze().to(self.device)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Sandwich(Model):
|
||||
"""Sandwich Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fea_dim=6,
|
||||
cnn_dim_1=64,
|
||||
cnn_dim_2=32,
|
||||
cnn_kernel_size=3,
|
||||
rnn_dim_1=16,
|
||||
rnn_dim_2=8,
|
||||
rnn_dups=3,
|
||||
rnn_layers=2,
|
||||
dropout=0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("Sandwich")
|
||||
self.logger.info("Sandwich pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.fea_dim = fea_dim
|
||||
self.cnn_dim_1 = cnn_dim_1
|
||||
self.cnn_dim_2 = cnn_dim_2
|
||||
self.cnn_kernel_size = cnn_kernel_size
|
||||
self.rnn_dim_1 = rnn_dim_1
|
||||
self.rnn_dim_2 = rnn_dim_2
|
||||
self.rnn_dups = rnn_dups
|
||||
self.rnn_layers = rnn_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"Sandwich parameters setting:"
|
||||
"\nfea_dim : {}"
|
||||
"\ncnn_dim_1 : {}"
|
||||
"\ncnn_dim_2 : {}"
|
||||
"\ncnn_kernel_size : {}"
|
||||
"\nrnn_dim_1 : {}"
|
||||
"\nrnn_dim_2 : {}"
|
||||
"\nrnn_dups : {}"
|
||||
"\nrnn_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size: {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
fea_dim,
|
||||
cnn_dim_1,
|
||||
cnn_dim_2,
|
||||
cnn_kernel_size,
|
||||
rnn_dim_1,
|
||||
rnn_dim_2,
|
||||
rnn_dups,
|
||||
rnn_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.sandwich_model = SandwichModel(
|
||||
fea_dim=self.fea_dim,
|
||||
cnn_dim_1=self.cnn_dim_1,
|
||||
cnn_dim_2=self.cnn_dim_2,
|
||||
cnn_kernel_size=self.cnn_kernel_size,
|
||||
rnn_dim_1=self.rnn_dim_1,
|
||||
rnn_dim_2=self.rnn_dim_2,
|
||||
rnn_dups=self.rnn_dups,
|
||||
rnn_layers=self.rnn_layers,
|
||||
dropout=self.dropout,
|
||||
device=self.device,
|
||||
)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.sandwich_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.sandwich_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.sandwich_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
self.sandwich_model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.sandwich_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.sandwich_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.sandwich_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.sandwich_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.sandwich_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.sandwich_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.sandwich_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
with torch.no_grad():
|
||||
pred = self.sandwich_model(x_batch).detach().cpu().numpy()
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
@@ -168,8 +168,7 @@ class TCN(Model):
|
||||
self.TCN_model.train()
|
||||
|
||||
for data in data_loader:
|
||||
data = torch.transpose(data, 1, 2)
|
||||
feature = data[:, 0:-1, :].to(self.device)
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.TCN_model(feature.float())
|
||||
@@ -188,8 +187,8 @@ class TCN(Model):
|
||||
losses = []
|
||||
|
||||
for data in data_loader:
|
||||
data = torch.transpose(data, 1, 2)
|
||||
feature = data[:, 0:-1, :].to(self.device)
|
||||
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
|
||||
@@ -16,13 +16,12 @@ import torch
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
|
||||
from qlib.backtest.decision import BaseTradeDecision, TradeRangeByTime
|
||||
from qlib.backtest.executor import SimulatorExecutor
|
||||
from qlib.backtest.high_performance_ds import BaseOrderIndicator
|
||||
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
|
||||
from qlib.rl.contrib.naive_config_parser import BacktestConfigParser
|
||||
from qlib.rl.contrib.utils import read_order_file
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
||||
@@ -124,105 +123,13 @@ def _generate_report(
|
||||
return report
|
||||
|
||||
|
||||
def single_with_simulator(
|
||||
backtest_config: dict,
|
||||
orders: pd.DataFrame,
|
||||
split: Literal["stock", "day"] = "stock",
|
||||
cash_limit: float | None = None,
|
||||
generate_report: bool = False,
|
||||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
|
||||
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.
|
||||
A new simulator will be created and used for every single-day order.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backtest_config:
|
||||
Backtest config
|
||||
orders:
|
||||
Orders to be executed. Example format:
|
||||
datetime instrument amount direction
|
||||
0 2020-06-01 INST 600.0 0
|
||||
1 2020-06-02 INST 700.0 1
|
||||
...
|
||||
split
|
||||
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
|
||||
cash_limit
|
||||
Limitation of cash.
|
||||
generate_report
|
||||
Whether to generate reports.
|
||||
|
||||
Returns
|
||||
-------
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
init_qlib(backtest_config["qlib"])
|
||||
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
|
||||
reports = []
|
||||
decisions = []
|
||||
for _, row in orders.iterrows():
|
||||
date = pd.Timestamp(row["datetime"])
|
||||
start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day)
|
||||
end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day)
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(row["direction"]),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=backtest_config["strategies"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
data_granularity=backtest_config["data_granularity"],
|
||||
)
|
||||
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
"codes": stocks,
|
||||
"freq": backtest_config["data_granularity"],
|
||||
}
|
||||
)
|
||||
|
||||
simulator = SingleAssetOrderExecution(
|
||||
order=order,
|
||||
executor_config=executor_config,
|
||||
exchange_config=exchange_config,
|
||||
qlib_config=None,
|
||||
cash_limit=None,
|
||||
)
|
||||
|
||||
reports.append(simulator.report_dict)
|
||||
decisions += simulator.decisions
|
||||
|
||||
indicator_1day_objs = [report["indicator_dict"]["1day"][1] for report in reports]
|
||||
indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}
|
||||
records = _convert_indicator_to_dataframe(indicator_info)
|
||||
assert records is None or not np.isnan(records["ffr"]).any()
|
||||
|
||||
if generate_report:
|
||||
_report = _generate_report(decisions, [report["indicator"] for report in reports])
|
||||
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
report = {stock_id: _report}
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
report = {day: _report}
|
||||
|
||||
return records, report
|
||||
else:
|
||||
return records
|
||||
|
||||
|
||||
def single_with_collect_data_loop(
|
||||
backtest_config: dict,
|
||||
orders: pd.DataFrame,
|
||||
time_range: Tuple[str, str],
|
||||
exchange_config: dict,
|
||||
strategy_config: dict,
|
||||
split: Literal["stock", "day"] = "stock",
|
||||
data_granularity: str = "1min",
|
||||
cash_limit: float | None = None,
|
||||
generate_report: bool = False,
|
||||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
|
||||
@@ -250,44 +157,42 @@ def single_with_collect_data_loop(
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
|
||||
init_qlib(backtest_config["qlib"])
|
||||
|
||||
trade_start_time = orders["datetime"].min()
|
||||
trade_end_time = orders["datetime"].max()
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
|
||||
strategy_config = {
|
||||
top_strategy_config = {
|
||||
"class": "FileOrderStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"file": orders,
|
||||
"trade_range": TradeRangeByTime(
|
||||
pd.Timestamp(backtest_config["start_time"]).time(),
|
||||
pd.Timestamp(backtest_config["end_time"]).time(),
|
||||
pd.Timestamp(time_range[0]).time(),
|
||||
pd.Timestamp(time_range[1]).time(),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=backtest_config["strategies"],
|
||||
top_executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=strategy_config,
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
data_granularity=backtest_config["data_granularity"],
|
||||
data_granularity=data_granularity,
|
||||
)
|
||||
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
exchange_config = {
|
||||
**exchange_config,
|
||||
**{
|
||||
"codes": stocks,
|
||||
"freq": backtest_config["data_granularity"],
|
||||
}
|
||||
)
|
||||
"freq": data_granularity,
|
||||
},
|
||||
}
|
||||
|
||||
strategy, executor = get_strategy_executor(
|
||||
start_time=pd.Timestamp(trade_start_time),
|
||||
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
|
||||
strategy=strategy_config,
|
||||
executor=executor_config,
|
||||
strategy=top_strategy_config,
|
||||
executor=top_executor_config,
|
||||
benchmark=None,
|
||||
account=cash_limit if cash_limit is not None else int(1e12),
|
||||
exchange_kwargs=exchange_config,
|
||||
@@ -295,7 +200,7 @@ def single_with_collect_data_loop(
|
||||
)
|
||||
|
||||
report_dict: dict = {}
|
||||
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))
|
||||
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict, show_progress=False))
|
||||
|
||||
indicator_dict = cast(INDICATOR_METRIC, report_dict.get("indicator_dict"))
|
||||
records = _convert_indicator_to_dataframe(indicator_dict["1day"][1].order_indicator_his)
|
||||
@@ -315,46 +220,54 @@ def single_with_collect_data_loop(
|
||||
|
||||
|
||||
def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame:
|
||||
order_df = read_order_file(backtest_config["order_file"])
|
||||
|
||||
cash_limit = backtest_config["exchange"].pop("cash_limit")
|
||||
generate_report = backtest_config.pop("generate_report")
|
||||
|
||||
stock_pool = order_df["instrument"].unique().tolist()
|
||||
stock_pool.sort()
|
||||
|
||||
single = single_with_simulator if with_simulator else single_with_collect_data_loop
|
||||
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
|
||||
init_qlib(backtest_config["simulator"]["qlib"])
|
||||
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
|
||||
res = Parallel(**mp_config)(
|
||||
delayed(single)(
|
||||
backtest_config=backtest_config,
|
||||
orders=order_df[order_df["instrument"] == stock].copy(),
|
||||
split="stock",
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
|
||||
single = single_with_collect_data_loop
|
||||
mp_config = {"n_jobs": backtest_config["runtime"]["concurrency"], "verbose": 10, "backend": "multiprocessing"}
|
||||
|
||||
for task_config in backtest_config["tasks"]:
|
||||
order_df = read_order_file(task_config["order_file"])
|
||||
exchange_config = task_config["exchange"]
|
||||
cash_limit = exchange_config.pop("cash_limit")
|
||||
generate_report = backtest_config["runtime"]["generate_report"]
|
||||
|
||||
stock_pool = order_df["instrument"].unique().tolist()
|
||||
stock_pool.sort()
|
||||
|
||||
#
|
||||
res = Parallel(**mp_config)(
|
||||
delayed(single)(
|
||||
orders=order_df[order_df["instrument"] == stock].copy(),
|
||||
time_range=task_config["time_range"],
|
||||
exchange_config=task_config["exchange"],
|
||||
strategy_config=backtest_config["strategies"],
|
||||
split="stock",
|
||||
data_granularity=task_config["data_granularity"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
)
|
||||
for stock in stock_pool
|
||||
)
|
||||
for stock in stock_pool
|
||||
)
|
||||
|
||||
output_path = Path(backtest_config["output_dir"])
|
||||
if generate_report:
|
||||
with (output_path / "report.pkl").open("wb") as f:
|
||||
report = {}
|
||||
for r in res:
|
||||
report.update(r[1])
|
||||
pickle.dump(report, f)
|
||||
res = pd.concat([r[0] for r in res], 0)
|
||||
else:
|
||||
res = pd.concat(res)
|
||||
|
||||
if not output_path.exists():
|
||||
os.makedirs(output_path)
|
||||
|
||||
if "pa" in res.columns:
|
||||
res["pa"] = res["pa"] * 10000.0 # align with training metrics
|
||||
res.to_csv(output_path / "backtest_result.csv")
|
||||
return res
|
||||
|
||||
#
|
||||
output_path = Path(task_config["output_dir"])
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
if generate_report:
|
||||
with (output_path / "report.pkl").open("wb") as f:
|
||||
report = {}
|
||||
for r in res:
|
||||
report.update(r[1])
|
||||
pickle.dump(report, f)
|
||||
res = pd.concat([r[0] for r in res], 0)
|
||||
else:
|
||||
res = pd.concat(res)
|
||||
|
||||
if "pa" in res.columns:
|
||||
res["pa"] = res["pa"] * 10000.0 # align with training metrics
|
||||
res.to_csv(output_path / "backtest_result.csv")
|
||||
# return res # TODO
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -362,6 +275,7 @@ if __name__ == "__main__":
|
||||
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
@@ -374,9 +288,11 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = get_backtest_config_fromfile(args.config_path)
|
||||
if args.n_jobs is not None:
|
||||
config["concurrency"] = args.n_jobs
|
||||
|
||||
config_parser = BacktestConfigParser(args.config_path)
|
||||
config = config_parser.parse()
|
||||
if args.n_jobs is not None: # Overwrite concurrency
|
||||
config["runtime"]["concurrency"] = args.n_jobs
|
||||
|
||||
backtest(
|
||||
backtest_config=config,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
@@ -30,7 +31,7 @@ def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist')
|
||||
raise FileNotFoundError(msg_tmpl.format(filename))
|
||||
|
||||
|
||||
def parse_backtest_config(path: str) -> dict:
|
||||
def load_config(path: str) -> dict:
|
||||
abs_path = os.path.abspath(path)
|
||||
check_file_exist(abs_path)
|
||||
|
||||
@@ -65,43 +66,154 @@ def parse_backtest_config(path: str) -> dict:
|
||||
base_file_name = [base_file_name]
|
||||
|
||||
for f in base_file_name:
|
||||
base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f))
|
||||
base_config = load_config(os.path.join(os.path.dirname(abs_path), f))
|
||||
config = merge_a_into_b(a=config, b=base_config)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _convert_all_list_to_tuple(config: dict) -> dict:
|
||||
for k, v in config.items():
|
||||
if isinstance(v, list):
|
||||
config[k] = tuple(v)
|
||||
elif isinstance(v, dict):
|
||||
config[k] = _convert_all_list_to_tuple(v)
|
||||
return config
|
||||
class BacktestConfigParser:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.raw_config = load_config(path)
|
||||
|
||||
def parse(self) -> dict:
|
||||
self._simulator_config = self._parse_simulator()
|
||||
self._exchange_config = self._simulator_config.pop("exchange")
|
||||
config = {
|
||||
"strategies": self.raw_config["strategies"],
|
||||
"runtime": self.raw_config["runtime"],
|
||||
"tasks": self._parse_tasks(),
|
||||
"simulator": self._simulator_config,
|
||||
}
|
||||
return config
|
||||
|
||||
def _parse_tasks(self) -> dict:
|
||||
task_config = []
|
||||
for task in self.raw_config["tasks"]:
|
||||
if "output_dir" not in task:
|
||||
task["output_dir"] = os.path.join("outputs_backtest", task["name"])
|
||||
if "exchange" not in task:
|
||||
task["exchange"] = copy.deepcopy(self._exchange_config)
|
||||
else:
|
||||
task["exchange"] = self._complete_exchange_config(task["exchange"])
|
||||
task_config.append(task)
|
||||
|
||||
return task_config
|
||||
|
||||
def _complete_exchange_config(self, exchange_config: dict) -> dict:
|
||||
exchange_config_default = {
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
"cash_limit": None,
|
||||
}
|
||||
exchange_config = merge_a_into_b(a=exchange_config, b=exchange_config_default)
|
||||
return exchange_config
|
||||
|
||||
def _parse_simulator(self) -> dict:
|
||||
config = self.raw_config["simulator"]
|
||||
|
||||
return {
|
||||
"qlib": config["qlib"],
|
||||
"exchange": self._complete_exchange_config(config["exchange"]),
|
||||
}
|
||||
|
||||
|
||||
def get_backtest_config_fromfile(path: str) -> dict:
|
||||
backtest_config = parse_backtest_config(path)
|
||||
class TrainingConfigParser:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.raw_config = load_config(path)
|
||||
|
||||
exchange_config_default = {
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
"cash_limit": None,
|
||||
}
|
||||
backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default)
|
||||
backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"])
|
||||
def parse(self) -> dict:
|
||||
return {
|
||||
"general": self._parse_general(),
|
||||
"policy": self.raw_config["policy"],
|
||||
"interpreter": self.raw_config["interpreter"],
|
||||
"runtime": self._parse_runtime(),
|
||||
"training": self._parse_training(),
|
||||
"simulator": self._parse_simulator(),
|
||||
}
|
||||
|
||||
backtest_config_default = {
|
||||
"debug_single_stock": None,
|
||||
"debug_single_day": None,
|
||||
"concurrency": -1,
|
||||
"multiplier": 1.0,
|
||||
"output_dir": "outputs_backtest/",
|
||||
"generate_report": False,
|
||||
"data_granularity": "1min",
|
||||
}
|
||||
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
|
||||
def _parse_general(self) -> dict:
|
||||
default = {
|
||||
"freq": "1min",
|
||||
"extra_module_paths": [],
|
||||
}
|
||||
return {**default, **self.raw_config["general"]}
|
||||
|
||||
return backtest_config
|
||||
def _parse_runtime(self) -> dict:
|
||||
default = {
|
||||
"seed": None,
|
||||
"use_cuda": False,
|
||||
"concurrency": 1,
|
||||
"parallel_mode": "dummy",
|
||||
}
|
||||
return {**default, **self.raw_config["runtime"]}
|
||||
|
||||
def _parse_training(self) -> dict:
|
||||
default = {
|
||||
"max_epoch": 100,
|
||||
"repeat_per_collect": 2,
|
||||
"earlystop_patience": float("inf"),
|
||||
"episode_per_collect": 10000,
|
||||
"batch_size": 256,
|
||||
"val_every_n_epoch": None,
|
||||
"checkpoint_path": "./outputs",
|
||||
"checkpoint_every_n_iters": 10,
|
||||
}
|
||||
|
||||
config = self.raw_config["training"]
|
||||
assert "order_dir" in config
|
||||
|
||||
return {**default, **config}
|
||||
|
||||
def _parse_simulator(self) -> dict:
|
||||
config = self.raw_config["simulator"]
|
||||
sim_type = config["type"]
|
||||
assert sim_type in ("simple", "full")
|
||||
|
||||
if sim_type == "simple":
|
||||
return {
|
||||
"type": sim_type,
|
||||
"data": {
|
||||
"feature_root_dir": config["data"]["feature_root_dir"],
|
||||
"feature_columns_today": config["data"]["feature_columns_today"],
|
||||
"default_start_time_index": config["data"].get("default_start_time_index", 0),
|
||||
"default_end_time_index": config["data"].get("default_end_time_index", 240),
|
||||
},
|
||||
"time_per_step": config["time_per_step"],
|
||||
"vol_limit": config["vol_limit"],
|
||||
}
|
||||
else:
|
||||
exchange_config_default = {
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
# "cash_limit": None,
|
||||
}
|
||||
exchange_config = {**exchange_config_default, **config["exchange"]}
|
||||
exchange_config["freq"] = self.raw_config["general"].get("freq", "1min")
|
||||
|
||||
ret_config = {
|
||||
"type": sim_type,
|
||||
"data": {
|
||||
"feature_root_dir": config["data"]["feature_root_dir"],
|
||||
"default_start_time_index": config["data"].get("default_start_time_index", 0),
|
||||
"default_end_time_index": config["data"].get("default_end_time_index", 240),
|
||||
},
|
||||
"qlib": {
|
||||
"provider_uri_1min": config["qlib"]["provider_uri_1min"],
|
||||
},
|
||||
"exchange": exchange_config,
|
||||
}
|
||||
|
||||
return ret_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrainingConfigParser("/home/huoran/exp_configs/amc4th_training_refined.yml")
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
pprint(parser.parse())
|
||||
|
||||
362
qlib/rl/contrib/train.py
Normal file
362
qlib/rl/contrib/train.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, cast, List, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl import Simulator
|
||||
from qlib.rl.contrib.naive_config_parser import TrainingConfigParser
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, backtest, train
|
||||
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
|
||||
from qlib.rl.utils.log import CsvWriter
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def get_executor_config(freq: int) -> dict:
|
||||
return {
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"inner_executor": {
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"generate_report": False,
|
||||
"time_per_step": f"{freq}min",
|
||||
"track_data": True,
|
||||
"trade_type": "serial",
|
||||
"verbose": False,
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "TWAPStrategy",
|
||||
"kwargs": {},
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
},
|
||||
"time_per_step": "30min",
|
||||
"track_data": True,
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "ProxySAOEStrategy",
|
||||
"module_path": "qlib.rl.order_execution.strategy",
|
||||
"kwargs": {},
|
||||
},
|
||||
"time_per_step": "1day",
|
||||
"track_data": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
if os.path.isfile(order_dir):
|
||||
return pd.read_pickle(order_dir)
|
||||
else:
|
||||
orders = []
|
||||
for file in order_dir.iterdir():
|
||||
order_data = pd.read_pickle(file)
|
||||
orders.append(order_data)
|
||||
return pd.concat(orders)
|
||||
|
||||
|
||||
def _freq_str_to_int(freq: str) -> int:
|
||||
if freq.endswith("min"):
|
||||
return int(freq.replace("min", ""))
|
||||
elif freq.endswith("hour"):
|
||||
return int(freq.replace("hour", "") * 60)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized freq string: {freq}")
|
||||
|
||||
|
||||
class LazyLoadDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
order_df: pd.DataFrame,
|
||||
default_start_time_index: int,
|
||||
default_end_time_index: int,
|
||||
) -> None:
|
||||
self._default_start_time_index = default_start_time_index
|
||||
self._default_end_time_index = default_end_time_index
|
||||
|
||||
self._order_df = order_df
|
||||
self._ticks_index: Optional[pd.DatetimeIndex] = None
|
||||
self._data_dir = Path(data_dir)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._order_df)
|
||||
|
||||
def __getitem__(self, index: int) -> Order:
|
||||
row = self._order_df.iloc[index]
|
||||
date = pd.Timestamp(str(row["date"]))
|
||||
|
||||
if self._ticks_index is None:
|
||||
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
|
||||
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
|
||||
# TODO: of all dates.
|
||||
|
||||
data = load_pickle_intraday_processed_data(
|
||||
data_dir=self._data_dir,
|
||||
stock_id=row["instrument"],
|
||||
date=date,
|
||||
feature_columns_today=[],
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
)
|
||||
self._ticks_index = [t - date for t in data.today.index]
|
||||
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(int(row["order_type"])),
|
||||
start_time=date + self._ticks_index[self._default_start_time_index],
|
||||
end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
|
||||
)
|
||||
|
||||
return order
|
||||
|
||||
|
||||
def _split_order_df_by_instrument(df: pd.DataFrame, k: int) -> List[pd.DataFrame]:
|
||||
df = df.copy()
|
||||
df["group"] = df["instrument"].apply(lambda s: hash(s) % k)
|
||||
dfs = [df[df["group"] == i].drop(columns=["group"]) for i in range(k)]
|
||||
return dfs
|
||||
|
||||
|
||||
def _get_simulator_factory(
|
||||
sim_type: str,
|
||||
data_dir: Path,
|
||||
freq_min: int,
|
||||
simulator_config: dict,
|
||||
) -> Callable[[Order], Simulator]:
|
||||
if sim_type == "simple":
|
||||
|
||||
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
|
||||
simulator = SingleAssetOrderExecutionSimple(
|
||||
order=order,
|
||||
data_dir=data_dir,
|
||||
feature_columns_today=simulator_config["data"]["feature_columns_today"],
|
||||
data_granularity=freq_min,
|
||||
ticks_per_step=simulator_config["time_per_step"],
|
||||
vol_threshold=simulator_config["vol_limit"],
|
||||
)
|
||||
return simulator
|
||||
|
||||
return _simulator_factory_simple
|
||||
elif sim_type == "full":
|
||||
init_qlib(simulator_config["qlib"])
|
||||
executor_config = get_executor_config(freq_min)
|
||||
exchange_config = simulator_config["exchange"]
|
||||
|
||||
def _simulator_factory_full(order: Order) -> SingleAssetOrderExecution:
|
||||
simulator = SingleAssetOrderExecution(
|
||||
order=order,
|
||||
executor_config=executor_config,
|
||||
exchange_config=exchange_config, # `codes` will be set in SingleAssetOrderExecution.__init__()
|
||||
qlib_config=None,
|
||||
cash_limit=None,
|
||||
)
|
||||
return simulator
|
||||
|
||||
return _simulator_factory_full
|
||||
else:
|
||||
raise ValueError(f"Unknown simulator type: {sim_type}")
|
||||
|
||||
|
||||
def train_and_test(
|
||||
freq: str,
|
||||
concurrency: int,
|
||||
parallel_mode: str,
|
||||
training_config: dict,
|
||||
simulator_config: dict,
|
||||
policy: BasePolicy,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
reward: Reward,
|
||||
run_training: bool,
|
||||
run_backtest: bool,
|
||||
) -> None:
|
||||
freq_min: int = _freq_str_to_int(freq)
|
||||
order_root_path = Path(training_config["order_dir"])
|
||||
feature_root_dir = simulator_config["data"]["feature_root_dir"]
|
||||
assert simulator_config["data"]["default_start_time_index"] % freq_min == 0
|
||||
assert simulator_config["data"]["default_end_time_index"] % freq_min == 0
|
||||
|
||||
_simulator_factory = _get_simulator_factory(
|
||||
sim_type=simulator_config["type"],
|
||||
data_dir=feature_root_dir,
|
||||
freq_min=freq_min,
|
||||
simulator_config=simulator_config,
|
||||
)
|
||||
|
||||
# Load orders
|
||||
load_data_tags = []
|
||||
orders_by_tag = {}
|
||||
if run_training:
|
||||
load_data_tags += ["train", "valid"]
|
||||
if run_backtest:
|
||||
load_data_tags += ["test"]
|
||||
for tag in load_data_tags:
|
||||
order_df = _read_orders(order_root_path / tag).reset_index()
|
||||
dfs = _split_order_df_by_instrument(order_df, concurrency)
|
||||
datasets = [
|
||||
LazyLoadDataset(
|
||||
data_dir=feature_root_dir,
|
||||
order_df=df,
|
||||
default_start_time_index=simulator_config["data"]["default_start_time_index"] // freq_min,
|
||||
default_end_time_index=simulator_config["data"]["default_end_time_index"] // freq_min,
|
||||
)
|
||||
for df in dfs
|
||||
]
|
||||
orders_by_tag[tag] = datasets
|
||||
|
||||
if run_training:
|
||||
callbacks: List[Callback] = [
|
||||
MetricsWriter(dirpath=Path(training_config["checkpoint_path"])),
|
||||
Checkpoint(
|
||||
dirpath=Path(training_config["checkpoint_path"]) / "checkpoints",
|
||||
every_n_iters=training_config["checkpoint_every_n_iters"],
|
||||
save_latest="copy",
|
||||
),
|
||||
EarlyStopping(
|
||||
patience=training_config["earlystop_patience"],
|
||||
monitor="val/pa",
|
||||
),
|
||||
]
|
||||
|
||||
train(
|
||||
simulator_fn=_simulator_factory,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
initial_states=cast(List[Sequence[Order]], orders_by_tag["train"]),
|
||||
trainer_kwargs={
|
||||
"max_iters": training_config["max_epoch"],
|
||||
"finite_env_type": parallel_mode,
|
||||
"concurrency": concurrency,
|
||||
"val_every_n_iters": training_config["val_every_n_epoch"],
|
||||
"callbacks": callbacks,
|
||||
},
|
||||
vessel_kwargs={
|
||||
"episode_per_iter": training_config["episode_per_collect"],
|
||||
"update_kwargs": {
|
||||
"batch_size": training_config["batch_size"],
|
||||
"repeat": training_config["repeat_per_collect"],
|
||||
},
|
||||
"val_initial_states": cast(List[Sequence[Order]], orders_by_tag["valid"]),
|
||||
},
|
||||
)
|
||||
|
||||
if run_backtest:
|
||||
backtest(
|
||||
simulator_fn=_simulator_factory,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
initial_states=cast(List[Sequence[Order]], orders_by_tag["test"]),
|
||||
policy=policy,
|
||||
logger=CsvWriter(Path(training_config["checkpoint_path"])),
|
||||
reward=reward,
|
||||
finite_env_type=parallel_mode, # type: ignore[arg-type]
|
||||
concurrency=concurrency,
|
||||
)
|
||||
|
||||
|
||||
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
|
||||
if not run_training and not run_backtest:
|
||||
warnings.warn("Skip the entire job since training and backtest are both skipped.")
|
||||
return
|
||||
|
||||
seed = config["runtime"]["seed"]
|
||||
if seed is not None:
|
||||
seed_everything(seed)
|
||||
|
||||
for extra_module_path in config["general"]["extra_module_paths"]:
|
||||
sys.path.append(extra_module_path)
|
||||
|
||||
state_interpreter: StateInterpreter = init_instance_by_config(config["interpreter"]["state"])
|
||||
action_interpreter: ActionInterpreter = init_instance_by_config(config["interpreter"]["action"])
|
||||
reward: Reward = init_instance_by_config(config["interpreter"]["reward"])
|
||||
|
||||
additional_policy_kwargs = {
|
||||
"obs_space": state_interpreter.observation_space,
|
||||
"action_space": action_interpreter.action_space,
|
||||
}
|
||||
# Create torch network
|
||||
if "network" in config["policy"]:
|
||||
network_config = config["policy"]["network"]
|
||||
network_config["kwargs"] = {
|
||||
**network_config.get("kwargs", {}),
|
||||
**{"obs_space": state_interpreter.observation_space},
|
||||
}
|
||||
additional_policy_kwargs["network"] = init_instance_by_config(network_config)
|
||||
|
||||
# Create policy
|
||||
policy_config = config["policy"]["policy"]
|
||||
policy_config["kwargs"] = {**policy_config.get("kwargs", {}), **additional_policy_kwargs}
|
||||
policy: BasePolicy = init_instance_by_config(policy_config)
|
||||
|
||||
use_cuda = config["runtime"]["use_cuda"]
|
||||
if use_cuda:
|
||||
policy.cuda()
|
||||
|
||||
train_and_test(
|
||||
freq=config["general"]["freq"],
|
||||
concurrency=config["runtime"]["concurrency"],
|
||||
parallel_mode=config["runtime"]["parallel_mode"],
|
||||
training_config=config["training"],
|
||||
simulator_config=config["simulator"],
|
||||
policy=policy,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
reward=reward,
|
||||
run_training=run_training,
|
||||
run_backtest=run_backtest,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
|
||||
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
|
||||
args = parser.parse_args()
|
||||
|
||||
config_parser = TrainingConfigParser(args.config_path)
|
||||
config = config_parser.parse()
|
||||
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||
@@ -1,268 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import yaml
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.data.native import load_handler_intraday_processed_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, backtest, train
|
||||
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
|
||||
from qlib.rl.utils.log import CsvWriter
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
if os.path.isfile(order_dir):
|
||||
return pd.read_pickle(order_dir)
|
||||
else:
|
||||
orders = []
|
||||
for file in order_dir.iterdir():
|
||||
order_data = pd.read_pickle(file)
|
||||
orders.append(order_data)
|
||||
return pd.concat(orders)
|
||||
|
||||
|
||||
class LazyLoadDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
order_file_path: Path,
|
||||
default_start_time_index: int,
|
||||
default_end_time_index: int,
|
||||
) -> None:
|
||||
self._default_start_time_index = default_start_time_index
|
||||
self._default_end_time_index = default_end_time_index
|
||||
|
||||
self._order_df = _read_orders(order_file_path).reset_index()
|
||||
self._ticks_index: Optional[pd.DatetimeIndex] = None
|
||||
self._data_dir = Path(data_dir)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._order_df)
|
||||
|
||||
def __getitem__(self, index: int) -> Order:
|
||||
row = self._order_df.iloc[index]
|
||||
date = pd.Timestamp(str(row["date"]))
|
||||
|
||||
if self._ticks_index is None:
|
||||
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
|
||||
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
|
||||
# TODO: of all dates.
|
||||
|
||||
data = load_handler_intraday_processed_data(
|
||||
data_dir=self._data_dir,
|
||||
stock_id=row["instrument"],
|
||||
date=date,
|
||||
feature_columns_today=[],
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
index_only=True,
|
||||
)
|
||||
self._ticks_index = [t - date for t in data.today.index]
|
||||
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(int(row["order_type"])),
|
||||
start_time=date + self._ticks_index[self._default_start_time_index],
|
||||
end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
|
||||
)
|
||||
|
||||
return order
|
||||
|
||||
|
||||
def train_and_test(
|
||||
env_config: dict,
|
||||
simulator_config: dict,
|
||||
trainer_config: dict,
|
||||
data_config: dict,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
run_training: bool,
|
||||
run_backtest: bool,
|
||||
) -> None:
|
||||
order_root_path = Path(data_config["source"]["order_dir"])
|
||||
|
||||
data_granularity = simulator_config.get("data_granularity", 1)
|
||||
|
||||
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
|
||||
return SingleAssetOrderExecutionSimple(
|
||||
order=order,
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
feature_columns_today=data_config["source"]["feature_columns_today"],
|
||||
feature_columns_yesterday=data_config["source"]["feature_columns_yesterday"],
|
||||
data_granularity=data_granularity,
|
||||
ticks_per_step=simulator_config["time_per_step"],
|
||||
vol_threshold=simulator_config["vol_limit"],
|
||||
)
|
||||
|
||||
assert data_config["source"]["default_start_time_index"] % data_granularity == 0
|
||||
assert data_config["source"]["default_end_time_index"] % data_granularity == 0
|
||||
|
||||
if run_training:
|
||||
train_dataset, valid_dataset = [
|
||||
LazyLoadDataset(
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
order_file_path=order_root_path / tag,
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
for tag in ("train", "valid")
|
||||
]
|
||||
|
||||
callbacks: List[Callback] = []
|
||||
if "checkpoint_path" in trainer_config:
|
||||
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
|
||||
callbacks.append(
|
||||
Checkpoint(
|
||||
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
|
||||
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
|
||||
save_latest="copy",
|
||||
),
|
||||
)
|
||||
if "earlystop_patience" in trainer_config:
|
||||
callbacks.append(
|
||||
EarlyStopping(
|
||||
patience=trainer_config["earlystop_patience"],
|
||||
monitor="val/pa",
|
||||
)
|
||||
)
|
||||
|
||||
train(
|
||||
simulator_fn=_simulator_factory_simple,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
initial_states=cast(List[Order], train_dataset),
|
||||
trainer_kwargs={
|
||||
"max_iters": trainer_config["max_epoch"],
|
||||
"finite_env_type": env_config["parallel_mode"],
|
||||
"concurrency": env_config["concurrency"],
|
||||
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
|
||||
"callbacks": callbacks,
|
||||
},
|
||||
vessel_kwargs={
|
||||
"episode_per_iter": trainer_config["episode_per_collect"],
|
||||
"update_kwargs": {
|
||||
"batch_size": trainer_config["batch_size"],
|
||||
"repeat": trainer_config["repeat_per_collect"],
|
||||
},
|
||||
"val_initial_states": valid_dataset,
|
||||
},
|
||||
)
|
||||
|
||||
if run_backtest:
|
||||
test_dataset = LazyLoadDataset(
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
order_file_path=order_root_path / "test",
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
|
||||
backtest(
|
||||
simulator_fn=_simulator_factory_simple,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
initial_states=test_dataset,
|
||||
policy=policy,
|
||||
logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
|
||||
reward=reward,
|
||||
finite_env_type=env_config["parallel_mode"],
|
||||
concurrency=env_config["concurrency"],
|
||||
)
|
||||
|
||||
|
||||
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
|
||||
if not run_training and not run_backtest:
|
||||
warnings.warn("Skip the entire job since training and backtest are both skipped.")
|
||||
return
|
||||
|
||||
if "seed" in config["runtime"]:
|
||||
seed_everything(config["runtime"]["seed"])
|
||||
|
||||
for extra_module_path in config["env"].get("extra_module_paths", []):
|
||||
sys.path.append(extra_module_path)
|
||||
|
||||
state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
|
||||
action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
|
||||
reward: Reward = init_instance_by_config(config["reward"])
|
||||
|
||||
additional_policy_kwargs = {
|
||||
"obs_space": state_interpreter.observation_space,
|
||||
"action_space": action_interpreter.action_space,
|
||||
}
|
||||
|
||||
# Create torch network
|
||||
if "network" in config:
|
||||
if "kwargs" not in config["network"]:
|
||||
config["network"]["kwargs"] = {}
|
||||
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
|
||||
additional_policy_kwargs["network"] = init_instance_by_config(config["network"])
|
||||
|
||||
# Create policy
|
||||
if "kwargs" not in config["policy"]:
|
||||
config["policy"]["kwargs"] = {}
|
||||
config["policy"]["kwargs"].update(additional_policy_kwargs)
|
||||
policy: BasePolicy = init_instance_by_config(config["policy"])
|
||||
|
||||
use_cuda = config["runtime"].get("use_cuda", False)
|
||||
if use_cuda:
|
||||
policy.cuda()
|
||||
|
||||
train_and_test(
|
||||
env_config=config["env"],
|
||||
simulator_config=config["simulator"],
|
||||
data_config=config["data"],
|
||||
trainer_config=config["trainer"],
|
||||
action_interpreter=action_interpreter,
|
||||
state_interpreter=state_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
run_training=run_training,
|
||||
run_backtest=run_backtest,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
|
||||
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_path, "r") as input_stream:
|
||||
config = yaml.safe_load(input_stream)
|
||||
|
||||
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||
@@ -13,6 +13,7 @@ import os
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.decision import TradeRange, TradeRangeByTime
|
||||
from qlib.constant import EPS_T
|
||||
from qlib.data.dataset import DatasetH
|
||||
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
|
||||
|
||||
@@ -140,6 +141,16 @@ def load_backtest_data(
|
||||
return backtest_data
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(1000),
|
||||
key=lambda path: path,
|
||||
)
|
||||
def _load_handler_pickle(path: str) -> DatasetH:
|
||||
with open(path, "rb") as fstream:
|
||||
obj = pickle.load(fstream)
|
||||
return obj
|
||||
|
||||
|
||||
class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
|
||||
|
||||
@@ -151,7 +162,6 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
index_only: bool = False,
|
||||
) -> None:
|
||||
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = df.reset_index()
|
||||
@@ -161,31 +171,17 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
|
||||
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
with open(path, "rb") as fstream:
|
||||
dataset = pickle.load(fstream)
|
||||
dataset = _load_handler_pickle(path)
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
|
||||
if index_only:
|
||||
self.today = _drop_stock_id(data[[]])
|
||||
self.yesterday = _drop_stock_id(data[[]])
|
||||
else:
|
||||
self.today = _drop_stock_id(data[feature_columns_today])
|
||||
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
|
||||
self.today = _drop_stock_id(data[feature_columns_today])
|
||||
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: (
|
||||
stock_id,
|
||||
date,
|
||||
backtest,
|
||||
index_only,
|
||||
),
|
||||
)
|
||||
def load_handler_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
@@ -193,10 +189,14 @@ def load_handler_intraday_processed_data(
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
index_only: bool = False,
|
||||
) -> HandlerIntradayProcessedData:
|
||||
return HandlerIntradayProcessedData(
|
||||
data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only
|
||||
data_dir,
|
||||
stock_id,
|
||||
date,
|
||||
feature_columns_today,
|
||||
feature_columns_yesterday,
|
||||
backtest,
|
||||
)
|
||||
|
||||
|
||||
@@ -229,5 +229,4 @@ class HandlerProcessedDataProvider(ProcessedDataProvider):
|
||||
self.feature_columns_today,
|
||||
self.feature_columns_yesterday,
|
||||
backtest=self.backtest,
|
||||
index_only=False,
|
||||
)
|
||||
|
||||
@@ -26,7 +26,6 @@ from typing import List, Sequence, cast
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
@@ -158,6 +157,15 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
|
||||
return cast(pd.DatetimeIndex, self.data.index)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(1000),
|
||||
key=lambda path: path,
|
||||
)
|
||||
def _load_df_pickle(path: str) -> pd.DataFrame:
|
||||
df = pd.read_pickle(path)
|
||||
return df
|
||||
|
||||
|
||||
class PickleIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle pickle-styled data."""
|
||||
|
||||
@@ -166,36 +174,18 @@ class PickleIntradayProcessedData(BaseIntradayProcessedData):
|
||||
data_dir: Path | str,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool,
|
||||
) -> None:
|
||||
proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
|
||||
if isinstance(data_dir, str):
|
||||
data_dir = Path(data_dir)
|
||||
path = data_dir / ("backtest" if backtest else "feature") / f"{stock_id}.pkl"
|
||||
df = _load_df_pickle(str(path))
|
||||
df = df.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
|
||||
# We have to infer the names here because,
|
||||
# unfortunately they are not included in the original data.
|
||||
cnames = _infer_processed_data_column_names(feature_dim)
|
||||
|
||||
time_length: int = len(time_index)
|
||||
|
||||
try:
|
||||
# new data format
|
||||
proc = proc.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
assert len(proc) == time_length and len(proc.columns) == feature_dim * 2
|
||||
proc_today = proc[cnames]
|
||||
proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2])
|
||||
except (IndexError, KeyError):
|
||||
# legacy data
|
||||
proc = proc.loc[pd.IndexSlice[stock_id, date]]
|
||||
assert time_length * feature_dim * 2 == len(proc)
|
||||
proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim))
|
||||
proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim))
|
||||
proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames)
|
||||
proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames)
|
||||
|
||||
self.today: pd.DataFrame = proc_today
|
||||
self.yesterday: pd.DataFrame = proc_yesterday
|
||||
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
|
||||
assert len(self.today) == len(self.yesterday) == time_length
|
||||
self.today = df[feature_columns_today]
|
||||
self.yesterday = df[feature_columns_yesterday]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
@@ -213,25 +203,38 @@ def load_simple_intraday_backtest_data(
|
||||
return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
|
||||
)
|
||||
def load_pickle_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
) -> BaseIntradayProcessedData:
|
||||
return PickleIntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
return PickleIntradayProcessedData(
|
||||
data_dir,
|
||||
stock_id,
|
||||
date,
|
||||
feature_columns_today,
|
||||
feature_columns_yesterday,
|
||||
backtest,
|
||||
)
|
||||
|
||||
|
||||
class PickleProcessedDataProvider(ProcessedDataProvider):
|
||||
def __init__(self, data_dir: Path) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._data_dir = data_dir
|
||||
self._backtest = backtest
|
||||
self._feature_columns_today = feature_columns_today
|
||||
self._feature_columns_yesterday = feature_columns_yesterday
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
@@ -244,8 +247,9 @@ class PickleProcessedDataProvider(ProcessedDataProvider):
|
||||
data_dir=self._data_dir,
|
||||
stock_id=stock_id,
|
||||
date=date,
|
||||
feature_dim=feature_dim,
|
||||
time_index=time_index,
|
||||
feature_columns_today=self._feature_columns_today,
|
||||
feature_columns_yesterday=self._feature_columns_yesterday,
|
||||
backtest=self._backtest,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generator, List, Optional
|
||||
import cachetools
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest import collect_data_loop, Exchange, get_exchange, get_strategy_executor
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime
|
||||
from qlib.backtest.executor import NestedExecutor
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
@@ -16,6 +17,18 @@ from .state import SAOEState
|
||||
from .strategy import SAOEStateAdapter, SAOEStrategy
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(1000),
|
||||
key=lambda order, _: order.stock_id,
|
||||
)
|
||||
def _create_exchange(order: Order, exchange_config: dict) -> Exchange:
|
||||
exchange_kwargs = {
|
||||
**exchange_config,
|
||||
"codes": [order.stock_id],
|
||||
}
|
||||
return get_exchange(**exchange_kwargs)
|
||||
|
||||
|
||||
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
"""Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools.
|
||||
|
||||
@@ -76,7 +89,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
executor=executor_config,
|
||||
benchmark=order.stock_id,
|
||||
account=cash_limit if cash_limit is not None else int(1e12),
|
||||
exchange_kwargs=exchange_config,
|
||||
exchange_kwargs=_create_exchange(order, exchange_config),
|
||||
pos_type="Position" if cash_limit is not None else "InfPosition",
|
||||
)
|
||||
|
||||
@@ -90,6 +103,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
trade_strategy=strategy,
|
||||
trade_executor=self._executor,
|
||||
return_value=self.report_dict,
|
||||
show_progress=False,
|
||||
)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
|
||||
@@ -12,7 +12,8 @@ from pathlib import Path
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.constant import EPS, EPS_T, float_or_ndarray
|
||||
from qlib.rl.data.base import BaseIntradayBacktestData
|
||||
from qlib.rl.data.native import DataframeIntradayBacktestData, load_handler_intraday_processed_data
|
||||
from qlib.rl.data.native import DataframeIntradayBacktestData
|
||||
from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data
|
||||
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.utils import LogLevel
|
||||
@@ -42,8 +43,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
Path to load backtest data.
|
||||
feature_columns_today
|
||||
Columns of today's feature.
|
||||
feature_columns_yesterday
|
||||
Columns of yesterday's feature.
|
||||
data_granularity
|
||||
Number of ticks between consecutive data entries.
|
||||
ticks_per_step
|
||||
@@ -80,7 +79,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
order: Order,
|
||||
data_dir: Path,
|
||||
feature_columns_today: List[str] = [],
|
||||
feature_columns_yesterday: List[str] = [],
|
||||
data_granularity: int = 1,
|
||||
ticks_per_step: int = 30,
|
||||
vol_threshold: Optional[float] = None,
|
||||
@@ -92,7 +90,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
self.order = order
|
||||
self.data_dir = data_dir
|
||||
self.feature_columns_today = feature_columns_today
|
||||
self.feature_columns_yesterday = feature_columns_yesterday
|
||||
self.ticks_per_step: int = ticks_per_step // data_granularity
|
||||
self.vol_threshold = vol_threshold
|
||||
|
||||
@@ -122,14 +119,13 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
|
||||
def get_backtest_data(self) -> BaseIntradayBacktestData:
|
||||
try:
|
||||
data = load_handler_intraday_processed_data(
|
||||
data = load_pickle_intraday_processed_data(
|
||||
data_dir=self.data_dir,
|
||||
stock_id=self.order.stock_id,
|
||||
date=pd.Timestamp(self.order.start_time.date()),
|
||||
feature_columns_today=self.feature_columns_today,
|
||||
feature_columns_yesterday=self.feature_columns_yesterday,
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
index_only=False,
|
||||
)
|
||||
return DataframeIntradayBacktestData(data.today)
|
||||
except (AttributeError, FileNotFoundError):
|
||||
|
||||
@@ -451,6 +451,7 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
state_interpreter: dict | StateInterpreter,
|
||||
action_interpreter: dict | ActionInterpreter,
|
||||
network: dict | torch.nn.Module | None = None,
|
||||
immediate_addition: bool = False,
|
||||
outer_trade_decision: BaseTradeDecision | None = None,
|
||||
level_infra: LevelInfrastructure | None = None,
|
||||
common_infra: CommonInfrastructure | None = None,
|
||||
@@ -501,9 +502,12 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
|
||||
if self._policy is not None:
|
||||
self._policy.eval()
|
||||
|
||||
self.immediate_addition = immediate_addition
|
||||
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:
|
||||
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
self.trade_amount_planned = collections.defaultdict(float)
|
||||
|
||||
def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
|
||||
assert hasattr(self.outer_trade_decision, "order_list")
|
||||
@@ -539,9 +543,15 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
|
||||
oh = self.trade_exchange.get_order_helper()
|
||||
order_list = []
|
||||
for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):
|
||||
for decision, exec_vol, state in zip(self.outer_trade_decision.get_decision(), exec_vols, states):
|
||||
order = cast(Order, decision)
|
||||
if self.immediate_addition:
|
||||
self.trade_amount_planned[order.stock_id] += exec_vol
|
||||
amount_planned = self.trade_amount_planned[order.stock_id]
|
||||
amount_finished = order.amount - state.position
|
||||
exec_vol = min(state.position, amount_planned - amount_finished)
|
||||
|
||||
if exec_vol != 0:
|
||||
order = cast(Order, decision)
|
||||
order_list.append(oh.create(order.stock_id, exec_vol, order.direction))
|
||||
|
||||
return TradeDecisionWithDetails(
|
||||
|
||||
@@ -20,7 +20,7 @@ def train(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
initial_states: List[Sequence[InitialStateType]],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
vessel_kwargs: Dict[str, Any],
|
||||
@@ -39,7 +39,9 @@ def train(
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
List of Initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in
|
||||
the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every
|
||||
state will be run exactly once. Otherwise, every worker will have its own iterator.
|
||||
policy
|
||||
Policy to train against.
|
||||
reward
|
||||
@@ -67,7 +69,7 @@ def backtest(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
initial_states: List[Sequence[InitialStateType]],
|
||||
policy: BasePolicy,
|
||||
logger: LogWriter | List[LogWriter],
|
||||
reward: Reward | None = None,
|
||||
@@ -87,7 +89,9 @@ def backtest(
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
List of Initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in
|
||||
the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every
|
||||
state will be run exactly once. Otherwise, every worker will have its own iterator.
|
||||
policy
|
||||
Policy to test against.
|
||||
logger
|
||||
|
||||
@@ -5,8 +5,9 @@ from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import copy
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from contextlib import AbstractContextManager, ExitStack, contextmanager
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast
|
||||
|
||||
@@ -206,45 +207,50 @@ class Trainer:
|
||||
|
||||
self._call_callback_hooks("on_fit_start")
|
||||
|
||||
while not self.should_stop:
|
||||
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
|
||||
_logger.info(msg)
|
||||
with _wrap_context(vessel.train_seed_iterators()) as train_iterators, _wrap_context(
|
||||
vessel.val_seed_iterators()
|
||||
) as valid_iterators:
|
||||
train_vector_env = self.venv_from_iterator(train_iterators)
|
||||
valid_vector_env = self.venv_from_iterator(valid_iterators)
|
||||
|
||||
self.initialize_iter()
|
||||
while not self.should_stop:
|
||||
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
|
||||
print(msg)
|
||||
_logger.info(msg)
|
||||
|
||||
self._call_callback_hooks("on_iter_start")
|
||||
self.initialize_iter()
|
||||
|
||||
self.current_stage = "train"
|
||||
self._call_callback_hooks("on_train_start")
|
||||
self._call_callback_hooks("on_iter_start")
|
||||
|
||||
# TODO
|
||||
# Add a feature that supports reloading the training environment every few iterations.
|
||||
with _wrap_context(vessel.train_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.train(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
self.current_stage = "train"
|
||||
self._call_callback_hooks("on_train_start")
|
||||
|
||||
self._call_callback_hooks("on_train_end")
|
||||
# TODO
|
||||
# Add a feature that supports reloading the training environment every few iterations.
|
||||
self.vessel.train(train_vector_env)
|
||||
|
||||
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
|
||||
# Implementation of validation loop
|
||||
self.current_stage = "val"
|
||||
self._call_callback_hooks("on_validate_start")
|
||||
with _wrap_context(vessel.val_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.validate(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
self._call_callback_hooks("on_train_end")
|
||||
|
||||
self._call_callback_hooks("on_validate_end")
|
||||
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
|
||||
# Implementation of validation loop
|
||||
self.current_stage = "val"
|
||||
self._call_callback_hooks("on_validate_start")
|
||||
|
||||
# This iteration is considered complete.
|
||||
# Bumping the current iteration counter.
|
||||
self.current_iter += 1
|
||||
self.vessel.validate(valid_vector_env)
|
||||
|
||||
if self.max_iters is not None and self.current_iter >= self.max_iters:
|
||||
self.should_stop = True
|
||||
self._call_callback_hooks("on_validate_end")
|
||||
|
||||
self._call_callback_hooks("on_iter_end")
|
||||
# This iteration is considered complete.
|
||||
# Bumping the current iteration counter.
|
||||
self.current_iter += 1
|
||||
|
||||
if self.max_iters is not None and self.current_iter >= self.max_iters:
|
||||
self.should_stop = True
|
||||
|
||||
self._call_callback_hooks("on_iter_end")
|
||||
|
||||
del train_vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
del valid_vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
|
||||
self._call_callback_hooks("on_fit_end")
|
||||
|
||||
@@ -265,16 +271,16 @@ class Trainer:
|
||||
|
||||
self.current_stage = "test"
|
||||
self._call_callback_hooks("on_test_start")
|
||||
with _wrap_context(vessel.test_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
with _wrap_context(vessel.test_seed_iterators()) as iterators:
|
||||
vector_env = self.venv_from_iterator(iterators)
|
||||
self.vessel.test(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
self._call_callback_hooks("on_test_end")
|
||||
|
||||
def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
|
||||
def venv_from_iterator(self, iterators: List[Iterable[InitialStateType]]) -> FiniteVectorEnv:
|
||||
"""Create a vectorized environment from iterator and the training vessel."""
|
||||
|
||||
def env_factory():
|
||||
def env_factory(iterator):
|
||||
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
|
||||
# and could be thread unsafe.
|
||||
# I'm not sure whether it's a design flaw.
|
||||
@@ -300,7 +306,7 @@ class Trainer:
|
||||
)
|
||||
|
||||
return vectorize_env(
|
||||
env_factory,
|
||||
[partial(env_factory, iterator=it) for it in iterators],
|
||||
self.finite_env_type,
|
||||
self.concurrency,
|
||||
self.loggers,
|
||||
@@ -334,8 +340,11 @@ class Trainer:
|
||||
@contextmanager
|
||||
def _wrap_context(obj):
|
||||
"""Make any object a (possibly dummy) context manager."""
|
||||
|
||||
if isinstance(obj, AbstractContextManager):
|
||||
if isinstance(obj, list) and isinstance(obj[0], AbstractContextManager):
|
||||
with ExitStack() as stack:
|
||||
yield [stack.enter_context(e) for e in obj]
|
||||
stack.pop_all().close()
|
||||
elif isinstance(obj, AbstractContextManager):
|
||||
# obj has __enter__ and __exit__
|
||||
with obj as ctx:
|
||||
yield ctx
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
|
||||
from typing import List, TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
@@ -49,19 +49,23 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType,
|
||||
def assign_trainer(self, trainer: Trainer) -> None:
|
||||
self.trainer = weakref.proxy(trainer) # type: ignore
|
||||
|
||||
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for training.
|
||||
def train_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
"""Override this to create a seed iterators for training.
|
||||
If the iterable is a context manager, the whole training will be invoked in the with-block,
|
||||
and the iterator will be automatically closed after the training is done."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for training is not available.")
|
||||
raise SeedIteratorNotAvailable("Seed iterators for training is not available.")
|
||||
|
||||
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for validation."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for validation is not available.")
|
||||
def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
"""Override this to create a seed iterators for validation."""
|
||||
raise SeedIteratorNotAvailable("Seed iterators for validation is not available.")
|
||||
|
||||
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for testing."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for testing is not available.")
|
||||
def test_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
"""Override this to create a seed iterators for testing."""
|
||||
raise SeedIteratorNotAvailable("Seed iterators for testing is not available.")
|
||||
|
||||
def train(self, vector_env: BaseVectorEnv) -> Dict[str, Any]:
|
||||
"""Implement this to train one iteration. In RL, one iteration usually refers to one collect."""
|
||||
@@ -120,9 +124,9 @@ class TrainingVessel(TrainingVesselBase):
|
||||
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
train_initial_states: Sequence[InitialStateType] | None = None,
|
||||
val_initial_states: Sequence[InitialStateType] | None = None,
|
||||
test_initial_states: Sequence[InitialStateType] | None = None,
|
||||
train_initial_states: List[Sequence[InitialStateType]] | None = None,
|
||||
val_initial_states: List[Sequence[InitialStateType]] | None = None,
|
||||
test_initial_states: List[Sequence[InitialStateType]] | None = None,
|
||||
buffer_size: int = 20000,
|
||||
episode_per_iter: int = 1000,
|
||||
update_kwargs: Dict[str, Any] = cast(Dict[str, Any], None),
|
||||
@@ -132,34 +136,49 @@ class TrainingVessel(TrainingVesselBase):
|
||||
self.action_interpreter = action_interpreter
|
||||
self.policy = policy
|
||||
self.reward = reward
|
||||
self.train_initial_states = train_initial_states
|
||||
self.val_initial_states = val_initial_states
|
||||
self.test_initial_states = test_initial_states
|
||||
self.train_initial_states = None if train_initial_states is None else train_initial_states
|
||||
self.val_initial_states = None if val_initial_states is None else val_initial_states
|
||||
self.test_initial_states = None if test_initial_states is None else test_initial_states
|
||||
self.buffer_size = buffer_size
|
||||
self.episode_per_iter = episode_per_iter
|
||||
self.update_kwargs = update_kwargs or {}
|
||||
|
||||
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
def train_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
if self.train_initial_states is not None:
|
||||
_logger.info("Training initial states collection size: %d", len(self.train_initial_states))
|
||||
# Implement fast_dev_run here.
|
||||
train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(train_initial_states, repeat=-1, shuffle=True)
|
||||
return super().train_seed_iterator()
|
||||
_logger.info(f"Training initial states collection sizes: {[len(e) for e in self.train_initial_states]}")
|
||||
train_initial_states = [
|
||||
self._random_subset("train", e, self.trainer.fast_dev_run) for e in self.train_initial_states
|
||||
]
|
||||
iterators = [DataQueue(e, repeat=-1, shuffle=True) for e in train_initial_states]
|
||||
return cast(List[Iterable[InitialStateType]], iterators)
|
||||
else:
|
||||
return super().train_seed_iterators()
|
||||
|
||||
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
if self.val_initial_states is not None:
|
||||
_logger.info("Validation initial states collection size: %d", len(self.val_initial_states))
|
||||
val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(val_initial_states, repeat=1)
|
||||
return super().val_seed_iterator()
|
||||
_logger.info(f"Validation initial states collection sizes: {[len(e) for e in self.val_initial_states]}")
|
||||
val_initial_states = [
|
||||
self._random_subset("val", e, self.trainer.fast_dev_run) for e in self.val_initial_states
|
||||
]
|
||||
iterators = [DataQueue(e, repeat=1) for e in val_initial_states]
|
||||
return cast(List[Iterable[InitialStateType]], iterators)
|
||||
else:
|
||||
return super().val_seed_iterators()
|
||||
|
||||
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
def test_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
if self.test_initial_states is not None:
|
||||
_logger.info("Testing initial states collection size: %d", len(self.test_initial_states))
|
||||
test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(test_initial_states, repeat=1)
|
||||
return super().test_seed_iterator()
|
||||
_logger.info(f"Testing initial states collection sizes: {[len(e) for e in self.test_initial_states]}")
|
||||
test_initial_states = [
|
||||
self._random_subset("test", e, self.trainer.fast_dev_run) for e in self.test_initial_states
|
||||
]
|
||||
iterators = [DataQueue(e, repeat=1) for e in test_initial_states]
|
||||
return cast(List[Iterable[InitialStateType]], iterators)
|
||||
else:
|
||||
return super().test_seed_iterators()
|
||||
|
||||
def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
|
||||
"""Create a collector and collects ``episode_per_iter`` episodes.
|
||||
|
||||
@@ -258,6 +258,46 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
|
||||
return np.stack(obs)
|
||||
|
||||
def step2(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
id: int | List[int] | np.ndarray | None = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert not self._zombie
|
||||
wrapped_id = self._wrap_id(id)
|
||||
id2idx = {i: k for k, i in enumerate(wrapped_id)}
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id))
|
||||
result = {}
|
||||
|
||||
# ask super to step alive envs and remap to current index
|
||||
if request_id:
|
||||
valid_act = np.stack([action[id2idx[i]] for i in request_id])
|
||||
tmp = super().step(valid_act, request_id)
|
||||
|
||||
for obs_next, rew, done, info in zip(*tmp):
|
||||
obs_next = self._postproc_env_obs(obs_next)
|
||||
result[info["env_id"]] = [obs_next, rew, done, info]
|
||||
|
||||
# logging
|
||||
for i, r in result.items():
|
||||
if i in self._alive_env_ids and r[0] is not None:
|
||||
for logger in self._logger:
|
||||
logger.on_env_step(i, *r)
|
||||
|
||||
for _, reward, __, info in result.values():
|
||||
self._set_default_info(info)
|
||||
self._set_default_rew(reward)
|
||||
for r in result.values():
|
||||
if r[0] is None:
|
||||
r[0] = self._get_default_obs()
|
||||
if r[1] is None:
|
||||
r[1] = self._get_default_rew()
|
||||
if r[3] is None:
|
||||
r[3] = self._get_default_info()
|
||||
|
||||
ret = list(map(np.stack, zip(*result.values())))
|
||||
return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret)
|
||||
|
||||
def step(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
@@ -311,7 +351,7 @@ class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):
|
||||
|
||||
|
||||
def vectorize_env(
|
||||
env_factory: Callable[..., gym.Env],
|
||||
env_factories: List[Callable[..., gym.Env]],
|
||||
env_type: FiniteEnvType,
|
||||
concurrency: int,
|
||||
logger: LogWriter | List[LogWriter],
|
||||
@@ -334,9 +374,10 @@ def vectorize_env(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env_factory
|
||||
Callable to instantiate one single ``gym.Env``.
|
||||
All concurrent workers will have the same ``env_factory``.
|
||||
env_factories
|
||||
Callables to instantiate one single ``gym.Env``.
|
||||
There should be 1 or `concurrency` env_factories. If there is 1 env_factory, all concurrent workers will have
|
||||
the same env_factory. Otherwise, each worker will have its own env_factory.
|
||||
env_type
|
||||
dummy or subproc or shmem. Corresponding to
|
||||
`parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.
|
||||
@@ -358,6 +399,8 @@ def vectorize_env(
|
||||
def env_factory(): ...
|
||||
vectorize_env(env_factory, ...)
|
||||
"""
|
||||
assert len(env_factories) in (1, concurrency)
|
||||
|
||||
env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = {
|
||||
"dummy": FiniteDummyVectorEnv,
|
||||
"subproc": FiniteSubprocVectorEnv,
|
||||
@@ -366,4 +409,7 @@ def vectorize_env(
|
||||
|
||||
finite_env_cls = env_type_cls_mapping[env_type]
|
||||
|
||||
return finite_env_cls(logger, [env_factory for _ in range(concurrency)])
|
||||
if len(env_factories) == 1:
|
||||
return finite_env_cls(logger, [env_factories[0] for _ in range(concurrency)])
|
||||
else:
|
||||
return finite_env_cls(logger, env_factories)
|
||||
|
||||
30
qlib/rl/utils/profiling.py
Normal file
30
qlib/rl/utils/profiling.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Generator
|
||||
|
||||
from line_profiler import LineProfiler
|
||||
|
||||
|
||||
@contextmanager
|
||||
def simple_perf(desc: str = "", out_path: str = None) -> Generator[None, None, None]:
|
||||
s = time.perf_counter()
|
||||
yield
|
||||
e = time.perf_counter()
|
||||
msg = f"{desc}: {(e - s) * 1000.0:.4f} ms"
|
||||
|
||||
if out_path is not None:
|
||||
with open(out_path, "a") as fstream:
|
||||
fstream.write(msg + "\n")
|
||||
else:
|
||||
print(msg)
|
||||
|
||||
|
||||
def lprofile(func: Callable) -> Callable:
|
||||
def wrapper(*args, **kwargs):
|
||||
lp = LineProfiler()
|
||||
lpw = lp(func)
|
||||
res = lpw(*args, **kwargs)
|
||||
lp.print_stats()
|
||||
return res
|
||||
|
||||
return wrapper
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import qlib
|
||||
@@ -12,15 +11,13 @@ import datetime
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from cryptography.fernet import Fernet
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
class GetData:
|
||||
DATASET_VERSION = "v2"
|
||||
REMOTE_URL = "https://qlibpublic.blob.core.windows.net/data/default/stock_data"
|
||||
# "?" is not included in the token.
|
||||
TOKEN = "gAAAAABkmDhojHc0VSCDdNK1MqmRzNLeDFXe5hy8obHpa6SDQh4de6nW5gtzuD-fa6O_WZb0yyqYOL7ndOfJX_751W3xN5YB4-n-P22jK-t6ucoZqhT70KPD0Lf0_P328QPJVZ1gDnjIdjhi2YLOcP4BFTHLNYO0mvzszR8TKm9iT5AKRvuysWnpi8bbYwGU9zAcJK3x9EPL43hOGtxliFHcPNGMBoJW4g_ercdhi0-Qgv5_JLsV-29_MV-_AhuaYvJuN2dEywBy"
|
||||
KEY = "EYcA8cgorA8X9OhyMwVfuFxn_1W3jGk6jCbs3L2oPoA="
|
||||
QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip"
|
||||
|
||||
def __init__(self, delete_zip_file=False):
|
||||
"""
|
||||
@@ -32,44 +29,24 @@ class GetData:
|
||||
"""
|
||||
self.delete_zip_file = delete_zip_file
|
||||
|
||||
def merge_remote_url(self, file_name: str):
|
||||
fernet = Fernet(self.KEY)
|
||||
token = fernet.decrypt(self.TOKEN).decode()
|
||||
return f"{self.REMOTE_URL}/{file_name}?{token}"
|
||||
def normalize_dataset_version(self, dataset_version: str = None):
|
||||
if dataset_version is None:
|
||||
dataset_version = self.DATASET_VERSION
|
||||
return dataset_version
|
||||
|
||||
def download_data(self, file_name: str, target_dir: [Path, str], delete_old: bool = True):
|
||||
"""
|
||||
Download the specified file to the target folder.
|
||||
def merge_remote_url(self, file_name: str, dataset_version: str = None):
|
||||
return f"{self.REMOTE_URL}/{self.normalize_dataset_version(dataset_version)}/{file_name}"
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
file_name: str
|
||||
dataset name, needs to endwith .zip, value from [rl_data.zip, csv_data_cn.zip, ...]
|
||||
may contain folder names, for example: v2/qlib_data_simple_cn_1d_latest.zip
|
||||
delete_old: bool
|
||||
delete an existing directory, by default True
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get rl data
|
||||
python get_data.py download_data --file_name rl_data.zip --target_dir ~/.qlib/qlib_data/rl_data
|
||||
When this command is run, the data will be downloaded from this link: https://qlibpublic.blob.core.windows.net/data/default/stock_data/rl_data.zip?{token}
|
||||
|
||||
# get cn csv data
|
||||
python get_data.py download_data --file_name csv_data_cn.zip --target_dir ~/.qlib/csv_data/cn_data
|
||||
When this command is run, the data will be downloaded from this link: https://qlibpublic.blob.core.windows.net/data/default/stock_data/csv_data_cn.zip?{token}
|
||||
-------
|
||||
|
||||
"""
|
||||
def _download_data(
|
||||
self, file_name: str, target_dir: [Path, str], delete_old: bool = True, dataset_version: str = None
|
||||
):
|
||||
target_dir = Path(target_dir).expanduser()
|
||||
target_dir.mkdir(exist_ok=True, parents=True)
|
||||
# saved file name
|
||||
_target_file_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "_" + os.path.basename(file_name)
|
||||
_target_file_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "_" + file_name
|
||||
target_path = target_dir.joinpath(_target_file_name)
|
||||
|
||||
url = self.merge_remote_url(file_name)
|
||||
url = self.merge_remote_url(file_name, dataset_version)
|
||||
resp = requests.get(url, stream=True, timeout=60)
|
||||
resp.raise_for_status()
|
||||
if resp.status_code != 200:
|
||||
@@ -79,7 +56,7 @@ class GetData:
|
||||
logger.warning(
|
||||
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
|
||||
)
|
||||
logger.info(f"{os.path.basename(file_name)} downloading......")
|
||||
logger.info(f"{file_name} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
@@ -90,8 +67,8 @@ class GetData:
|
||||
if self.delete_zip_file:
|
||||
target_path.unlink()
|
||||
|
||||
def check_dataset(self, file_name: str):
|
||||
url = self.merge_remote_url(file_name)
|
||||
def check_dataset(self, file_name: str, dataset_version: str = None):
|
||||
url = self.merge_remote_url(file_name, dataset_version)
|
||||
resp = requests.get(url, stream=True, timeout=60)
|
||||
status = True
|
||||
if resp.status_code == 404:
|
||||
@@ -163,11 +140,9 @@ class GetData:
|
||||
---------
|
||||
# get 1d data
|
||||
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
When this command is run, the data will be downloaded from this link: https://qlibpublic.blob.core.windows.net/data/default/stock_data/v2/qlib_data_cn_1d_latest.zip?{token}
|
||||
|
||||
# get 1min data
|
||||
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --interval 1min --region cn
|
||||
When this command is run, the data will be downloaded from this link: https://qlibpublic.blob.core.windows.net/data/default/stock_data/v2/qlib_data_cn_1min_latest.zip?{token}
|
||||
-------
|
||||
|
||||
"""
|
||||
@@ -180,12 +155,29 @@ class GetData:
|
||||
|
||||
qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__))
|
||||
|
||||
def _get_file_name_with_version(qlib_version, dataset_version):
|
||||
dataset_version = "v2" if dataset_version is None else dataset_version
|
||||
file_name_with_version = f"{dataset_version}/{name}_{region.lower()}_{interval.lower()}_{qlib_version}.zip"
|
||||
return file_name_with_version
|
||||
def _get_file_name(v):
|
||||
return self.QLIB_DATA_NAME.format(
|
||||
dataset_name=name, region=region.lower(), interval=interval.lower(), qlib_version=v
|
||||
)
|
||||
|
||||
file_name = _get_file_name_with_version(qlib_version, dataset_version=version)
|
||||
if not self.check_dataset(file_name):
|
||||
file_name = _get_file_name_with_version("latest", dataset_version=version)
|
||||
self.download_data(file_name.lower(), target_dir, delete_old)
|
||||
file_name = _get_file_name(qlib_version)
|
||||
if not self.check_dataset(file_name, version):
|
||||
file_name = _get_file_name("latest")
|
||||
self._download_data(file_name.lower(), target_dir, delete_old, dataset_version=version)
|
||||
|
||||
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
|
||||
"""download cn csv data from remote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = "csv_data_cn.zip"
|
||||
self._download_data(file_name, target_dir)
|
||||
|
||||
1
setup.py
1
setup.py
@@ -80,7 +80,6 @@ REQUIRED = [
|
||||
"gym",
|
||||
# Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail.
|
||||
"protobuf<=3.20.1;python_version<='3.8'",
|
||||
"cryptography",
|
||||
]
|
||||
|
||||
# Numpy include
|
||||
|
||||
@@ -35,7 +35,7 @@ class TestDumpData(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
GetData().download_data(file_name="csv_data_cn.zip", target_dir=SOURCE_DIR)
|
||||
GetData().csv_data_cn(SOURCE_DIR)
|
||||
TestDumpData.DUMP_DATA = DumpDataAll(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS)
|
||||
TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
|
||||
provider_uri = str(QLIB_DIR.resolve())
|
||||
|
||||
@@ -42,7 +42,7 @@ class TestGetData(unittest.TestCase):
|
||||
self.assertFalse(df.dropna().empty, "get qlib data failed")
|
||||
|
||||
def test_1_csv_data(self):
|
||||
GetData().download_data(file_name="csv_data_cn.zip", target_dir=SOURCE_DIR)
|
||||
GetData().csv_data_cn(SOURCE_DIR)
|
||||
stock_name = set(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
|
||||
self.assertEqual(len(stock_name), 85, "get csv data failed")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user