注意力汇聚
🧪 1. 生成数据集
我们首先构造一个非线性回归问题:
\[y_i = 2\sin(x_i) + x_i^{0.8} + \epsilon, \quad \epsilon \sim \mathcal{N}(0, 0.5)
\]
✅ 实现代码:
import torch
import matplotlib.pyplot as plt
from d2l import torch as d2l
n_train = 50 # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5) # 排序后的训练样本
def f(x):
return 2 * torch.sin(x) + x**0.8
y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,)) # 带噪声的输出
x_test = torch.arange(0, 5, 0.1) # 测试输入
y_truth = f(x_test) # 真实输出
n_test = len(x_test) # 测试样本数
x_train是排序过的随机数,模拟真实世界中有序的输入;y_train加入了正态分布噪声,使任务更具挑战性。
📈 2. 画图函数
用于可视化训练数据、真实函数和预测结果。
✅ 实现代码:
def plot_kernel_reg(y_hat):
d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
xlim=[0, 5], ylim=[-1, 5])
d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);
🧮 3. 平均汇聚(Average Pooling)
最简单的估计器:忽略输入 $ x $,只用所有训练输出的均值作为预测。
✅ 实现代码:
y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)
使用
torch.repeat_interleave将标量重复为与测试集长度相同的张量。
🔍 4. 非参数注意力汇聚(Nadaraya-Watson 核回归)
基于高斯核的距离加权平均:
\[f(x) = \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}(x - x_i)^2\right) y_i
\]
✅ 实现代码:
# 构造查询与键之间的距离矩阵
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
attention_weights = torch.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# 注意力加权求和
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)
🔥 可视化注意力权重:
d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
xlabel='Sorted training inputs',
ylabel='Sorted testing inputs')
⚙️ 5. 带参数注意力汇聚(Parametric Attention Pooling)
在距离计算中引入可学习参数 $ w $:
\[f(x) = \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}((x - x_i)w)^2\right) y_i
\]
✅ 定义模型:
class NWKernelRegression(torch.nn.Module):
def __init__(self):
super().__init__()
self.w = torch.nn.Parameter(torch.ones(1))
def forward(self, queries, keys, values):
queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
self.attention_weights = torch.softmax(
-((queries - keys) * self.w) ** 2 / 2, dim=1)
return torch.bmm(self.attention_weights.unsqueeze(1), values.unsqueeze(-1)).squeeze()
📊 6. 准备训练数据(排除自身)
每个训练样本不考虑自己,仅关注其他样本:
X_tile = x_train.repeat((n_train, 1))
Y_tile = y_train.repeat((n_train, 1))
# 构建键值对,排除自己
keys = X_tile[~torch.eye(n_train, dtype=torch.bool)].reshape((n_train, -1))
values = Y_tile[~torch.eye(n_train, dtype=torch.bool)].reshape((n_train, -1))
🏋️ 7. 训练模型
使用均方误差损失(MSE)和 SGD 进行训练:
net = NWKernelRegression()
loss = torch.nn.MSELoss(reduction='none')
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])
for epoch in range(5):
optimizer.zero_grad()
y_hat = net(x_train, keys, values)
l = loss(y_hat, y_train).sum()
l.backward()
optimizer.step()
print(f'epoch {epoch+1}, loss {float(l):.6f}')
animator.add(epoch + 1, float(l))
📉 8. 预测与可视化
使用训练好的模型进行预测并绘图:
keys_test = x_train.repeat((n_test, 1))
values_test = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys_test, values_test)
plot_kernel_reg(y_hat)
🔎 9. 可视化注意力权重(训练后)
d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
xlabel='Sorted training inputs',
ylabel='Sorted testing inputs')
你会发现:
- 注意力权重更加集中(因为 $ w $ 被优化)
- 预测曲线更“锐利”,不如非参数模型平滑
📌 10. 总结
| 方法 | 是否参数化 | 是否灵活 | 平滑度 | 是否可微 |
|---|---|---|---|---|
| 平均汇聚 | 否 | 否 | 最差 | 是 |
| Nadaraya-Watson 核回归 | 否 | 中等 | 好 | 是 |
| 带参数注意力汇聚 | 是 | 强 | 一般 | 是 |
结论:带参数注意力机制通过引入可学习参数 $ w $,提高了模型灵活性,但也可能导致过拟合或局部震荡。它为现代注意力机制(如 Transformer)奠定了基础。

浙公网安备 33010602011771号