"""
A complete Transformer implementation, with commentary.

Main components:
- Multi-Head Self-Attention
- Positional Encoding
- Feed-Forward Network
- Encoder/Decoder Layers
- Complete Transformer Model
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
"""
多头自注意力机制 (Multi-Head Self-Attention)
核心思想:
- 将输入投影到Q、K、V三个矩阵
- 计算注意力权重:Attention(Q,K,V) = softmax(QK^T/√d_k)V
- 多个注意力头并行计算,捕获不同位置和表示子空间的信息
"""
def __init__(self, d_model, n_heads, dropout=0.1):
"""
Args:
d_model: 模型维度 (通常512或768)
n_heads: 注意力头数 (通常8或12)
dropout: dropout概率
"""
super(MultiHeadAttention, self).__init__()
        # Ensure d_model is divisible by n_heads
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
self.d_model = d_model
self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension per head
        # Linear projections for Q, K, V.
        # Each projection handles all heads at once with a single
        # d_model x d_model matrix, which is more efficient than a
        # separate matrix per head.
self.w_q = nn.Linear(d_model, d_model, bias=False)
self.w_k = nn.Linear(d_model, d_model, bias=False)
self.w_v = nn.Linear(d_model, d_model, bias=False)
        # Output projection layer
self.w_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
        # Scaling factor √d_k; keeps the dot products from pushing softmax into saturation
self.scale = math.sqrt(self.d_k)
def forward(self, query, key, value, mask=None):
"""
Args:
query: [batch_size, seq_len, d_model]
key: [batch_size, seq_len, d_model]
value: [batch_size, seq_len, d_model]
mask: [batch_size, seq_len, seq_len] 注意力掩码
Returns:
output: [batch_size, seq_len, d_model]
attention_weights: [batch_size, n_heads, seq_len, seq_len]
"""
        batch_size, q_seq_len, d_model = query.size()
        # 1. Linear projections to obtain Q, K, V
        Q = self.w_q(query)  # [batch_size, q_seq_len, d_model]
        K = self.w_k(key)    # [batch_size, k_seq_len, d_model]
        V = self.w_v(value)  # [batch_size, v_seq_len, d_model]
        # 2. Reshape into multi-head form
        # Note: K and V may have a different sequence length than Q (cross-attention)
        k_seq_len = key.size(1)
        v_seq_len = value.size(1)
Q = Q.view(batch_size, q_seq_len, self.n_heads, self.d_k)
K = K.view(batch_size, k_seq_len, self.n_heads, self.d_k)
V = V.view(batch_size, v_seq_len, self.n_heads, self.d_k)
        # Transpose for matrix multiplication: [batch_size, n_heads, seq_len, d_k]
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
        # 3. Compute attention
attention_output, attention_weights = self.scaled_dot_product_attention(
Q, K, V, mask, self.scale
)
        # 4. Concatenate the heads
        # [batch_size, n_heads, q_seq_len, d_k] -> [batch_size, q_seq_len, n_heads, d_k]
attention_output = attention_output.transpose(1, 2).contiguous()
# [batch_size, q_seq_len, n_heads, d_k] -> [batch_size, q_seq_len, d_model]
attention_output = attention_output.view(batch_size, q_seq_len, d_model)
        # 5. Output projection
output = self.w_o(attention_output)
return output, attention_weights
def scaled_dot_product_attention(self, Q, K, V, mask, scale):
"""
缩放点积注意力核心计算
公式:Attention(Q,K,V) = softmax(QK^T/√d_k)V
"""
        # Compute attention scores: QK^T / scale
        # [batch, n_heads, q_len, d_k] × [batch, n_heads, d_k, k_len]
        # = [batch, n_heads, q_len, k_len]
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
        # Apply the mask (if provided)
        if mask is not None:
            # Set masked positions to a large negative value so that
            # their softmax weights are driven toward 0
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
        # Normalize scores into attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)
        # nn.Dropout is already a no-op in eval mode, so no explicit training check is needed
        attention_weights = self.dropout(attention_weights)
        # Weight V by the attention weights
        # [batch, n_heads, q_len, k_len] × [batch, n_heads, k_len, d_k]
        # = [batch, n_heads, q_len, d_k]
attention_output = torch.matmul(attention_weights, V)
return attention_output, attention_weights
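
# --- Illustrative sanity check (an addition, not part of the original file) ---
# A minimal sketch confirming the shapes documented above: the output matches
# the query's [batch, seq_len, d_model], and the weights are per-head
# [batch, n_heads, q_len, k_len]. All dimension values here are arbitrary.
def _check_multihead_attention_shapes():
    mha = MultiHeadAttention(d_model=64, n_heads=4)
    mha.eval()
    q = torch.randn(2, 7, 64)   # query sequence of length 7
    kv = torch.randn(2, 9, 64)  # key/value sequence of length 9 (cross-attention case)
    out, attn = mha(q, kv, kv)
    assert out.shape == (2, 7, 64)
    assert attn.shape == (2, 4, 7, 9)
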
class PositionalEncoding(nn.Module):
"""
位置编码 (Positional Encoding)
由于Transformer没有循环或卷积结构,需要显式地给序列添加位置信息
使用sin/cos函数生成固定的位置编码
公式:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
"""
def __init__(self, d_model, max_seq_len=5000):
"""
Args:
d_model: 模型维度
max_seq_len: 支持的最大序列长度
"""
super(PositionalEncoding, self).__init__()
        # Build the positional-encoding matrix
        pe = torch.zeros(max_seq_len, d_model)
        # Position indices [0, 1, 2, ..., max_seq_len-1]
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # Frequency term 1 / 10000^(2i/d_model), computed in log space for numerical stability
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        # Apply sin to even indices and cos to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions: sin
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions: cos
        # Add a batch dimension and register as a buffer (excluded from gradient updates)
pe = pe.unsqueeze(0) # [1, max_seq_len, d_model]
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: [batch_size, seq_len, d_model]
Returns:
x + positional_encoding: [batch_size, seq_len, d_model]
"""
# 取出对应长度的位置编码并加到输入上
seq_len = x.size(1)
return x + self.pe[:, :seq_len, :]
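
# --- Illustrative check (an addition, not part of the original file) ---
# A small sketch verifying the registered buffer against the closed-form
# definition PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) for one arbitrary
# (pos, i) pair.
def _check_positional_encoding():
    d_model = 16
    enc = PositionalEncoding(d_model)
    pos, i = 3, 2  # arbitrary position and even-dimension index
    expected = math.sin(pos / (10000 ** (2 * i / d_model)))
    assert torch.isclose(enc.pe[0, pos, 2 * i], torch.tensor(expected), atol=1e-6)
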
class FeedForward(nn.Module):
"""
前馈神经网络 (Feed Forward Network)
结构:Linear -> ReLU -> Linear
通常中间层维度是输入的4倍(如512->2048->512)
"""
def __init__(self, d_model, d_ff, dropout=0.1):
"""
Args:
d_model: 输入/输出维度
d_ff: 中间层维度(通常是d_model的4倍)
dropout: dropout概率
"""
super(FeedForward, self).__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
Args:
x: [batch_size, seq_len, d_model]
Returns:
output: [batch_size, seq_len, d_model]
"""
# Linear -> ReLU -> Dropout -> Linear
return self.linear2(self.dropout(F.relu(self.linear1(x))))
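
# --- Illustrative check (an addition, not part of the original file) ---
# The FFN is applied position-wise: every time step is transformed
# independently with shared weights, so permuting positions commutes with
# the layer. A minimal sketch (dropout disabled for determinism):
def _check_ffn_is_position_wise():
    ffn = FeedForward(d_model=32, d_ff=128, dropout=0.0)
    ffn.eval()
    x = torch.randn(2, 5, 32)
    perm = torch.randperm(5)
    assert torch.allclose(ffn(x)[:, perm], ffn(x[:, perm]), atol=1e-6)
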
class EncoderLayer(nn.Module):
"""
Transformer编码器层
结构:
1. Multi-Head Self-Attention + 残差连接 + LayerNorm
2. Feed Forward + 残差连接 + LayerNorm
"""
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super(EncoderLayer, self).__init__()
self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
# Layer Normalization
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
Args:
x: [batch_size, seq_len, d_model]
mask: 注意力掩码
Returns:
output: [batch_size, seq_len, d_model]
"""
# 1. Self-Attention + 残差连接 + LayerNorm
attn_output, _ = self.self_attention(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# 2. Feed Forward + 残差连接 + LayerNorm
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x
class DecoderLayer(nn.Module):
"""
Transformer解码器层
结构:
1. Masked Multi-Head Self-Attention + 残差 + LayerNorm
2. Multi-Head Cross-Attention + 残差 + LayerNorm
3. Feed Forward + 残差 + LayerNorm
"""
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super(DecoderLayer, self).__init__()
self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
self.cross_attention = MultiHeadAttention(d_model, n_heads, dropout)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
"""
Args:
x: 解码器输入 [batch_size, tgt_len, d_model]
encoder_output: 编码器输出 [batch_size, src_len, d_model]
src_mask: 源序列掩码
tgt_mask: 目标序列掩码(下三角掩码)
Returns:
output: [batch_size, tgt_len, d_model]
"""
# 1. Masked Self-Attention(防止看到未来信息)
self_attn_output, _ = self.self_attention(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(self_attn_output))
# 2. Cross-Attention(解码器attend到编码器输出)
cross_attn_output, _ = self.cross_attention(x, encoder_output, encoder_output, src_mask)
x = self.norm2(x + self.dropout(cross_attn_output))
# 3. Feed Forward
ff_output = self.feed_forward(x)
x = self.norm3(x + self.dropout(ff_output))
return x
class Transformer(nn.Module):
"""
完整的Transformer模型
包含:
- 输入嵌入 + 位置编码
- N层编码器
- N层解码器
- 输出线性层
"""
def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_heads=8,
n_layers=6, d_ff=2048, max_seq_len=5000, dropout=0.1):
"""
Args:
src_vocab_size: 源语言词汇表大小
tgt_vocab_size: 目标语言词汇表大小
d_model: 模型维度
n_heads: 注意力头数
n_layers: 编码器/解码器层数
d_ff: 前馈网络中间层维度
max_seq_len: 最大序列长度
dropout: dropout概率
"""
super(Transformer, self).__init__()
self.d_model = d_model
        # Token embedding layers
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Positional encoding
self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        # Encoder stack
self.encoder_layers = nn.ModuleList([
EncoderLayer(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
        # Decoder stack
self.decoder_layers = nn.ModuleList([
DecoderLayer(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
        # Output projection layer
self.output_projection = nn.Linear(d_model, tgt_vocab_size)
self.dropout = nn.Dropout(dropout)
        # Parameter initialization
self.init_parameters()
def init_parameters(self):
"""Xavier初始化"""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def encode(self, src, src_mask=None):
"""
编码器前向传播
Args:
src: 源序列 [batch_size, src_len]
src_mask: 源序列掩码
Returns:
encoder_output: [batch_size, src_len, d_model]
"""
# 词嵌入 + 位置编码
src_emb = self.src_embedding(src) * math.sqrt(self.d_model)
src_emb = self.positional_encoding(src_emb)
src_emb = self.dropout(src_emb)
# 通过编码器层
encoder_output = src_emb
for encoder_layer in self.encoder_layers:
encoder_output = encoder_layer(encoder_output, src_mask)
return encoder_output
def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
"""
解码器前向传播
Args:
tgt: 目标序列 [batch_size, tgt_len]
encoder_output: 编码器输出 [batch_size, src_len, d_model]
src_mask: 源序列掩码
tgt_mask: 目标序列掩码
Returns:
decoder_output: [batch_size, tgt_len, d_model]
"""
# 词嵌入 + 位置编码
tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
tgt_emb = self.positional_encoding(tgt_emb)
tgt_emb = self.dropout(tgt_emb)
# 通过解码器层
decoder_output = tgt_emb
for decoder_layer in self.decoder_layers:
decoder_output = decoder_layer(decoder_output, encoder_output, src_mask, tgt_mask)
return decoder_output
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
"""
完整前向传播
Args:
src: 源序列 [batch_size, src_len]
tgt: 目标序列 [batch_size, tgt_len]
src_mask: 源序列掩码
tgt_mask: 目标序列掩码
Returns:
output: [batch_size, tgt_len, tgt_vocab_size]
"""
# 编码
encoder_output = self.encode(src, src_mask)
# 解码
decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
# 输出投影
output = self.output_projection(decoder_output)
return output
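
# --- Illustrative inference sketch (an addition, not part of the original file) ---
# The encode/decode split above lets inference encode the source once and run
# the decoder autoregressively. A minimal greedy-decoding sketch; bos_idx=1 and
# eos_idx=2 are hypothetical token conventions, not defined by the model itself.
@torch.no_grad()
def greedy_decode(model, src, max_len=50, bos_idx=1, eos_idx=2):
    model.eval()
    src_mask = create_padding_mask(src)
    memory = model.encode(src, src_mask)  # encode the source sequence once
    ys = torch.full((src.size(0), 1), bos_idx, dtype=torch.long)
    for _ in range(max_len - 1):
        tgt_mask = create_look_ahead_mask(ys.size(1))
        out = model.decode(ys, memory, src_mask, tgt_mask)
        logits = model.output_projection(out[:, -1])  # logits for the last position
        next_token = logits.argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_token], dim=1)
        if (next_token == eos_idx).all():  # stop once every sequence emitted <eos>
            break
    return ys
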
def create_padding_mask(seq, pad_idx=0):
"""
创建padding掩码,遮蔽padding位置
Args:
seq: [batch_size, seq_len]
pad_idx: padding token的索引
Returns:
mask: [batch_size, 1, 1, seq_len]
"""
mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
return mask
def create_look_ahead_mask(seq_len):
"""
创建下三角掩码,防止解码器看到未来信息
Args:
seq_len: 序列长度
Returns:
mask: [1, 1, seq_len, seq_len]
"""
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask.unsqueeze(0).unsqueeze(0)
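
# --- Illustrative example (an addition, not part of the original file) ---
# The two masks combine by broadcasting: [1, 1, L, L] & [batch, 1, 1, L]
# -> [batch, 1, L, L], which then broadcasts across heads inside attention.
def _demo_combined_target_mask():
    tgt = torch.tensor([[5, 6, 7, 0]])      # one sequence whose last token is padding
    look_ahead = create_look_ahead_mask(4)  # [1, 1, 4, 4], lower-triangular
    padding = create_padding_mask(tgt)      # [1, 1, 1, 4]
    combined = look_ahead & padding         # [1, 1, 4, 4]
    assert combined.shape == (1, 1, 4, 4)
    assert not combined[0, 0, 3, 3].item()  # the padded position is masked out
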
# Usage example and training code
if __name__ == "__main__":
    # Model hyperparameters
src_vocab_size = 10000
tgt_vocab_size = 10000
d_model = 512
n_heads = 8
n_layers = 6
d_ff = 2048
max_seq_len = 100
    # Build the model
model = Transformer(
src_vocab_size=src_vocab_size,
tgt_vocab_size=tgt_vocab_size,
d_model=d_model,
n_heads=n_heads,
n_layers=n_layers,
d_ff=d_ff,
max_seq_len=max_seq_len
)
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
    # Dummy data
batch_size = 32
src_len = 20
tgt_len = 25
src = torch.randint(1, src_vocab_size, (batch_size, src_len))
tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_len))
    # Build masks
src_mask = create_padding_mask(src)
tgt_mask = create_look_ahead_mask(tgt_len) & create_padding_mask(tgt)
    # Forward smoke test (no gradients needed)
with torch.no_grad():
output = model(src, tgt, src_mask, tgt_mask)
print(f"输出形状: {output.shape}") # [batch_size, tgt_len, tgt_vocab_size]
    # Minimal training-loop example
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding positions
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
model.train()
for epoch in range(3):
        # Forward pass with teacher forcing: the decoder input is tgt without its last token
        output = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :, :-1, :-1])
        # Compute the loss (each position predicts the next token)
target = tgt[:, 1:].contiguous().view(-1)
output = output.contiguous().view(-1, tgt_vocab_size)
loss = criterion(output, target)
        # Backward pass
optimizer.zero_grad()
loss.backward()
        # Gradient clipping (guards against exploding gradients)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
print("\n=== Transformer模型实现完成 ===")
