From 5da5cf51756ebd0b5cd4909c7e058a8efe0071ea Mon Sep 17 00:00:00 2001 From: aurora5161 <565056427@qq.com> Date: Sun, 6 Feb 2022 22:34:00 +0800 Subject: [PATCH] add weight param (#907) --- qlib/contrib/model/pytorch_lstm_ts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index d7705981a..70b8b0ce8 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -138,7 +138,7 @@ class LSTM(Model): loss = weight * (pred - label) ** 2 return torch.mean(loss) - def loss_fn(self, pred, label): + def loss_fn(self, pred, label, weight): mask = ~torch.isnan(label) if weight is None: @@ -154,7 +154,7 @@ class LSTM(Model): mask = torch.isfinite(label) if self.metric in ("", "loss"): - return -self.loss_fn(pred[mask], label[mask]) + return -self.loss_fn(pred[mask], label[mask], weight = None) raise ValueError("unknown metric `%s`" % self.metric)