1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00

Add "mse" metric option to ALSTM.metric_fn (#1810)

This commit is contained in:
raikiriww
2024-06-19 17:31:47 +08:00
committed by GitHub
parent 155c17f8ff
commit 73ec0f4003

View File

@@ -160,6 +160,10 @@ class ALSTM(Model):
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
elif self.metric == "mse":
mask = ~torch.isnan(label)
weight = torch.ones_like(label)
return -self.mse(pred[mask], label[mask], weight[mask])
raise ValueError("unknown metric `%s`" % self.metric)