注意力头

注意力机制(Attention Mechanism)是现代深度学习中一个非常重要的概念,尤其在自然语言处理(NLP)和计算机视觉(CV)领域中广泛应用。注意力头(Attention Head)是多头注意力机制(Multi-Head Attention)中的一个组成部分,用于从不同的子空间中提取信息,从而提高模型的表达能力和灵活性。

1. 注意力机制的基本原理

注意力机制的核心思想是让模型在处理输入数据时,能够动态地关注到更重要的部分。这类似于人类在阅读时,会更关注某些关键词或句子,而不是平均地处理所有内容。

1.1 单头注意力(Single-Head Attention)

单头注意力机制通过计算输入序列中每个元素之间的相关性(或相似性),生成一个注意力权重矩阵,然后根据这些权重对输入序列进行加权求和,得到输出。
具体来说,单头注意力机制的计算步骤如下:
  1. 计算查询(Query)、键(Key)和值(Value):
    Q=XWQ,K=XWK,V=XWV
    其中,X 是输入序列,WQ、WK 和 WV 是可训练的权重矩阵。
  2. 计算注意力分数:
    Attention Scores=softmax(dkQKT)
    其中,dk 是键向量的维度,用于缩放分数,防止梯度消失。
  3. 计算加权求和:
    Output=Attention Scores×V

2. 多头注意力(Multi-Head Attention)

多头注意力机制通过将输入序列分成多个不同的子空间(或“头”),分别计算注意力,然后将这些结果拼接起来,从而提高模型的表达能力和灵活性。

2.1 多头注意力的计算步骤

  1. 将输入序列分成多个头:
    Qi=XWQi,Ki=XWKi,Vi=XWVifor i=1,2,,h
    其中,h 是头的数量,WQiWKiWVi 是每个头的可训练权重矩阵。
  2. 分别计算每个头的注意力分数:
    Attention Scoresi=softmax(dkQiKiT)
  3. 分别计算每个头的加权求和:
    Outputi=Attention Scoresi×Vi
  4. 将所有头的结果拼接起来:
    Concatenated Output=Concat(Output1,Output2,,Outputh)
  5. 通过一个线性变换层:
    Final Output=Concatenated OutputWO
    其中,WO 是另一个可训练的权重矩阵。

3. 注意力头的作用

每个注意力头可以关注到输入序列中不同的部分,从而捕捉到不同类型的特征。通过将多个头的结果拼接起来,模型能够更全面地理解输入序列,从而提高其表达能力和灵活性。

3.1 举例说明

假设输入序列是一个句子,每个单词都有一个对应的向量表示。通过多头注意力机制,不同的头可以关注到句子中的不同单词或短语,从而捕捉到句子中的不同语义信息。

4. 实现示例

4.1 PyTorch 实现

以下是一个简单的多头注意力机制的实现示例:
Python
复制
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size = x.size(0)

        # 将输入序列分成多个头
        Q = self.query(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力分数
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_scores = F.softmax(attention_scores, dim=-1)

        # 计算加权求和
        attention_output = torch.matmul(attention_scores, V).transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)

        # 通过线性变换层
        output = self.out(attention_output)
        return output

# 示例使用
embed_dim = 64
num_heads = 4
input_tensor = torch.randn(32, 10, embed_dim)  # 假设批量大小为 32,序列长度为 10
attention = MultiHeadAttention(embed_dim, num_heads)
output = attention(input_tensor)
print(output.shape)  # 输出形状应为 [32, 10, 64]
 

5. 注意力头的数量

注意力头的数量 h 是一个超参数,需要通过实验来选择合适的值。常见的值有 8、16、32 等。注意力头的数量越多,模型的表达能力越强,但计算量也会相应增加。

6. 总结

注意力头是多头注意力机制中的一个组成部分,通过将输入序列分成多个子空间,分别计算注意力,然后将这些结果拼接起来,从而提高模型的表达能力和灵活性。注意力头在自然语言处理和计算机视觉中被广泛应用,能够显著提高模型的性能和效率。
posted @ 2025-08-08 18:35  yinghualeihenmei  阅读(71)  评论(0)    收藏  举报