mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Fix bugs for models.
This commit is contained in:
@@ -32,7 +32,6 @@ from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
|
||||
class DailyBatchSampler(Sampler):
|
||||
|
||||
def __init__(self, data_source):
|
||||
self.data_source = data_source
|
||||
self.data = self.data_source.data.loc[self.data_source.get_index()]
|
||||
@@ -41,7 +40,7 @@ class DailyBatchSampler(Sampler):
|
||||
|
||||
def __iter__(self):
|
||||
for idx, count in zip(self.daily_index, self.daily_count):
|
||||
yield slice(idx, idx+count)
|
||||
yield slice(idx, idx + count)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_source)
|
||||
@@ -272,10 +271,14 @@ class GATs(Model):
|
||||
raise ValueError("the path of the pretrained model should be given first!")
|
||||
self.logger.info("Loading pretrained model...")
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)
|
||||
pretrained_model = LSTMModel(
|
||||
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers
|
||||
)
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)
|
||||
pretrained_model = GRUModel(
|
||||
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers
|
||||
)
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
@@ -238,7 +238,7 @@ class TSDataSampler:
|
||||
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
# self.index_link = self.build_link(self.data)
|
||||
self.idx_df, self.idx_map = self.build_index(self.data)
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
def get_index(self):
|
||||
"""
|
||||
@@ -368,7 +368,6 @@ class TSDataSampler:
|
||||
else:
|
||||
indices = self._get_indices(*self._get_row_col(idx))
|
||||
|
||||
|
||||
# 1) for better performance, use the last nan line for padding the lost date
|
||||
# 2) In case of precision problems. We use np.float64. # TODO: I'm not sure if whether np.float64 will result in
|
||||
# precision problems. It will not cause any problems in my tests at least
|
||||
|
||||
Reference in New Issue
Block a user