1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Update pytorch_nn.py

This commit is contained in:
Wendi Li
2021-01-23 13:21:46 +00:00
committed by you-n-g
parent afdf58b4fa
commit 84d77f4585

View File

@@ -259,7 +259,7 @@ class DNNModelPytorch(Model):
loss = torch.mul(sqr_loss, w).mean()
return loss
elif loss_type == "binary":
loss = nn.BCELoss()
loss = nn.BCELoss(weight=w)
return loss(pred, target)
else:
raise NotImplementedError("loss {} is not supported!".format(loss_type))