代码改变世界

深入解析:解密Transformer中的前向神经网络(FFN)

2025-10-01 18:44  tlnshuju  阅读(24)  评论(0)    收藏  举报

transformer中前向神经网络(FFN)

FFN
在Transformer模型中自注意力机制通常抢走了所有风头,但前向神经网络(FFN,也被称为位置级前馈网络) 是Transformer模型中一个不可或缺的、功能强大的组件。

简单来说,它的主要作用是:对自注意力机制提取出的信息进行深加工和变换,为每个位置(单词)生成更丰富、更复杂的特征表示。

让我们用一个比喻来理解:

  • 自注意力层 就像一个会议室里的讨论会。每个参会者(单词)都会倾听其他所有人的发言,并基于与所有人的关系来更新自己的观点。这个过程决定了“哪些信息是重要的”。
  • 前向神经网络 就像每个参会者回到自己的私人办公室,将讨论会上收集到的所有信息和观点进行深度消化、整合和升华,形成自己更成熟、更复杂的最终立场。这个过程负责“如何加工和处理这些重要信息”。

一、FFN在Transformer中的位置与结构

在Transformer的每个编码器层和解码器层中,FFN都紧跟在自注意力层之后:

输入 -> 自注意力层 -> 残差连接 & 层归一化 -> FFN层 -> 残差连接 & 层归一化 -> 输出

它的结构非常简单,就是一个两层的全连接神经网络:

  1. 第一层(扩展层): 一个全连接层,通常使用 ReLUGELU 作为激活函数。关键点是,这一层的输出维度远大于输入维度(通常是4倍,例如输入是512维,输出是2048维)。这就像一个“信息膨胀”的过程。
  2. 第二层(收缩层): 另一个全连接层,不使用激活函数(或可视为线性激活)。这一层将膨胀后的维度投影回原始的模型维度(例如从2048维变回512维)。

公式表示:
FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
(其中 W₁W₂ 是可学习的权重矩阵,b₁b₂ 是偏置项)

在这里插入图片描述


二、FFN的具体作用与重要性

1. 提供非线性变换能力

这是FFN最核心的作用。自注意力层本质上是线性操作(加权求和),即使加了Softmax,整个变换的非线性能力仍然有限。FFN中的激活函数(如ReLU)引入了强大的非线性,使得模型能够学习并拟合极其复杂的函数。没有FFN,Transformer就退化为一个简单的线性模型,表达能力将大打折扣。

2. 独立处理每个位置的信息

FFN是位置独立的。它作用于自注意力层的输出向量的每一个位置(即序列中的每一个词向量),并且对每个位置的处理方式是完全相同的(参数共享)。这意味着:

  • 它专注于加工单个位置的综合信息
  • 它不会像自注意力那样在位置之间混入信息,那个工作已经由自注意力层完成了。
3. 一个“专家”网络

你可以将每个FFN层看作一个特征转换专家。它学习到的是如何将自注意力层提供的、经过上下文信息加权过的词表示,映射到另一个更有利于最终任务(如翻译、分类)的表示空间。不同层的FFN可能学习到不同级别的抽象特征:

  • 底层FFN: 可能更关注语法、词性等基础特征。
  • 高层FFN: 可能更关注语义、逻辑等高级特征。
4. 增加模型的容量(可学习参数)

虽然FFN结构简单,但由于中间层维度很大,它实际上是Transformer模型中参数最多的部分!对于一个有L层、模型维度为d、FFN中间层为4d的Transformer:

  • 自注意力层的参数量约为 4d²(每层)。
  • FFN层的参数量约为 2 * d * (4d) = 8d²(每层)。
    所以,FFN贡献了模型大部分的可学习参数,为模型存储知识提供了巨大的空间。

三、一个简化的例子

假设我们的任务是翻译一句话,模型已经学到了“apple”这个词在“I eat an apple”和“Apple Inc.”中含义不同。

  1. 自注意力层的工作

    • 在“I eat an apple”中,自注意力机制发现“apple”与“eat”强相关,从而将其识别为水果
    • 在“Apple Inc.”中,自注意力机制发现“Apple”与“Inc.”强相关,从而将其识别为公司
    • 自注意力层为每个位置的“apple”都输出一个已经包含上下文线索的向量
  2. FFN层的工作

    • FFN接收这个已经带有“水果”或“公司”线索的向量。
    • 由于FFN在大量数据上训练过,它已经学会了“水果”相关的向量应该被转换成一种富含甜味、圆形、可食用等语义特征的表示;而“公司”相关的向量应该被转换成一种富含科技、创新、商业等语义特征的表示。
    • FFN执行这种复杂的非线性特征变换,为后续的层(或解码器)提供更精炼、更高级的特征。

总结

特性自注意力层前向神经网络
核心功能信息聚合: 收集序列中所有位置的信息,建立全局依赖关系。信息加工: 对每个位置的信息进行非线性变换和深度处理。
操作范围序列内所有位置之间(交互式)每个位置独立(并行式)
关键作用决定“关注谁决定“如何转化
类比会议室讨论个人深度思考

没有自注意力,模型就是“瞎子”,看不到上下文。而没有FFN,模型就是“傻子”,无法对看到的信息进行复杂思考和升华。两者相辅相成,共同构成了Transformer强大的建模能力。

附python代码实现

根据原始论文《Attention Is All You Need》的描述,FFN的定义如下:

FFN(x) = max(0, xW₁ + b₁)W₂ + b₂

其中,中间层的维度是模型隐藏层的4倍(即 d_ff = 4 * d_model)。


基础版本实现(使用PyTorch)

这是最直接、最符合论文描述的实现。

import torch
import torch.nn as nn
class PositionwiseFFN(nn.Module):
"""Position-wise Feed-Forward Network."""
def __init__(self, d_model, d_ff, dropout=0.1):
"""
Args:
d_model: 模型隐藏层的维度(输入和输出的维度)
d_ff: 前向神经网络中间层的维度(通常是d_model的4倍)
dropout: Dropout比率
"""
super(PositionwiseFFN, self).__init__()
# 第一层(扩展层):从 d_model 到 d_ff
self.linear1 = nn.Linear(d_model, d_ff)
# 第二层(收缩层):从 d_ff 回到 d_model
self.linear2 = nn.Linear(d_ff, d_model)
# Dropout层
self.dropout = nn.Dropout(dropout)
# 激活函数:原始论文使用ReLU
self.activation = nn.ReLU()
def forward(self, x):
"""
Args:
x: 输入张量,形状为 [batch_size, seq_len, d_model]
Returns:
输出张量,形状为 [batch_size, seq_len, d_model]
"""
# x -> linear1 -> ReLU -> dropout -> linear2
return self.linear2(self.dropout(self.activation(self.linear1(x))))

更实用的版本(支持GELU和预LN配置)

现代Transformer的实现中,常常使用GELU作为激活函数,并且有时会采用Pre-LayerNorm结构。

import torch
import torch.nn as nn
class PositionwiseFFN(nn.Module):
"""增强版的前向神经网络,支持GELU和Pre-LN"""
def __init__(self, d_model, d_ff, dropout=0.1, activation="relu", pre_layer_norm=False):
"""
Args:
d_model: 模型隐藏层的维度
d_ff: 前向神经网络中间层的维度
dropout: Dropout比率
activation: 激活函数,'relu' 或 'gelu'
pre_layer_norm: 是否使用Pre-LayerNorm(True:在FFN之前做LayerNorm;False:在之后做)
"""
super(PositionwiseFFN, self).__init__()
self.d_model = d_model
self.d_ff = d_ff
self.pre_layer_norm = pre_layer_norm
# 线性层
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
# 激活函数
if activation == "relu":
self.activation = nn.ReLU()
elif activation == "gelu":
self.activation = nn.GELU()  # 在BERT等模型中常用
else:
raise ValueError(f"Unsupported activation: {activation}")
# LayerNorm
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, x):
"""
Args:
x: 输入张量,形状为 [batch_size, seq_len, d_model]
Returns:
输出张量,形状与输入相同
"""
residual = x  # 保存残差连接
if self.pre_layer_norm:
# Pre-LN: 先做LayerNorm,再通过FFN
x = self.layer_norm(x)
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = self.dropout(x)
x = residual + x  # 残差连接
else:
# Post-LN (原始论文): 先通过FFN和残差连接,最后做LayerNorm
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = self.dropout(x)
x = self.layer_norm(x + residual)  # 残差连接 + LayerNorm
return x

使用示例

让我们看看如何在代码中使用这个FFN模块。

# 示例使用
if __name__ == "__main__":
# 定义参数
batch_size = 2
seq_len = 10
d_model = 512
d_ff = 2048  # 通常是d_model的4倍
# 创建FFN实例
ffn = PositionwiseFFN(d_model, d_ff, dropout=0.1, activation="gelu", pre_layer_norm=True)
# 创建随机输入(模拟Transformer层的输出)
# 形状: [batch_size, seq_len, d_model]
x = torch.randn(batch_size, seq_len, d_model)
print(f"输入形状: {x.shape}")
# 前向传播
output = ffn(x)
print(f"输出形状: {output.shape}")
print(f"输入和输出形状相同: {x.shape == output.shape}")
# 参数统计
total_params = sum(p.numel() for p in ffn.parameters())
print(f"FFN总参数量: {total_params:,}")
# 计算FFN参数量(理论值)
# linear1: d_model * d_ff + d_ff (权重 + 偏置)
# linear2: d_ff * d_model + d_model (权重 + 偏置)
theoretical_params = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
print(f"FFN理论参数量: {theoretical_params:,}")

输出示例:

输入形状: torch.Size([2, 10, 512])
输出形状: torch.Size([2, 10, 512])
输入和输出形状相同: True
FFN总参数量: 2,362,368
FFN理论参数量: 2,362,368

四、在完整Transformer层中的集成

这里展示FFN如何与自注意力机制结合形成一个完整的Transformer编码器层:

class TransformerEncoderLayer(nn.Module):
"""Transformer编码器层"""
def __init__(self, d_model, nhead, d_ff, dropout=0.1):
super(TransformerEncoderLayer, self).__init__()
# 自注意力层
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# 第一个FFN(在自注意力之后)
self.ffn = PositionwiseFFN(d_model, d_ff, dropout, activation="relu", pre_layer_norm=True)
# Dropout
self.dropout = nn.Dropout(dropout)
# LayerNorm
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
# 自注意力部分 (Pre-LN)
residual = x
x = self.norm1(x)
x, _ = self.self_attn(x, x, x, attn_mask=mask)
x = residual + self.dropout(x)
# FFN部分 (Pre-LN)
residual = x
x = self.norm2(x)
x = self.ffn(x)  # 这里ffn内部已经包含了残差连接
x = residual + self.dropout(x)
return x