1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

add arg weight_decay

This commit is contained in:
bxdd
2020-12-09 00:27:54 +08:00
committed by you-n-g
parent 2873813562
commit 56e579e20f
2 changed files with 8 additions and 3 deletions

View File

@@ -67,6 +67,7 @@ task:
max_steps: 8000
batch_size: 8192
GPU: 0
weight_decay: 0.0002
dataset:
class: DatasetH
module_path: qlib.data.dataset

View File

@@ -62,6 +62,7 @@ class DNNModelPytorch(Model):
loss="mse",
GPU="0",
seed=None,
weight_decay=0.0,
**kwargs
):
# Set logger.
@@ -82,6 +83,7 @@ class DNNModelPytorch(Model):
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.use_GPU = torch.cuda.is_available()
self.seed = seed
self.weight_decay = weight_decay
self.logger.info(
"DNN parameters setting:"
@@ -98,7 +100,8 @@ class DNNModelPytorch(Model):
"\neval_steps : {}"
"\nseed : {}"
"\nvisible_GPU : {}"
"\nuse_GPU : {}".format(
"\nuse_GPU : {}"
"\nweight_decay : {}".format(
layers,
lr,
max_steps,
@@ -113,6 +116,7 @@ class DNNModelPytorch(Model):
seed,
GPU,
self.use_GPU,
weight_decay
)
)
@@ -126,9 +130,9 @@ class DNNModelPytorch(Model):
self.dnn_model = Net(input_dim, output_dim, layers, loss=self.loss_type)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=2e-4)
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=2e-4)
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))