mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 10:31:00 +08:00
Update setting for model training.
This commit is contained in:
@@ -70,7 +70,7 @@ if __name__ == "__main__":
|
||||
"lr": 1e-3,
|
||||
"early_stop": 20,
|
||||
"batch_size": 800,
|
||||
"metric": "IC",
|
||||
"metric": "loss",
|
||||
"loss": "mse",
|
||||
"seed": 0,
|
||||
"GPU": 0,
|
||||
|
||||
@@ -46,7 +46,7 @@ class GRU(Model):
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="IC",
|
||||
metric="",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
@@ -140,21 +140,17 @@ class GRU(Model):
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
if self.metric == "IC":
|
||||
return self.cal_ic(pred[mask], label[mask])
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def cal_ic(self, pred, label):
|
||||
return torch.mean(pred * label)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values) * 100
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.gru_model.train()
|
||||
|
||||
@@ -193,7 +189,6 @@ class GRU(Model):
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user