diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 868ab1513..780dc4b91 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -267,7 +267,7 @@ class DNNModelPytorch(Model): loss = torch.mul(sqr_loss, w).mean() return loss elif loss_type == "binary": - loss = nn.BCELoss(weight=w) + loss = nn.BCEWithLogitsLoss(weight=w) return loss(pred, target) else: raise NotImplementedError("loss {} is not supported!".format(loss_type)) @@ -334,16 +334,8 @@ class Net(nn.Module): dnn_layers.append(seq) drop_input = nn.Dropout(0.05) dnn_layers.append(drop_input) - if loss == "mse": - fc = nn.Linear(hidden_units, output_dim) - dnn_layers.append(fc) - - elif loss == "binary": - fc = nn.Linear(hidden_units, output_dim) - sigmoid = nn.Sigmoid() - dnn_layers.append(nn.Sequential(fc, sigmoid)) - else: - raise NotImplementedError("loss {} is not supported!".format(loss)) + fc = nn.Linear(hidden_units, output_dim) + dnn_layers.append(fc) # optimizer self.dnn_layers = nn.ModuleList(dnn_layers) self._weight_init()