mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
gat2
This commit is contained in:
@@ -117,11 +117,7 @@ class GAT(Model):
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss))
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
|
||||
|
||||
self.GAT_model = GATModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
@@ -211,7 +207,6 @@ class GAT(Model):
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user