mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
add pretrain-mode to gats
This commit is contained in:
@@ -37,9 +37,10 @@ task:
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: IC
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: GRU
|
||||
base_model: LSTM
|
||||
with_pretrain: True
|
||||
seed: 0
|
||||
GPU: 0
|
||||
dataset:
|
||||
|
||||
BIN
examples/benchmarks/GRU/model_gru_csi300.pkl
Normal file
BIN
examples/benchmarks/GRU/model_gru_csi300.pkl
Normal file
Binary file not shown.
BIN
examples/benchmarks/LSTM/model_lstm_csi300.pkl
Normal file
BIN
examples/benchmarks/LSTM/model_lstm_csi300.pkl
Normal file
Binary file not shown.
@@ -70,9 +70,10 @@ if __name__ == "__main__":
|
||||
"lr": 1e-3,
|
||||
"early_stop": 20,
|
||||
"batch_size": 800,
|
||||
"metric": "IC",
|
||||
"metric": "loss",
|
||||
"loss": "mse",
|
||||
"base_model": "GRU",
|
||||
"base_model": "LSTM",
|
||||
"with_pretrain": True,
|
||||
"seed": 0,
|
||||
"GPU": 0,
|
||||
},
|
||||
|
||||
@@ -55,6 +55,7 @@ class GAT(Model):
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
with_pretrain=True,
|
||||
optimizer="adam",
|
||||
GPU="0",
|
||||
seed=0,
|
||||
@@ -77,6 +78,7 @@ class GAT(Model):
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.visible_GPU = GPU
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
@@ -95,6 +97,7 @@ class GAT(Model):
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nwith_pretrain : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -110,6 +113,7 @@ class GAT(Model):
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
with_pretrain,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -256,6 +260,23 @@ class GAT(Model):
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.with_pretrain:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
if self.base_model == "LSTM":
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
pretrained_model = LSTMModel()
|
||||
pretrained_model.load_state_dict(torch.load('benchmarks/LSTM/model_lstm_csi300.pkl'))
|
||||
elif self.base_model == "GRU":
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
pretrained_model = GRUModel()
|
||||
pretrained_model.load_state_dict(torch.load('benchmarks/GRU/model_gru_csi300.pkl'))
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
|
||||
Reference in New Issue
Block a user