import torch
import torch.nn as nn
from math import sqrt


class SelfAttention(nn.Module):
    def __init__(self, input_dim, dim_q_k, dim_v):
        super().__init__()
        self.q = nn.Linear(input_dim, dim_q_k)
        self.k = nn.Linear(input_dim, dim_q_k)
        self.v = nn.Linear(input_dim, dim_v)
        self.norm = sqrt(dim_q_k)

    def forward(self, x):
        # (b, s, input_dim) -> (b, s, dim_q_k)
        q = self.q(x)
        k = self.k(x)
        # (b, s, input_dim) -> (b, s, dim_v)
        v = self.v(x)
        # scaled dot-product scores: (b, s, s)
        d = torch.bmm(q, k.transpose(1, 2)) / self.norm
        softmax_out = torch.softmax(d, dim=-1)
        # weighted sum of values: (b, s, dim_v)
        attn_out = torch.bmm(softmax_out, v)
        return attn_out
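A minimal shape check for SelfAttention; the batch size, sequence length, and dimensions below are arbitrary example values, not taken from the original.

x = torch.randn(2, 5, 16)  # (b=2, s=5, input_dim=16)
self_attn = SelfAttention(input_dim=16, dim_q_k=8, dim_v=12)
print(self_attn(x).shape)  # torch.Size([2, 5, 12])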
Cross attention below is actually target attention: the query comes from one sequence while the keys and values come from the other.
import torch
import torch.nn as nn
from math import sqrt


class CrossAttention(nn.Module):
    def __init__(self, input_dim, dim_q_k, dim_v):
        super().__init__()
        self.q = nn.Linear(input_dim, dim_q_k)
        self.k = nn.Linear(input_dim, dim_q_k)
        self.v = nn.Linear(input_dim, dim_v)
        self.norm = sqrt(dim_q_k)

    def forward(self, encoder_input, decoder_input):
        # queries come from encoder_input: (b, s_q, input_dim) -> (b, s_q, dim_q_k)
        q = self.q(encoder_input)
        # keys and values come from decoder_input
        # (b, s_kv, input_dim) -> (b, s_kv, dim_q_k)
        k = self.k(decoder_input)
        # (b, s_kv, input_dim) -> (b, s_kv, dim_v)
        v = self.v(decoder_input)
        # scaled dot-product scores: (b, s_q, s_kv)
        a = torch.bmm(q, k.transpose(1, 2)) / self.norm
        softmax_out = torch.softmax(a, dim=-1)
        # (b, s_q, dim_v)
        cross_attn_out = torch.bmm(softmax_out, v)
        return cross_attn_out
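A minimal shape check for CrossAttention under the same assumptions (arbitrary example dimensions); note that the two inputs may have different sequence lengths.

encoder_input = torch.randn(2, 4, 16)  # query side: (b=2, s_q=4, input_dim=16)
decoder_input = torch.randn(2, 6, 16)  # key/value side: (b=2, s_kv=6, input_dim=16)
cross_attn = CrossAttention(input_dim=16, dim_q_k=8, dim_v=12)
print(cross_attn(encoder_input, decoder_input).shape)  # torch.Size([2, 4, 12])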
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, input_dim, head, dim_q_k, dim_v):
        super().__init__()
        assert dim_q_k % head == 0 and dim_v % head == 0
        self.head = head
        self.dim_per_head_q_k = dim_q_k // head
        self.dim_per_head_v = dim_v // head
        # project to the full dim_q_k / dim_v, then split into heads in forward
        self.q = nn.Linear(input_dim, dim_q_k)
        self.k = nn.Linear(input_dim, dim_q_k)
        self.v = nn.Linear(input_dim, dim_v)
        self.norm = sqrt(self.dim_per_head_q_k)
        self.out = nn.Linear(dim_v, input_dim)

    def forward(self, x):
        # x: (b, s, input_dim)
        b = x.size(0)
        s = x.size(1)
        # (b, s, input_dim) -> (b, head, s, dim_per_head_q_k)
        q = self.q(x).view(b, -1, self.head, self.dim_per_head_q_k).transpose(1, 2)
        k = self.k(x).view(b, -1, self.head, self.dim_per_head_q_k).transpose(1, 2)
        # (b, s, input_dim) -> (b, head, s, dim_per_head_v)
        v = self.v(x).view(b, -1, self.head, self.dim_per_head_v).transpose(1, 2)
        # scaled dot-product scores: (b, head, s, s)
        d = torch.matmul(q, k.transpose(2, 3)) / self.norm
        softmax_out = torch.softmax(d, dim=-1)
        # (b, head, s, dim_per_head_v)
        attn_out = torch.matmul(softmax_out, v)
        # (b, head, s, dim_per_head_v) -> (b, s, head, dim_per_head_v) -> (b, s, dim_v)
        attn_out = attn_out.transpose(1, 2).contiguous().view(b, s, -1)
        # (b, s, dim_v) -> (b, s, input_dim)
        out = self.out(attn_out)
        return out
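A minimal shape check for MultiHeadSelfAttention (arbitrary example values; dim_q_k and dim_v must be divisible by head); the final projection maps back to input_dim.

x = torch.randn(2, 5, 16)  # (b=2, s=5, input_dim=16)
mha = MultiHeadSelfAttention(input_dim=16, head=4, dim_q_k=8, dim_v=12)
print(mha(x).shape)  # torch.Size([2, 5, 16])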