注意力机制介绍及其代码实现

1. 介绍注意力机制

注意力机制(Attention Mechanism)是一种能够根据输入序列中不同位置的相关性,动态分配权重的计算方法。它的核心思想是:在处理某个位置的表示时,并不是仅依赖该位置本身的信息,而是参考整个序列中其他位置的信息,并根据相关性(权重)来融合这些信息。

在自然语言处理、计算机视觉等任务中,注意力机制能够让模型“聚焦”于与当前任务最相关的部分,从而提升特征表示能力。常见的注意力计算流程包括:

  1. 将输入向量映射为 Query (Q)Key (K)Value (V) 三组向量。
  2. 通过 Q 与 K 的点积计算相似度分数(score),并进行缩放(缩放点积注意力)。
  3. 对分数使用 softmax 转化为概率分布(权重)。
  4. 将权重与 V 相乘并求和,得到融合了上下文信息的表示。

公式表示为:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]

2. 实现单头注意力代码

import torch
from torch import nn
import math
from torch.nn import functional as F
# 生成数据
sentences = ['我','喜欢','玩','游戏']
sentences_dict = {
    "我":0,
    "喜欢":1,
    "玩":2,
    "游戏":3
}
# 转成embedding
dim = 512
em = nn.Embedding(num_embeddings=len(sentences),embedding_dim=dim) # num_embeddings意思允许0-(n-1)的索引
sentences_em = em(torch.tensor([sentences_dict[idx] for idx in sentences]))
print("embedding初始数据:",sentences_em.shape)

# 得到QKV
fc = nn.ModuleList([nn.Linear(dim,dim) for _ in range(3)])
Q, K, V = fc[0](sentences_em),fc[1](sentences_em),fc[2](sentences_em)
print("V:",V.shape)

# 计算Q和K每个字之间的余弦关系,并通dim进行压缩范围,得到score
score = torch.matmul(Q,K.transpose(-1,-2))/math.sqrt(dim)
print("score:",score.shape)

# 将score归一化为score_sf
score_sf = F.softmax(score,-1)
print("score_sf:",score_sf.shape)

# 将score_sf与V信息融合为score_sf_v
score_sf_v = torch.matmul(score_sf,V)
print("score_sf_v:",score_sf_v.shape)

3. 实现多头注意力代码

import torch
from torch import nn
import math
from torch.nn import functional as F
# 生成数据
sentences = ['我','喜欢','玩','游戏']
sentences_dict = {
    "我":0,
    "喜欢":1,
    "玩":2,
    "游戏":3
}
# 转成embedding
dim = 512
em = nn.Embedding(num_embeddings=len(sentences),embedding_dim=dim) # num_embeddings意思允许0-(n-1)的索引
sentences_em = em(torch.tensor([sentences_dict[idx] for idx in sentences]))
print("embedding初始数据:",sentences_em.shape)

# 得到QKV
fc = nn.ModuleList([nn.Linear(dim,dim) for _ in range(3)])
Q, K, V = fc[0](sentences_em),fc[1](sentences_em),fc[2](sentences_em)
print("V:",V.shape)

# 拆分QKV
head_num = 8
fc_num = nn.ModuleList([nn.Linear(dim,dim//head_num) for _ in range(3*head_num)])
Q_num = torch.stack([fc_num[i](Q) for i in range(0 * head_num, 1 * head_num)])
K_num = torch.stack([fc_num[i](K) for i in range(1 * head_num, 2 * head_num)])
V_num = torch.stack([fc_num[i](V) for i in range(2 * head_num, 3 * head_num)])
print("V_num:",V_num.shape)

# 计算Q和K每个字之间的余弦关系,并通dim进行压缩范围,得到score
score = torch.matmul(Q_num,K_num.transpose(-1,-2))/math.sqrt(dim//head_num)
print("score:",score.shape)

# 将score归一化为score_sf
score_sf = F.softmax(score,-1)
print("score_sf:",score_sf.shape)

# 将score_sf与V信息融合为score_sf_v
score_sf_v = torch.matmul(score_sf,V_num)
print("score_sf_v:",score_sf_v.shape)

# 将拆分后的数据合并
# 注意必须由[4,8,64]变成[4,512],如果由[4,64,8]或者其他变成[4,512]将破坏head_num的连续性
score_sf_v_num = score_sf_v.transpose(0, 1).reshape(score_sf_v.shape[1], -1) 
print("score_sf_v_num:",score_sf_v_num.shape)

# 使用线性变换减少合并带来的数据割裂
score_sf_v_num_fc = nn.Linear(dim,dim)(score_sf_v_num)
print("score_sf_v_num_fc:", score_sf_v_num_fc.shape)

4. 总结与拓展

通过单头和多头注意力的实现,我们可以直观理解 Transformer 中核心的上下文信息融合过程:

  • 单头注意力:结构简单,便于理解,但在不同语义子空间的特征捕捉上能力有限。
  • 多头注意力:并行使用多个注意力头在不同的表示子空间中学习特征,并在最后拼接结果,从而增强模型表达能力。

进一步的拓展包括:

  • 添加 Mask:在自回归任务(如文本生成)中避免模型访问未来信息。
  • 位置编码(Positional Encoding):由于注意力本身不包含位置信息,需要额外注入位置信息以保留序列顺序特征。
  • 优化实现:使用合并 QKV 的线性映射、批量矩阵运算等方式提高计算效率。
  • 应用场景:注意力机制已经广泛应用于机器翻译、文本生成、图像识别、多模态任务等。
posted @ 2025-08-13 19:06  xxtl  阅读(116)  评论(0)    收藏  举报