diff --git a/examples/benchmarks/GATs/workflow_config_gats.yaml b/examples/benchmarks/GATs/workflow_config_gats.yaml index 7212e0ee2..c38b4b312 100644 --- a/examples/benchmarks/GATs/workflow_config_gats.yaml +++ b/examples/benchmarks/GATs/workflow_config_gats.yaml @@ -40,19 +40,19 @@ port_analysis_config: &port_analysis_config min_cost: 5 task: model: - class: GAT_Classic - module_path: qlib.contrib.model.pytorch_gats_classic + class: GATs + module_path: qlib.contrib.model.pytorch_gats kwargs: d_feat: 6 hidden_size: 64 num_layers: 2 - dropout: 0.0 + dropout: 0.7 n_epochs: 200 - lr: 1e-3 + lr: 1e-4 early_stop: 20 metric: loss loss: mse - base_model: GRU + base_model: LSTM seed: 0 GPU: 0 dataset: