From 73ec0f40036aa12a9d26e46984ae5bb9ad8443f5 Mon Sep 17 00:00:00 2001 From: raikiriww Date: Wed, 19 Jun 2024 17:31:47 +0800 Subject: [PATCH] Add "mse" metric option to ALSTM.metric_fn (#1810) --- qlib/contrib/model/pytorch_alstm_ts.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 008d78940..3fb7cb9e1 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -160,6 +160,10 @@ class ALSTM(Model): if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) + elif self.metric == "mse": + mask = ~torch.isnan(label) + weight = torch.ones_like(label) + return -self.mse(pred[mask], label[mask], weight[mask]) raise ValueError("unknown metric `%s`" % self.metric)