mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Update pytorch_nn.py
This commit is contained in:
@@ -259,7 +259,7 @@ class DNNModelPytorch(Model):
|
||||
loss = torch.mul(sqr_loss, w).mean()
|
||||
return loss
|
||||
elif loss_type == "binary":
|
||||
loss = nn.BCELoss()
|
||||
loss = nn.BCELoss(weight=w)
|
||||
return loss(pred, target)
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
|
||||
Reference in New Issue
Block a user