diff --git a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index 70b8b0ce8..e703130fb 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -154,7 +154,7 @@ class LSTM(Model): mask = torch.isfinite(label) if self.metric in ("", "loss"): - return -self.loss_fn(pred[mask], label[mask], weight = None) + return -self.loss_fn(pred[mask], label[mask], weight=None) raise ValueError("unknown metric `%s`" % self.metric)