补全llm知识体系的地基:梯度爆炸与梯度消失

背景:反向传播与链式求导

  • 神经网络参数的更新依赖于链式求导法则:求损失函数关于待更新参数的偏导
  • 该偏导来源于该层及更深层逐层偏导数的累乘
  • 理想来说,我们希望每层偏导数都在1附近,或者至少大部分在1附近
  • 否则,过多小于1或者过度接近于0的偏导,会导致传递到上层的梯度非常接近0,参数更新幅度微小,这就是梯度消失
  • 或者,过大的梯度累乘会导致传递到上层的梯度过大,使得每次更新幅度过大,模型参数难以收敛或者收敛不够稳定,这就是梯度爆炸

发生的直观表现

  • 发生梯度爆炸时,顶层参数的更新很大,进而导致输入在通过顶层时,该层的输出变化也非常大,从而具有很大的方差
  • 反之,发生梯度消失时,顶层参数几乎不更新,进而输入通过顶层时输出基本不变,顶层失去区分度,方差非常小

常见的处理方法(梯度消失):

  1. 进行归一化/有效的初始化
  • 我们考虑一个简单的“全连接层”场景,即一个线性映射+一个激活函数(设为Sigmoid)
  • 输入标记为x0,则输出x1 = Sigmoid(w*x0+b),损失函数为Loss = L(x1, y)
  • 更新w时,Loss关于w的偏导 = Loss关于x1的偏导 * sigmoid函数在x1位置上的导数 * x0
  • 归一化后
    • post-norm能直接将“sigmoid函数在x1位置上的导数”调整到不会梯度消失的区间
    • pre-norm加上合适的参数初始化,也可以达到相同的效果(把x1调整到N(0,1)或其他相近分布)
  1. 修改激活函数
  • 采用其他函数替代Sigmoid,例如ReLU,能够有效避免“在某个x1区间上,会导致梯度消失”
  • ReLU存在x1为负时导数为0,进一步改进还有elu,LeakyReLU等一系列工作
  1. 残差结构
  • 残差结构把 x1 = Sigmoid(wx0+b) 变成了 x1 = Sigmoid(wx0+b) + x0
  • 假设在x0前面还有其他层w0需要更新,本来它的梯度 = x0关于w0的偏导 * x1关于x0的偏导 * Loss关于x1的偏导
  • 加入残差结构后,x1关于x0的偏导 = sigmoid函数在x1位置上的导数 * w + 1
  • 这个+1能够使得中间这一项不可能太小,从而避免了梯度消失

常见的处理方法(梯度爆炸):

  1. 梯度剪切
  • 根据当前迭代轮次和预设定的参数,剪切掉模长/绝对值等大于一定程度的梯度,防止过大的梯度用于更新和传导
  1. 正则化
  • 在Loss function中添加正则项,主要是L2正则项。该正则项把w引入了梯度中,即直接加入了梯度惩罚项。参数绝对值越大,梯度更新就会使参数越小,故而能够限制w大小,w减小之后,梯度也会随之减小。
  • 补充:其他正则化及作用
    • L1正则化:使部分权重趋向于0,能实现特征选择
    • L2正则化:防止过拟合、梯度爆炸
    • Dropout:防止过拟合,原理是训练时随机置零一部分输出,使得神经元连接表达多样,在推理时全部启用并按置零概率缩放,从而保持输出一致
posted @ 2025-05-18 16:38  Phile-matology  阅读(90)  评论(0)    收藏  举报