diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index e013238fa..4fbe12fb7 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -102,7 +102,7 @@ class SFM_Model(nn.Module): i = self.inner_activation( x_i + torch.matmul(h_tm1 * B_U[0], self.U_i) - ) # not sure whether I am doing in the right unsquuze + ) ste = self.inner_activation(x_ste + torch.matmul(h_tm1 * B_U[0], self.U_ste)) fre = self.inner_activation(x_fre + torch.matmul(h_tm1 * B_U[0], self.U_fre))