Building a Basic Transformer Architecture, with Training and Testing

Implementing the Transformer model architecture from scratch, together with training and testing code.

Below is a complete implementation of the Transformer model, including the encoder, the decoder, and the multi-head attention mechanism, together with training and testing code. The implementation proceeds in the following steps:

  1. Implement positional encoding (PositionalEncoding)
  2. Implement multi-head attention (ScaledDotProductAttention and MultiHeadAttention)
  3. Implement the position-wise feed-forward network (FeedForward)
  4. Implement the encoder layer (EncoderLayer) and the encoder (Encoder)
  5. Implement the decoder layer (DecoderLayer) and the decoder (Decoder)
  6. Combine the encoder and decoder into the full Transformer model
  7. Define the optimizer, loss function, and training step

The complete implementation:

import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import numpy as np
from typing import Optional

# Positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        
        # Create the positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Compute the sine and cosine positional encodings
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # add a batch dimension
        
        # Register the encoding as a buffer (not a trainable parameter)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return x

# Scaled dot-product attention
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, Q, K, V, mask: Optional[torch.Tensor] = None):
        # Q, K, V shapes: (batch_size, n_heads, seq_len, d_k)
        d_k = Q.size(-1)
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        
        # Apply the mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Compute attention weights
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply the attention weights to the value vectors
        output = torch.matmul(attn_weights, V)
        
        return output, attn_weights

# Multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Linear projection layers
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.attention = ScaledDotProductAttention(dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)
    
    def forward(self, Q, K, V, mask: Optional[torch.Tensor] = None):
        # Keep the input for the residual connection
        residual = Q
        
        # Linear projections, then split into heads
        batch_size = Q.size(0)
        
        Q = self.W_q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # Masks from create_mask already carry a broadcastable head dimension
        # (batch_size, 1, ..., seq_len); only add one if a 3-D mask is passed in
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)
        
        # Compute attention
        x, attn_weights = self.attention(Q, K, V, mask)
        
        # Concatenate the heads and apply the output projection
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        x = self.W_o(x)
        x = self.dropout(x)
        
        # Residual connection and layer normalization
        x = self.layer_norm(x + residual)
        
        return x, attn_weights

# Position-wise feed-forward network
class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        self.activation = nn.ReLU()
    
    def forward(self, x):
        residual = x
        x = self.activation(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)
        x = self.layer_norm(x + residual)
        return x

# Encoder layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
    
    def forward(self, x, mask: Optional[torch.Tensor] = None):
        x, attn_weights = self.self_attn(x, x, x, mask)
        x = self.feed_forward(x)
        return x, attn_weights

# Decoder layer
class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
    
    def forward(self, x, enc_output, src_mask: Optional[torch.Tensor] = None, 
                tgt_mask: Optional[torch.Tensor] = None):
        # Self-attention (with the target mask)
        x, self_attn_weights = self.self_attn(x, x, x, tgt_mask)
        
        # Cross-attention (queries from the decoder, keys/values from the encoder)
        x, cross_attn_weights = self.cross_attn(x, enc_output, enc_output, src_mask)
        
        # Feed-forward network
        x = self.feed_forward(x)
        
        return x, self_attn_weights, cross_attn_weights

# Encoder
class Encoder(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_layers: int, 
                 n_heads: int, d_ff: int, max_len: int, dropout: float = 0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout) 
            for _ in range(n_layers)
        ])
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask: Optional[torch.Tensor] = None):
        # Embedding and positional encoding
        x = self.embedding(x)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        
        # Pass through all encoder layers
        attn_weights_list = []
        for layer in self.layers:
            x, attn_weights = layer(x, mask)
            attn_weights_list.append(attn_weights)
        
        return x, attn_weights_list

# Decoder
class Decoder(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_layers: int, 
                 n_heads: int, d_ff: int, max_len: int, dropout: float = 0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout) 
            for _ in range(n_layers)
        ])
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, enc_output, src_mask: Optional[torch.Tensor] = None, 
                tgt_mask: Optional[torch.Tensor] = None):
        # Embedding and positional encoding
        x = self.embedding(x)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        
        # Pass through all decoder layers
        self_attn_weights_list = []
        cross_attn_weights_list = []
        for layer in self.layers:
            x, self_attn_weights, cross_attn_weights = layer(
                x, enc_output, src_mask, tgt_mask
            )
            self_attn_weights_list.append(self_attn_weights)
            cross_attn_weights_list.append(cross_attn_weights)
        
        return x, self_attn_weights_list, cross_attn_weights_list

# Transformer model
class Transformer(nn.Module):
    def __init__(self, src_vocab_size: int, tgt_vocab_size: int, d_model: int = 512, 
                 n_layers: int = 6, n_heads: int = 8, d_ff: int = 2048, 
                 max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.encoder = Encoder(src_vocab_size, d_model, n_layers, n_heads, 
                              d_ff, max_len, dropout)
        self.decoder = Decoder(tgt_vocab_size, d_model, n_layers, n_heads, 
                              d_ff, max_len, dropout)
        self.linear = nn.Linear(d_model, tgt_vocab_size)
        
        # Parameter initialization (Xavier uniform for weight matrices)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def forward(self, src, tgt, src_mask: Optional[torch.Tensor] = None, 
                tgt_mask: Optional[torch.Tensor] = None):
        # Encoder forward pass
        enc_output, enc_attn_weights = self.encoder(src, src_mask)
        
        # Decoder forward pass
        dec_output, dec_self_attn_weights, dec_cross_attn_weights = self.decoder(
            tgt, enc_output, src_mask, tgt_mask
        )
        
        # Project to the target vocabulary size
        output = self.linear(dec_output)
        
        return output, enc_attn_weights, dec_self_attn_weights, dec_cross_attn_weights

# Mask creation
def create_mask(src, tgt, pad_idx):
    # Source mask (for encoder self-attention and encoder-decoder attention)
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
    
    # Target mask (for decoder self-attention): padding mask combined with a causal mask
    tgt_len = tgt.size(1)
    tgt_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)
    subsequent_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=src.device)).bool()
    tgt_mask = tgt_mask & subsequent_mask.unsqueeze(0).unsqueeze(0)
    
    return src_mask, tgt_mask
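
# Illustrative shape check: for src of shape (2, 5) and tgt of shape (2, 7),
# create_mask returns src_mask of shape (2, 1, 1, 5) and tgt_mask of shape
# (2, 1, 7, 7); both broadcast against the (batch_size, n_heads, q_len, k_len)
# attention scores inside MultiHeadAttention.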

# Training function
def train(model, iterator, optimizer, criterion, clip, pad_idx):
    model.train()
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        src = batch.src
        tgt = batch.tgt
        
        # Create masks
        src_mask, tgt_mask = create_mask(src, tgt, pad_idx)
        
        optimizer.zero_grad()
        
        # Forward pass (teacher forcing: feed the target shifted right)
        output, _, _, _ = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :, :-1, :-1])
        
        # Compute the loss against the target shifted left by one position
        output_dim = output.shape[-1]
        output = output.contiguous().view(-1, output_dim)
        tgt = tgt[:, 1:].contiguous().view(-1)
        
        loss = criterion(output, tgt)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        # Update parameters
        optimizer.step()
        
        epoch_loss += loss.item()
    
    return epoch_loss / len(iterator)

# Evaluation function
def evaluate(model, iterator, criterion, pad_idx):
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            tgt = batch.tgt
            
            # Create masks
            src_mask, tgt_mask = create_mask(src, tgt, pad_idx)
            
            # Forward pass
            output, _, _, _ = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :, :-1, :-1])
            
            # Compute the loss
            output_dim = output.shape[-1]
            output = output.contiguous().view(-1, output_dim)
            tgt = tgt[:, 1:].contiguous().view(-1)
            
            loss = criterion(output, tgt)
            
            epoch_loss += loss.item()
    
    return epoch_loss / len(iterator)

# Inference (greedy decoding)
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol, device):
    model.eval()
    
    # Encode the source sequence
    with torch.no_grad():
        enc_output, _ = model.encoder(src, src_mask)
    
    # Initialize the target sequence with the start symbol
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src).to(device)
    
    for i in range(max_len - 1):
        # Create the target mask (pad_idx is assumed to be 0)
        _, tgt_mask = create_mask(src, ys, 0)
        
        with torch.no_grad():
            out, _, _ = model.decoder(ys, enc_output, src_mask, tgt_mask)
            out = model.linear(out[:, -1])
            prob = out.softmax(-1)
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            
            ys = torch.cat([ys, torch.ones(1, 1).type_as(src).fill_(next_word).to(device)], dim=1)
            
            if next_word == end_symbol:
                break
    
    return ys

# Example usage
if __name__ == "__main__":
    # Hyperparameters
    SRC_VOCAB_SIZE = 10000
    TGT_VOCAB_SIZE = 10000
    D_MODEL = 512
    N_LAYERS = 6
    N_HEADS = 8
    D_FF = 2048
    MAX_LEN = 100
    DROPOUT = 0.1
    BATCH_SIZE = 32
    N_EPOCHS = 10
    CLIP = 1
    LEARNING_RATE = 0.0005
    
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Initialize the model
    model = Transformer(
        SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, D_MODEL, N_LAYERS, 
        N_HEADS, D_FF, MAX_LEN, DROPOUT
    ).to(device)
    
    # Optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # pad_idx is assumed to be 0
    
    # Print the number of trainable parameters
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of trainable parameters: {num_params}")
    
    # Data loading and preprocessing would go here.
    # Since this is a demo, we use dummy data instead.
    
    # Dummy batch container holding .src and .tgt tensors
    class DummyData:
        def __init__(self, batch_size, seq_len, vocab_size):
            self.src = torch.randint(1, vocab_size, (batch_size, seq_len))
            self.tgt = torch.randint(1, vocab_size, (batch_size, seq_len + 1))
    
    # Dummy "data loaders" (plain lists of batches)
    train_data = [DummyData(BATCH_SIZE, 10, SRC_VOCAB_SIZE) for _ in range(100)]
    val_data = [DummyData(BATCH_SIZE, 10, SRC_VOCAB_SIZE) for _ in range(20)]
    
    # Training loop
    best_valid_loss = float('inf')
    
    for epoch in range(N_EPOCHS):
        train_loss = 0
        valid_loss = 0
        
        # Training
        for batch in train_data:
            batch.src = batch.src.to(device)
            batch.tgt = batch.tgt.to(device)
            
            train_loss += train(model, [batch], optimizer, criterion, CLIP, 0)
        
        # Validation
        for batch in val_data:
            batch.src = batch.src.to(device)
            batch.tgt = batch.tgt.to(device)
            
            valid_loss += evaluate(model, [batch], criterion, 0)
        
        train_loss /= len(train_data)
        valid_loss /= len(val_data)
        
        print(f'Epoch: {epoch+1:02}')
        print(f'\tTrain Loss: {train_loss:.3f}')
        print(f'\t Val. Loss: {valid_loss:.3f}')
        
        # Save the best model
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'transformer-model.pt')
    
    # Test inference
    print("Testing inference...")
    test_src = torch.randint(1, SRC_VOCAB_SIZE, (1, 10)).to(device)
    src_mask = (test_src != 0).unsqueeze(1).unsqueeze(2)
    
    # Greedy decoding
    result = greedy_decode(model, test_src, src_mask, MAX_LEN, 1, 2, device)
    print(f"Source sequence: {test_src.tolist()}")
    print(f"Generated sequence: {result.tolist()}")

Code Notes

This implementation contains all of the key components of the Transformer model:

  1. Positional encoding: adds position information to the input sequence using sine and cosine functions (see the formulas after this list)
  2. Scaled dot-product attention: computes attention weights between queries, keys, and values (see the formulas after this list)
  3. Multi-head attention: runs the attention mechanism in parallel across multiple heads
  4. Feed-forward network: the fully connected network after each attention sublayer
  5. Encoder and decoder layers: the basic building blocks of the Transformer
  6. The complete Transformer model: combines the encoder and the decoder
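
For reference, the positional encoding and the scaled dot-product attention implemented above follow the formulas from the original Transformer paper ("Attention Is All You Need"):

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$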

In addition, the code includes training and evaluation functions, a greedy-decoding inference function, mask-generation utilities (handling both padding and future positions), and parameter initialization. To use this Transformer on a concrete task such as machine translation, you need to prepare a suitable dataset and implement a data loader; a minimal sketch of such a pipeline follows.
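
As an illustration only, the sketch below shows one way to build such a pipeline with PyTorch's Dataset/DataLoader so that each batch exposes .src and .tgt exactly as the train and evaluate functions above expect. The toy corpus and the special-token indices (pad = 0, <sos> = 1, <eos> = 2) are assumptions made for this example, not part of the code above.

from dataclasses import dataclass
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2  # assumed special-token indices

@dataclass
class Batch:
    src: torch.Tensor  # (batch_size, src_len), padded with PAD_IDX
    tgt: torch.Tensor  # (batch_size, tgt_len), padded with PAD_IDX

class ParallelDataset(Dataset):
    """Wraps two aligned lists of token-id sequences (source and target)."""
    def __init__(self, src_ids, tgt_ids):
        assert len(src_ids) == len(tgt_ids)
        self.src_ids, self.tgt_ids = src_ids, tgt_ids

    def __len__(self):
        return len(self.src_ids)

    def __getitem__(self, idx):
        src = torch.tensor(self.src_ids[idx], dtype=torch.long)
        # Wrap the target with <sos>/<eos> so that tgt[:, :-1] (decoder input)
        # and tgt[:, 1:] (loss target) in train()/evaluate() line up correctly
        tgt = torch.tensor([SOS_IDX] + list(self.tgt_ids[idx]) + [EOS_IDX], dtype=torch.long)
        return src, tgt

def collate_fn(samples):
    # Pad each field of the batch to the longest sequence in that batch
    srcs, tgts = zip(*samples)
    return Batch(
        src=pad_sequence(list(srcs), batch_first=True, padding_value=PAD_IDX),
        tgt=pad_sequence(list(tgts), batch_first=True, padding_value=PAD_IDX),
    )

# Toy usage with an already-tokenized corpus of token-id sequences
dataset = ParallelDataset(src_ids=[[5, 6, 7], [8, 9]], tgt_ids=[[10, 11], [12, 13, 14]])
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

Each batch yielded by the loader can then be moved to the device and passed to the train and evaluate functions just like the DummyData batches in the demo above.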

posted @ 2025-08-20 14:39  Jcpeng_std