Transformer 架构全解析
Transformer 是 2017 年谷歌团队在论文 Attention Is All You Need 中提出的基于自注意力机制的序列建模架构,打破了传统循环神经网络(RNN)的串行计算范式,成为自然语言处理(NLP)、计算机视觉(CV)等领域的基础模型架构。
一、是什么:核心概念与关键特征
1. 定义
Transformer 是一种编码器 - 解码器(Encoder - Decoder)结构的深度学习模型,核心驱动力是自注意力机制,能够在并行计算的前提下,高效捕捉序列数据中的长距离依赖关系。
2. 核心内涵
- 以注意力机制替代 RNN 的循环结构,实现序列数据的并行处理。
- 编码器负责对输入序列进行特征提取,解码器负责根据编码器的输出和已生成的目标序列,生成下一个 token。
3. 关键特征
| 特征 | 具体说明 |
|---|---|
| 并行计算 | 无需像 RNN 那样按顺序处理序列,可一次性计算所有位置的特征,大幅提升训练效率 |
| 长距离依赖捕捉 | 自注意力机制让序列中任意位置的 token 能直接关联其他位置的 token,无梯度消失风险 |
| 灵活的注意力机制 | 支持多头注意力、掩码注意力等多种变体,适配不同任务场景 |
| 模块化设计 | 编码器和解码器由多个相同的层堆叠而成,结构清晰,易于扩展 |
二、为什么需要:解决的痛点与应用价值
1. 传统序列模型的核心痛点
在 Transformer 出现前,主流序列模型是 RNN、LSTM、GRU 等循环结构,存在明显缺陷:
- 串行计算:必须按顺序处理序列(先算第 1 个 token,再算第 2 个……),无法充分利用 GPU 的并行计算能力,训练速度慢。
- 长距离依赖弱化:随着序列长度增加,梯度会逐渐消失或爆炸,导致模型难以捕捉远距离 token 之间的关联(比如长文本中开头和结尾的逻辑关系)。
- 计算复杂度高:LSTM 等模型的时间复杂度为 $O(n)$,长序列场景下效率极低。
2. Transformer 的核心优势
- 并行计算:时间复杂度降为 $O(1)$,可一次性处理整个序列,训练效率提升数倍。
- 长距离依赖无衰减:自注意力机制通过直接计算 token 间的关联权重,无视序列长度,精准捕捉远距离依赖。
- 通用适配性:不仅适用于 NLP 的翻译、文本生成任务,还能通过结构调整(如 ViT)应用于 CV 的图像分类、目标检测等任务。
3. 实际应用价值
Transformer 是大语言模型(LLM)的基石,BERT、GPT、T5 等知名模型均基于 Transformer 架构衍生而来;在 CV 领域,Vision Transformer(ViT)将图像切分为 patch 序列,实现了图像任务的性能突破。从机器翻译到智能对话,从图像识别到视频分析,Transformer 已成为跨领域的通用模型架构。
三、核心工作模式:关键要素与运作逻辑
Transformer 的核心是编码器 - 解码器结构,以及支撑该结构的 5 大关键要素:自注意力机制、多头注意力、位置编码、前馈神经网络(FFN)、残差连接与层归一化。
1. 核心要素拆解
(1)自注意力机制(Self-Attention)
自注意力是 Transformer 的核心驱动力,本质是让序列中的每个 token 都能“关注”到序列中所有其他 token,并根据关联程度分配不同的权重,最终生成融合全局信息的特征表示。
- 计算步骤:
- 对输入 token 的嵌入向量生成三个向量:查询向量(Query, Q)、键向量(Key, K)、值向量(Value, V)(通过三个不同的线性层映射得到)。
- 计算 Query 与所有 Key 的点积,得到注意力分数(反映 token 间的关联程度)。
- 对注意力分数进行缩放(除以 $\sqrt{d_k}$,$d_k$ 是 Key 的维度,防止分数过大导致 Softmax 饱和)。
- 用 Softmax 函数将分数归一化为 0 - 1 之间的权重。
- 使用权重对 Value 向量进行加权求和,得到该 token 的自注意力输出。
- 公式:$Attention(Q,K,V) = Softmax(\frac{QK^T}{\sqrt{d_k}})V$
(2)多头注意力(Multi-Head Attention)
单头自注意力只能捕捉一种维度的关联,多头注意力通过并行运行多个自注意力头,捕捉序列中不同角度的依赖关系,再将结果拼接并线性映射,提升模型的表达能力。
- 计算步骤:
- 将 Q、K、V 分别通过 $h$ 组不同的线性层,得到 $h$ 组子 Q、子 K、子 V。
- 对每组子向量计算自注意力,得到 $h$ 个注意力输出。
- 将 $h$ 个输出拼接,再通过一个线性层,得到最终的多头注意力输出。
- 公式:$MultiHead(Q,K,V) = Concat(head_1,...,head_h)W^O$,其中 $head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)$
(3)位置编码(Positional Encoding)
自注意力机制本身不包含序列位置信息(打乱序列顺序,计算结果不变),而序列数据的位置是核心特征(比如“我吃苹果”和“苹果吃我”的语义差异)。位置编码的作用是给每个 token 注入位置信息,使其能区分不同的位置。
- 实现方式:论文中采用正弦余弦位置编码,公式为:
$PE_{(pos,2i)} = sin(pos/10000^{2i/d_{model}})$
$PE_{(pos,2i+1)} = cos(pos/10000^{2i/d_{model}})$
其中 $pos$ 是 token 在序列中的位置,$i$ 是编码维度的索引,$d_{model}$ 是嵌入向量的维度。也可使用可学习的位置编码。
(4)前馈神经网络(Feed-Forward Network, FFN)
FFN 是一个两层全连接神经网络,作用是对注意力机制的输出进行非线性变换,进一步提取特征。
- 公式:$FFN(x) = max(0, xW_1 + b_1)W_2 + b_2$
其中第一层使用 ReLU 激活函数,第二层是线性变换。
(5)残差连接与层归一化(Residual Connection & Layer Normalization)
Transformer 的每层输出都采用“残差连接 + 层归一化”的组合,解决深度模型的梯度消失问题,加速训练收敛。
- 公式:$LayerNorm(x + SubLayer(x))$
其中 $SubLayer(x)$ 是该层的子模块输出(如注意力模块或 FFN 模块)。
2. 要素间的关联
- 位置编码与嵌入向量相加,得到带位置信息的输入特征,输入编码器/解码器。
- 编码器的核心是多头自注意力 + FFN,每层都通过残差连接和层归一化传递特征。
- 解码器的核心是掩码多头自注意力 + 编码器 - 解码器注意力 + FFN,掩码防止模型关注未来的 token,编码器 - 解码器注意力让解码器关注编码器的输入特征。
- 所有模块的组合,最终实现“输入序列特征提取 → 目标序列生成”的端到端建模。
四、工作流程:完整链路与流程图
Transformer 的工作流程分为编码器流程、解码器流程和输出流程三部分,以下结合 Mermaid 流程图详细说明。
1. 整体架构流程图
2. 分步工作流程
(1)编码器流程(处理输入序列)
编码器由 $N$ 个相同的编码器层堆叠而成(论文中 $N=6$),每个编码器层包含两个子层:
- 输入预处理:将输入序列的每个 token 转换为嵌入向量,再加上位置编码,得到维度为 $[seq_len, d_{model}]$ 的输入特征。
- 多头自注意力层:对输入特征计算多头自注意力,捕捉输入序列内部的依赖关系。
- 残差连接与层归一化:将注意力输出与原始输入特征相加,再进行层归一化。
- 前馈神经网络层:对归一化后的特征进行非线性变换。
- 残差连接与层归一化:将 FFN 输出与上一层输入相加,再进行层归一化。
- 重复上述步骤 $N$ 次,最终得到编码器的输出 Memory(维度为 $[seq_len, d_{model}]$),作为解码器的输入之一。
(2)解码器流程(生成目标序列)
解码器同样由 $N$ 个相同的解码器层堆叠而成,每个解码器层包含三个子层,且采用自回归生成方式(逐个生成 token):
- 输入预处理:将已生成的目标序列 token 转换为嵌入向量,加上位置编码,得到目标特征。
- 掩码多头自注意力层:与普通多头自注意力的区别是加入掩码(Mask),将未来位置的 token 注意力分数设为 -∞,使模型只能关注已生成的 token,避免信息泄露。
- 残差连接与层归一化:同编码器。
- 编码器 - 解码器注意力层:Query 来自解码器上一层的输出,Key 和 Value 来自编码器的 Memory,让解码器关注输入序列的相关特征(比如翻译任务中,目标词对应输入词的位置)。
- 残差连接与层归一化:同编码器。
- 前馈神经网络层:同编码器。
- 残差连接与层归一化:同编码器。
- 重复上述步骤 $N$ 次,得到解码器的输出特征。
(3)输出流程
- 线性层:将解码器的输出特征映射到词汇表维度(维度为 $[seq_len, vocab_size]$)。
- Softmax 层:将线性层的输出转换为概率分布,概率最大的 token 即为当前生成的 token。
- 将生成的 token 加入目标序列,重复解码器流程,直到生成结束符(EOS)。
五、入门实操:基于 PyTorch 实现简单 Transformer
以下是基于 PyTorch 实现一个小型 Transformer 模型的入门步骤,用于英文到中文的简单翻译任务(核心模块实现)。
1. 环境准备
- 安装依赖:
pip install torch torchtext - 关键库:PyTorch(模型构建)、TorchText(数据处理)
2. 核心模块实现
(1)位置编码实现
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
# 初始化位置编码矩阵
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
# 正弦编码偶数维度,余弦编码奇数维度
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1) # [max_len, 1, d_model]
self.register_buffer('pe', pe) # 不参与训练的参数
def forward(self, x):
# x: [seq_len, batch_size, d_model]
return x + self.pe[:x.size(0), :]
(2)多头注意力实现
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_head):
super().__init__()
self.n_head = n_head
self.d_k = d_model // n_head # 每个头的维度
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
# 线性变换并拆分多头:[batch_size, seq_len, d_model] → [batch_size, n_head, seq_len, d_k]
q = self.w_q(q).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
k = self.w_k(k).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
v = self.w_v(v).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
# 计算注意力分数:Q·K^T / sqrt(d_k)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
# 应用掩码
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Softmax 归一化
attn = torch.softmax(scores, dim=-1)
# 加权求和并拼接多头
output = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_k)
# 最终线性变换
output = self.w_o(output)
return output
(3)编码器层与解码器层实现
# 编码器层
class EncoderLayer(nn.Module):
def __init__(self, d_model, n_head, ff_dim, dropout=0.1):
super().__init__()
self.attn = MultiHeadAttention(d_model, n_head)
self.ffn = nn.Sequential(
nn.Linear(d_model, ff_dim),
nn.ReLU(),
nn.Linear(ff_dim, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 多头自注意力 + 残差 + 层归一化
attn_output = self.attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# FFN + 残差 + 层归一化
ffn_output = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_output))
return x
# 解码器层
class DecoderLayer(nn.Module):
def __init__(self, d_model, n_head, ff_dim, dropout=0.1):
super().__init__()
self.masked_attn = MultiHeadAttention(d_model, n_head)
self.enc_dec_attn = MultiHeadAttention(d_model, n_head)
self.ffn = nn.Sequential(
nn.Linear(d_model, ff_dim),
nn.ReLU(),
nn.Linear(ff_dim, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
# 掩码多头自注意力
attn1 = self.masked_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(attn1))
# 编码器-解码器注意力
attn2 = self.enc_dec_attn(x, enc_output, enc_output, src_mask)
x = self.norm2(x + self.dropout(attn2))
# FFN
ffn_output = self.ffn(x)
x = self.norm3(x + self.dropout(ffn_output))
return x
(4)完整 Transformer 模型整合
class Transformer(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_head, n_layers, ff_dim, dropout=0.1):
super().__init__()
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model)
self.encoder = nn.Sequential(*[EncoderLayer(d_model, n_head, ff_dim, dropout) for _ in range(n_layers)])
self.decoder = nn.Sequential(*[DecoderLayer(d_model, n_head, ff_dim, dropout) for _ in range(n_layers)])
self.fc = nn.Linear(d_model, tgt_vocab_size)
self.dropout = nn.Dropout(dropout)
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
# 输入嵌入 + 位置编码
src_emb = self.dropout(self.pos_encoding(self.src_embedding(src)))
tgt_emb = self.dropout(self.pos_encoding(self.tgt_embedding(tgt)))
# 编码器输出
enc_output = self.encoder(src_emb, src_mask)
# 解码器输出
dec_output = self.decoder(tgt_emb, enc_output, src_mask, tgt_mask)
# 线性层映射到词汇表
output = self.fc(dec_output)
return output
3. 关键操作要点
- 掩码构建:训练时,解码器的掩码分为填充掩码(屏蔽 padding token)和序列掩码(屏蔽未来 token),需确保掩码维度与注意力分数匹配。
- 超参数选择:$d_{model}$ 通常设为 128/256/512,$n_{head}$ 设为 2/4/8(需满足 $d_{model}$ 能被 $n_{head}$ 整除),$n_{layers}$ 设为 2/6/12。
- 数据处理:需将文本序列转换为整数索引,设置统一的序列长度(过长截断,过短 padding)。
4. 实操注意事项
- 训练时使用交叉熵损失函数,优化器选择 Adam,学习率采用预热策略(论文中使用 warmup_steps=4000)。
- 防止过拟合:加入 Dropout 层(概率 0.1)、使用早停(Early Stopping)、增加训练数据量。
- 自回归推理:生成目标序列时,需逐个 token 生成,将上一步的输出作为下一步的输入。
六、常见问题及解决方案
问题 1:训练过程中出现 NaN(梯度爆炸)
- 现象:模型训练几步后,损失值变为 NaN,无法继续训练。
- 原因:注意力分数过大,导致 Softmax 后梯度爆炸;或学习率过高,参数更新幅度过大。
- 解决方案:
- 梯度裁剪:在反向传播时,限制梯度的最大范数,比如
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)。 - 缩放注意力分数:严格执行 $\frac{QK^T}{\sqrt{d_k}}$ 的缩放操作,避免分数过大导致 Softmax 饱和。
- 降低学习率:将初始学习率从 1e-3 降低到 1e-4 或 1e-5,配合预热策略。
- 梯度裁剪:在反向传播时,限制梯度的最大范数,比如
问题 2:长序列训练效率低、显存不足
- 现象:当序列长度超过 1024 时,训练速度急剧下降,甚至出现显存溢出(OOM)。
- 原因:自注意力的时间复杂度和空间复杂度均为 $O(n^2)$,长序列下计算量和显存占用呈平方增长。
- 解决方案:
- 使用稀疏注意力:只计算 token 周围局部范围的注意力(如 Local Attention),或只关注关键 token(如 Sparse Attention),将复杂度降为 $O(n)$。
- 序列分段处理:将长序列切分为多个短片段,分别编码后再融合特征。
- 使用混合精度训练:采用 FP16 精度替代 FP32,减少显存占用。
问题 3:模型过拟合,测试集性能差
- 现象:训练集损失持续下降,测试集损失先降后升,模型泛化能力弱。
- 原因:模型参数过多、训练数据量不足、缺乏正则化机制。
- 解决方案:
- 数据增强:对文本数据进行同义词替换、随机插入/删除 token 等操作,扩充训练数据。
- 增加正则化:提高 Dropout 概率,加入权重衰减(Weight Decay),限制模型参数的大小。
- 预训练 + 微调:先在大规模通用语料上预训练 Transformer,再在目标任务的小数据集上微调,提升泛化能力。
交付物提议
我可以帮你整理Transformer 核心公式的速记清单,方便你快速记忆和复习,需要吗?

浙公网安备 33010602011771号