1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 09:31:18 +08:00
Files
qlib/examples/trade/model/util.py
Yuchen Fang a03b08bb4c format
2021-01-28 00:41:02 +08:00

192 lines
6.4 KiB
Python

import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from copy import deepcopy
import sys
from tianshou.data import to_torch
class Attention(nn.Module):
    """Additive attention that pools a sequence of value vectors into one vector per row.

    A small MLP (``get_w``) scores every step from the concatenation ``[key, value_t]``.
    The scores are softmax-normalised over the sequence axis and used to take a
    weighted sum of ``value``, which is then projected by ``fc`` (Linear + ReLU).

    Args:
        in_dim: feature size of ``key`` and of each ``value`` step.
        out_dim: output feature size.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
        self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)

    def forward(self, value, key):
        """Pool ``value`` of shape (B, L, in_dim) using query ``key`` of shape (B, in_dim).

        Returns:
            Tensor of shape (B, out_dim).
        """
        length = value.shape[1]
        # Broadcast the query along the sequence axis; expand avoids an extra copy
        # (torch.cat below materialises the result anyway).
        key = key.unsqueeze(dim=1).expand(-1, length, -1)  # (B, L, in_dim)
        # Squeeze only the trailing score axis. A bare ``squeeze()`` also removes the
        # batch or sequence axis when either has size 1, making the softmax below run
        # over the wrong dimension (same convention as MaskAttention).
        weight = self.get_w(torch.cat((key, value), dim=-1)).squeeze(-1)  # (B, L)
        weight = weight.softmax(dim=-1).unsqueeze(dim=-1)  # (B, L, 1)
        out = (value * weight).sum(dim=1)  # (B, in_dim)
        return self.fc(out)
class MaskAttention(nn.Module):
    """Additive attention pooling that ignores positions past each row's valid length.

    Every step of ``value`` is scored by an MLP over ``[key, value_t]``; steps outside
    the mask built from ``seq_len`` get a score of ``-inf`` before the softmax, so
    they receive zero weight. The weighted sum is projected by ``fc`` (Linear + ReLU).

    Args:
        in_dim: feature size of ``key`` and of each ``value`` step.
        out_dim: output feature size.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
        self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)

    def forward(self, value, key, seq_len, maxlen=9):
        """Pool ``value`` (B, L, in_dim) with query ``key`` (B, in_dim).

        Args:
            seq_len: per-row tensor of shape (B,); steps ``0..seq_len[i]`` (inclusive,
                because the mask is built from ``seq_len + 1``) are attended.
            maxlen: mask width; the boolean indexing below requires it to equal L.

        Returns:
            Tensor of shape (B, out_dim).
        """
        steps = value.shape[1]
        tiled_key = key.unsqueeze(dim=1).repeat([1, steps, 1])
        scores = self.get_w(torch.cat((tiled_key, value), dim=-1)).squeeze(-1)  # (B, L)
        keep = sequence_mask(seq_len + 1, maxlen=maxlen, device=value.device)
        scores[~keep] = float("-inf")
        probs = scores.softmax(dim=-1).unsqueeze(dim=-1)  # (B, L, 1)
        return self.fc((value * probs).sum(dim=1))
class TFMaskAttention(nn.Module):
    """Masked additive attention over a sequence made of three equal-length segments.

    Identical to ``MaskAttention`` except that the length mask of width ``maxlen`` is
    tiled three times along the sequence axis, so ``value`` must have
    ``3 * maxlen`` steps. NOTE(review): presumably three streams of ``maxlen`` steps
    concatenated along time — confirm against the caller.

    Args:
        in_dim: feature size of ``key`` and of each ``value`` step.
        out_dim: output feature size.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
        self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)

    def forward(self, value, key, seq_len, maxlen=9):
        """Pool ``value`` (B, 3 * maxlen, in_dim) with query ``key`` (B, in_dim).

        Within each segment, steps ``0..seq_len[i]`` (inclusive) are attended.
        Returns a tensor of shape (B, out_dim).
        """
        steps = value.shape[1]
        tiled_key = key.unsqueeze(dim=1).repeat([1, steps, 1])
        scores = self.get_w(torch.cat((tiled_key, value), dim=-1)).squeeze(-1)
        # Same per-row mask reused for each of the three segments: (B, maxlen * 3).
        keep = sequence_mask(seq_len + 1, maxlen=maxlen, device=value.device).repeat(1, 3)
        scores[~keep] = float("-inf")
        probs = scores.softmax(dim=-1).unsqueeze(dim=-1)
        return self.fc((value * probs).sum(dim=1))
class NNAttention(nn.Module):
    """Dot-product attention with learned linear projections of Q, K and V.

    Scores are the raw (unscaled) dot products between projected queries and keys,
    softmax-normalised over the key axis, and used to average the projected values.

    Args:
        in_dim: feature size of the ``Q``, ``K`` and ``V`` inputs.
        out_dim: feature size of the projections and of the output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.q_net = nn.Linear(in_dim, out_dim)
        self.k_net = nn.Linear(in_dim, out_dim)
        self.v_net = nn.Linear(in_dim, out_dim)

    def forward(self, Q, K, V):
        """Attend from ``Q`` (B, Lq, in_dim) over ``K``/``V`` (B, Lk, in_dim).

        Returns:
            Tensor of shape (B, Lq, out_dim).
        """
        query = self.q_net(Q)
        key = self.k_net(K)
        value = self.v_net(V)
        # Computed from projections of Q, so the scores already live on Q's device.
        scores = torch.einsum("bqd,bkd->bqk", query, key)
        weights = torch.softmax(scores, dim=-1)
        return torch.einsum("bqk,bkd->bqd", weights, value)
class Reshape(nn.Module):
    """Layer that reshapes its input to a fixed target shape (e.g. inside ``nn.Sequential``).

    Args:
        *args: the target shape; one entry may be ``-1`` to be inferred.
    """

    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        # ``reshape`` returns a view whenever ``view`` would succeed, but also accepts
        # non-contiguous inputs (e.g. the output of a transpose), where ``view`` raises.
        return x.reshape(self.shape)
class DARNN(nn.Module):
    """Two-stream LSTM model with attention from the "today" window to the previous window.

    The flat input is viewed as (B, 480, 6) and split into two windows:
    - the first 238 steps, left-padded with one all-zero step (239 steps), optionally
      concatenated with a learned positional embedding;
    - steps 240..479 (the previous window).
    Each window runs through its own LSTM. ``NNAttention`` lets every today step
    attend over the previous-window outputs, and the result is concatenated with the
    today outputs. Every ``interval``-th step is then kept, ``seq_len`` steps in total.

    Keyword config (all required): ``emb_dim``, ``hidden_size``, ``num_layers``,
    ``is_bidir``, ``dropout``, ``seq_len``, ``interval``. ``device`` is accepted but
    not used here.

    NOTE(review): ``fc_out`` expects ``hidden_size * 2`` features and ``attention``
    expects ``hidden_size`` inputs, which appears to hold only when ``is_bidir`` is
    False — confirm before enabling a bidirectional LSTM.
    """

    def __init__(self, device="cpu", **kargs):
        super().__init__()
        # Read the required config keys in a fixed order (a missing key raises KeyError).
        for option in ("emb_dim", "hidden_size", "num_layers", "is_bidir", "dropout", "seq_len", "interval"):
            setattr(self, option, kargs[option])
        self.today_length = 238
        self.prev_length = 240
        self.input_length = 480
        self.input_size = 6

        shared_lstm = dict(
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=self.is_bidir,
            dropout=self.dropout,
        )
        # Submodules are created in this order on purpose: it fixes the RNG draws
        # used for parameter initialisation.
        self.rnn = nn.LSTM(input_size=self.input_size + self.emb_dim, **shared_lstm)
        self.prev_rnn = nn.LSTM(input_size=self.input_size, **shared_lstm)
        self.fc_out = nn.Linear(in_features=self.hidden_size * 2, out_features=1)
        self.attention = NNAttention(self.hidden_size, self.hidden_size)
        self.act_out = nn.Sigmoid()
        if self.emb_dim != 0:
            self.pos_emb = nn.Embedding(self.input_length, self.emb_dim)

    def forward(self, inputs):
        """Return ``(probs, features)``.

        ``probs`` is the sigmoid of ``fc_out`` applied to the sampled features, with
        shape (B, seq_len). ``features`` are the sampled, concatenated LSTM and
        attention outputs.
        """
        frames = inputs.view(-1, self.input_length, self.input_size)  # [B, T, F]
        today = frames[:, : self.today_length, :]
        leading_pad = torch.zeros_like(today[:, :1, :])
        today = torch.cat((leading_pad, today), dim=1)
        previous = frames[:, 240 : 240 + self.prev_length, :]
        if self.emb_dim != 0:
            positions = torch.arange(end=self.today_length + 1, device=frames.device)
            pos_feat = self.pos_emb(positions).repeat([today.size(0), 1, 1])
            today = torch.cat((today, pos_feat), dim=-1)
        # Run the previous window first, as before (LSTM dropout consumes RNG in training).
        prev_outs, _ = self.prev_rnn(previous)
        today_outs, _ = self.rnn(today)
        context = self.attention(today_outs, prev_outs, prev_outs)
        features = torch.cat((today_outs, context), dim=-1)
        sample_steps = range(0, self.seq_len * self.interval, self.interval)
        features = features[:, sample_steps, :]
        probs = self.act_out(self.fc_out(features).squeeze(-1))
        return probs, features
class Transpose(nn.Module):
    """Layer that swaps two tensor dimensions (e.g. for use inside ``nn.Sequential``).

    Args:
        dim1: first dimension to swap (default 0).
        dim2: second dimension to swap (default 1).
    """

    def __init__(self, dim1=0, dim2=1):
        super().__init__()
        self.dim1, self.dim2 = dim1, dim2

    def forward(self, x):
        return torch.transpose(x, self.dim1, self.dim2)
class SelfAttention(nn.Module):
    """Apply ``nn.MultiheadAttention`` with the same tensor as query, key and value.

    All constructor arguments are forwarded to ``nn.MultiheadAttention``. Only the
    attention output is returned; the attention weights are discarded.
    """

    def __init__(self, *args, **kargs):
        super().__init__()
        self.attention = nn.MultiheadAttention(*args, **kargs)

    def forward(self, x):
        attended, _weights = self.attention(x, x, x)
        return attended
def onehot_enc(y, len):
    """Return a float one-hot encoding of the 1-D integer label tensor ``y``.

    Args:
        y: int64 tensor of shape (N,) with class indices in ``[0, len)``.
        len: number of classes. The name shadows the builtin but is kept for API
            compatibility.

    Returns:
        Float tensor of shape (N, len) on ``y``'s device, with 1 at ``[i, y[i]]``
        and 0 elsewhere.
    """
    y = y.unsqueeze(-1)  # (N, 1) index for scatter along dim 1
    # Allocate on y's device so the scatter below does not mix devices.
    y_onehot = torch.zeros(y.shape[0], len, device=y.device)
    # Must be the in-place ``scatter_``: out-of-place ``scatter`` returns a new
    # tensor, and discarding it left the result all zeros.
    y_onehot.scatter_(1, y, 1)
    return y_onehot
def sequence_mask(lengths, maxlen=None, dtype=torch.bool, device=None):
    """Build a length mask: ``mask[i, j]`` is ``(j + 1) <= lengths[i]``.

    Args:
        lengths: 1-D tensor (or array-like) of sequence lengths, shape (N,).
        maxlen: number of mask columns; defaults to ``max(lengths)`` (0 for empty input).
        dtype: dtype of the returned mask (default ``torch.bool``).
        device: device of the returned mask; defaults to the device of ``lengths``.

    Returns:
        Tensor of shape (N, maxlen) and dtype ``dtype``.
    """
    lengths = torch.as_tensor(lengths, device=device)
    if maxlen is None:
        # int() avoids passing a 0-d tensor as a size; empty input yields zero columns.
        maxlen = int(lengths.max()) if lengths.numel() > 0 else 0
    # 1-based positions compared with ``<=``; this matches the previous cumsum
    # formulation exactly, including non-integer lengths.
    positions = torch.arange(1, maxlen + 1, device=lengths.device)
    mask = positions.unsqueeze(0) <= lengths.unsqueeze(1)
    # Tensor.type / Tensor.to are not in-place: the converted mask must be returned.
    return mask.to(dtype)