diff --git a/examples/benchmarks/TabNet/pretrain/best.model b/examples/benchmarks/TabNet/pretrain/best.model index 9a3939232..a85cbe392 100644 Binary files a/examples/benchmarks/TabNet/pretrain/best.model and b/examples/benchmarks/TabNet/pretrain/best.model differ diff --git a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml index e1bc0f69e..4e9f0e7e9 100644 --- a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml +++ b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml @@ -41,7 +41,7 @@ port_analysis_config: &port_analysis_config min_cost: 5 task: model: - class: TabNet_Model + class: TabnetModel module_path: qlib.contrib.model.pytorch_tabnet kwargs: pretrain: True diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index 0be28f6dd..ef1c8e2a8 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -28,7 +28,7 @@ from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP -class TabNet_Model(Model): +class TabnetModel(Model): def __init__( self, d_feat=158,