From 666e1ffcbdb7d1dfa0537b033d5070ecfc2c9a18 Mon Sep 17 00:00:00 2001 From: lwwang1995 Date: Mon, 7 Dec 2020 21:17:02 +0800 Subject: [PATCH] Update settings. --- .../benchmarks/SFM/workflow_config_sfm_Alpha158.yaml | 2 +- qlib/contrib/model/pytorch_gats_ts.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/benchmarks/SFM/workflow_config_sfm_Alpha158.yaml b/examples/benchmarks/SFM/workflow_config_sfm_Alpha158.yaml index 7c7775c55..cd00aadec 100755 --- a/examples/benchmarks/SFM/workflow_config_sfm_Alpha158.yaml +++ b/examples/benchmarks/SFM/workflow_config_sfm_Alpha158.yaml @@ -57,7 +57,7 @@ task: num_layers: 2 dropout: 0.0 n_epochs: 200 - lr: 1e-1 + lr: 5e-2 early_stop: 10 batch_size: 800 metric: loss diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index b20e57d45..76b94aa21 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -32,15 +32,17 @@ from ...contrib.model.pytorch_gru import GRUModel class DailyBatchSampler(Sampler): + def __init__(self, data_source): + self.data_source = data_source self.data = self.data_source.data.loc[self.data_source.get_index()] - self.daily_count = self.data.groupby(level=0).size().values - self.daily_index = np.roll(np.cumsum(self.daily_count), 1) + self.daily_count = self.data.groupby(level=0).size().values[1:] + self.daily_index = np.roll(np.cumsum(self.daily_count), 1)[1:] def __iter__(self): for idx, count in zip(self.daily_index, self.daily_count): - yield slice(idx, idx + count) + yield np.arange(idx, idx + count) def __len__(self): return len(self.data_source) @@ -202,6 +204,8 @@ class GATs(Model): self.GAT_model.train() for data in data_loader: + + data = data.squeeze() feature = data[:, :, 0:-1].to(self.device) label = data[:, -1, -1].to(self.device) @@ -222,6 +226,7 @@ class GATs(Model): for data in data_loader: + data = data.squeeze() feature = data[:, :, 0:-1].to(self.device) # feature[torch.isnan(feature)] = 0 label = data[:, -1, -1].to(self.device) @@ -335,6 +340,7 @@ class GATs(Model): for data in test_loader: + data = data.squeeze() feature = data[:, :, 0:-1].to(self.device) with torch.no_grad():