diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index c22e48204..9e5aa3e28 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -61,7 +61,7 @@ class GATs(Model): with_pretrain=True, model_path=None, optimizer="adam", - GPU="0", + GPU=0, seed=None, **kwargs ):