1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00
This commit is contained in:
Hong Zhang
2020-11-26 13:55:12 +08:00
parent 398f67f8d8
commit f42661f2d4

View File

@@ -117,11 +117,7 @@ class GAT(Model):
seed,
)
)
if loss not in {"mse", "binary"}:
raise NotImplementedError("loss {} is not supported!".format(loss))
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
self.GAT_model = GATModel(
d_feat=self.d_feat,
hidden_size=self.hidden_size,
@@ -211,7 +207,6 @@ class GAT(Model):
losses = []
indices = np.arange(len(x_values))
np.random.shuffle(indices)
for i in range(len(indices))[:: self.batch_size]: