diff --git a/examples/benchmarks/GATs/worflow_config_gats.yaml b/examples/benchmarks/GATs/worflow_config_gats.yaml index 84eeff4db..37bced99d 100644 --- a/examples/benchmarks/GATs/worflow_config_gats.yaml +++ b/examples/benchmarks/GATs/worflow_config_gats.yaml @@ -37,9 +37,10 @@ task: lr: 1e-3 early_stop: 20 batch_size: 800 - metric: IC + metric: loss loss: mse - base_model: GRU + base_model: LSTM + with_pretrain: True seed: 0 GPU: 0 dataset: diff --git a/examples/benchmarks/GRU/model_gru_csi300.pkl b/examples/benchmarks/GRU/model_gru_csi300.pkl new file mode 100644 index 000000000..46347ce8c Binary files /dev/null and b/examples/benchmarks/GRU/model_gru_csi300.pkl differ diff --git a/examples/benchmarks/LSTM/model_lstm_csi300.pkl b/examples/benchmarks/LSTM/model_lstm_csi300.pkl new file mode 100644 index 000000000..ff7fee450 Binary files /dev/null and b/examples/benchmarks/LSTM/model_lstm_csi300.pkl differ diff --git a/examples/workflow_by_code_gats.py b/examples/workflow_by_code_gats.py index 6b15b77b4..3bb4edf08 100644 --- a/examples/workflow_by_code_gats.py +++ b/examples/workflow_by_code_gats.py @@ -70,9 +70,10 @@ if __name__ == "__main__": "lr": 1e-3, "early_stop": 20, "batch_size": 800, - "metric": "IC", + "metric": "loss", "loss": "mse", - "base_model": "GRU", + "base_model": "LSTM", + "with_pretrain": True, "seed": 0, "GPU": 0, }, diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 22ed6812d..77e3b9de9 100755 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -55,6 +55,7 @@ class GAT(Model): early_stop=20, loss="mse", base_model="GRU", + with_pretrain=True, optimizer="adam", GPU="0", seed=0, @@ -77,6 +78,7 @@ class GAT(Model): self.optimizer = optimizer.lower() self.loss = loss self.base_model = base_model + self.with_pretrain = with_pretrain self.visible_GPU = GPU self.use_gpu = torch.cuda.is_available() self.seed = seed @@ -95,6 +97,7 @@ class GAT(Model): "\noptimizer : {}" "\nloss_type : {}" "\nbase_model : {}" + "\nwith_pretrain : {}" "\nvisible_GPU : {}" "\nuse_GPU : {}" "\nseed : {}".format( @@ -110,6 +113,7 @@ class GAT(Model): optimizer.lower(), loss, base_model, + with_pretrain, GPU, self.use_gpu, seed, @@ -256,6 +260,23 @@ class GAT(Model): evals_result["train"] = [] evals_result["valid"] = [] + # load pretrained base_model + if self.with_pretrain: + self.logger.info("Loading pretrained model...") + if self.base_model == "LSTM": + from ...contrib.model.pytorch_lstm import LSTMModel + pretrained_model = LSTMModel() + pretrained_model.load_state_dict(torch.load('benchmarks/LSTM/model_lstm_csi300.pkl')) + elif self.base_model == "GRU": + from ...contrib.model.pytorch_gru import GRUModel + pretrained_model = GRUModel() + pretrained_model.load_state_dict(torch.load('benchmarks/GRU/model_gru_csi300.pkl')) + model_dict = self.GAT_model.state_dict() + pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict} + model_dict.update(pretrained_dict) + self.GAT_model.load_state_dict(model_dict) + self.logger.info("Loading pretrained model Done...") + # train self.logger.info("training...") self._fitted = True