机器学习基础(十四):Transformer与注意力机制
一、引言
上一篇我们学习了 RNN 和 LSTM,它们通过循环结构处理序列数据。但 RNN 有一个根本性的瓶颈:
必须按顺序处理,无法并行
这导致训练长序列时效率极低。2017年,Google 发表论文《Attention Is All You Need》,提出了 Transformer 架构——完全基于注意力机制,彻底抛弃了循环结构。
本文目标:
- 理解注意力机制的直观原理
- 掌握 Self-Attention 的完整公式
- 从零搭建 Transformer
- 实现一个机器翻译模型
二、注意力机制的本质
2.1 直观理解
想象你在读这句话:
"小明把苹果给了小红,她很高兴。"
读到"她"时,你的大脑会自动"关注"到"小红"——这就是注意力。
核心思想: 在处理序列的每个位置时,动态地决定应该"看"哪些其他位置。
2.2 查询-键-值(Query-Key-Value)
注意力机制借鉴了数据库查询的思想:
| 概念 | 类比 | 作用 |
|---|---|---|
| Query | 搜索关键词 | 我想找什么? |
| Key | 文档标签 | 每个位置的内容是什么? |
| Value | 文档内容 | 实际要获取的信息 |
计算过程:
- 用 Query 和每个 Key 计算相似度(点积)
- 用 softmax 转换为注意力权重
- 用权重对 Value 做加权求和
# 伪代码
def attention(query, keys, values):
scores = query @ keys.T # 相似度:[1, seq_len]
weights = softmax(scores) # 注意力权重:[1, seq_len]
output = weights @ values # 加权求和:[1, dim]
return output
三、Self-Attention 自注意力
3.1 什么是自注意力?
Self-Attention:序列中的每个位置都关注序列中的所有位置(包括自己)。
输入一个序列,输出同样长度的序列,每个输出都是对所有输入的加权组合。
3.2 完整公式推导
给定输入序列 \(X \in \mathbb{R}^{n \times d_{model}}\)(\(n\) 个词,每个 \(d_{model}\) 维)
Step 1: 生成 Q/K/V
其中 \(W^Q, W^K \in \mathbb{R}^{d_{model} \times d_k}\),\(W^V \in \mathbb{R}^{d_{model} \times d_v}\)
Step 2: 计算注意力分数
为什么要除以 \(\sqrt{d_k}\)?
当 \(d_k\) 很大时,点积的数值会很大,导致 softmax 进入梯度饱和区。除以 \(\sqrt{d_k}\) 可以稳定梯度。
3.3 PyTorch 实现
import torch
import torch.nn as nn
import math
class SelfAttention(nn.Module):
def __init__(self, d_model=512):
super().__init__()
self.d_model = d_model
self.d_k = d_model # 简化:d_k = d_model
# 三个线性投影
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
"""
x: [batch, seq_len, d_model]
mask: [batch, seq_len, seq_len] 用于遮挡未来信息
"""
batch, seq_len, _ = x.shape
# 生成 Q/K/V
Q = self.W_q(x) # [batch, seq_len, d_model]
K = self.W_k(x)
V = self.W_v(x)
# 计算注意力分数: Q @ K^T
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# [batch, seq_len, seq_len]
# 应用 mask(Decoder 中需要)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax 得到注意力权重
attn_weights = torch.softmax(scores, dim=-1)
# 加权求和
output = torch.matmul(attn_weights, V) # [batch, seq_len, d_model]
return output, attn_weights
# 测试
x = torch.randn(2, 10, 512) # batch=2, seq_len=10, d_model=512
attn = SelfAttention()
out, weights = attn(x)
print(f"输入: {x.shape}")
print(f"输出: {out.shape}")
print(f"注意力权重: {weights.shape}") # [2, 10, 10]
四、Multi-Head Attention 多头注意力
4.1 为什么需要多头?
单个注意力可能只关注一种关系。多头允许模型同时关注不同方面的信息:
- 一个头关注语法关系
- 一个头关注语义关系
- 一个头关注指代关系
4.2 多头机制
将 Q/K/V 投影到 \(h\) 个低维空间,分别做注意力,再拼接:
其中 \(\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\)
4.3 PyTorch 实现
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, num_heads=8):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个头的维度
# 线性投影
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 1. 线性投影并分头
# [batch, seq, d_model] -> [batch, seq, heads, d_k] -> [batch, heads, seq, d_k]
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 2. 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = torch.softmax(scores, dim=-1)
context = torch.matmul(attn_weights, V) # [batch, heads, seq, d_k]
# 3. 拼接并线性变换
context = context.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
)
output = self.W_o(context)
return output, attn_weights
# 测试
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out, weights = mha(x, x, x)
print(f"多头注意力输出: {out.shape}") # [2, 10, 512]
print(f"注意力权重: {weights.shape}") # [2, 8, 10, 10]
五、Transformer 整体架构
5.1 Encoder-Decoder 结构
┌─────────────────────────────────────────────────────────────┐
│ Transformer │
├─────────────────────────────┬───────────────────────────────┤
│ Encoder │ Decoder │
│ ┌─────────────────────┐ │ ┌─────────────────────┐ │
│ │ Input Embedding │ │ │ Output Embedding │ │
│ │ + Positional Enc │ │ │ + Positional Enc │ │
│ └──────────┬──────────┘ │ └──────────┬──────────┘ │
│ ▼ │ ▼ │
│ ┌─────────────────────┐ │ ┌─────────────────────┐ │
│ │ Multi-Head Attn │ │ │ Masked MHA │ │
│ │ + Add & Norm │ │ │ + Add & Norm │ │
│ └──────────┬──────────┘ │ └──────────┬──────────┘ │
│ ▼ │ ▼ │
│ ┌─────────────────────┐ │ ┌─────────────────────┐ │
│ │ Feed Forward │ │ │ Multi-Head Attn │ │
│ │ + Add & Norm │ │ │ (Q from decoder, │ │
│ └──────────┬──────────┘ │ │ K/V from encoder) │ │
│ │ │ │ + Add & Norm │ │
│ │ (N×堆叠) │ └──────────┬──────────┘ │
│ ▼ │ ▼ │
│ ┌─────────────────────┐ │ ┌─────────────────────┐ │
│ │ Encoder Output │───┐│ │ Feed Forward │ │
│ └─────────────────────┘ ││ │ + Add & Norm │ │
│ ││ └──────────┬──────────┘ │
└────────────────────────────│┘ │ │
│ │ (N×堆叠) │
│ ▼ │
│ ┌─────────────────────┐ │
│ │ Linear + Softmax │ │
│ │ -> Output Prob │ │
│ └─────────────────────┘ │
└────────────────────────────────┘
5.2 组件详解
| 组件 | 作用 |
|---|---|
| Input Embedding | 将词索引转为向量 |
| Positional Encoding | 注入位置信息(Transformer 本身没有顺序概念) |
| Multi-Head Attention | 捕捉全局依赖关系 |
| Add & Norm | 残差连接 + Layer Normalization |
| Feed Forward | 两个线性层夹 ReLU:\(\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2\) |
| Masked MHA | Decoder 中遮挡未来位置,防止偷看答案 |
六、位置编码(Positional Encoding)
6.1 为什么需要位置编码?
Self-Attention 是位置无关的——打乱输入顺序,输出不变。但语言中顺序很重要!
6.2 正弦位置编码
Transformer 使用正弦/余弦函数生成位置编码:
优点:
- 可以处理任意长度的序列
- 相对位置可以线性表示:\(PE_{pos+k}\) 可以用 \(PE_{pos}\) 线性表示
6.3 PyTorch 实现
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
# 创建位置编码矩阵 [max_len, d_model]
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# 分母项
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model)
)
# 正弦和余弦
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# 注册为 buffer(不参与训练)
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x):
"""
x: [batch, seq_len, d_model]
"""
return x + self.pe[:, :x.size(1), :]
# 测试
pe = PositionalEncoding(d_model=512)
x = torch.randn(2, 10, 512)
out = pe(x)
print(f"位置编码后: {out.shape}")
七、Layer Norm & 残差连接
7.1 残差连接(Residual Connection)
解决深层网络的梯度消失问题,允许训练更深的网络。
7.2 Layer Normalization
对每个样本的所有特征做归一化:
与 BatchNorm 不同,LayerNorm 在序列维度上归一化,更适合变长序列。
class TransformerBlock(nn.Module):
def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
super().__init__()
self.attention = MultiHeadAttention(d_model, num_heads)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 自注意力 + 残差 + LayerNorm
attn_out, _ = self.attention(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_out))
# 前馈 + 残差 + LayerNorm
ff_out = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_out))
return x
八、Transformer vs RNN
| 特性 | RNN/LSTM | Transformer |
|---|---|---|
| 并行性 | ❌ 顺序处理 | ✅ 完全并行 |
| 长距离依赖 | ❌ 梯度消失 | ✅ 直接连接 |
| 计算复杂度 | \(O(n)\) | \(O(n^2)\)(注意力矩阵) |
| 位置信息 | 天然有序 | 需要位置编码 |
| 训练速度 | 慢 | 快(可并行) |
| 显存占用 | 低 | 高(注意力矩阵) |
关键突破: Transformer 用 \(O(n^2)\) 的计算换取了完全并行和直接长距离依赖。
九、完整 Transformer 实现
class Transformer(nn.Module):
def __init__(
self,
src_vocab_size,
tgt_vocab_size,
d_model=512,
num_heads=8,
num_encoder_layers=6,
num_decoder_layers=6,
d_ff=2048,
max_len=5000,
dropout=0.1
):
super().__init__()
self.d_model = d_model
# Embedding + Positional Encoding
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len)
# Encoder
self.encoder_layers = nn.ModuleList([
TransformerBlock(d_model, num_heads, d_ff, dropout)
for _ in range(num_encoder_layers)
])
# Decoder
self.decoder_layers = nn.ModuleList([
DecoderBlock(d_model, num_heads, d_ff, dropout)
for _ in range(num_decoder_layers)
])
# 输出层
self.output_layer = nn.Linear(d_model, tgt_vocab_size)
self.dropout = nn.Dropout(dropout)
def encode(self, src, src_mask=None):
"""编码器前向传播"""
x = self.src_embedding(src) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
x = self.dropout(x)
for layer in self.encoder_layers:
x = layer(x, src_mask)
return x
def decode(self, tgt, memory, src_mask=None, tgt_mask=None):
"""解码器前向传播"""
x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
x = self.dropout(x)
for layer in self.decoder_layers:
x = layer(x, memory, src_mask, tgt_mask)
return x
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
memory = self.encode(src, src_mask)
output = self.decode(tgt, memory, src_mask, tgt_mask)
return self.output_layer(output)
class DecoderBlock(nn.Module):
"""带 Masked MHA 和 Cross-Attention 的解码器块"""
def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
super().__init__()
self.masked_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, memory, src_mask=None, tgt_mask=None):
# Masked Self-Attention
attn_out, _ = self.masked_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(attn_out))
# Cross-Attention (Q from decoder, K/V from encoder)
attn_out, _ = self.cross_attn(x, memory, memory, src_mask)
x = self.norm2(x + self.dropout(attn_out))
# Feed Forward
ff_out = self.feed_forward(x)
x = self.norm3(x + self.dropout(ff_out))
return x
def create_look_ahead_mask(size):
"""创建上三角 mask,防止看到未来信息"""
mask = torch.triu(torch.ones(size, size), diagonal=1)
return mask == 0 # 下三角为 True
# 测试
src_vocab_size = 10000
tgt_vocab_size = 10000
model = Transformer(src_vocab_size, tgt_vocab_size)
src = torch.randint(0, src_vocab_size, (2, 20)) # [batch, src_len]
tgt = torch.randint(0, tgt_vocab_size, (2, 15)) # [batch, tgt_len]
tgt_mask = create_look_ahead_mask(15).unsqueeze(0).unsqueeze(0) # [1, 1, 15, 15]
output = model(src, tgt, tgt_mask=tgt_mask)
print(f"Transformer 输出: {output.shape}") # [2, 15, tgt_vocab_size]
十、实战:机器翻译
10.1 数据准备(使用 torchtext)
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
# 简化的词表类
class Vocab:
def __init__(self):
self.word2idx = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
self.idx2word = {0: '<pad>', 1: '<sos>', 2: '<eos>', 3: '<unk>'}
self.count = 4
def build_vocab(self, sentences, min_freq=2):
from collections import Counter
counter = Counter()
for sent in sentences:
counter.update(sent.split())
for word, freq in counter.items():
if freq >= min_freq and word not in self.word2idx:
self.word2idx[word] = self.count
self.idx2word[self.count] = word
self.count += 1
def encode(self, sentence):
tokens = ['<sos>'] + sentence.split() + ['<eos>']
return [self.word2idx.get(t, 3) for t in tokens]
def decode(self, indices):
words = [self.idx2word.get(i, '<unk>') for i in indices]
return ' '.join(words)
# 示例数据(实际使用 WMT14 等数据集)
src_sentences = [
"hello world",
"how are you",
"i love machine learning",
"transformer is powerful",
"attention is all you need"
]
tgt_sentences = [
"你好 世界",
"你 好吗",
"我 喜欢 机器 学习",
"transformer 很 强大",
"注意力 就是 你 需要 的 一切"
]
# 构建词表
src_vocab = Vocab()
tgt_vocab = Vocab()
src_vocab.build_vocab(src_sentences)
tgt_vocab.build_vocab(tgt_sentences)
print(f"源语言词表大小: {len(src_vocab.word2idx)}")
print(f"目标语言词表大小: {len(tgt_vocab.word2idx)}")
10.2 训练循环
def train_epoch(model, dataloader, optimizer, criterion, device):
model.train()
total_loss = 0
for src, tgt in dataloader:
src, tgt = src.to(device), tgt.to(device)
# tgt_input: 去掉最后一个词(输入)
# tgt_label: 去掉第一个词(目标)
tgt_input = tgt[:, :-1]
tgt_label = tgt[:, 1:]
# 创建 mask
tgt_mask = create_look_ahead_mask(tgt_input.size(1)).to(device)
# 前向传播
optimizer.zero_grad()
output = model(src, tgt_input, tgt_mask=tgt_mask)
# 计算损失(忽略 <pad>)
output = output.view(-1, output.size(-1))
tgt_label = tgt_label.contiguous().view(-1)
loss = criterion(output, tgt_label)
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
def translate(model, src_sentence, src_vocab, tgt_vocab, device, max_len=50):
"""贪心解码进行翻译"""
model.eval()
# 编码源句子
src = torch.tensor([src_vocab.encode(src_sentence)]).to(device)
memory = model.encode(src)
# 自回归生成
tgt_indices = [1] # <sos>
with torch.no_grad():
for _ in range(max_len):
tgt = torch.tensor([tgt_indices]).to(device)
tgt_mask = create_look_ahead_mask(len(tgt_indices)).to(device)
output = model.decode(tgt, memory, tgt_mask=tgt_mask)
output = model.output_layer(output)
# 取最后一个位置的预测
next_token = output[0, -1].argmax().item()
tgt_indices.append(next_token)
if next_token == 2: # <eos>
break
return tgt_vocab.decode(tgt_indices)
# 完整训练示例
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Transformer(
src_vocab_size=len(src_vocab.word2idx),
tgt_vocab_size=len(tgt_vocab.word2idx),
d_model=256,
num_heads=4,
num_encoder_layers=2,
num_decoder_layers=2,
d_ff=512
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)
criterion = nn.CrossEntropyLoss(ignore_index=0) # 忽略 <pad>
# 准备数据
def collate_fn(batch):
src, tgt = zip(*batch)
src = pad_sequence([torch.tensor(s) for s in src], batch_first=True, padding_value=0)
tgt = pad_sequence([torch.tensor(t) for t in tgt], batch_first=True, padding_value=0)
return src, tgt
data = [(src_vocab.encode(s), tgt_vocab.encode(t))
for s, t in zip(src_sentences, tgt_sentences)]
dataloader = DataLoader(data, batch_size=2, shuffle=True, collate_fn=collate_fn)
# 训练
for epoch in range(100):
loss = train_epoch(model, dataloader, optimizer, criterion, device)
if (epoch + 1) % 20 == 0:
print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
# 测试翻译
test_sentence = "hello world"
translation = translate(model, test_sentence, src_vocab, tgt_vocab, device)
print(f"\n翻译结果:")
print(f"英文: {test_sentence}")
print(f"中文: {translation}")
十一、Transformer 的变体与应用
11.1 经典变体
| 模型 | 结构 | 特点 | 应用 |
|---|---|---|---|
| BERT | Encoder-only | 双向注意力,预训练 + 微调 | 文本分类、NER、问答 |
| GPT | Decoder-only | 自回归生成 | 文本生成、对话 |
| T5 | Encoder-Decoder | 统一文本到文本框架 | 翻译、摘要、问答 |
| Vision Transformer | Encoder-only | 图像分块作为序列 | 图像分类 |
11.2 注意力可视化
def visualize_attention(attention_weights, src_tokens, tgt_tokens):
"""
可视化注意力权重
attention_weights: [num_heads, tgt_len, src_len]
"""
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()
for i, ax in enumerate(axes):
if i < attention_weights.size(0):
attn = attention_weights[i].cpu().detach().numpy()
im = ax.imshow(attn, cmap='viridis', aspect='auto')
ax.set_title(f'Head {i+1}')
ax.set_xticks(range(len(src_tokens)))
ax.set_yticks(range(len(tgt_tokens)))
ax.set_xticklabels(src_tokens, rotation=90)
ax.set_yticklabels(tgt_tokens)
plt.colorbar(im, ax=ax)
plt.tight_layout()
plt.savefig('attention_visualization.png', dpi=150)
plt.show()
# 使用示例
# visualize_attention(attn_weights[0], src_words, tgt_words)
十二、总结
12.1 核心要点
Transformer
├── 注意力机制
│ ├── Self-Attention: 序列内部的全局依赖
│ ├── Multi-Head: 多视角并行关注
│ └── Scaled Dot-Product: QK^T / sqrt(d_k)
├── 架构设计
│ ├── Encoder: 双向理解
│ ├── Decoder: 自回归生成
│ └── Cross-Attention: 源目标信息融合
├── 关键技巧
│ ├── Positional Encoding: 注入位置信息
│ ├── Residual Connection: 缓解梯度消失
│ └── Layer Normalization: 稳定训练
└── 优势
├── 完全并行: 训练速度快
└── 长距离依赖: 直接连接
12.2 学习路径回顾
机器学习基础系列
├── (一) 绪论
├── (二) 线性回归
├── (三) 逻辑回归
├── (四) 决策树
├── (五) 集成学习
├── (六) 支持向量机
├── (七) 聚类
├── (八) 降维
├── (九) PyTorch入门
├── (十) 卷积神经网络CNN
├── (十一) 过拟合与正则化
├── (十二) 优化器大全
├── (十三) 循环神经网络RNN与LSTM
└── (十四) Transformer与注意力机制 ← 你在这里
└── 下一步: BERT/GPT预训练模型
12.3 下一步学习建议
- 深入理解 BERT:双向编码器的预训练与微调
- 探索 GPT 系列:自回归语言模型的演进
- Vision Transformer:注意力在 CV 领域的应用
- 高效 Transformer:稀疏注意力、线性注意力等优化
附录:Self-Attention 速查表
| 符号 | 含义 | 维度 |
|---|---|---|
| \(X\) | 输入序列 | \([n, d_{model}]\) |
| \(Q\) | Query 矩阵 | \([n, d_k]\) |
| \(K\) | Key 矩阵 | \([n, d_k]\) |
| \(V\) | Value 矩阵 | \([n, d_v]\) |
| \(W^Q, W^K, W^V\) | 投影矩阵 | 学习参数 |
| \(\text{Attention}\) | 注意力输出 | \([n, d_v]\) |
核心公式:
本文代码可在 GitHub 获取:https://github.com/paywqiao-max/ML-Basics

浙公网安备 33010602011771号