1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

add pretrain-mode to gats

This commit is contained in:
meng-ustc
2020-11-25 17:46:34 +08:00
parent c14a99a735
commit cd7c81cfd0
5 changed files with 27 additions and 4 deletions

View File

@@ -37,9 +37,10 @@ task:
lr: 1e-3
early_stop: 20
batch_size: 800
metric: IC
metric: loss
loss: mse
base_model: GRU
base_model: LSTM
with_pretrain: True
seed: 0
GPU: 0
dataset:

Binary file not shown.

Binary file not shown.

View File

@@ -70,9 +70,10 @@ if __name__ == "__main__":
"lr": 1e-3,
"early_stop": 20,
"batch_size": 800,
"metric": "IC",
"metric": "loss",
"loss": "mse",
"base_model": "GRU",
"base_model": "LSTM",
"with_pretrain": True,
"seed": 0,
"GPU": 0,
},

View File

@@ -55,6 +55,7 @@ class GAT(Model):
early_stop=20,
loss="mse",
base_model="GRU",
with_pretrain=True,
optimizer="adam",
GPU="0",
seed=0,
@@ -77,6 +78,7 @@ class GAT(Model):
self.optimizer = optimizer.lower()
self.loss = loss
self.base_model = base_model
self.with_pretrain = with_pretrain
self.visible_GPU = GPU
self.use_gpu = torch.cuda.is_available()
self.seed = seed
@@ -95,6 +97,7 @@ class GAT(Model):
"\noptimizer : {}"
"\nloss_type : {}"
"\nbase_model : {}"
"\nwith_pretrain : {}"
"\nvisible_GPU : {}"
"\nuse_GPU : {}"
"\nseed : {}".format(
@@ -110,6 +113,7 @@ class GAT(Model):
optimizer.lower(),
loss,
base_model,
with_pretrain,
GPU,
self.use_gpu,
seed,
@@ -256,6 +260,23 @@ class GAT(Model):
evals_result["train"] = []
evals_result["valid"] = []
# load pretrained base_model
if self.with_pretrain:
self.logger.info("Loading pretrained model...")
if self.base_model == "LSTM":
from ...contrib.model.pytorch_lstm import LSTMModel
pretrained_model = LSTMModel()
pretrained_model.load_state_dict(torch.load('benchmarks/LSTM/model_lstm_csi300.pkl'))
elif self.base_model == "GRU":
from ...contrib.model.pytorch_gru import GRUModel
pretrained_model = GRUModel()
pretrained_model.load_state_dict(torch.load('benchmarks/GRU/model_gru_csi300.pkl'))
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")
# train
self.logger.info("training...")
self._fitted = True