1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 09:31:18 +08:00
This commit is contained in:
Yuchen Fang
2021-01-28 00:41:22 +08:00
parent a03b08bb4c
commit 4dc14a2489

View File

@@ -20,15 +20,10 @@ class OPD_Extractor(nn.Module):
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
self.raw_fc = nn.Sequential(
nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),
)
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
self.fc = nn.Sequential(
nn.Linear(hidden_size * 2, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 32),
nn.ReLU(),
nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 32), nn.ReLU(),
)
def forward(self, inp):