From f4674ef98c841f3022e2c4aa0e40accde0d54ce0 Mon Sep 17 00:00:00 2001 From: Young Date: Wed, 10 Jul 2024 05:31:14 +0000 Subject: [PATCH] Add model template; --- examples/benchmarks/GeneralPtNN/README.md | 15 +++ .../GeneralPtNN/workflow_config_gru.yaml | 97 ++++++++++++++++++ .../GeneralPtNN/workflow_config_mlp.yaml | 98 +++++++++++++++++++ 3 files changed, 210 insertions(+) create mode 100644 examples/benchmarks/GeneralPtNN/README.md create mode 100755 examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml create mode 100644 examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml diff --git a/examples/benchmarks/GeneralPtNN/README.md b/examples/benchmarks/GeneralPtNN/README.md new file mode 100644 index 000000000..817b1f9c2 --- /dev/null +++ b/examples/benchmarks/GeneralPtNN/README.md @@ -0,0 +1,15 @@ + + +# Introduction + +What is GeneralPtNN +- Fix previous design that fail to support both Time-series and tabular data +- Now you can just replace the Pytorch model structure to run a NN model. + +We provide an example to demonstrate the effectiveness of the current design. +- `workflow_config_gru.yaml` align with previous results [GRU(Kyunghyun Cho, et al.)](../README.md#Alpha158 dataset) +- `workflow_config_mlp.yaml` align with previous results [MLP](../README.md#Alpha158 dataset) + +# TODO + +We will align existing models to current design. diff --git a/examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml b/examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml new file mode 100755 index 000000000..a2f03a230 --- /dev/null +++ b/examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml @@ -0,0 +1,97 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn +market: &market csi300 +benchmark: &benchmark SH000300 +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: FilterCol + kwargs: + fields_group: feature + col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", + "ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5", + "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW" + ] + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: DropnaLabel + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy + kwargs: + signal: + topk: 50 + n_drop: 5 + backtest: + start_time: 2017-01-01 + end_time: 2020-08-01 + account: 100000000 + benchmark: *benchmark + exchange_kwargs: + limit_threshold: 0.095 + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 +task: + model: + class: GRU + module_path: qlib.contrib.model.pytorch_gru_ts + kwargs: + d_feat: 20 + hidden_size: 64 + num_layers: 2 + dropout: 0.0 + n_epochs: 200 + lr: 2e-4 + early_stop: 10 + batch_size: 800 + metric: loss + loss: mse + n_jobs: 20 + GPU: 0 + dataset: + class: TSDatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + step_len: 20 + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: + model: + dataset: + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml b/examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml new file mode 100644 index 000000000..6c85546ca --- /dev/null +++ b/examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml @@ -0,0 +1,98 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn +market: &market csi300 +benchmark: &benchmark SH000300 +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: [ + { + "class" : "DropCol", + "kwargs":{"col_list": ["VWAP0"]} + }, + { + "class" : "CSZFillna", + "kwargs":{"fields_group": "feature"} + } + ] + learn_processors: [ + { + "class" : "DropCol", + "kwargs":{"col_list": ["VWAP0"]} + }, + { + "class" : "DropnaProcessor", + "kwargs":{"fields_group": "feature"} + }, + "DropnaLabel", + { + "class": "CSZScoreNorm", + "kwargs": {"fields_group": "label"} + } + ] + process_type: "independent" + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy + kwargs: + signal: + topk: 50 + n_drop: 5 + backtest: + start_time: 2017-01-01 + end_time: 2020-08-01 + account: 100000000 + benchmark: *benchmark + exchange_kwargs: + limit_threshold: 0.095 + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 +task: + model: + class: DNNModelPytorch + module_path: qlib.contrib.model.pytorch_nn + kwargs: + loss: mse + lr: 0.002 + optimizer: adam + max_steps: 8000 + batch_size: 8192 + GPU: 0 + weight_decay: 0.0002 + pt_model_kwargs: + input_dim: 157 + dataset: + class: DatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: + model: + dataset: + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config