mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
Add random seed.
This commit is contained in:
@@ -113,6 +113,9 @@ class ALSTM(Model):
|
||||
)
|
||||
)
|
||||
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.ALSTM_model = ALSTMModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
|
||||
@@ -113,6 +113,9 @@ class GRU(Model):
|
||||
)
|
||||
)
|
||||
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.gru_model = GRUModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
|
||||
@@ -113,6 +113,9 @@ class LSTM(Model):
|
||||
)
|
||||
)
|
||||
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.lstm_model = LSTMModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
|
||||
@@ -61,6 +61,7 @@ class DNNModelPytorch(Model):
|
||||
optimizer="gd",
|
||||
loss="mse",
|
||||
GPU="0",
|
||||
seed=0,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
@@ -80,6 +81,7 @@ class DNNModelPytorch(Model):
|
||||
self.loss_type = loss
|
||||
self.visible_GPU = GPU
|
||||
self.use_GPU = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"DNN parameters setting:"
|
||||
@@ -94,6 +96,7 @@ class DNNModelPytorch(Model):
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\neval_steps : {}"
|
||||
"\nseed : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}".format(
|
||||
layers,
|
||||
@@ -107,11 +110,15 @@ class DNNModelPytorch(Model):
|
||||
optimizer,
|
||||
loss,
|
||||
eval_steps,
|
||||
seed,
|
||||
GPU,
|
||||
self.use_GPU,
|
||||
)
|
||||
)
|
||||
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss))
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
|
||||
@@ -282,6 +282,9 @@ class SFM(Model):
|
||||
)
|
||||
)
|
||||
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.sfm_model = SFM_Model(
|
||||
d_feat=self.d_feat,
|
||||
output_dim=self.output_dim,
|
||||
|
||||
Reference in New Issue
Block a user