mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 09:31:18 +08:00
192 lines
6.4 KiB
Python
192 lines
6.4 KiB
Python
import torch
|
|
import numpy as np
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
from copy import deepcopy
|
|
import sys
|
|
|
|
from tianshou.data import to_torch
|
|
|
|
|
|
class Attention(nn.Module):
    """Additive attention pooling of a value sequence, conditioned on a key vector.

    A small MLP (``get_w``) scores each concatenated ``[key, value_t]`` pair.
    The scores are softmax-normalised over the sequence axis. The weighted sum
    of the values is then passed through a ``Linear + ReLU`` projection (``fc``).

    Args:
        in_dim: Feature size of both ``value`` and ``key``.
        out_dim: Feature size of the returned vector.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Scores the concatenation [key, value_t] (2 * in_dim features) -> 1 scalar per step.
        self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
        self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)

    def forward(self, value, key):
        """Pool ``value`` over its sequence axis using key-conditioned weights.

        Args:
            value: Tensor of shape ``(B, L, in_dim)``.
            key: Tensor of shape ``(B, in_dim)``.

        Returns:
            Tensor of shape ``(B, out_dim)``.
        """
        length = value.shape[1]
        # Broadcast the key along the sequence axis: (B, in_dim) -> (B, L, in_dim).
        key = key.unsqueeze(dim=1).repeat([1, length, 1])
        # Squeeze only the trailing score axis. A bare squeeze() would also drop a
        # batch or sequence axis of size 1, and correctness would then rely on
        # broadcasting coincidences.
        weight = self.get_w(torch.cat((key, value), dim=-1)).squeeze(-1)  # (B, L)
        weight = weight.softmax(dim=-1).unsqueeze(dim=-1)  # (B, L, 1)
        out = (value * weight).sum(dim=1)
        return self.fc(out)
|
|
|
|
|
|
class MaskAttention(nn.Module):
    """Additive attention pooling that ignores positions past each row's length.

    Scoring is the same as an unmasked additive attention: an MLP over
    ``[key, value_t]``. Before the softmax, every position outside the
    ``sequence_mask`` built from ``seq_len + 1`` is set to ``-inf``, so those
    positions receive zero weight.

    Args:
        in_dim: Feature size of ``value`` and ``key``.
        out_dim: Feature size of the output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.get_w = nn.Sequential(
            nn.Linear(in_dim * 2, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, 1),
        )
        self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU())

    def forward(self, value, key, seq_len, maxlen=9):
        """Masked attention pooling.

        Args:
            value: Tensor of shape ``(B, L, in_dim)``.
            key: Tensor of shape ``(B, in_dim)``.
            seq_len: Tensor of shape ``(B,)``. In row ``i``, the first
                ``seq_len[i] + 1`` positions are attended and the rest are masked.
            maxlen: Width of the mask; it must equal ``L`` so the mask lines
                up with the scores.

        Returns:
            Tensor of shape ``(B, out_dim)``.
        """
        steps = value.shape[1]
        tiled_key = key.unsqueeze(dim=1).repeat([1, steps, 1])  # (B, L, in_dim)
        scores = self.get_w(torch.cat((tiled_key, value), dim=-1)).squeeze(-1)  # (B, L)

        valid = sequence_mask(seq_len + 1, maxlen=maxlen, device=value.device)
        scores[~valid] = float("-inf")  # masked positions get softmax weight 0

        attn = scores.softmax(dim=-1).unsqueeze(dim=-1)  # (B, L, 1)
        pooled = (value * attn).sum(dim=1)
        return self.fc(pooled)
|
|
|
|
|
|
class TFMaskAttention(nn.Module):
    """Masked additive attention over a value sequence made of equal-length segments.

    Scoring works like ``MaskAttention``. Here, however, ``value`` is laid out as
    consecutive segments of ``maxlen`` steps each. The per-row length mask is
    tiled once per segment, so the same relative positions are kept in every
    segment.

    Args:
        in_dim: Feature size of ``value`` and ``key``.
        out_dim: Feature size of the output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
        self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)

    def forward(self, value, key, seq_len, maxlen=9):
        """Masked attention pooling over a segmented sequence.

        Args:
            value: Tensor of shape ``(B, L, in_dim)``, where ``L`` is a multiple
                of ``maxlen``. Each block of ``maxlen`` steps is one segment.
            key: Tensor of shape ``(B, in_dim)``.
            seq_len: Tensor of shape ``(B,)``. Within every segment of row ``i``,
                the first ``seq_len[i] + 1`` positions are attended.
            maxlen: Length of one segment.

        Returns:
            Tensor of shape ``(B, out_dim)``.

        Raises:
            ValueError: If ``maxlen`` is not positive or ``L`` is not a multiple
                of ``maxlen``.
        """
        length = value.shape[1]
        if maxlen <= 0 or length % maxlen != 0:
            raise ValueError(
                f"value length {length} must be a positive multiple of maxlen={maxlen}"
            )
        # The number of segments is derived from the input rather than hard-coded.
        num_segments = length // maxlen

        key = key.unsqueeze(dim=1).repeat([1, length, 1])  # (B, L, in_dim)
        weight = self.get_w(torch.cat((key, value), dim=-1)).squeeze(-1)  # (B, L)

        mask = sequence_mask(seq_len + 1, maxlen=maxlen, device=value.device)  # (B, maxlen)
        mask = mask.repeat(1, num_segments)  # tile across segments -> (B, L)
        weight = weight.masked_fill(~mask, float("-inf"))

        weight = weight.softmax(dim=-1).unsqueeze(dim=-1)  # (B, L, 1)
        out = (value * weight).sum(dim=1)
        return self.fc(out)
|
|
|
|
|
|
class NNAttention(nn.Module):
    """Dot-product attention with learned query/key/value projections.

    ``Q``, ``K`` and ``V`` each go through their own ``nn.Linear(in_dim, out_dim)``.
    The attention scores are the raw dot products of projected queries and keys,
    with no ``1/sqrt(d)`` scaling. They are softmax-normalised over the key axis
    and used to average the projected values.

    Args:
        in_dim: Feature size of ``Q``, ``K`` and ``V``.
        out_dim: Feature size after projection (and of the output).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.q_net = nn.Linear(in_dim, out_dim)
        self.k_net = nn.Linear(in_dim, out_dim)
        self.v_net = nn.Linear(in_dim, out_dim)

    def forward(self, Q, K, V):
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            Q: Tensor of shape ``(B, Lq, in_dim)``.
            K: Tensor of shape ``(B, Lk, in_dim)``.
            V: Tensor of shape ``(B, Lk, in_dim)``.

        Returns:
            Tensor of shape ``(B, Lq, out_dim)``.
        """
        query, keys, values = self.q_net(Q), self.k_net(K), self.v_net(V)
        scores = torch.matmul(query, keys.transpose(-2, -1))  # (B, Lq, Lk)
        probs = scores.softmax(dim=-1)
        return torch.matmul(probs, values)
|
|
|
|
|
|
class Reshape(nn.Module):
    """Module wrapper around ``Tensor.view`` with a target shape fixed at construction.

    The positional constructor arguments are stored as a tuple in ``self.shape``
    and forwarded unchanged to ``Tensor.view``.
    """

    def __init__(self, *args):
        super().__init__()
        self.shape = args  # target shape tuple passed to Tensor.view

    def forward(self, x):
        """Return ``x`` viewed as ``self.shape``. ``Tensor.view`` does not copy."""
        target_shape = self.shape
        return x.view(target_shape)
|
|
|
|
|
|
class DARNN(nn.Module):
    """Two-stream LSTM model with attention from the "today" stream onto the "prev" stream.

    ``forward`` views its input as ``(B, input_length, input_size)`` and builds
    two streams from it:

    * The "today" stream is the first ``today_length`` steps, prefixed with one
      all-zero step. When ``emb_dim != 0``, it is also concatenated with a learned
      positional embedding.
    * The "prev" stream is the last ``prev_length`` steps.

    Each stream is encoded by its own LSTM. The today outputs attend over the
    prev outputs via ``NNAttention``. The concatenation is subsampled every
    ``interval`` steps (``seq_len`` steps kept) and scored by a linear head
    followed by a sigmoid.

    Args:
        device: Not used by this class; kept for interface compatibility.
        **kargs: Must contain ``emb_dim``, ``hidden_size``, ``num_layers``,
            ``is_bidir``, ``dropout``, ``seq_len`` and ``interval``.
    """

    def __init__(self, device="cpu", **kargs):
        super().__init__()
        self.emb_dim = kargs["emb_dim"]
        self.hidden_size = kargs["hidden_size"]
        self.num_layers = kargs["num_layers"]
        self.is_bidir = kargs["is_bidir"]
        self.dropout = kargs["dropout"]
        self.seq_len = kargs["seq_len"]
        self.interval = kargs["interval"]
        self.today_length = 238
        self.prev_length = 240
        self.input_length = 480
        self.input_size = 6

        # A bidirectional LSTM concatenates forward and backward states, so its
        # per-step output width is doubled. Every layer consuming LSTM outputs
        # must be sized from this width, not from hidden_size alone.
        rnn_out_size = self.hidden_size * (2 if self.is_bidir else 1)

        self.rnn = nn.LSTM(
            input_size=self.input_size + self.emb_dim,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=self.is_bidir,
            dropout=self.dropout,
        )
        self.prev_rnn = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=self.is_bidir,
            dropout=self.dropout,
        )
        # The head consumes [today_outs, attended prev_outs] concatenated.
        self.fc_out = nn.Linear(in_features=rnn_out_size * 2, out_features=1)
        self.attention = NNAttention(rnn_out_size, rnn_out_size)
        self.act_out = nn.Sigmoid()
        if self.emb_dim != 0:
            self.pos_emb = nn.Embedding(self.input_length, self.emb_dim)

    def forward(self, inputs):
        """Run both streams, attend, subsample and score.

        Args:
            inputs: Tensor whose elements can be viewed as
                ``(-1, input_length, input_size)``.

        Returns:
            A tuple ``(probs, feats)``:

            * ``probs`` has shape ``(B, seq_len)`` and holds the sigmoid outputs
              of the head.
            * ``feats`` has shape ``(B, seq_len, 2 * rnn_out_size)`` and holds the
              subsampled features fed to the head.
        """
        inputs = inputs.view(-1, self.input_length, self.input_size)  # [B, T, F]

        # "today" stream: first today_length steps, shifted right by one zero step.
        today_input = inputs[:, : self.today_length, :]
        today_input = torch.cat((torch.zeros_like(today_input[:, :1, :]), today_input), dim=1)

        # "prev" stream: the trailing prev_length steps of the window.
        prev_start = self.input_length - self.prev_length
        prev_input = inputs[:, prev_start : prev_start + self.prev_length, :]

        if self.emb_dim != 0:
            positions = torch.arange(end=self.today_length + 1, device=inputs.device)
            embedding = self.pos_emb(positions).repeat([today_input.size(0), 1, 1])
            today_input = torch.cat((today_input, embedding), dim=-1)

        prev_outs, _ = self.prev_rnn(prev_input)
        today_outs, _ = self.rnn(today_input)

        outs = self.attention(today_outs, prev_outs, prev_outs)
        outs = torch.cat((today_outs, outs), dim=-1)
        # Keep seq_len steps, one every `interval` steps starting at 0.
        outs = outs[:, range(0, self.seq_len * self.interval, self.interval), :]
        return self.act_out(self.fc_out(outs).squeeze(-1)), outs
|
|
|
|
|
|
class Transpose(nn.Module):
    """Module wrapper that swaps two tensor dimensions.

    Args:
        dim1: First dimension to swap (default ``0``).
        dim2: Second dimension to swap (default ``1``).
    """

    def __init__(self, dim1=0, dim2=1):
        super().__init__()
        self.dim1, self.dim2 = dim1, dim2

    def forward(self, x):
        """Return ``x`` with dimensions ``dim1`` and ``dim2`` swapped."""
        return torch.transpose(x, self.dim1, self.dim2)
|
|
|
|
|
|
class SelfAttention(nn.Module):
    """Self-attention: ``nn.MultiheadAttention`` applied with query = key = value.

    All constructor arguments are forwarded to ``nn.MultiheadAttention``.
    ``forward`` returns only the attention output and discards the attention
    weights.
    """

    def __init__(self, *args, **kargs):
        super().__init__()
        self.attention = nn.MultiheadAttention(*args, **kargs)

    def forward(self, x):
        """Attend ``x`` to itself and return the attention output."""
        attended, _weights = self.attention(x, x, x)
        return attended
|
|
|
|
|
|
def onehot_enc(y, len):
    """Return the one-hot encoding of integer class labels.

    Args:
        y: 1-D tensor of class indices in ``[0, len)``.
        len: Number of classes, i.e. the width of the encoding. This shadows the
            builtin ``len``; the name is kept so keyword callers keep working.

    Returns:
        Float tensor of shape ``(y.shape[0], len)`` on ``y``'s device. Row ``i``
        holds a single ``1`` at column ``y[i]``.
    """
    # scatter requires int64 indices shaped like the output along dim 1.
    index = y.long().unsqueeze(-1)
    y_onehot = torch.zeros(index.shape[0], len, device=y.device)
    # Must be the in-place scatter_: the out-of-place scatter returns a new
    # tensor and leaves y_onehot all zeros.
    y_onehot.scatter_(1, index, 1)
    return y_onehot
|
|
|
|
|
|
def sequence_mask(lengths, maxlen=None, dtype=torch.bool, device=None):
    """Build a per-row prefix mask from a vector of lengths.

    Args:
        lengths: 1-D tensor of per-row lengths.
        maxlen: Number of columns. Defaults to ``lengths.max()``, or 0 when
            ``lengths`` is empty.
        dtype: dtype of the returned mask (default ``torch.bool``).
        device: Device of the mask. Defaults to ``lengths.device``.

    Returns:
        Tensor of shape ``(len(lengths), maxlen)`` and dtype ``dtype``. Entry
        ``[i, j]`` is set iff ``j + 1 <= lengths[i]``, i.e. each row is true on
        its first ``lengths[i]`` columns.
    """
    if device is None:
        device = lengths.device
    if maxlen is None:
        maxlen = int(lengths.max()) if lengths.numel() else 0
    # 1-based column positions, compared against each row's length.
    positions = torch.arange(1, maxlen + 1, device=device)
    mask = ~(positions.unsqueeze(0) > lengths.unsqueeze(-1))
    # Tensor.type returns a new tensor; the result must be returned for dtype to apply.
    return mask.type(dtype)
|