mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
add arg weight_decay
This commit is contained in:
@@ -67,6 +67,7 @@ task:
|
||||
max_steps: 8000
|
||||
batch_size: 8192
|
||||
GPU: 0
|
||||
weight_decay: 0.0002
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -62,6 +62,7 @@ class DNNModelPytorch(Model):
|
||||
loss="mse",
|
||||
GPU="0",
|
||||
seed=None,
|
||||
weight_decay=0.0,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
@@ -82,6 +83,7 @@ class DNNModelPytorch(Model):
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_GPU = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
self.logger.info(
|
||||
"DNN parameters setting:"
|
||||
@@ -98,7 +100,8 @@ class DNNModelPytorch(Model):
|
||||
"\neval_steps : {}"
|
||||
"\nseed : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}".format(
|
||||
"\nuse_GPU : {}"
|
||||
"\nweight_decay : {}".format(
|
||||
layers,
|
||||
lr,
|
||||
max_steps,
|
||||
@@ -113,6 +116,7 @@ class DNNModelPytorch(Model):
|
||||
seed,
|
||||
GPU,
|
||||
self.use_GPU,
|
||||
weight_decay
|
||||
)
|
||||
)
|
||||
|
||||
@@ -126,9 +130,9 @@ class DNNModelPytorch(Model):
|
||||
|
||||
self.dnn_model = Net(input_dim, output_dim, layers, loss=self.loss_type)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=2e-4)
|
||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=2e-4)
|
||||
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user