diff --git a/qlib/contrib/model/pytorch_tcts.py b/qlib/contrib/model/pytorch_tcts.py index 8dadefb68..4f87e5f1e 100644 --- a/qlib/contrib/model/pytorch_tcts.py +++ b/qlib/contrib/model/pytorch_tcts.py @@ -145,7 +145,7 @@ class TCTS(Model): init_fore_model = copy.deepcopy(self.fore_model) for p in init_fore_model.parameters(): - p.init_fore_model = False + p.requires_grad = False self.fore_model.train() self.weight_model.train()