diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index fad52e834..aa5b22119 100755 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -117,11 +117,7 @@ class GAT(Model): seed, ) ) - - if loss not in {"mse", "binary"}: - raise NotImplementedError("loss {} is not supported!".format(loss)) - self._scorer = mean_squared_error if loss == "mse" else roc_auc_score - + self.GAT_model = GATModel( d_feat=self.d_feat, hidden_size=self.hidden_size, @@ -211,7 +207,6 @@ class GAT(Model): losses = [] indices = np.arange(len(x_values)) - np.random.shuffle(indices) for i in range(len(indices))[:: self.batch_size]: