1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 10:31:00 +08:00

Update setting for model training.

This commit is contained in:
lwwang1995
2020-11-26 21:34:16 +08:00
parent 6b053137fd
commit 38cfb22cba
2 changed files with 3 additions and 8 deletions

View File

@@ -70,7 +70,7 @@ if __name__ == "__main__":
"lr": 1e-3,
"early_stop": 20,
"batch_size": 800,
"metric": "IC",
"metric": "loss",
"loss": "mse",
"seed": 0,
"GPU": 0,

View File

@@ -46,7 +46,7 @@ class GRU(Model):
dropout=0.0,
n_epochs=200,
lr=0.001,
metric="IC",
metric="",
batch_size=2000,
early_stop=20,
loss="mse",
@@ -140,21 +140,17 @@ class GRU(Model):
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric == "IC":
return self.cal_ic(pred[mask], label[mask])
if self.metric == "" or self.metric == "loss": # use loss
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def cal_ic(self, pred, label):
return torch.mean(pred * label)
def train_epoch(self, x_train, y_train):
x_train_values = x_train.values
y_train_values = np.squeeze(y_train.values) * 100
y_train_values = np.squeeze(y_train.values)
self.gru_model.train()
@@ -193,7 +189,6 @@ class GRU(Model):
losses = []
indices = np.arange(len(x_values))
np.random.shuffle(indices)
for i in range(len(indices))[:: self.batch_size]: