diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index a41eeabbb..1623e7e1c 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -259,7 +259,7 @@ class DNNModelPytorch(Model): loss = torch.mul(sqr_loss, w).mean() return loss elif loss_type == "binary": - loss = nn.BCELoss() + loss = nn.BCELoss(weight=w) return loss(pred, target) else: raise NotImplementedError("loss {} is not supported!".format(loss_type))