mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
update mlp model
This commit is contained in:
@@ -65,7 +65,7 @@ task:
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
batch_size: 8192
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
|
||||
@@ -50,7 +50,7 @@ class DNNModelPytorch(Model):
|
||||
self,
|
||||
input_dim,
|
||||
output_dim,
|
||||
layers=(256, 512, 768, 512, 256, 128, 64),
|
||||
layers=(256,),
|
||||
lr=0.001,
|
||||
max_steps=300,
|
||||
batch_size=2000,
|
||||
@@ -126,9 +126,9 @@ class DNNModelPytorch(Model):
|
||||
|
||||
self.dnn_model = Net(input_dim, output_dim, layers, loss=self.loss_type)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr)
|
||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=2e-4)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr)
|
||||
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=2e-4)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user