动手学习深度学习第十章——————注意力机制

注意力提示

注意力是稀缺的,而环境中的干扰注意力的信息却并不少。 比如人类的视觉神经系统大约每秒收到\(10^8\)位的信息, 这远远超过了大脑能够完全处理的水平。 幸运的是,人类的祖先已经从经验(也称为数据)中认识到 “并非感官的所有输入都是一样的”。 在整个人类历史中,这种只将注意力引向感兴趣的一小部分信息的能力, 使人类的大脑能够更明智地分配资源来生存、成长和社交, 例如发现天敌、找寻食物和伴侣。

2. 注意力机制的通用框架

所有注意力模型都围绕 “查询(Query)- 键(Key)- 值(Value)”(QKV)框架展开,而 “注意力提示” 正是 QKV 的对应关系:
查询(Query):对应 “自主性提示”(我要找什么?);
键(Key):对应 “非自主性提示”(数据本身有什么特征?);
值(Value):对应 “聚焦后要提取的信息”(找到关键信息后,需要用它做什么?)。
举个直观例子:
你(Query)想在电商评论区找 “手机续航” 相关评价(自主性提示)→ 评论里的关键词(Key,如 “续航、电池、待机”,非自主性提示)→ 包含这些关键词的评论内容(Value)就是你最终关注的信息。

3. 注意力的两个核心行为

注意力汇聚(Attention Pooling):将 “值” 根据 “查询与键的匹配度” 加权求和(核心操作);
注意力分数(Attention Score):衡量 “查询” 与 “键” 的匹配程度(如相似度、点积),分数越高,对应 “值” 的权重越大。

注意力的可视化

平均汇聚层可以被视为输入的加权平均值, 其中各输入的权重是一样的。 实际上,注意力汇聚得到的是加权平均的总和值, 其中权重是在给定的查询和不同的键之间计算得出的
代码实现

import torch
from d2l import torch as d2l
#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);
attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')

小结

注意力机制与全连接层或者汇聚层的区别源于增加的自主提示。

由于包含了自主性提示,注意力机制与全连接的层或汇聚层不同。

注意力机制通过注意力汇聚使选择偏向于值(感官输入),其中包含查询(自主性提示)和键(非自主性提示)。键和值是成对的。

可视化查询和键之间的注意力权重是可行的。

10.2注意力汇聚:Nadaraya-Watson 核回归

一、核心原理:注意力汇聚的数学表达

1.1 问题背景给定训练数据集

\(\(\{(x_1, y_1), (x_2, y_2), ..., (x_n, y_n)\}\)\),我们要预测任意输入 x 对应的输出 \(\(\hat{y}\)\)。传统的 “平均回归” 会直接取所有 \(\(y_i\) \)的均值,完全忽略 x 与 \(\(x_i\)\) 的相似度;而 Nadaraya-Watson 核回归则通过注意力权重区分样本的重要性:与 x 越相似的\( \(x_i\)\),权重越高,对 (\hat{y}) 的贡献越大;与 x 越疏远的 \(\(x_i\)\),权重越低,对\( \(\hat{y}\)\) 的贡献越小。

1.2 数学公式

(1)注意力权重(核函数)最常用的是高斯核(Gaussian Kernel)(也叫 RBF 核),用于计算样本 (x_i) 相对于预测点 x 的注意力权重:\(\(w_i(x) = \frac{\exp\left(-\frac{1}{2}(x - x_i)^2 / \sigma^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}(x - x_j)^2 / \sigma^2\right)}\)\)分子:衡量 x 与 (x_i) 的相似度(距离越小,指数值越大,权重越高);分母:归一化项,确保所有权重之和为 1(注意力权重的基本要求);\((\sigma\)\):带宽参数,控制注意力的 “聚焦程度”——(\sigma) 越小,权重越集中在相似样本上(聚焦性强);\(\(\sigma\)\) 越大,权重越分散(接近平均)。
(2)注意力汇聚(预测输出)通过加权平均得到最终预测值(注意力汇聚的核心操作):\(\(\hat{y}(x) = \sum_{i=1}^n w_i(x) \cdot y_i\)\)这个公式的本质是:输出 = 注意力权重 × 样本值 的加权和,是注意力机制最基础的 “汇聚(Aggregation)” 操作。

1.3 简化版(无参数)vs 带参数版无参数版:

\(\(\sigma\)\) 手动设定(如固定为 1),完全依赖数据分布;带参数版:将 \(\(\sigma\)\) 作为可学习参数,让模型自动学习 “聚焦程度”(更贴近深度学习的用法)。

代码实现

import torch
from torch import nn
from d2l import torch as d2l

# ===================== 1. 生成训练/测试数据 =====================
n_train = 50  # 训练样本数量
# 生成训练集输入x_train:在[0,5)范围内随机采样50个点并排序(排序便于后续可视化)
# torch.rand(n_train) * 5 → 生成50个0~5的随机数;torch.sort()返回(values, indices),取values
x_train, _ = torch.sort(torch.rand(n_train) * 5)  

# 定义真实函数:用于生成标签的基准函数(非线性函数)
def f(x):
    return 2 * torch.sin(x) + x**0.8  # 2*sin(x) + x的0.8次方

# 生成训练集输出y_train:真实函数值 + 高斯噪声(均值0,标准差0.5)
# torch.normal(0.0, 0.5, (n_train,)) → 生成50个噪声值,模拟真实场景的观测误差
y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  

# 生成测试集输入x_test:在[0,5)范围内以0.1为步长生成均匀序列(共50个点),用于测试模型泛化能力
x_test = torch.arange(0, 5, 0.1)  
# 测试集的真实输出y_truth:无噪声的真实函数值,用于对比模型预测效果
y_truth = f(x_test)  
# 测试样本数量
n_test = len(x_test)  
print("测试样本数:", n_test)  # 输出:50

# ===================== 2. 定义可视化函数 =====================
def plot_kernel_reg(y_hat):
    """
    可视化核回归的预测结果
    参数:
        y_hat: 模型对测试集x_test的预测值
    """
    # 绘制测试集的真实值和预测值曲线
    # x轴:x_test,y轴:[真实值y_truth, 预测值y_hat],图例标注Truth/Pred
    # xlim/ylim:限定坐标轴范围,让可视化更清晰
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    # 绘制训练集样本点(散点图),alpha=0.5设置透明度,避免遮挡曲线
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);  

# ===================== 3. 基准模型:均值回归(无注意力) =====================
# 最简单的基准模型:所有测试点的预测值都等于训练集y_train的均值(完全忽略x的信息)
# torch.repeat_interleave → 将均值重复n_test次,匹配测试集长度
y_hat = torch.repeat_interleave(y_train.mean(), n_test)  
# 可视化均值回归的预测结果
plot_kernel_reg(y_hat)  

# ===================== 4. 无参数Nadaraya-Watson核回归(注意力汇聚) =====================
# 构造重复的测试输入矩阵:用于批量计算每个测试点与所有训练点的距离
# X_repeat形状:(n_test, n_train) → 每一行都是相同的测试输入(如[0.0,0.0,...,0.0], [0.1,0.1,...,0.1])
# torch.repeat_interleave(x_test, n_train) → 将每个测试点重复n_train次,得到一维张量;
# reshape((-1, n_train)) → 重塑为n_test行、n_train列的矩阵
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))  

# 计算注意力权重:使用高斯核(softmax归一化的平方指数核)
# (X_repeat - x_train) → 每个测试点与所有训练点的距离(形状:n_test×n_train)
# -(...)**2 / 2 → 高斯核的指数部分(等价于 exp(-(x-q)²/(2σ²)) 中σ=1的情况)
# nn.functional.softmax(..., dim=1) → 按行归一化(dim=1),确保每行权重和为1
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)  

# 注意力汇聚:加权平均计算预测值
# torch.matmul(attention_weights, y_train) → 矩阵乘法(n_test×n_train)×(n_train,)=(n_test,)
# 每个测试点的预测值 = 所有训练点y_train的加权和,权重为注意力权重
y_hat = torch.matmul(attention_weights, y_train)  
# 可视化核回归的预测结果
plot_kernel_reg(y_hat)  

# 可视化注意力权重热力图:展示每个测试点对训练点的注意力分配
# unsqueeze(0).unsqueeze(0) → 增加两个维度,适配d2l.show_heatmaps的输入要求(4维:批量×通道×高度×宽度)
# xlabel/ylabel:标注坐标轴(训练输入/测试输入)
d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

# ===================== 5. 补充:批量矩阵乘法(bmm)示例 =====================
# 演示torch.bmm的用法(后续带参数核回归会用到)
# X形状:(2, 1, 4) → 2个批量,每个批量是1行4列的矩阵
X = torch.ones((2, 1, 4))
# Y形状:(2, 4, 6) → 2个批量,每个批量是4行6列的矩阵
Y = torch.ones((2, 4, 6))
# bmm要求:第一个矩阵的最后一维 = 第二个矩阵的倒数第二维(4=4)
# 输出形状:(2, 1, 6) → 2个批量,每个批量是1行6列的矩阵
print("bmm结果形状:", torch.bmm(X, Y).shape)  

# 另一个bmm示例:权重×值
# weights形状:(2, 10) → 2个批量,每个批量是10个权重
weights = torch.ones((2, 10)) * 0.1  # 权重均为0.1(和为1)
# values形状:(2, 10) → 2个批量,每个批量是10个值
values = torch.arange(20.0).reshape((2, 10))  
# unsqueeze(1) → 权重变为(2,1,10);unsqueeze(-1) → 值变为(2,10,1)
# bmm结果形状:(2,1,1) → 每个批量的加权和(0.1×0 + 0.1×1 + ... + 0.1×9 = 4.5;0.1×10+...+0.1×19=14.5)
print("权重×值的bmm结果:", torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1)))

# ===================== 6. 定义带参数的Nadaraya-Watson核回归模型 =====================
class NWKernelRegression(nn.Module):
    """
    带可学习参数的Nadaraya-Watson核回归模型
    核心:将高斯核的带宽参数σ纳入模型,通过梯度下降学习最优值
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # 定义可学习参数w(对应高斯核的1/σ,确保w>0)
        # torch.rand((1,)) → 初始化值为0~1的随机数;requires_grad=True → 开启梯度计算
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))  

    def forward(self, queries, keys, values):
        """
        前向传播:计算注意力权重并执行注意力汇聚
        参数:
            queries: 查询点(待预测的输入,形状:(n_queries,))
            keys: 键(训练集输入,形状:(n_queries, n_keys))
            values: 值(训练集输出,形状:(n_queries, n_keys))
        返回:
            预测值(形状:(n_queries,))
        """
        # 扩展查询点维度:匹配键的形状,便于计算距离
        # queries原本是(n_queries,),repeat_interleave后变为(n_queries×n_keys,),reshape后为(n_queries, n_keys)
        # 每个查询点重复n_keys次,每行对应一个查询点的所有键
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        
        # 计算注意力权重:带可学习参数w的高斯核
        # (queries - keys) * self.w → 缩放距离(w等价于1/σ,控制注意力聚焦程度)
        # softmax(dim=1) → 按行归一化,确保每行权重和为1
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        
        # 注意力汇聚:批量矩阵乘法实现加权平均
        # unsqueeze(1) → 权重变为(n_queries, 1, n_keys);unsqueeze(-1) → 值变为(n_queries, n_keys, 1)
        # bmm结果为(n_queries, 1, 1),reshape(-1)后变为(n_queries,)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)

# ===================== 7. 构造留一法(Leave-One-Out)训练数据 =====================
# 构造重复的训练输入矩阵:X_tile形状(n_train, n_train),每行都是相同的训练输入
# 例如:第一行是[x1,x2,...,x50],第二行也是[x1,x2,...,x50],以此类推
X_tile = x_train.repeat((n_train, 1))
# 构造重复的训练输出矩阵:Y_tile形状(n_train, n_train),每行都是相同的训练输出
Y_tile = y_train.repeat((n_train, 1))

# 生成键(keys):排除当前行对应的训练样本(留一法,避免自预测)
# 1 - torch.eye(n_train) → 生成对角为0、其余为1的矩阵(bool类型)
# X_tile[bool_mask] → 筛选出非对角元素,reshape后为(n_train, n_train-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# 生成值(values):对应键的训练输出,形状(n_train, n_train-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

# ===================== 8. 训练带参数的核回归模型 =====================
# 初始化模型
net = NWKernelRegression()
# 定义损失函数:均方误差损失(MSELoss),reduction='none' → 保留每个样本的损失(不求和/平均)
loss = nn.MSELoss(reduction='none')
# 定义优化器:随机梯度下降(SGD),学习率0.5
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
# 初始化动画器:用于动态展示训练过程中的损失变化
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

# 训练5个epoch
for epoch in range(5):
    # 清零梯度(避免梯度累积)
    trainer.zero_grad()
    # 前向传播:用训练集x_train作为查询点,keys/values为留一法构造的键值对
    # 计算每个训练样本的预测值,并与真实值y_train计算损失
    l = loss(net(x_train, keys, values), y_train)
    # 损失求和并反向传播(因为reduction='none',需要sum()得到标量损失)
    l.sum().backward()
    # 更新模型参数(w)
    trainer.step()
    # 打印当前epoch的损失
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    # 将当前损失添加到动画器,动态更新可视化
    animator.add(epoch + 1, float(l.sum()))

# ===================== 9. 测试带参数的核回归模型 =====================
# 构造测试集的键:形状(n_test, n_train),每行都是相同的训练输入
keys = x_train.repeat((n_test, 1))
# 构造测试集的值:形状(n_test, n_train),每行都是相同的训练输出
values = y_train.repeat((n_test, 1))
# 模型预测:detach() → 分离计算图,避免梯度追踪;unsqueeze(1) → 适配可视化函数的输入形状
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
# 可视化测试集预测结果
plot_kernel_reg(y_hat)

# 可视化训练后模型的注意力权重热力图
d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

小结

Nadaraya-Watson核回归是具有注意力机制的机器学习范例。

Nadaraya-Watson核回归的注意力汇聚是对训练数据中输出的加权平均。从注意力的角度来看,分配给每个值的注意力权重取决于将值所对应的键和查询作为输入的函数。

注意力汇聚可以分为非参数型和带参数型。

10.3 注意力评分函数

image

1.1 注意力机制三要素查询(Query, Q):

当前需要计算注意力的 “目标”(如 Transformer 中当前位置的词向量);键(Key, K):用于匹配查询的 “参考”(如所有位置的词向量);值(Value, V):需要被加权求和的 “内容”(通常与 Key 一一对应)。

1.2 注意力机制通用流程计算评分:

用「注意力评分函数」计算 Q 与每个 K 的相似度(评分值);归一化权重:用 softmax 将评分值转换为 0~1 的注意力权重(和为 1);加权求和:用注意力权重对 Value 加权平均,得到最终的注意力输出。公式总结:\(\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{\text{score}(Q, K)}{d}\right) V\)\)其中:\(\(\text{score}(Q, K)\)\):注意力评分函数(核心);d:缩放因子(部分评分函数需要,避免梯度 / 数值问题);\(\(\text{softmax}\)\):归一化评分值为权重;V:值的加权求和。

二、常见注意力评分函数

1.掩蔽softmax操作

正如上面提到的,softmax操作用于输出一个概率分布作为注意力权重。 在某些情况下,并非所有的值都应该被纳入到注意力汇聚中。 例如,为了在 9.5节中高效处理小批量数据集, 某些文本序列被填充了没有意义的特殊词元。 为了仅将有意义的词元作为值来获取注意力汇聚, 可以指定一个有效序列长度(即词元的个数), 以便在计算softmax时过滤掉超出指定范围的位置。 下面的masked_softmax函数 实现了这样的掩蔽softmax操作(masked softmax operation), 其中任何超出有效长度的位置都被掩蔽并置为0。

#@save
def masked_softmax(X, valid_lens):
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

2.. 加性注意力

一般来说,当查询和键是不同长度的矢量时,可以使用加性注意力作为评分函数。 给定查询\(\(\mathbf{q} \in \mathbb{R}^q\)\)和 键\(\(\mathbf{k} \in \mathbb{R}^k\)\), 加性注意力(additive attention)$的评分函数为

(10.3.3)\(\[a(\mathbf q, \mathbf k) = \mathbf w_v^\top \text{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},\]\)
其中可学习的参数是\(\(\mathbf W_q\in\mathbb R^{h\times q}\)、 \(\mathbf W_k\in\mathbb R^{h\times k}\)\)\(\(\mathbf w_v\in\mathbb R^{h}\)\)。 如 (10.3.3)所示, 将查询和键连结起来后输入到一个多层感知机(MLP)中, 感知机包含一个隐藏层,其隐藏单元数是一个超参数\(\(h\)\)。 通过使用\(\(\tanh\)\)作为激活函数,并且禁用偏置项。

3. 缩放点积注意力(Scaled Dot-Product Attention)

(1)数学原理最简洁的评分函数,直接计算 Q 与 K 的点积,再除以缩放因子\(\(\sqrt{d_k}\)(\(d_k\)\)是 Key 的维度):评分函数:\(\(\text{score}(q, k) = q^T k / \sqrt{d_k}\)\)注意力权重:\(\(w = \text{softmax}\left(q^T k / \sqrt{d_k}\right)\)\)核心原因:当 Key 维度\(\(d_k\)\)较大时,点积结果的方差会随\(\(d_k\)\)增大而增大,导致 softmax 输出趋近于 0/1(梯度消失),除以\(\(\sqrt{d_k}\)\)可将方差归一化到 1。

小结

将注意力汇聚的输出计算可以作为值的加权平均,选择不同的注意力评分函数会带来不同的注意力汇聚操作。

当查询和键是不同长度的矢量时,可以使用可加性注意力评分函数。当它们的长度相同时,使用缩放的“点-积”注意力评分函数的计算效率更高。

10.4 Bahdanau 注意力

一、Bahdanau 注意力的核心背景

1.1 传统 Seq2Seq 的痛点传统 Seq2Seq(编码器 - 解码器)模型:编码器将整个输入序列压缩为单个固定长度的隐藏状态(上下文向量);解码器仅依赖这个固定向量生成输出序列。问题:当输入序列较长时(如长句子翻译),固定长度向量无法完整保留输入序列的全部信息,导致解码器生成的结果丢失细节。1.2 Bahdanau 注意力的解决思路Bahdanau 注意力为解码器的每个时间步动态计算 “上下文向量”:解码器当前时间步的隐藏状态作为「查询(Query)」;编码器所有时间步的隐藏状态作为「键(Key)」和「值(Value)」;通过加性注意力评分函数计算查询与每个键的相似度,生成注意力权重;加权求和编码器隐藏状态(值),得到适配当前解码步的上下文向量;上下文向量与解码器输入拼接,共同参与解码预测。核心公式(加性注意力评分):\(\(\text{score}(s_{t-1}, h_i) = v^T \tanh(W_1 h_i + W_2 s_{t-1})\)\(\alpha_{ti} = \text{softmax}(\text{score}(s_{t-1}, h_i))\)\(c_t = \sum_{i=1}^n \alpha_{ti} h_i\)\)其中:\(\(s_{t-1}\)\):解码器第\( \(t-1\)\) 步的隐藏状态(Query);\(\(h_i\)\):编码器第 i 步的隐藏状态(Key/Value);\(\(\alpha_{ti}\)\):第 t 步对编码器第 i 步的注意力权重;\(\(c_t\)\):第 t 步的上下文向量(注意力汇聚结果)。

二、Bahdanau 注意力的结构拆解Bahdanau 注意力主要包含 3 个核心模块:

编码器:通常用双向 LSTM/GRU,输出所有时间步的隐藏状态(作为 Key/Value);注意力层:实现加性注意力评分、权重归一化、上下文向量计算;解码器:融合 “注意力上下文向量” 和 “解码器输入”,生成输出序列。

代码实现

import torch
from torch import nn
from d2l import torch as d2l

# ===================== 1. 定义注意力解码器基类 =====================
#@save  # d2l库的装饰器,标记为可保存的函数/类
class AttentionDecoder(d2l.Decoder):
    """带有注意力机制解码器的基本接口(抽象类)"""
    def __init__(self, **kwargs):
        # 继承自d2l.Decoder基类(Seq2Seq解码器的通用接口)
        super(AttentionDecoder, self).__init__(**kwargs)

    @property  # 装饰器:将方法转为属性,方便外部获取注意力权重
    def attention_weights(self):
        # 抽象方法:子类必须实现,用于返回注意力权重
        raise NotImplementedError  # 抛出未实现异常,强制子类重写

# ===================== 2. 实现带Bahdanau注意力的Seq2Seq解码器 =====================
class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        """
        初始化带加性注意力的解码器
        参数说明:
            vocab_size: 目标语言词汇表大小
            embed_size: 词嵌入维度
            num_hiddens: RNN隐藏层维度
            num_layers: RNN层数
            dropout: dropout概率(防止过拟合)
        """
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        # 初始化加性注意力层(Bahdanau Attention)
        # 输入参数:query_dim, key_dim, value_dim, dropout → 均设为num_hiddens
        self.attention = d2l.AdditiveAttention(
            num_hiddens, num_hiddens, num_hiddens, dropout)
        # 词嵌入层:将目标语言的词索引转为词向量
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # GRU层:输入维度=词嵌入维度+注意力上下文维度(embed_size + num_hiddens)
        # 因为解码器每一步输入 = 词嵌入 + 注意力上下文向量
        self.rnn = nn.GRU(
            embed_size + num_hiddens, num_hiddens, num_layers,
            dropout=dropout)
        # 全连接层:将RNN输出映射到目标词汇表大小(用于预测下一个词)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """
        初始化解码器状态(接收编码器输出)
        参数:
            enc_outputs: 编码器输出,包含(outputs, hidden_state)
                - outputs: (batch_size, num_steps, num_hiddens) 编码器所有时间步隐藏态
                - hidden_state: (num_layers, batch_size, num_hiddens) 编码器最后一层隐藏态
            enc_valid_lens: 编码器序列有效长度(用于屏蔽PAD token的注意力)
        返回:
            解码器初始状态:(enc_outputs_permute, hidden_state, enc_valid_lens)
        """
        # 拆分编码器输出为输出序列和最后隐藏态
        outputs, hidden_state = enc_outputs
        # 调整编码器输出维度:(batch_size, num_steps, num_hiddens) → (num_steps, batch_size, num_hiddens)
        # 适配后续RNN的输入格式(RNN默认输入格式:seq_len, batch, feature)
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        """
        解码器前向传播(核心逻辑)
        参数:
            X: 目标序列输入,形状 (batch_size, num_steps) → 如(4,7)
            state: 解码器初始状态(来自init_state)
        返回:
            outputs: 解码器输出(每个时间步的词汇概率),形状 (batch_size, num_steps, vocab_size)
            state: 更新后的解码器状态
        """
        # 从状态中拆分出编码器输出、编码器最后隐藏态、编码器有效长度
        enc_outputs, hidden_state, enc_valid_lens = state
        # 步骤1:目标序列词嵌入 → (batch_size, num_steps, embed_size)
        # 调整维度:permute(1,0,2) → (num_steps, batch_size, embed_size)
        # 目的:方便按时间步遍历(for x in X),每次取一个时间步的所有batch数据
        X = self.embedding(X).permute(1, 0, 2)
        
        # 初始化存储变量:
        # outputs: 存储每个时间步的RNN输出
        # self._attention_weights: 存储每个时间步的注意力权重
        outputs, self._attention_weights = [], []
        
        # 步骤2:按时间步遍历目标序列(逐词解码)
        for x in X:
            # x: 当前时间步的词嵌入,形状 (batch_size, embed_size)
            
            # 步骤3:构造注意力查询向量(Query)
            # hidden_state[-1]: 解码器上一步的最后一层隐藏态 → (batch_size, num_hiddens)
            # unsqueeze(1): 增加维度 → (batch_size, 1, num_hiddens)(适配注意力层输入)
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            
            # 步骤4:计算注意力上下文向量(Context Vector)
            # 调用加性注意力层:query(解码器隐藏态) + key/value(编码器输出)
            # enc_valid_lens: 屏蔽编码器PAD位置的注意力
            # context形状:(batch_size, 1, num_hiddens)
            context = self.attention(
                query, enc_outputs, enc_outputs, enc_valid_lens)
            
            # 步骤5:拼接当前词嵌入和上下文向量
            # x.unsqueeze(1): (batch_size, embed_size) → (batch_size, 1, embed_size)
            # cat(dim=-1): 在特征维度拼接 → (batch_size, 1, embed_size + num_hiddens)
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            
            # 步骤6:GRU前向传播(更新解码器状态)
            # x.permute(1,0,2): (batch_size,1,embed+hidden) → (1, batch_size, embed+hidden)
            # 输入格式适配GRU(seq_len, batch, feature),这里seq_len=1(单步)
            # out: (1, batch_size, num_hiddens) → 当前步RNN输出
            # hidden_state: (num_layers, batch_size, num_hiddens) → 更新后的解码器隐藏态
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            
            # 存储当前步输出和注意力权重
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        
        # 步骤7:拼接所有时间步输出并映射到词汇表
        # torch.cat(outputs, dim=0): 拼接所有时间步 → (num_steps, batch_size, num_hiddens)
        # self.dense: 线性变换 → (num_steps, batch_size, vocab_size)
        outputs = self.dense(torch.cat(outputs, dim=0))
        
        # 步骤8:调整输出维度 → (batch_size, num_steps, vocab_size)
        outputs = outputs.permute(1, 0, 2)
        
        # 返回输出和更新后的状态
        return outputs, [enc_outputs, hidden_state, enc_valid_lens]

    @property
    def attention_weights(self):
        """实现基类的抽象属性,返回注意力权重"""
        return self._attention_weights

# ===================== 3. 测试解码器(验证维度正确性) =====================
# 初始化编码器:词汇表大小10,嵌入维度8,隐藏层16,层数2
encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                             num_layers=2)
encoder.eval()  # 切换到评估模式(关闭dropout等训练层)

# 初始化解码器:参数与编码器匹配
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,
                                  num_layers=2)
decoder.eval()

# 构造测试输入X:(batch_size=4, num_steps=7),全0序列(模拟词索引)
X = torch.zeros((4, 7), dtype=torch.long)  

# 初始化解码器状态:传入编码器输出(encoder(X)),有效长度为None
state = decoder.init_state(encoder(X), None)

# 解码器前向传播
output, state = decoder(X, state)

# 打印各维度信息(验证正确性)
print("解码器输出形状:", output.shape)          # (4,7,10) → batch=4, steps=7, vocab=10
print("状态列表长度:", len(state))             # 3 → (enc_outputs, hidden_state, enc_valid_lens)
print("编码器输出形状:", state[0].shape)       # (7,4,16) → steps=7, batch=4, hidden=16
print("解码器隐藏态层数:", len(state[1]))      # 2 → 与num_layers一致
print("解码器单层隐藏态形状:", state[1][0].shape)  # (4,16) → batch=4, hidden=16

# ===================== 4. 训练英→法翻译模型 =====================
# 超参数设置
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10  # 批次大小64,序列长度10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()  # 学习率0.005,训练250轮,使用GPU

# 加载英法翻译数据集:返回数据迭代器、源语言词汇表(英语)、目标语言词汇表(法语)
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

# 初始化编码器(适配真实数据集的词汇表大小)
encoder = d2l.Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
# 初始化解码器(适配目标语言词汇表)
decoder = Seq2SeqAttentionDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
# 组合编码器-解码器为完整的Seq2Seq模型
net = d2l.EncoderDecoder(encoder, decoder)

# 训练模型:d2l封装的Seq2Seq训练函数(包含梯度下降、损失计算、验证等)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

# ===================== 5. 测试翻译效果(可视化注意力) =====================
# 测试句子对:英语→法语
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']

# 逐句测试翻译
for eng, fra in zip(engs, fras):
    # 预测翻译结果 + 获取注意力权重序列
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    # 计算BLEU分数(评估翻译质量,k=2表示考虑2-gram)
    print(f'{eng} => {translation},  bleu {d2l.bleu(translation, fra, k=2):.3f}')

# 处理注意力权重:可视化最后一个句子的注意力热力图
# 拼接所有时间步的注意力权重 → (num_steps_query, num_steps_key)
attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((
    1, 1, -1, num_steps))

# 绘制注意力热力图:X轴=Key位置(英语输入),Y轴=Query位置(法语输出)
# 只显示到最后一个英语句子的长度+1(包含结束词元)
d2l.show_heatmaps(
    attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),
    xlabel='Key positions', ylabel='Query positions')

小结

预测词元时,如果不是所有输入词元都是相关的,那么具有Bahdanau注意力的循环神经网络编码器-解码器会有选择地统计输入序列的不同部分。这是通过将上下文变量视为加性注意力池化的输出来实现的。

在循环神经网络编码器-解码器中,Bahdanau注意力将上一时间步的解码器隐状态视为查询,在所有时间步的编码器隐状态同时视为键和值。

10.5 多头注意力

一、多头注意力的核心动机

1.1 单头注意力的局限单头注意力(如缩放点积注意力)只能捕捉单一维度的特征关联:

例如在机器翻译中,单头注意力可能只关注 “词汇匹配”,而忽略 “语法结构” 或 “位置关系”;无法同时捕捉短距离依赖(如相邻词)和长距离依赖(如句首句尾)。

1.2 多头注意力的解决思路通过 “分而治之” 的策略:

将 Q/K/V 分别通过 h 个独立的线性变换(投影),得到 h 组低维的 Q/K/V(维度为 \(\(d_k = d_{model}/h\))\);对每组低维 Q/K/V 独立计算缩放点积注意力(即 “多头”);将 h 个头的注意力输出拼接,再通过一次线性变换,得到最终输出。核心优势:每个头聚焦输入序列的不同特征维度(如头 1 关注词性、头 2 关注位置、头 3 关注语义);低维子空间计算注意力,降低单头高维计算的复杂度,提升并行效率。

代码实现

import math
import torch
from torch import nn
from d2l import torch as d2l

# ===================== 1. 定义多头注意力核心类 =====================
#@save  # d2l库装饰器,标记为可保存的函数/类
class MultiHeadAttention(nn.Module):
    """多头注意力实现(基于缩放点积注意力)"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        """
        初始化多头注意力层
        参数说明:
            key_size: 键(Key)的维度
            query_size: 查询(Query)的维度
            value_size: 值(Value)的维度
            num_hiddens: 模型隐藏层总维度(需能被num_heads整除)
            num_heads: 注意力头的数量
            dropout: dropout概率(防止过拟合)
            bias: 线性层是否使用偏置
        """
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads  # 保存注意力头数
        # 初始化缩放点积注意力(单头),d2l.DotProductAttention已实现缩放+softmax+dropout
        self.attention = d2l.DotProductAttention(dropout)
        
        # 定义4个线性层:将Q/K/V投影到num_hiddens维度,最终输出也投影到num_hiddens
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)  # Query投影层
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)    # Key投影层
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)  # Value投影层
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)  # 多头结果拼接后的投影层

    def forward(self, queries, keys, values, valid_lens):
        """
        多头注意力前向传播(核心逻辑)
        参数:
            queries: 查询,形状 (batch_size, num_queries, query_size)
            keys: 键,形状 (batch_size, num_kvpairs, key_size)
            values: 值,形状 (batch_size, num_kvpairs, value_size)
            valid_lens: 有效长度(屏蔽PAD token),形状 (batch_size,) 或 (batch_size, num_queries)
        返回:
            output: 多头注意力输出,形状 (batch_size, num_queries, num_hiddens)
        """
        # 步骤1:线性投影 + 维度变换(拆分注意力头,适配并行计算)
        # transpose_qkv作用:将投影后的Q/K/V拆分为num_heads个头,形状变为 (batch_size*num_heads, 序列长度, num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        # 步骤2:处理有效长度(适配多头)
        if valid_lens is not None:
            # valid_lens原本形状:(batch_size,) 或 (batch_size, num_queries)
            # 重复num_heads次:每个头使用相同的有效长度(如[3,2] → [3,3,3,3,3,2,2,2,2,2],num_heads=5)
            # dim=0:按批次维度重复,确保每个头的有效长度与原批次对应
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # 步骤3:计算单头注意力(所有头并行计算)
        # output形状:(batch_size*num_heads, num_queries, num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # 步骤4:拼接多头结果(逆转transpose_qkv的维度变换)
        # output_concat形状:(batch_size, num_queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        
        # 步骤5:最终线性投影(融合多头信息)
        return self.W_o(output_concat)

# ===================== 2. 辅助函数:拆分Q/K/V为多头(并行计算) =====================
#@save
def transpose_qkv(X, num_heads):
    """
    为多注意力头的并行计算变换Q/K/V的形状
    核心:将num_hiddens拆分为num_heads × (num_hiddens/num_heads),并调整维度顺序
    参数:
        X: 输入张量,形状 (batch_size, seq_len, num_hiddens)
        num_heads: 注意力头数
    返回:
        变换后的张量,形状 (batch_size*num_heads, seq_len, num_hiddens/num_heads)
    """
    # 步骤1:拆分最后一维为num_heads × (num_hiddens/num_heads)
    # X形状:(batch_size, seq_len, num_hiddens) → (batch_size, seq_len, num_heads, num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 步骤2:调整维度顺序 → 将num_heads维度放到batch_size之后,方便并行计算
    # X形状:(batch_size, num_heads, seq_len, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # 步骤3:合并batch_size和num_heads维度 → 模拟“每个头为一个独立批次”
    # X形状:(batch_size*num_heads, seq_len, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

# ===================== 3. 辅助函数:拼接多头结果(逆转transpose_qkv) =====================
#@save
def transpose_output(X, num_heads):
    """
    逆转transpose_qkv的操作,将多头结果拼接回原维度
    参数:
        X: 多头注意力输出,形状 (batch_size*num_heads, seq_len, num_hiddens/num_heads)
        num_heads: 注意力头数
    返回:
        拼接后的张量,形状 (batch_size, seq_len, num_hiddens)
    """
    # 步骤1:拆分batch_size*num_heads维度 → (batch_size, num_heads, seq_len, num_hiddens/num_heads)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    
    # 步骤2:调整维度顺序 → 恢复为 (batch_size, seq_len, num_heads, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    
    # 步骤3:拼接num_heads和num_hiddens/num_heads维度 → (batch_size, seq_len, num_hiddens)
    return X.reshape(X.shape[0], X.shape[1], -1)

# ===================== 4. 测试多头注意力(验证维度正确性) =====================
# 超参数设置:隐藏层总维度100,注意力头数5(100/5=20,每个头维度20)
num_hiddens, num_heads = 100, 5
# 初始化多头注意力层(Q/K/V维度均为100,隐藏层100,头数5,dropout=0.5)
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()  # 切换到评估模式(关闭dropout)

# 构造测试数据
batch_size, num_queries = 2, 4  # 批次大小2,查询数4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])  # 键值对数量6,有效长度[3,2](batch0的键前3个有效,batch1的键前2个有效)
X = torch.ones((batch_size, num_queries, num_hiddens))  # 查询张量 (2,4,100)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))  # 键/值张量 (2,6,100)

# 前向传播计算多头注意力
output = attention(X, Y, Y, valid_lens)

# 打印输出形状(验证正确性)
print("多头注意力输出形状:", output.shape)  # 预期:(2,4,100) → batch_size=2, num_queries=4, num_hiddens=100

小结

多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。

基于适当的张量操作,可以实现多头注意力的并行计算。

10.6自注意力和位置编码

在深度学习中,经常使用卷积神经网络(CNN)或循环神经网络(RNN)对序列进行编码。 想象一下,有了注意力机制之后,我们将词元序列输入注意力池化中, 以便同一组词元同时充当查询、键和值。 具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出。 由于查询、键和值来自同一组输入,因此被称为 自注意力(self-attention)
自注意力的优势并行计算:无需像 RNN 逐时间步计算,可一次性处理整个序列;长距离依赖:直接建模任意两个位置的依赖关系,无梯度消失问题;全局视野:每个位置能关注到序列所有位置的信息(而非 RNN 的局部视野)。自注意力的数学公式与多头注意力公式一致,仅需满足 Q=K=V:\(\(\text{SelfAttention}(X) = \text{MultiHead}(X, X, X)\)其中 X 是输入序列(形状:\((\text{batch_size}, \text{seq_len}, d_{\text{model}})\))\)。1.2 自注意力的关键细节:掩码(Mask)自注意力需根据场景添加掩码,避免无效信息干扰:PAD 掩码:屏蔽填充(PAD)token 的注意力(所有场景);未来掩码(Look-Ahead Mask):解码器自注意力专用,屏蔽当前位置之后的所有位置(防止 “看到未来信息”)。二、位置编码(Positional Encoding)2.1 核心动机自注意力本身不包含序列位置信息(输入序列打乱后,自注意力输出不变),而序列的位置对语义至关重要(如 “我吃苹果”≠“苹果吃我”)。位置编码的作用:为每个位置生成唯一的 “位置向量”,与词向量相加,让模型感知序列的位置关系。2.2 位置编码的实现方式Transformer 采用正弦余弦位置编码(可无限扩展,适配任意长度序列),核心公式:(\begin{align}
PE_{(pos, 2i)} &= \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \
PE_{(pos, 2i+1)} &= \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
\end{align
})其中:pos:序列中的位置索引(从 0 开始);i:位置向量的维度索引(从 0 开始);(d_{\text{model}}):模型总维度(如 512);偶数维度用正弦函数,奇数维度用余弦函数。核心特性唯一性:不同位置的编码向量唯一;可解释性:位置差的编码可通过三角函数公式推导(如 (PE_{pos+k}) 可由 (PE_{pos}) 线性表示);无限扩展:可计算任意位置(即使训练时未见过的长度)的编码。

代码实现

import math
import torch
from torch import nn
from d2l import torch as d2l

# ===================== 1. 测试多头自注意力(Self-Attention) =====================
# 超参数设置:模型隐藏层维度=100,注意力头数=5
num_hiddens, num_heads = 100, 5
# 初始化多头注意力层(d2l库已实现的MultiHeadAttention)
# 参数:key_size=100, query_size=100, value_size=100, num_hiddens=100, num_heads=5, dropout=0.5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()  # 切换到评估模式(关闭dropout等训练层)

# 构造测试数据
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])  # 批次=2,查询数=4,有效长度=[3,2]
X = torch.ones((batch_size, num_queries, num_hiddens))  # 输入张量:(2,4,100) → 模拟自注意力输入(Q=K=V=X)

# 计算多头自注意力(自注意力:Q=K=V=X)
# valid_lens:屏蔽PAD token,batch0前3个位置有效,batch1前2个位置有效
output = attention(X, X, X, valid_lens)
print("多头自注意力输出形状:", output.shape)  # 输出:(2,4,100) → batch×查询数×隐藏层维度

# ===================== 2. 定义位置编码类(正弦余弦编码) =====================
#@save  # d2l库装饰器,标记为可保存的函数/类
class PositionalEncoding(nn.Module):
    """位置编码:为序列注入位置信息(弥补自注意力无位置感知的缺陷)"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        """
        初始化位置编码
        参数:
            num_hiddens: 编码维度(需与词向量维度一致)
            dropout: dropout概率(防止过拟合)
            max_len: 预计算的最大序列长度(适配超长序列)
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)  # Dropout层
        
        # 步骤1:创建位置编码矩阵P,形状 (1, max_len, num_hiddens)
        # 1表示batch维度(广播复用),max_len是最大序列长度,num_hiddens是编码维度
        self.P = torch.zeros((1, max_len, num_hiddens))
        
        # 步骤2:计算位置编码的核心值X
        # X = pos / 10000^(2i/d_model),其中pos是位置索引,i是维度索引
        # 1. 生成位置索引:(max_len, 1) → [0,1,...,max_len-1]
        position = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        # 2. 生成分母项:10000^(2i/num_hiddens) → 避免数值过大,用pow实现
        div_term = torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        # 3. 计算X:(max_len, num_hiddens/2) → 每个位置对应所有偶数维度的分母
        X = position / div_term
        
        # 步骤3:填充位置编码矩阵P
        # 偶数维度(0,2,4...)用正弦函数
        self.P[:, :, 0::2] = torch.sin(X)
        # 奇数维度(1,3,5...)用余弦函数
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        """
        前向传播:将位置编码加到输入序列上(仅加有效长度的编码)
        参数:
            X: 输入序列(词向量),形状 (batch_size, seq_len, num_hiddens)
        返回:
            X + 位置编码(经过Dropout)
        """
        # 将位置编码加到输入上:仅取前X.shape[1]个位置(适配当前序列长度)
        # to(X.device):确保位置编码与输入在同一设备(CPU/GPU)
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        # Dropout防止过拟合
        return self.dropout(X)

# ===================== 3. 测试并可视化位置编码 =====================
# 超参数:编码维度=32,序列长度=60
encoding_dim, num_steps = 32, 60
# 初始化位置编码(dropout=0,不丢弃任何值)
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()  # 评估模式(关闭Dropout)

# 构造输入:全0序列(模拟词向量为0的情况,仅看位置编码效果),形状 (1, 60, 32)
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))

# 提取有效长度的位置编码矩阵P:(1, 60, 32)
P = pos_encoding.P[:, :X.shape[1], :]

# 可视化位置编码:绘制前60个位置的第6~9维度的值
# x轴:位置索引(0~59),y轴:编码值,图例:维度6~9
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])

# ===================== 4. 二进制演示:位置编码的周期性 =====================
# 打印0~7的二进制(演示位置编码的二进制规律,解释编码的周期性)
for i in range(8):
    # f-string格式化:>03b表示补零到3位二进制
    print(f'{i}的二进制是:{i:>03b}')

# ===================== 5. 可视化位置编码热力图 =====================
# 调整P的形状:(1, 60, 32) → (1, 1, 60, 32)
# 适配d2l.show_heatmaps的输入要求:(batch_size, num_heads, seq_len, encoding_dim)
P = P[0, :, :].unsqueeze(0).unsqueeze(0)

# 绘制热力图:x轴=编码维度,y轴=位置,颜色深浅表示编码值大小
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
                  ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')

小结

在自注意力中,查询、键和值都来自同一组输入。

卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。

为了使用序列的顺序信息,可以通过在输入表示中添加位置编码,来注入绝对的或相对的位置信息。

10.7 Transformer

一、Transformer 核心架构

Transformer 由编码器(Encoder) 和解码器(Decoder) 两部分组成,整体结构如下:1.1 编码器(Encoder)编码器由 N 个相同的编码器层 堆叠而成(论文中 N=6),每个编码器层包含:多头自注意力层(Multi-Head Self-Attention):建模输入序列内部的依赖关系;前馈神经网络(FFN):对每个位置的特征独立进行非线性变换;残差连接 + 层归一化:防止梯度消失,加速训练。1.2 解码器(Decoder)解码器由 N 个相同的解码器层 堆叠而成(论文中 N=6),每个解码器层包含:掩码多头自注意力层(Masked Multi-Head Self-Attention):建模输出序列内部的依赖,屏蔽未来位置;编码器 - 解码器注意力层(Encoder-Decoder Attention):建模输入与输出序列的关联;前馈神经网络(FFN):同编码器;残差连接 + 层归一化:同编码器。1.3 其他核心组件词嵌入(Embedding):将词索引转为词向量,编码器 / 解码器各有独立的嵌入层;位置编码(Positional Encoding):为序列注入位置信息(自注意力无位置感知);输出层:线性层 + Softmax,将解码器输出映射为目标词汇表概率。

二、Transformer 核心公式

2.1 编码器层
\(\(\begin{align*} \tilde{X} &= \text{MultiHeadSelfAttention}(X, X, X, \text{valid\_lens}) + X \\ X_1 &= \text{LayerNorm}(\tilde{X}) \\ \tilde{X}_1 &= \text{FFN}(X_1) + X_1 \\ X_{\text{enc}} &= \text{LayerNorm}(\tilde{X}_1) \end{align*}\)\)
2.2 解码器层
\(\(\begin{align*} \tilde{Y} &= \text{MaskedMultiHeadSelfAttention}(Y, Y, Y, \text{valid\_lens}) + Y \\ Y_1 &= \text{LayerNorm}(\tilde{Y}) \\ \tilde{Y}_1 &= \text{MultiHeadAttention}(Y_1, X_{\text{enc}}, X_{\text{enc}}, \text{valid\_lens}) + Y_1 \\ Y_2 &= \text{LayerNorm}(\tilde{Y}_1) \\ \tilde{Y}_2 &= \text{FFN}(Y_2) + Y_2 \\ Y_{\text{dec}} &= \text{LayerNorm}(\tilde{Y}_2) \end{align*}\)\)
2.3 前馈神经网络
\((FFNtext{FFN}(X) = \max(0, X W_1 + b_1) W_2 + b_2\)其中 \(W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}\),\(W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}\)(论文中 \(d_{\text{ff}}=2048\),\(d_{\text{model}}=512\)\)

代码实现

import math
import pandas as pd  # 用于处理注意力权重的缺失值填充
import torch
from torch import nn
from d2l import torch as d2l

# ===================== 1. 定义基于位置的前馈网络(FFN) =====================
#@save  # d2l库装饰器,标记为可保存的函数/类
class PositionWiseFFN(nn.Module):
    """基于位置的前馈网络(Transformer核心组件)"""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        """
        参数说明:
            ffn_num_input: FFN输入维度(与模型隐藏层维度一致)
            ffn_num_hiddens: FFN隐藏层维度(通常比输入维度大)
            ffn_num_outputs: FFN输出维度(与输入维度一致,适配残差连接)
        """
        super(PositionWiseFFN, self).__init__(**kwargs)
        # 第一层线性变换:输入→隐藏层
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()  # 非线性激活
        # 第二层线性变换:隐藏层→输出(维度还原)
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        """前向传播:X → 线性1 → ReLU → 线性2"""
        return self.dense2(self.relu(self.dense1(X)))

# 测试FFN:输入维度4,隐藏层4,输出维度8
ffn = PositionWiseFFN(4, 4, 8)
ffn.eval()  # 评估模式(关闭Dropout等训练层)
# 输入:(2,3,4) → batch=2, seq_len=3, input_dim=4
# 输出第一个样本的结果(验证维度变换)
print("FFN测试输出(第一个样本):", ffn(torch.ones((2, 3, 4)))[0])

# ===================== 2. 对比层归一化(LayerNorm)和批量归一化(BatchNorm) =====================
# 初始化层归一化(归一化维度=2)
ln = nn.LayerNorm(2)
# 初始化批量归一化(归一化维度=2,1D表示序列维度)
bn = nn.BatchNorm1d(2)
# 构造测试数据:(2,2) → 2个样本,每个样本2维
X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
# 打印归一化结果(对比LayerNorm和BatchNorm的差异)
print('layer norm:', ln(X), '\nbatch norm:', bn(X))

# ===================== 3. 定义残差连接+层归一化(AddNorm) =====================
#@save
class AddNorm(nn.Module):
    """残差连接后进行层规范化(Transformer核心技巧)"""
    def __init__(self, normalized_shape, dropout, **kwargs):
        """
        参数说明:
            normalized_shape: 层归一化的维度(如[100,24]表示seq_len=100, num_hiddens=24)
            dropout: dropout概率(防止过拟合)
        """
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)  # Dropout层
        self.ln = nn.LayerNorm(normalized_shape)  # 层归一化

    def forward(self, X, Y):
        """
        前向传播:残差连接(X + Dropout(Y))→ 层归一化
        参数:
            X: 原始输入(残差连接的捷径)
            Y: 子层输出(如注意力/FFN的输出)
        """
        return self.ln(self.dropout(Y) + X)

# 测试AddNorm:归一化维度[3,4],dropout=0.5
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
# 输入:X和Y均为(2,3,4) → batch=2, seq_len=3, num_hiddens=4
print("AddNorm输出形状:", add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape)

# ===================== 4. 定义Transformer编码器块 =====================
#@save
class EncoderBlock(nn.Module):
    """Transformer编码器块(单个)"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        """
        参数说明:
            key_size/query_size/value_size: 注意力的Q/K/V维度
            num_hiddens: 模型隐藏层总维度
            norm_shape: AddNorm的归一化维度
            ffn_num_input: FFN输入维度(=num_hiddens)
            ffn_num_hiddens: FFN隐藏层维度
            num_heads: 注意力头数
            dropout: dropout概率
            use_bias: 线性层是否使用偏置
        """
        super(EncoderBlock, self).__init__(**kwargs)
        # 多头自注意力层(Q=K=V=输入X)
        self.attention = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
            use_bias)
        # 第一个AddNorm:自注意力输出 + 残差
        self.addnorm1 = AddNorm(norm_shape, dropout)
        # 基于位置的前馈网络
        self.ffn = PositionWiseFFN(
            ffn_num_input, ffn_num_hiddens, num_hiddens)
        # 第二个AddNorm:FFN输出 + 残差
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        """
        编码器块前向传播:自注意力 → AddNorm → FFN → AddNorm
        参数:
            X: 输入序列,形状 (batch_size, seq_len, num_hiddens)
            valid_lens: 有效长度(屏蔽PAD token)
        """
        # 步骤1:多头自注意力(Q=K=V=X) + 残差+层归一化
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        # 步骤2:FFN + 残差+层归一化
        return self.addnorm2(Y, self.ffn(Y))

# 测试编码器块:
# 输入X:(2,100,24) → batch=2, seq_len=100, num_hiddens=24
# valid_lens:[3,2] → batch0前3个位置有效,batch1前2个位置有效
X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
encoder_blk.eval()
print("编码器块输出形状:", encoder_blk(X, valid_lens).shape)

# ===================== 5. 定义完整的Transformer编码器 =====================
#@save
class TransformerEncoder(d2l.Encoder):
    """Transformer编码器(堆叠多个编码器块)"""
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        """
        参数说明:
            vocab_size: 源语言词汇表大小
            num_layers: 编码器块的堆叠层数
            其余参数:同EncoderBlock
        """
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens  # 模型隐藏层维度
        self.embedding = nn.Embedding(vocab_size, num_hiddens)  # 词嵌入层
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)  # 位置编码
        self.blks = nn.Sequential()  # 堆叠编码器块
        # 循环添加num_layers个编码器块
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        """
        编码器前向传播:词嵌入 → 位置编码 → 堆叠编码器块
        参数:
            X: 源语言序列(词索引),形状 (batch_size, seq_len)
            valid_lens: 有效长度
        """
        # 步骤1:词嵌入 + 缩放(乘以√num_hiddens,平衡词嵌入和位置编码的幅值)
        # 步骤2:添加位置编码
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # 初始化注意力权重存储(用于后续可视化)
        self.attention_weights = [None] * len(self.blks)
        # 步骤3:逐层通过编码器块
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            # 保存每一层的注意力权重
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X

# 测试完整编码器:
# 词汇表大小200,隐藏层24,层数2,头数8
encoder = TransformerEncoder(
    200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
# 输入:(2,100) → batch=2, seq_len=100(词索引序列)
print("完整编码器输出形状:", encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape)

# ===================== 6. 定义Transformer解码器块 =====================
class DecoderBlock(nn.Module):
    """解码器中第i个块(包含掩码自注意力+编码器-解码器注意力)"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        """
        参数说明:
            i: 解码器块的索引(用于管理预测阶段的状态)
            其余参数:同编码器块
        """
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i  # 块索引
        # 掩码多头自注意力(解码器自注意力,屏蔽未来位置)
        self.attention1 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)  # 自注意力的残差+归一化
        # 编码器-解码器注意力(Q=解码器输出,K/V=编码器输出)
        self.attention2 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)  # 交叉注意力的残差+归一化
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)  # FFN
        self.addnorm3 = AddNorm(norm_shape, dropout)  # FFN的残差+归一化

    def forward(self, X, state):
        """
        解码器块前向传播:
        掩码自注意力 → AddNorm → 编码器-解码器注意力 → AddNorm → FFN → AddNorm
        参数:
            X: 解码器输入序列,形状 (batch_size, seq_len, num_hiddens)
            state: 解码器状态,包含(编码器输出, 编码器有效长度, 解码器历史状态)
        """
        # 从state中拆分:编码器输出、编码器有效长度
        enc_outputs, enc_valid_lens = state[0], state[1]
        
        # 处理解码器自注意力的key/value(适配预测阶段的逐词解码)
        # 训练阶段:state[2][self.i]为None,key_values=当前X
        # 预测阶段:key_values=历史输出+当前X(拼接已解码的词)
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values  # 更新解码器历史状态
        
        # 处理解码器有效长度(训练/预测阶段不同)
        if self.training:
            batch_size, num_steps, _ = X.shape
            # 训练阶段:dec_valid_lens=[1,2,...,num_steps](屏蔽未来位置)
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None  # 预测阶段无需手动设置(掩码已处理)

        # 步骤1:掩码多头自注意力(Q=X,K/V=key_values)
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)  # 残差+归一化
        
        # 步骤2:编码器-解码器注意力(Q=Y,K/V=编码器输出)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)  # 残差+归一化
        
        # 步骤3:FFN + 残差+归一化
        return self.addnorm3(Z, self.ffn(Z)), state

# 测试解码器块:
decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
decoder_blk.eval()
X = torch.ones((2, 100, 24))  # 解码器输入
# 初始化解码器状态:(编码器输出, 编码器有效长度, 解码器历史状态)
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
print("解码器块输出形状:", decoder_blk(X, state)[0].shape)

# ===================== 7. 定义完整的Transformer解码器 =====================
class TransformerDecoder(d2l.AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, **kwargs):
        """
        参数说明:
            vocab_size: 目标语言词汇表大小
            num_layers: 解码器块的堆叠层数
            其余参数:同解码器块
        """
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)  # 目标语言词嵌入
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)  # 位置编码
        self.blks = nn.Sequential()  # 堆叠解码器块
        # 循环添加num_layers个解码器块
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                DecoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)  # 输出层(映射到词汇表)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """初始化解码器状态:(编码器输出, 编码器有效长度, 解码器历史状态列表)"""
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        """
        解码器前向传播:词嵌入 → 位置编码 → 堆叠解码器块 → 输出层
        """
        # 词嵌入 + 缩放 + 位置编码(同编码器)
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # 初始化注意力权重存储:[自注意力权重, 编码器-解码器注意力权重]
        self._attention_weights = [[None] * len(self.blks) for _ in range (2)]
        # 逐层通过解码器块
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # 保存解码器自注意力权重
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            # 保存编码器-解码器注意力权重
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        # 输出层映射到词汇表
        return self.dense(X), state

    @property
    def attention_weights(self):
        """返回注意力权重(用于可视化)"""
        return self._attention_weights

# ===================== 8. 训练Transformer(英法翻译任务) =====================
# 超参数设置(简化版,适配小数据集)
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()  # 学习率、训练轮数、设备
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4  # FFN参数、注意力头数
key_size, query_size, value_size = 32, 32, 32  # 注意力Q/K/V维度
norm_shape = [32]  # 层归一化维度

# 加载英法翻译数据集:返回数据迭代器、源语言词汇表(英语)、目标语言词汇表(法语)
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

# 初始化编码器
encoder = TransformerEncoder(
    len(src_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
# 初始化解码器
decoder = TransformerDecoder(
    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
# 组合编码器-解码器为完整的Seq2Seq模型
net = d2l.EncoderDecoder(encoder, decoder)

# 训练模型(d2l封装的Seq2Seq训练函数)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

# ===================== 9. 测试翻译效果 =====================
# 测试句子对:英语→法语
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']

# 逐句测试翻译并计算BLEU分数
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation},  bleu {d2l.bleu(translation, fra, k=2):.3f}')

# ===================== 10. 可视化编码器注意力权重 =====================
# 拼接所有编码器层的注意力权重:(num_layers, num_heads, batch*seq_len, num_steps)
enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape((num_layers, num_heads,
    -1, num_steps))
print("编码器注意力权重形状:", enc_attention_weights.shape)

# 绘制编码器注意力热力图(4个头,x轴=Key位置,y轴=Query位置)
d2l.show_heatmaps(
    enc_attention_weights.cpu(), xlabel='Key positions',
    ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
    figsize=(7, 3.5))

# ===================== 11. 可视化解码器注意力权重 =====================
# 处理解码器注意力权重:展平为2D列表(填充缺失值)
dec_attention_weights_2d = [head[0].tolist()
                            for step in dec_attention_weight_seq
                            for attn in step for blk in attn for head in blk]
# 用pandas填充缺失值(预测阶段长度不足的位置补0)
dec_attention_weights_filled = torch.tensor(
    pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
# 重塑形状:(seq_len, 2, num_layers, num_heads, num_steps)
dec_attention_weights = dec_attention_weights_filled.reshape((-1, 2, num_layers, num_heads, num_steps))
# 拆分:解码器自注意力权重、编码器-解码器注意力权重
dec_self_attention_weights, dec_inter_attention_weights = \
    dec_attention_weights.permute(1, 2, 3, 0, 4)
print("解码器自注意力权重形状:", dec_self_attention_weights.shape)
print("编码器-解码器注意力权重形状:", dec_inter_attention_weights.shape)

# 绘制解码器自注意力热力图(包含序列开始符)
d2l.show_heatmaps(
    dec_self_attention_weights[:, :, :, :len(translation.split()) + 1],
    xlabel='Key positions', ylabel='Query positions',
    titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))

# 绘制编码器-解码器注意力热力图
d2l.show_heatmaps(
    dec_inter_attention_weights, xlabel='Key positions',
    ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
    figsize=(7, 3.5))
posted @ 2025-12-14 15:31  Morphis‘  阅读(84)  评论(0)    收藏  举报