Complete Transformer Implementation, with Annotations

Main components:

  1. Multi-Head Self-Attention
  2. Positional Encoding
  3. Feed Forward Network
  4. Encoder/Decoder Layer
  5. Complete Transformer Model

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """
    Multi-Head Self-Attention.

    Core idea:
    - Project the input into Q, K, and V matrices
    - Compute attention weights: Attention(Q,K,V) = softmax(QK^T/sqrt(d_k))V
    - Run multiple attention heads in parallel to capture information from
      different positions and representation subspaces
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        """
        Args:
            d_model: model dimension (typically 512 or 768)
            n_heads: number of attention heads (typically 8 or 12)
            dropout: dropout probability
        """
        super(MultiHeadAttention, self).__init__()

        # d_model must be divisible by n_heads
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension of each head

        # Linear projections for Q, K, V.
        # Each projection computes all heads at once with a single
        # d_model x d_model matrix, which is more efficient than one
        # matrix per head.
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)

        # Output projection
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # Scaling factor that keeps the softmax from saturating
        self.scale = math.sqrt(self.d_k)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch_size, q_seq_len, d_model]
            key: [batch_size, k_seq_len, d_model]
            value: [batch_size, k_seq_len, d_model]
            mask: attention mask, broadcastable to
                  [batch_size, n_heads, q_seq_len, k_seq_len]

        Returns:
            output: [batch_size, q_seq_len, d_model]
            attention_weights: [batch_size, n_heads, q_seq_len, k_seq_len]
        """
        batch_size = query.size(0)

        # 1. Linear projections to Q, K, V
        Q = self.w_q(query)  # [batch_size, q_seq_len, d_model]
        K = self.w_k(key)    # [batch_size, k_seq_len, d_model]
        V = self.w_v(value)  # [batch_size, k_seq_len, d_model]

        # 2. Reshape into multi-head form.
        # Note: Q and K/V may have different sequence lengths
        # (in particular for cross-attention).
        q_seq_len = query.size(1)
        k_seq_len = key.size(1)
        v_seq_len = value.size(1)

        Q = Q.view(batch_size, q_seq_len, self.n_heads, self.d_k)
        K = K.view(batch_size, k_seq_len, self.n_heads, self.d_k)
        V = V.view(batch_size, v_seq_len, self.n_heads, self.d_k)

        # Transpose for matrix multiplication: [batch_size, n_heads, seq_len, d_k]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # 3. Compute attention
        attention_output, attention_weights = self.scaled_dot_product_attention(
            Q, K, V, mask, self.scale
        )

        # 4. Concatenate the heads
        # [batch_size, n_heads, q_seq_len, d_k] -> [batch_size, q_seq_len, n_heads, d_k]
        attention_output = attention_output.transpose(1, 2).contiguous()

        # [batch_size, q_seq_len, n_heads, d_k] -> [batch_size, q_seq_len, d_model]
        attention_output = attention_output.view(batch_size, q_seq_len, self.d_model)

        # 5. Output projection
        output = self.w_o(attention_output)

        return output, attention_weights

    def scaled_dot_product_attention(self, Q, K, V, mask, scale):
        """
        Core scaled dot-product attention.

        Formula: Attention(Q,K,V) = softmax(QK^T/sqrt(d_k))V
        """
        # Attention scores: QK^T
        # [batch_size, n_heads, q_seq_len, d_k] x [batch_size, n_heads, d_k, k_seq_len]
        # = [batch_size, n_heads, q_seq_len, k_seq_len]
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

        # Apply the mask (if provided)
        if mask is not None:
            # Set masked positions to a large negative number so they
            # become ~0 after the softmax
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)

        # Attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)
        # nn.Dropout is active in training mode and a no-op in eval mode,
        # so no explicit self.training check is needed
        attention_weights = self.dropout(attention_weights)

        # Apply the attention weights to V
        # [batch_size, n_heads, q_seq_len, k_seq_len] x [batch_size, n_heads, k_seq_len, d_k]
        # = [batch_size, n_heads, q_seq_len, d_k]
        attention_output = torch.matmul(attention_weights, V)

        return attention_output, attention_weights

class PositionalEncoding(nn.Module):
    """
    Positional Encoding.

    Since the Transformer has no recurrence or convolution, position
    information must be added to the sequence explicitly. Fixed encodings
    are generated with sin/cos functions.

    Formulas:
    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    """

    def __init__(self, d_model, max_seq_len=5000):
        """
        Args:
            d_model: model dimension
            max_seq_len: maximum supported sequence length
        """
        super(PositionalEncoding, self).__init__()

        # Positional encoding matrix
        pe = torch.zeros(max_seq_len, d_model)

        # Position indices [0, 1, 2, ..., max_seq_len-1]
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)

        # Divisor term: 10000^(-2i/d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
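        # The exp/log form above is the numerically stable way to compute
        # 10000^(-2i/d_model), since x^y = exp(y * ln(x)).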
    
        # Apply sin and cos
        pe[:, 0::2] = torch.sin(position * div_term)  # sin at even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # cos at odd indices

        # Add a batch dimension and register as a buffer
        # (saved with the model, but not updated by gradients)
        pe = pe.unsqueeze(0)  # [1, max_seq_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]

        Returns:
            x + positional_encoding: [batch_size, seq_len, d_model]
        """
        # Slice the encoding to the input length and add it to the input
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]
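
# For contrast with the fixed sin/cos encoding above, here is a minimal
# sketch of the other common option: learned positional embeddings, as used
# by BERT and GPT. (This class is an addition, not part of the original model.)
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=5000):
        super(LearnedPositionalEncoding, self).__init__()
        # One trainable d_model-dimensional vector per position
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        positions = torch.arange(x.size(1), device=x.device)  # [seq_len]
        return x + self.pos_embedding(positions).unsqueeze(0)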

class FeedForward(nn.Module):
    """
    Position-wise Feed Forward Network.

    Structure: Linear -> ReLU -> Linear
    The hidden dimension is typically 4x the model dimension
    (e.g. 512 -> 2048 -> 512).
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        Args:
            d_model: input/output dimension
            d_ff: hidden dimension (typically 4 * d_model)
            dropout: dropout probability
        """
        super(FeedForward, self).__init__()

        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]

        Returns:
            output: [batch_size, seq_len, d_model]
        """
        # Linear -> ReLU -> Dropout -> Linear
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
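
# Note: many later Transformer variants (e.g. BERT, GPT) replace ReLU with
# GELU here; swapping F.relu for F.gelu is the only change required.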

class EncoderLayer(nn.Module):
    """
    Transformer encoder layer.

    Structure:
    1. Multi-Head Self-Attention + residual connection + LayerNorm
    2. Feed Forward + residual connection + LayerNorm
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()

        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

        # Layer Normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            mask: attention mask

        Returns:
            output: [batch_size, seq_len, d_model]
        """
        # 1. Self-Attention + residual connection + LayerNorm
        attn_output, _ = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))

        # 2. Feed Forward + residual connection + LayerNorm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))

        return x
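
# Note: this is the post-norm arrangement from the original paper
# (sublayer, then residual add, then LayerNorm). Many modern implementations
# use pre-norm instead, i.e. x = x + self.dropout(sublayer(self.norm(x))),
# which tends to train more stably for deep stacks.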

class DecoderLayer(nn.Module):
    """
    Transformer decoder layer.

    Structure:
    1. Masked Multi-Head Self-Attention + residual + LayerNorm
    2. Multi-Head Cross-Attention + residual + LayerNorm
    3. Feed Forward + residual + LayerNorm
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()

        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: decoder input [batch_size, tgt_len, d_model]
            encoder_output: encoder output [batch_size, src_len, d_model]
            src_mask: source-sequence mask
            tgt_mask: target-sequence mask (lower-triangular)

        Returns:
            output: [batch_size, tgt_len, d_model]
        """
        # 1. Masked Self-Attention (prevents attending to future positions)
        self_attn_output, _ = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))

        # 2. Cross-Attention (the decoder attends to the encoder output)
        cross_attn_output, _ = self.cross_attention(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))

        # 3. Feed Forward
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))

        return x

class Transformer(nn.Module):
    """
    Complete Transformer model.

    Contains:
    - input embeddings + positional encoding
    - N encoder layers
    - N decoder layers
    - output linear layer
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_heads=8,
                 n_layers=6, d_ff=2048, max_seq_len=5000, dropout=0.1):
        """
        Args:
            src_vocab_size: source vocabulary size
            tgt_vocab_size: target vocabulary size
            d_model: model dimension
            n_heads: number of attention heads
            n_layers: number of encoder/decoder layers
            d_ff: hidden dimension of the feed-forward network
            max_seq_len: maximum sequence length
            dropout: dropout probability
        """
        super(Transformer, self).__init__()

        self.d_model = d_model

        # Token embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)

        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Decoder layers
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Output projection
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)

        self.dropout = nn.Dropout(dropout)

        # Parameter initialization
        self.init_parameters()

    def init_parameters(self):
        """Xavier initialization for all weight matrices"""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
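
    # Note: the original paper additionally ties the weights of the two
    # embedding layers and the pre-softmax output projection; they are kept
    # separate here for clarity.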

    def encode(self, src, src_mask=None):
        """
        Encoder forward pass.

        Args:
            src: source sequence [batch_size, src_len]
            src_mask: source-sequence mask

        Returns:
            encoder_output: [batch_size, src_len, d_model]
        """
        # Token embedding (scaled by sqrt(d_model)) + positional encoding
        src_emb = self.src_embedding(src) * math.sqrt(self.d_model)
        src_emb = self.positional_encoding(src_emb)
        src_emb = self.dropout(src_emb)

        # Pass through the encoder layers
        encoder_output = src_emb
        for encoder_layer in self.encoder_layers:
            encoder_output = encoder_layer(encoder_output, src_mask)

        return encoder_output

    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        """
        Decoder forward pass.

        Args:
            tgt: target sequence [batch_size, tgt_len]
            encoder_output: encoder output [batch_size, src_len, d_model]
            src_mask: source-sequence mask
            tgt_mask: target-sequence mask

        Returns:
            decoder_output: [batch_size, tgt_len, d_model]
        """
        # Token embedding (scaled by sqrt(d_model)) + positional encoding
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.positional_encoding(tgt_emb)
        tgt_emb = self.dropout(tgt_emb)

        # Pass through the decoder layers
        decoder_output = tgt_emb
        for decoder_layer in self.decoder_layers:
            decoder_output = decoder_layer(decoder_output, encoder_output, src_mask, tgt_mask)

        return decoder_output

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Full forward pass.

        Args:
            src: source sequence [batch_size, src_len]
            tgt: target sequence [batch_size, tgt_len]
            src_mask: source-sequence mask
            tgt_mask: target-sequence mask

        Returns:
            output: [batch_size, tgt_len, tgt_vocab_size]
        """
        # Encode
        encoder_output = self.encode(src, src_mask)

        # Decode
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)

        # Output projection
        output = self.output_projection(decoder_output)

        return output

def create_padding_mask(seq, pad_idx=0):
    """
    Create a padding mask that hides padding positions.

    Args:
        seq: [batch_size, seq_len]
        pad_idx: index of the padding token

    Returns:
        mask: bool tensor [batch_size, 1, 1, seq_len]
    """
    mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
    return mask
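
# Example: seq = [[5, 7, 0]] with pad_idx=0 yields
# mask = [[[[True, True, False]]]], which broadcasts over heads and
# query positions inside scaled_dot_product_attention.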

def create_look_ahead_mask(seq_len):
    """
    Create a lower-triangular mask so the decoder cannot see future positions.

    Args:
        seq_len: sequence length

    Returns:
        mask: bool tensor [1, 1, seq_len, seq_len]
    """
    # .bool() so this mask can be combined with the padding mask via &
    mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
    return mask.unsqueeze(0).unsqueeze(0)
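
# A convenience helper (an addition, not in the original post) that builds
# both masks for an (src, tgt) batch in one call, combining the look-ahead
# and padding masks the same way the demo below does.
def create_masks(src, tgt, pad_idx=0):
    src_mask = create_padding_mask(src, pad_idx)
    tgt_mask = create_look_ahead_mask(tgt.size(1)) & create_padding_mask(tgt, pad_idx)
    return src_mask, tgt_mask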

# Usage example and training code

if __name__ == "__main__":
    # Model hyperparameters
    src_vocab_size = 10000
    tgt_vocab_size = 10000
    d_model = 512
    n_heads = 8
    n_layers = 6
    d_ff = 2048
    max_seq_len = 100

    # Build the model
    model = Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        n_layers=n_layers,
        d_ff=d_ff,
        max_seq_len=max_seq_len
    )

    print(f"Parameter count: {sum(p.numel() for p in model.parameters()):,}")

    # Dummy data
    batch_size = 32
    src_len = 20
    tgt_len = 25

    src = torch.randint(1, src_vocab_size, (batch_size, src_len))
    tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_len))

    # Build the masks
    src_mask = create_padding_mask(src)
    tgt_mask = create_look_ahead_mask(tgt_len) & create_padding_mask(tgt)
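
    # Shape check: the padding mask broadcasts over query positions, the
    # look-ahead mask over the batch; together they give one row per target position.
    print(f"src_mask: {src_mask.shape}")  # [batch_size, 1, 1, src_len]
    print(f"tgt_mask: {tgt_mask.shape}")  # [batch_size, 1, tgt_len, tgt_len]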

    # Forward pass
    with torch.no_grad():
        output = model(src, tgt, src_mask, tgt_mask)
        print(f"Output shape: {output.shape}")  # [batch_size, tgt_len, tgt_vocab_size]

    # Minimal training-loop example
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    model.train()
    for epoch in range(3):
        # Forward pass: feed the target shifted right by one position
        output = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :, :-1, :-1])

        # Loss against the next token at each position
        target = tgt[:, 1:].contiguous().view(-1)
        output = output.contiguous().view(-1, tgt_vocab_size)

        loss = criterion(output, target)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping (guards against exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

print("\n=== Transformer模型实现完成 ===")