diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 6fb455b76..b20e57d45 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -32,7 +32,6 @@ from ...contrib.model.pytorch_gru import GRUModel class DailyBatchSampler(Sampler): - def __init__(self, data_source): self.data_source = data_source self.data = self.data_source.data.loc[self.data_source.get_index()] @@ -41,7 +40,7 @@ class DailyBatchSampler(Sampler): def __iter__(self): for idx, count in zip(self.daily_index, self.daily_count): - yield slice(idx, idx+count) + yield slice(idx, idx + count) def __len__(self): return len(self.data_source) @@ -272,10 +271,14 @@ class GATs(Model): raise ValueError("the path of the pretrained model should be given first!") self.logger.info("Loading pretrained model...") if self.base_model == "LSTM": - pretrained_model = LSTMModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers) + pretrained_model = LSTMModel( + d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers + ) pretrained_model.load_state_dict(torch.load(self.model_path)) elif self.base_model == "GRU": - pretrained_model = GRUModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers) + pretrained_model = GRUModel( + d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers + ) pretrained_model.load_state_dict(torch.load(self.model_path)) else: raise ValueError("unknown base model name `%s`" % self.base_model) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index a07f7ab8f..96e4a6e41 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -238,7 +238,7 @@ class TSDataSampler: self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) # self.index_link = self.build_link(self.data) self.idx_df, self.idx_map = self.build_index(self.data) - self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance + self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance def get_index(self): """ @@ -368,7 +368,6 @@ class TSDataSampler: else: indices = self._get_indices(*self._get_row_col(idx)) - # 1) for better performance, use the last nan line for padding the lost date # 2) In case of precision problems. We use np.float64. # TODO: I'm not sure if whether np.float64 will result in # precision problems. It will not cause any problems in my tests at least