Lecture 9.3. Tiny Transformer: A Minimal Transformer
Introduction
A minimal Transformer encoder-decoder (Seq2Seq) implementation, suitable for learning, experimentation, and small sequence-to-sequence tasks such as translation or summarization.
The implementation covers positional encoding, a multi-layer encoder, a multi-layer decoder, and the training and inference procedures. The code is deliberately compact and easy to read, making it a good vehicle for understanding the basic principles of the Transformer.
Main Components
- PositionalEncoding: sinusoidal (sine/cosine) positional encoding that adds position information to the input embeddings.
- TransformerEncoderLayerWithTrace: a single encoder layer with self-attention and a feed-forward network.
- TinyTransformer: the multi-layer stacked encoder.
- TransformerDecoderLayer: a single decoder layer with self-attention, cross-attention, and a feed-forward network.
- TransformerDecoder: the multi-layer stacked decoder.
- TinyTransformerSeq2Seq: the full encoder-decoder model.
- Seq2SeqDataset: a simple sequence-to-sequence dataset.
- train: the training loop (one epoch).
- greedy_decode: greedy-decoding inference.
- generate_subsequent_mask: builds the autoregressive (causal) mask.
A short shape walkthrough of how these pieces fit together is shown right after this list.
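The sketch below assumes the full code listing at the end of this section is saved as demo.py (as the usage examples also assume) and simply traces tensor shapes through one encoder/decoder pass:
import torch
from demo import TinyTransformerSeq2Seq, generate_subsequent_mask
model = TinyTransformerSeq2Seq(src_vocab_size=100, tgt_vocab_size=100)  # defaults: d_model=64, 2 layers
src = torch.randint(0, 100, (2, 10))        # (batch=2, src_seq_len=10)
tgt_in = torch.randint(0, 100, (2, 12))     # (batch=2, tgt_seq_len=12)
tgt_mask = generate_subsequent_mask(tgt_in.size(1))
memory, _ = model.encoder(src)                     # (2, 10, 64): contextualized source states
logits = model.decoder(tgt_in, memory, tgt_mask)   # (2, 12, 100): per-position vocabulary logits
print(memory.shape, logits.shape)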
Dependencies
- Python 3.7+
- torch >= 1.10
Install PyTorch (CPU build shown as an example):
pip install torch
Usage Examples
1. Build the model
from demo import TinyTransformerSeq2Seq
src_vocab_size = 1000
trg_vocab_size = 1000
model = TinyTransformerSeq2Seq(src_vocab_size, trg_vocab_size)
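The constructor falls back to small defaults defined in the code listing (d_model=64, heads=4, d_ff=128, num_layers=2, max_len=64); they can also be passed explicitly. A sketch that simply restates those defaults:
model = TinyTransformerSeq2Seq(
    src_vocab_size, trg_vocab_size,
    d_model=64, heads=4, d_ff=128, num_layers=2, max_len=64,
)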
2. Build a dataset and DataLoader
import torch
from demo import Seq2SeqDataset
from torch.utils.data import DataLoader
src_data = torch.randint(0, src_vocab_size, (100, 10))  # 100 samples, 10 tokens each
trg_data = torch.randint(0, trg_vocab_size, (100, 12))  # 100 samples, 12 tokens each
dataset = Seq2SeqDataset(src_data, trg_data)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
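The random tensors above are already rectangular. With real tokenized data of varying lengths you would typically pad each sample to a common length, using index 0 as padding so it matches ignore_index=0 in the loss below. A minimal sketch using pad_sequence (the sample tensors are hypothetical):
from torch.nn.utils.rnn import pad_sequence
raw_src = [torch.tensor([5, 7, 9]), torch.tensor([3, 4])]              # hypothetical tokenized samples
padded_src = pad_sequence(raw_src, batch_first=True, padding_value=0)  # (2, 3), shorter rows padded with 0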
3. Train the model
import torch.optim as optim
import torch.nn as nn
from demo import train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=0)
epoch_loss = train(model, dataloader, optimizer, criterion, tgt_pad_idx=0)
print(f"Train loss: {epoch_loss}")
4. Inference (greedy decoding)
from demo import greedy_decode
src = torch.randint(0, src_vocab_size, (1, 10)).to(device)
sos_idx = 1  # assume 1 is the <sos> token id
eos_idx = 2  # assume 2 is the <eos> token id
max_len = 12
output = greedy_decode(model, src, sos_idx, eos_idx, max_len)
print("Output token ids:", output)
Notes
- This implementation is intended for teaching and experimentation; it does not include industrial-grade details such as complete masking, weight initialization, or distributed training.
- You need to prepare suitable training data and a vocabulary yourself.
- For production-grade NLP tasks, HuggingFace Transformers is recommended instead.
Tiny Transformer example code 1:
import torch
import torch.nn as nn
import torch.optim as optim
import math
from torch.utils.data import DataLoader, Dataset
# =========================
# Positional encoding module
# =========================
class PositionalEncoding(nn.Module):
"""
    Adds position information to the input embeddings so the model can capture sequence order.
"""
def __init__(self, d_model, max_len=512):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
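        # div_term = 1 / 10000^(2i / d_model); combined with the sin/cos below this gives
        # PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))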
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, max_len, d_model)
self.register_buffer('pe', pe)
def forward(self, x):
# x: (batch, seq_len, d_model)
return x + self.pe[:, :x.size(1)]
# =========================
# Encoder layer
# =========================
class TransformerEncoderLayerWithTrace(nn.Module):
"""
    A single Transformer encoder layer with self-attention and a feed-forward network.
"""
def __init__(self, d_model, nhead, dim_feedforward):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(0.1)
def forward(self, src, src_mask=None):
# src: (batch, seq_len, d_model)
src2, attn_weights = self.self_attn(src, src, src, attn_mask=src_mask)
src = self.norm1(src + self.dropout(src2))
src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))
src = self.norm2(src + self.dropout(src2))
return src, attn_weights
# =========================
# Encoder
# =========================
class TinyTransformer(nn.Module):
"""
    A multi-layer stacked Transformer encoder.
"""
def __init__(self, src_vocab_size, tgt_vocab_size, d_model, nhead, dim_feedforward, max_len, num_layers):
super().__init__()
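        # Note: tgt_vocab_size is accepted but not used here; the encoder only embeds source tokens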
self.embedding = nn.Embedding(src_vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, max_len)
self.layers = nn.ModuleList([
TransformerEncoderLayerWithTrace(d_model, nhead, dim_feedforward)
for _ in range(num_layers)
])
def forward(self, src, trace=False):
# src: (batch, seq_len)
src = self.embedding(src)
src = self.pos_encoder(src)
attn_weights_all = []
for layer in self.layers:
src, attn_weights = layer(src)
if trace:
attn_weights_all.append(attn_weights)
return src, attn_weights_all
# =========================
# Decoder layer
# =========================
class TransformerDecoderLayer(nn.Module):
"""
    A single Transformer decoder layer with self-attention, cross-attention, and a feed-forward network.
"""
def __init__(self, d_model, nhead, dim_feedforward):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(0.1)
def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
# tgt: (batch, tgt_seq_len, d_model)
# memory: (batch, src_seq_len, d_model)
tgt2, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)
tgt = self.norm1(tgt + self.dropout(tgt2))
tgt2, _ = self.cross_attn(tgt, memory, memory, attn_mask=memory_mask)
tgt = self.norm2(tgt + self.dropout(tgt2))
tgt2 = self.linear2(self.dropout(torch.relu(self.linear1(tgt))))
tgt = self.norm3(tgt + self.dropout(tgt2))
return tgt
# =========================
# Decoder
# =========================
class TransformerDecoder(nn.Module):
"""
    A multi-layer stacked Transformer decoder.
"""
def __init__(self, vocab_size, d_model, nhead, dim_feedforward, num_layers, max_len=512):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, max_len)
self.layers = nn.ModuleList([
TransformerDecoderLayer(d_model, nhead, dim_feedforward)
for _ in range(num_layers)
])
self.out_proj = nn.Linear(d_model, vocab_size)
def forward(self, tgt_ids, memory, tgt_mask=None):
# tgt_ids: (batch, tgt_seq_len)
x = self.embedding(tgt_ids)
x = self.pos_encoder(x)
for layer in self.layers:
x = layer(x, memory, tgt_mask)
return self.out_proj(x)
# =========================
# Full Seq2Seq model
# =========================
class TinyTransformerSeq2Seq(nn.Module):
"""
    An encoder-decoder Transformer model.
"""
def __init__(self, src_vocab_size, tgt_vocab_size, d_model=64, heads=4, d_ff=128, num_layers=2, max_len=64):
super().__init__()
self.encoder = TinyTransformer(src_vocab_size, tgt_vocab_size, d_model, heads, d_ff, max_len, num_layers)
self.decoder = TransformerDecoder(
vocab_size=tgt_vocab_size,
d_model=d_model,
nhead=heads,
dim_feedforward=d_ff,
num_layers=num_layers,
max_len=max_len
)
def forward(self, src_ids, tgt_input_ids, tgt_mask=None):
# src_ids: (batch, src_seq_len)
# tgt_input_ids: (batch, tgt_seq_len)
memory, _ = self.encoder(src_ids, trace=False)
logits = self.decoder(tgt_input_ids, memory, tgt_mask)
return logits
# =========================
# Utility: generate the autoregressive mask
# =========================
def generate_subsequent_mask(size):
"""
    Generates the autoregressive mask that prevents the decoder from attending to future positions.
"""
return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)
# =========================
# Toy dataset
# =========================
class Seq2SeqDataset(Dataset):
"""
    A simple sequence-to-sequence dataset.
"""
def __init__(self, src_data, tgt_data):
self.src = src_data
self.tgt = tgt_data
def __len__(self):
return len(self.src)
def __getitem__(self, idx):
return self.src[idx], self.tgt[idx]
# =========================
# Training loop
# =========================
def train(model, dataloader, optimizer, criterion, tgt_pad_idx):
"""
    Trains the model for one epoch and returns the average loss.
"""
model.train()
total_loss = 0
for src, tgt in dataloader:
src, tgt = src.to(device), tgt.to(device)
tgt_input = tgt[:, :-1]
tgt_output = tgt[:, 1:]
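        # Teacher forcing: the decoder input is the target shifted right by one position,
        # and the loss compares each position's prediction with the next target token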
tgt_mask = generate_subsequent_mask(tgt_input.size(1)).to(device)
logits = model(src, tgt_input, tgt_mask)
loss = criterion(logits.view(-1, logits.size(-1)), tgt_output.reshape(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
# =========================
# Inference (greedy decoding)
# =========================
def greedy_decode(model, src, sos_idx, eos_idx, max_len):
"""
    Greedy decoding: pick the highest-probability token at each step, until <eos> or the maximum length is reached.
"""
model.eval()
src = src.to(device)
memory, _ = model.encoder(src, trace=False)
tgt = torch.ones((src.size(0), 1), dtype=torch.long).fill_(sos_idx).to(device)
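    # Start from a single <sos> column; one greedily chosen token is appended per iteration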
for _ in range(max_len - 1):
tgt_mask = generate_subsequent_mask(tgt.size(1)).to(device)
out = model.decoder(tgt, memory, tgt_mask)
next_token = out[:, -1, :].argmax(dim=-1).unsqueeze(1)
tgt = torch.cat([tgt, next_token], dim=1)
if (next_token == eos_idx).all():
break
return tgt
# =========================
# Device selection
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
