注意力机制(Attention Mechanism)讲解
注意力机制(Attention Mechanism)是深度学习中一种模仿人类认知能力的核心技术,通过动态分配权重来聚焦输入数据的关键部分。它在自然语言处理(NLP)、计算机视觉(CV)和多模态任务中广泛应用,尤其因Transformer模型的成功而成为现代AI的核心组件。以下是注意力机制的详细解析:
1. 核心思想
注意力机制的核心是让模型在处理输入时,能够动态选择性地关注重要信息,忽略无关内容。
- 类比人类行为:就像阅读文章时,我们会重点关注关键词句,忽略次要信息。
- 数学本质:通过计算输入元素之间的相关性权重,对信息进行加权融合。
2. 基本结构
注意力机制包含三个核心组件:
- 查询(Query, Q):当前需要生成输出的目标(如解码器的当前状态)。
- 键(Key, K):输入元素的标识,用于与查询计算相关性。
- 值(Value, V):输入元素的实际内容,根据权重聚合后生成输出。
3. 计算步骤
以缩放点积注意力(Scaled Dot-Product Attention)为例:
- 相似度计算:查询(Q)与键(K)的点积,度量相关性。
$ \text{Score} = Q \cdot K^T $
(为避免梯度消失,除以 $ \sqrt{d_k} $ 缩放,$ d_k $ 是键的维度) - 权重归一化:通过Softmax将得分转换为概率分布。
$ \text{Attention Weights} = \text{Softmax}(\text{Score}) $ - 加权求和:用权重对值(V)进行聚合,生成上下文向量。
$ \text{Output} = \text{Attention Weights} \cdot V $
4. 主要类型
4.1 自注意力(Self-Attention)
- 特点:Q、K、V来自同一输入序列,捕捉序列内部的长距离依赖。
- 应用场景:Transformer编码器处理文本或图像时,分析词与词、像素与像素的关系。
- 示例:句子中的代词(如“它”)通过自注意力找到指代的名词(如“苹果”)。
4.2 交叉注意力(Cross-Attention)
- 特点:Q来自目标序列,K、V来自另一源序列,实现跨序列交互。
- 应用场景:Transformer解码器生成输出时,关注编码器的输出(如机器翻译)。
4.3 多头注意力(Multi-Head Attention)
- 思想:将Q、K、V投影到多个子空间,并行计算多组注意力,增强模型表达能力。
- 计算步骤:
- 将Q、K、V拆分为$ h $个头(如8头)。
- 每个头独立计算注意力。
- 拼接所有头的输出,并通过线性层融合。
- 优势:允许模型同时关注不同位置和不同语义层面的信息。
5. 注意力机制的优势
- 长距离依赖建模:传统RNN/LSTM难以处理长序列,注意力机制直接计算任意两个位置的关联。
- 并行计算:与RNN的时序计算不同,注意力可并行处理所有位置,提升训练速度。
- 可解释性:注意力权重可视化为热力图,直观显示模型关注的位置(如翻译时关注源句子的哪些词)。
6. 经典应用案例
6.1 Transformer模型
- 编码器:使用自注意力处理输入序列(如源语言句子)。
- 解码器:先通过自注意力处理目标序列,再通过交叉注意力融合编码器信息。
6.2 BERT与GPT
- BERT:通过双向自注意力学习上下文表示。
- GPT:使用掩码自注意力(仅关注左侧上下文)生成文本。
6.3 图像分类(Vision Transformer, ViT)
- 将图像切分为块(Patch),通过自注意力建模块间关系,替代传统CNN。
7. 注意力机制的变体与改进
- 局部注意力(Local Attention):限制关注窗口,减少计算量(适用于长序列)。
- 稀疏注意力(Sparse Attention):仅计算部分位置的权重(如Longformer的滑动窗口)。
- 内存压缩注意力(Memory-Compressed Attention):对键值进行降采样(如Linformer)。
- 先验引导注意力:引入外部知识约束权重(如在VQA中结合目标检测框信息)。
8. 数学形式化与代码示例
数学公式
\[\text{Attention}(Q, K, V) = \text{Softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V
\]
PyTorch代码(多头注意力)
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, num_heads=8):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
# 线性变换层
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, query, key, value):
batch_size = query.size(0)
# 线性变换并拆分为多头
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# 计算缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
attn_weights = torch.softmax(scores, dim=-1)
context = torch.matmul(attn_weights, V)
# 拼接多头输出并融合
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
output = self.W_o(context)
return output
9. 挑战与未来方向
- 计算复杂度:注意力矩阵的$ O(n^2) $复杂度难以处理超长序列(如DNA数据)。
- 动态稀疏性:如何自动学习稀疏注意力模式,平衡效果与效率。
- 跨模态统一:设计通用注意力机制处理文本、图像、语音等多模态数据。
总结
注意力机制通过动态加权聚焦关键信息,突破了传统模型处理序列数据的局限性,成为Transformer、BERT、GPT等里程碑模型的核心。其核心价值在于:
- 灵活性:可扩展为自注意力、交叉注意力、多头注意力等变体。
- 可解释性:注意力权重提供模型决策依据的可视化解释。
- 通用性:适用于文本、图像、语音等多种任务,是构建通用AI的重要基础。
以下是对多头注意力代码的详细注释,包含每个变量的形状说明(假设输入 query, key, value 的形状为 (batch_size, seq_len, d_model)):
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, num_heads=8):
super().__init__()
self.d_model = d_model # 模型维度,如512
self.num_heads = num_heads # 注意力头数,如8
self.head_dim = d_model // num_heads # 每个头的维度,如512/8=64
# 定义线性变换层
self.W_q = nn.Linear(d_model, d_model) # 查询变换
self.W_k = nn.Linear(d_model, d_model) # 键变换
self.W_v = nn.Linear(d_model, d_model) # 值变换
self.W_o = nn.Linear(d_model, d_model) # 输出融合层
def forward(self, query, key, value):
batch_size = query.size(0) # 批大小,如32
# 步骤1:线性变换 + 拆分为多头
# query形状: (batch_size, seq_len_q, d_model) → 例如 (32, 10, 512)
Q = self.W_q(query) # 形状: (batch_size, seq_len_q, d_model)
K = self.W_k(key) # 形状: (batch_size, seq_len_k, d_model)
V = self.W_v(value) # 形状: (batch_size, seq_len_v, d_model)
# 将Q/K/V拆分为多个头(通过view和transpose改变形状)
# 目标形状: (batch_size, num_heads, seq_len, head_dim)
Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# Q形状: (batch_size, num_heads, seq_len_q, head_dim) → (32, 8, 10, 64)
K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# K形状: (batch_size, num_heads, seq_len_k, head_dim) → (32, 8, 10, 64)
V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# V形状: (batch_size, num_heads, seq_len_v, head_dim) → (32, 8, 10, 64)
# 步骤2:计算缩放点积注意力
# Q与K的转置相乘,计算相似度
scores = torch.matmul(Q, K.transpose(-2, -1)) # 矩阵乘法
# scores形状: (batch_size, num_heads, seq_len_q, seq_len_k) → (32, 8, 10, 10)
scores = scores / (self.head_dim ** 0.5) # 缩放
# 计算注意力权重(Softmax归一化)
attn_weights = torch.softmax(scores, dim=-1)
# attn_weights形状: (batch_size, num_heads, seq_len_q, seq_len_k) → (32, 8, 10, 10)
# 用权重对V加权求和
context = torch.matmul(attn_weights, V)
# context形状: (batch_size, num_heads, seq_len_q, head_dim) → (32, 8, 10, 64)
# 步骤3:拼接多头结果
# 转置并重塑形状以合并多头
context = context.transpose(1, 2).contiguous()
# 转置后形状: (batch_size, seq_len_q, num_heads, head_dim) → (32, 10, 8, 64)
context = context.view(batch_size, -1, self.d_model)
# 合并后形状: (batch_size, seq_len_q, d_model) → (32, 10, 512)
# 步骤4:通过输出线性层融合特征
output = self.W_o(context)
# output形状: (batch_size, seq_len_q, d_model) → (32, 10, 512)
return output
关键形状变化说明
-
输入形状:
query:(batch_size, seq_len_q, d_model)key:(batch_size, seq_len_k, d_model)value:(batch_size, seq_len_v, d_model)
(注:seq_len_k和seq_len_v通常相同)
-
线性变换后:
Q/K/V: 保持与输入相同的形状(batch_size, seq_len, d_model)。
-
拆分多头后:
Q/K/V:(batch_size, num_heads, seq_len, head_dim)
(通过view和transpose实现维度重组)
-
注意力计算:
scores:(batch_size, num_heads, seq_len_q, seq_len_k)attn_weights: 与scores形状相同context:(batch_size, num_heads, seq_len_q, head_dim)
-
合并多头后:
context:(batch_size, seq_len_q, d_model)
(通过transpose和view恢复原始形状)
-
最终输出:
output:(batch_size, seq_len_q, d_model)(与输入query的序列长度一致)
示例数值
假设:
batch_size = 2seq_len_q = 5(目标序列长度)seq_len_k = seq_len_v = 10(源序列长度)d_model = 512num_heads = 8head_dim = 512 // 8 = 64
则各变量形状变化如下:
| 步骤 | 张量 | 形状 |
|---|---|---|
| 输入 | query |
(2, 5, 512) |
| 线性变换后的Q | Q |
(2, 5, 512) |
| 拆分多头后的Q | Q |
(2, 8, 5, 64) |
| 注意力权重 | attn_weights |
(2, 8, 5, 10) |
| 加权后的上下文 | context |
(2, 8, 5, 64) |
| 合并多头后的上下文 | context |
(2, 5, 512) |
| 最终输出 | output |
(2, 5, 512) |
通过这种形状注释,可以直观理解注意力机制中数据流的维度变化。

浙公网安备 33010602011771号