10.5. Multi-Head Attention — Dive into Deep Learning 2.0.0 documentation (d2l.ai)
Multi-Head Attention | Algorithm + Code — bilibili
Code implementation
The input x has shape [1, 4, 2]: 1 is the number of samples (sentences) in the batch, 4 is the sequence length (4 words), and 2 is the feature dimension of each word after encoding.
from math import sqrt

import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim_in, d_model, num_heads=3):
        super(MultiHeadSelfAttention, self).__init__()
        self.dim_in = dim_in
        self.d_model = d_model
        self.num_heads = num_heads
        # d_model must be divisible by num_heads
        assert d_model % num_heads == 0, "d_model must be a multiple of num_heads"
        # Linear projections for query, key, and value
        self.linear_q = nn.Linear(dim_in, d_model)
        self.linear_k = nn.Linear(dim_in, d_model)
        self.linear_v = nn.Linear(dim_in, d_model)
        self.scale = 1 / sqrt(d_model // num_heads)
        # Final linear layer
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: tensor of shape (batch, n, dim_in)
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in

        nh = self.num_heads
        dk = self.d_model // nh  # dimension of each head

        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        v = self.linear_v(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)

        dist = torch.matmul(q, k.transpose(2, 3)) * self.scale  # (batch, nh, n, n)
        dist = torch.softmax(dist, dim=-1)                      # (batch, nh, n, n)

        att = torch.matmul(dist, v)                                 # (batch, nh, n, dk)
        att = att.transpose(1, 2).reshape(batch, n, self.d_model)   # (batch, n, d_model)

        # Final linear transformation
        output = self.fc(att)
        return output


x = torch.rand((1, 4, 2))
multi_head_att = MultiHeadSelfAttention(x.shape[2], 6, 3)  # d_model=6, num_heads=3
output = multi_head_att(x)
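To make the shapes explicit, a quick check can be run after the snippet above. It reuses x, output, and the d_model=6, num_heads=3 configuration from the code; the expected sizes follow directly from the reshape/transpose comments.

print(x.shape)       # torch.Size([1, 4, 2])  -> (batch, n, dim_in)
print(output.shape)  # torch.Size([1, 4, 6])  -> (batch, n, d_model)
# Each of the 3 heads attends over d_model // num_heads = 2 features,
# and the per-head attention matrix dist has shape (batch, nh, n, n) = (1, 3, 4, 4).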
101 - The role of Multi-head — bilibili
The final linear layer processes the concatenated head outputs and can be used to compress the dimension back down.
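As a minimal sketch of that idea: in the class above, fc maps d_model to d_model, but it could instead project the concatenated heads back down to the input feature size. The subclass name OutputCompressedAttention below is hypothetical and only for illustration.

import torch
import torch.nn as nn

# Hypothetical variant: the final layer projects the concatenated heads
# (d_model features) back down to the input feature size dim_in,
# so the block's output shape matches its input shape.
class OutputCompressedAttention(MultiHeadSelfAttention):
    def __init__(self, dim_in, d_model, num_heads=3):
        super().__init__(dim_in, d_model, num_heads)
        self.fc = nn.Linear(d_model, dim_in)  # compress d_model -> dim_in

x = torch.rand(1, 4, 2)
att = OutputCompressedAttention(dim_in=2, d_model=6, num_heads=3)
print(att(x).shape)  # torch.Size([1, 4, 2]) -- same shape as the input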