From f42661f2d4250ffc26f0f4cedd07d94097ba6b4a Mon Sep 17 00:00:00 2001 From: Hong Zhang Date: Thu, 26 Nov 2020 13:55:12 +0800 Subject: [PATCH] gat2 --- qlib/contrib/model/pytorch_gats.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index fad52e834..aa5b22119 100755 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -117,11 +117,7 @@ class GAT(Model): seed, ) ) - - if loss not in {"mse", "binary"}: - raise NotImplementedError("loss {} is not supported!".format(loss)) - self._scorer = mean_squared_error if loss == "mse" else roc_auc_score - + self.GAT_model = GATModel( d_feat=self.d_feat, hidden_size=self.hidden_size, @@ -211,7 +207,6 @@ class GAT(Model): losses = [] indices = np.arange(len(x_values)) - np.random.shuffle(indices) for i in range(len(indices))[:: self.batch_size]: