1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Fix bugs for models.

This commit is contained in:
lwwang1995
2020-12-07 19:50:00 +08:00
committed by you-n-g
parent 4a748525bc
commit 70fb760830
2 changed files with 8 additions and 6 deletions

View File

@@ -32,7 +32,6 @@ from ...contrib.model.pytorch_gru import GRUModel
class DailyBatchSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
self.data = self.data_source.data.loc[self.data_source.get_index()]
@@ -41,7 +40,7 @@ class DailyBatchSampler(Sampler):
def __iter__(self):
for idx, count in zip(self.daily_index, self.daily_count):
yield slice(idx, idx+count)
yield slice(idx, idx + count)
def __len__(self):
return len(self.data_source)
@@ -272,10 +271,14 @@ class GATs(Model):
raise ValueError("the path of the pretrained model should be given first!")
self.logger.info("Loading pretrained model...")
if self.base_model == "LSTM":
pretrained_model = LSTMModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)
pretrained_model = LSTMModel(
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers
)
pretrained_model.load_state_dict(torch.load(self.model_path))
elif self.base_model == "GRU":
pretrained_model = GRUModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)
pretrained_model = GRUModel(
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers
)
pretrained_model.load_state_dict(torch.load(self.model_path))
else:
raise ValueError("unknown base model name `%s`" % self.base_model)

View File

@@ -238,7 +238,7 @@ class TSDataSampler:
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
# self.index_link = self.build_link(self.data)
self.idx_df, self.idx_map = self.build_index(self.data)
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
def get_index(self):
"""
@@ -368,7 +368,6 @@ class TSDataSampler:
else:
indices = self._get_indices(*self._get_row_col(idx))
# 1) for better performance, use the last nan line for padding the lost date
# 2) In case of precision problems. We use np.float64. # TODO: I'm not sure if whether np.float64 will result in
# precision problems. It will not cause any problems in my tests at least