深入解析:解密Transformer中的前向神经网络(FFN)
2025-10-01 18:44 tlnshuju 阅读(24) 评论(0) 收藏 举报transformer中前向神经网络(FFN)
在Transformer模型中自注意力机制通常抢走了所有风头,但前向神经网络(FFN,也被称为位置级前馈网络) 是Transformer模型中一个不可或缺的、功能强大的组件。
简单来说,它的主要作用是:对自注意力机制提取出的信息进行深加工和变换,为每个位置(单词)生成更丰富、更复杂的特征表示。
让我们用一个比喻来理解:
- 自注意力层 就像一个会议室里的讨论会。每个参会者(单词)都会倾听其他所有人的发言,并基于与所有人的关系来更新自己的观点。这个过程决定了“哪些信息是重要的”。
- 前向神经网络 就像每个参会者回到自己的私人办公室,将讨论会上收集到的所有信息和观点进行深度消化、整合和升华,形成自己更成熟、更复杂的最终立场。这个过程负责“如何加工和处理这些重要信息”。
一、FFN在Transformer中的位置与结构
在Transformer的每个编码器层和解码器层中,FFN都紧跟在自注意力层之后:
输入 -> 自注意力层 -> 残差连接 & 层归一化 -> FFN层 -> 残差连接 & 层归一化 -> 输出
它的结构非常简单,就是一个两层的全连接神经网络:
- 第一层(扩展层): 一个全连接层,通常使用 ReLU 或 GELU 作为激活函数。关键点是,这一层的输出维度远大于输入维度(通常是4倍,例如输入是512维,输出是2048维)。这就像一个“信息膨胀”的过程。
- 第二层(收缩层): 另一个全连接层,不使用激活函数(或可视为线性激活)。这一层将膨胀后的维度投影回原始的模型维度(例如从2048维变回512维)。
公式表示:FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
(其中 W₁
和 W₂
是可学习的权重矩阵,b₁
和 b₂
是偏置项)
二、FFN的具体作用与重要性
1. 提供非线性变换能力
这是FFN最核心的作用。自注意力层本质上是线性操作(加权求和),即使加了Softmax,整个变换的非线性能力仍然有限。FFN中的激活函数(如ReLU)引入了强大的非线性,使得模型能够学习并拟合极其复杂的函数。没有FFN,Transformer就退化为一个简单的线性模型,表达能力将大打折扣。
2. 独立处理每个位置的信息
FFN是位置独立的。它作用于自注意力层的输出向量的每一个位置(即序列中的每一个词向量),并且对每个位置的处理方式是完全相同的(参数共享)。这意味着:
- 它专注于加工单个位置的综合信息。
- 它不会像自注意力那样在位置之间混入信息,那个工作已经由自注意力层完成了。
3. 一个“专家”网络
你可以将每个FFN层看作一个特征转换专家。它学习到的是如何将自注意力层提供的、经过上下文信息加权过的词表示,映射到另一个更有利于最终任务(如翻译、分类)的表示空间。不同层的FFN可能学习到不同级别的抽象特征:
- 底层FFN: 可能更关注语法、词性等基础特征。
- 高层FFN: 可能更关注语义、逻辑等高级特征。
4. 增加模型的容量(可学习参数)
虽然FFN结构简单,但由于中间层维度很大,它实际上是Transformer模型中参数最多的部分!对于一个有L层、模型维度为d、FFN中间层为4d的Transformer:
- 自注意力层的参数量约为
4d²
(每层)。 - FFN层的参数量约为
2 * d * (4d) = 8d²
(每层)。
所以,FFN贡献了模型大部分的可学习参数,为模型存储知识提供了巨大的空间。
三、一个简化的例子
假设我们的任务是翻译一句话,模型已经学到了“apple”这个词在“I eat an apple”和“Apple Inc.”中含义不同。
自注意力层的工作:
- 在“I eat an apple”中,自注意力机制发现“apple”与“eat”强相关,从而将其识别为水果。
- 在“Apple Inc.”中,自注意力机制发现“Apple”与“Inc.”强相关,从而将其识别为公司。
- 自注意力层为每个位置的“apple”都输出一个已经包含上下文线索的向量。
FFN层的工作:
- FFN接收这个已经带有“水果”或“公司”线索的向量。
- 由于FFN在大量数据上训练过,它已经学会了“水果”相关的向量应该被转换成一种富含甜味、圆形、可食用等语义特征的表示;而“公司”相关的向量应该被转换成一种富含科技、创新、商业等语义特征的表示。
- FFN执行这种复杂的非线性特征变换,为后续的层(或解码器)提供更精炼、更高级的特征。
总结
特性 | 自注意力层 | 前向神经网络 |
---|---|---|
核心功能 | 信息聚合: 收集序列中所有位置的信息,建立全局依赖关系。 | 信息加工: 对每个位置的信息进行非线性变换和深度处理。 |
操作范围 | 序列内所有位置之间(交互式) | 每个位置独立(并行式) |
关键作用 | 决定“关注谁” | 决定“如何转化” |
类比 | 会议室讨论 | 个人深度思考 |
没有自注意力,模型就是“瞎子”,看不到上下文。而没有FFN,模型就是“傻子”,无法对看到的信息进行复杂思考和升华。两者相辅相成,共同构成了Transformer强大的建模能力。
附python代码实现
根据原始论文《Attention Is All You Need》的描述,FFN的定义如下:
FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
其中,中间层的维度是模型隐藏层的4倍(即 d_ff = 4 * d_model
)。
基础版本实现(使用PyTorch)
这是最直接、最符合论文描述的实现。
import torch
import torch.nn as nn
class PositionwiseFFN(nn.Module):
"""Position-wise Feed-Forward Network."""
def __init__(self, d_model, d_ff, dropout=0.1):
"""
Args:
d_model: 模型隐藏层的维度(输入和输出的维度)
d_ff: 前向神经网络中间层的维度(通常是d_model的4倍)
dropout: Dropout比率
"""
super(PositionwiseFFN, self).__init__()
# 第一层(扩展层):从 d_model 到 d_ff
self.linear1 = nn.Linear(d_model, d_ff)
# 第二层(收缩层):从 d_ff 回到 d_model
self.linear2 = nn.Linear(d_ff, d_model)
# Dropout层
self.dropout = nn.Dropout(dropout)
# 激活函数:原始论文使用ReLU
self.activation = nn.ReLU()
def forward(self, x):
"""
Args:
x: 输入张量,形状为 [batch_size, seq_len, d_model]
Returns:
输出张量,形状为 [batch_size, seq_len, d_model]
"""
# x -> linear1 -> ReLU -> dropout -> linear2
return self.linear2(self.dropout(self.activation(self.linear1(x))))
更实用的版本(支持GELU和预LN配置)
现代Transformer的实现中,常常使用GELU作为激活函数,并且有时会采用Pre-LayerNorm结构。
import torch
import torch.nn as nn
class PositionwiseFFN(nn.Module):
"""增强版的前向神经网络,支持GELU和Pre-LN"""
def __init__(self, d_model, d_ff, dropout=0.1, activation="relu", pre_layer_norm=False):
"""
Args:
d_model: 模型隐藏层的维度
d_ff: 前向神经网络中间层的维度
dropout: Dropout比率
activation: 激活函数,'relu' 或 'gelu'
pre_layer_norm: 是否使用Pre-LayerNorm(True:在FFN之前做LayerNorm;False:在之后做)
"""
super(PositionwiseFFN, self).__init__()
self.d_model = d_model
self.d_ff = d_ff
self.pre_layer_norm = pre_layer_norm
# 线性层
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
# 激活函数
if activation == "relu":
self.activation = nn.ReLU()
elif activation == "gelu":
self.activation = nn.GELU() # 在BERT等模型中常用
else:
raise ValueError(f"Unsupported activation: {activation}")
# LayerNorm
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, x):
"""
Args:
x: 输入张量,形状为 [batch_size, seq_len, d_model]
Returns:
输出张量,形状与输入相同
"""
residual = x # 保存残差连接
if self.pre_layer_norm:
# Pre-LN: 先做LayerNorm,再通过FFN
x = self.layer_norm(x)
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = self.dropout(x)
x = residual + x # 残差连接
else:
# Post-LN (原始论文): 先通过FFN和残差连接,最后做LayerNorm
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = self.dropout(x)
x = self.layer_norm(x + residual) # 残差连接 + LayerNorm
return x
使用示例
让我们看看如何在代码中使用这个FFN模块。
# 示例使用
if __name__ == "__main__":
# 定义参数
batch_size = 2
seq_len = 10
d_model = 512
d_ff = 2048 # 通常是d_model的4倍
# 创建FFN实例
ffn = PositionwiseFFN(d_model, d_ff, dropout=0.1, activation="gelu", pre_layer_norm=True)
# 创建随机输入(模拟Transformer层的输出)
# 形状: [batch_size, seq_len, d_model]
x = torch.randn(batch_size, seq_len, d_model)
print(f"输入形状: {x.shape}")
# 前向传播
output = ffn(x)
print(f"输出形状: {output.shape}")
print(f"输入和输出形状相同: {x.shape == output.shape}")
# 参数统计
total_params = sum(p.numel() for p in ffn.parameters())
print(f"FFN总参数量: {total_params:,}")
# 计算FFN参数量(理论值)
# linear1: d_model * d_ff + d_ff (权重 + 偏置)
# linear2: d_ff * d_model + d_model (权重 + 偏置)
theoretical_params = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
print(f"FFN理论参数量: {theoretical_params:,}")
输出示例:
输入形状: torch.Size([2, 10, 512])
输出形状: torch.Size([2, 10, 512])
输入和输出形状相同: True
FFN总参数量: 2,362,368
FFN理论参数量: 2,362,368
四、在完整Transformer层中的集成
这里展示FFN如何与自注意力机制结合形成一个完整的Transformer编码器层:
class TransformerEncoderLayer(nn.Module):
"""Transformer编码器层"""
def __init__(self, d_model, nhead, d_ff, dropout=0.1):
super(TransformerEncoderLayer, self).__init__()
# 自注意力层
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# 第一个FFN(在自注意力之后)
self.ffn = PositionwiseFFN(d_model, d_ff, dropout, activation="relu", pre_layer_norm=True)
# Dropout
self.dropout = nn.Dropout(dropout)
# LayerNorm
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
# 自注意力部分 (Pre-LN)
residual = x
x = self.norm1(x)
x, _ = self.self_attn(x, x, x, attn_mask=mask)
x = residual + self.dropout(x)
# FFN部分 (Pre-LN)
residual = x
x = self.norm2(x)
x = self.ffn(x) # 这里ffn内部已经包含了残差连接
x = residual + self.dropout(x)
return x