详细介绍:详解 KL 散度的反向传播计算:以三分类神经网络为例

一、先明确反向传播的核心前提

1. 模型结构简化(聚焦关键层)

为了清晰展示,我们简化神经网络结构(实际模型可能有多个隐藏层,但核心逻辑一致):

  • 输入层:手写数字图像的特征向量(如扁平化后的28×28=784维);
  • 隐藏层:1个全连接层(假设输出维度为100,激活函数用ReLU);
  • 输出层:全连接层(输出维度=类别数3,无激活函数,输出为「logits」:z0,z1,z2z_0, z_1, z_2z0,z1,z2);
  • 激活层:Softmax函数(将logits转换为预测概率分布Q(x)=[Q0,Q1,Q2]Q(x) = [Q_0, Q_1, Q_2]Q(x)=[Q0,Q1,Q2])。
2. 损失函数定义

由于大家的目标是对齐PPPQQQ整体损失函数L\mathcal{L}L直接等于KL散度(实际中可能会叠加其他损失项,但此处聚焦KL散度的反向传播):
L=DKL(P∥Q)=∑c=02Pc⋅ln⁡(PcQc)\mathcal{L} = D_{KL}(P \parallel Q) = \sum_{c=0}^2 P_c \cdot \ln\left( \frac{P_c}{Q_c} \right)L=DKL(PQ)=c=02Pcln(QcPc)

  • 其中:PcP_cPc是真实分布的固定值(如P0=0.3,P1=0.5,P2=0.2P_0=0.3, P_1=0.5, P_2=0.2P0=0.3,P1=0.5,P2=0.2),Qc=Softmax(zc)=ezc∑k=02ezkQ_c = \text{Softmax}(z_c) = \frac{e^{z_c}}{\sum_{k=0}^2 e^{z_k}}Qc=Softmax(zc)=k=02ezkezczcz_czc是输出层logits,是模型参数的函数)。
3. 反向传播目标

计算损失 L\mathcal{L}L对模型所有可训练参数(隐藏层权重W1W_1W1、偏置 b1b_1b1;输出层权重W2W_2W2、偏置 b2b_2b2)的梯度,然后用梯度下降更新参数:θ=θ−η⋅∂L∂θ\theta = \theta - \eta \cdot \frac{\partial \mathcal{L}}{\partial \theta}θ=θηθLθ\thetaθ表示任意可训练参数,η\etaη是学习率,如0.001)。

二、关键步骤:推导KL散度对logits的梯度(核心桥梁)

反向传播的核心是“链式法则”,而KL散度对输出层logitszcz_czc 的梯度 ∂L∂zc\frac{\partial \mathcal{L}}{\partial z_c}zcL是连接损失和模型参数的关键(基于logits直接由输出层参数计算得到,再往前传播到隐藏层即可)。

1. 简化KL散度公式(方便求导)

先拆分KL散度的表达式:L=∑c=02Pc⋅ln⁡Pc−∑c=02Pc⋅ln⁡Qc\mathcal{L} = \sum_{c=0}^2 P_c \cdot \ln P_c - \sum_{c=0}^2 P_c \cdot \ln Q_cL=c=02PclnPcc=02PclnQc

  • 第一项 ∑c=02Pc⋅ln⁡Pc\sum_{c=0}^2 P_c \cdot \ln P_cc=02PclnPc:是“真实分布的熵”,PcP_cPc是固定值(从数据中统计得到),因此对任何模型参数(包括zcz_czc)的导数都为0
  • 第二项是关键:L=−∑c=02Pc⋅ln⁡Qc\mathcal{L} = - \sum_{c=0}^2 P_c \cdot \ln Q_cL=c=02PclnQc(求导时只需关注这一项)。
2. 代入Softmax公式,

∂L∂zc\frac{\partial \mathcal{L}}{\partial z_c}zcL 已知 Qc=ezc∑k=02ezk=ezcZQ_c = \frac{e^{z_c}}{\sum_{k=0}^2 e^{z_k}} = \frac{e^{z_c}}{Z}Qc=k=02ezkezc=Zezc(其中 Z=∑k=02ezkZ = \sum_{k=0}^2 e^{z_k}Z=k=02ezk是归一化常数)。 根据链式法则:∂L∂zc=−∑j=02Pj⋅∂ln⁡Qj∂zc\frac{\partial \mathcal{L}}{\partial z_c} = - \sum_{j=0}^2 P_j \cdot \frac{\partial \ln Q_j}{\partial z_c}zcL=j=02PjzclnQj(对每个logitzcz_czc,需考虑它对所有QjQ_jQj的影响,因为ZZZ 包含所有 zkz_kzk)。
进一步展开:∂ln⁡Qj∂zc=1Qj⋅∂Qj∂zc \frac{\partial \ln Q_j}{\partial z_c} = \frac{1}{Q_j} \cdot \frac{\partial Q_j}{\partial z_c}zclnQj=Qj1zcQj
因此: ∂L∂zc=−∑j=02Pj⋅1Qj⋅∂Qj∂zc\frac{\partial \mathcal{L}}{\partial z_c} = - \sum_{j=0}^2 P_j \cdot \frac{1}{Q_j} \cdot \frac{\partial Q_j}{\partial z_c}zcL=j=02PjQj1zcQj

3. 利用Softmax的梯度性质(关键简化)

Softmax函数有一个核心梯度性质(必须记住,推导略):∂Qj∂zc=Qj⋅(δc,j−Qc)\frac{\partial Q_j}{\partial z_c} = Q_j \cdot (\delta_{c,j} - Q_c)zcQj=Qj(δc,jQc)

  • 其中 δc,j\delta_{c,j}δc,j是「克罗内克函数」:当c=jc = jc=j 时,δc,j=1\delta_{c,j} = 1δc,j=1;当 c≠jc \neq jc=j 时,δc,j=0\delta_{c,j} = 0δc,j=0。 将这个性质代入上式:∂L∂zc=−∑j=02Pj⋅1Qj⋅Qj⋅(δc,j−Qc)\frac{\partial \mathcal{L}}{\partial z_c} = - \sum_{j=0}^2 P_j \cdot \frac{1}{Q_j} \cdot Q_j \cdot (\delta_{c,j} - Q_c)zcL=j=02PjQj1Qj(δc,jQc)
  • QjQ_jQj约分后简化为:∂L∂zc=−∑j=02Pj⋅(δc,j−Qc)\frac{\partial \mathcal{L}}{\partial z_c} = - \sum_{j=0}^2 P_j \cdot (\delta_{c,j} - Q_c)zcL=j=02Pj(δc,jQc)
4. 拆分求和项,最终化简

将求和项拆分为j=cj=cj=cj≠cj \neq cj=c 两部分:

  • j=cj = cj=c 时:δc,j=1\delta_{c,j} = 1δc,j=1,该项为 Pc⋅(1−Qc)P_c \cdot (1 - Q_c)Pc(1Qc)
    • j≠cj \neq cj=c 时:δc,j=0\delta_{c,j} = 0δc,j=0,该项为 Pj⋅(0−Qc)=−Pj⋅QcP_j \cdot (0 - Q_c) = - P_j \cdot Q_cPj(0Qc)=PjQc
  • 因此: ∂L∂zc=−[Pc⋅(1−Qc)−Qc⋅∑j≠cPj]\frac{\partial \mathcal{L}}{\partial z_c} = - \left[ P_c \cdot (1 - Q_c) - Q_c \cdot \sum_{j \neq c} P_j \right]zcL=Pc(1Qc)Qcj=cPj
  • 又因为 ∑j=02Pj=1\sum_{j=0}^2 P_j = 1j=02Pj=1,所以 ∑j≠cPj=1−Pc\sum_{j \neq c} P_j = 1 - P_cj=cPj=1Pc,代入后: ∂L∂zc=−[Pc−PcQc−Qc(1−Pc)]\frac{\partial \mathcal{L}}{\partial z_c} = - \left[ P_c - P_c Q_c - Q_c (1 - P_c) \right]zcL=[PcPcQcQc(1Pc)]
  • 展开括号并化简:∂L∂zc=−[Pc−PcQc−Qc+PcQc]=−[Pc−Qc]=Qc−Pc \frac{\partial \mathcal{L}}{\partial z_c} = - \left[ P_c - P_c Q_c - Q_c + P_c Q_c \right] = - \left[ P_c - Q_c \right] = Q_c - P_czcL=[PcPcQcQc+PcQc]=[PcQc]=QcPc
最终核心结论(震惊的简化!)

KL散度对输出层logitszcz_czc的梯度,竟然简化为:∂L∂zc=Qc−Pc\boxed{\frac{\partial \mathcal{L}}{\partial z_c} = Q_c - P_c}zcL=QcPc

  • 这意味着:每个logitzcz_czc的梯度 = 对应类别的预测概率 - 真实概率。

三、结合例子计算梯度(衔接之前的数值)

假设例子中:

  • 真实分布 P=[P0,P1,P2]=[0.3,0.5,0.2]P = [P_0, P_1, P_2] = [0.3, 0.5, 0.2]P=[P0,P1,P2]=[0.3,0.5,0.2]
  • 预测分布 Q=[Q0,Q1,Q2]=[0.35,0.45,0.2]Q = [Q_0, Q_1, Q_2] = [0.35, 0.45, 0.2]Q=[Q0,Q1,Q2]=[0.35,0.45,0.2]
    代入梯度公式,计算每个logit的梯度:
  • z0z_0z0 的梯度:∂L∂z0=Q0−P0=0.35−0.3=0.05\frac{\partial \mathcal{L}}{\partial z_0} = Q_0 - P_0 = 0.35 - 0.3 = 0.05z0L=Q0P0=0.350.3=0.05
  • z1z_1z1 的梯度:∂L∂z1=Q1−P1=0.45−0.5=−0.05\frac{\partial \mathcal{L}}{\partial z_1} = Q_1 - P_1 = 0.45 - 0.5 = -0.05z1L=Q1P1=0.450.5=0.05
  • z2z_2z2 的梯度:∂L∂z2=Q2−P2=0.2−0.2=0\frac{\partial \mathcal{L}}{\partial z_2} = Q_2 - P_2 = 0.2 - 0.2 = 0z2L=Q2P2=0.20.2=0

梯度结果解读:

  • ∂L∂z0=0.05>0\frac{\partial \mathcal{L}}{\partial z_0} = 0.05 > 0z0L=0.05>0:损失 L\mathcal{L}Lz0z_0z0增大而增大,因此参数更新时要减小z0z_0z0相关的权重(让z0z_0z0变小,进而让Q0Q_0Q0从0.35下降到0.3,贴近P0P_0P0);
  • ∂L∂z1=−0.05<0\frac{\partial \mathcal{L}}{\partial z_1} = -0.05 < 0z1L=0.05<0:损失 L\mathcal{L}Lz1z_1z1增大而减小,因此参数更新时要增大z1z_1z1相关的权重(让z1z_1z1变大,进而让Q1Q_1Q1从0.45上升到0.5,贴近P1P_1P1);
  • ∂L∂z2=0\frac{\partial \mathcal{L}}{\partial z_2} = 0z2L=0Q2Q_2Q2 已完全贴近 P2P_2P2,无需调整与z2z_2z2相关的参数。

四、梯度反向传播到前层参数(完整流程)

有了logits的梯度∂L∂zc\frac{\partial \mathcal{L}}{\partial z_c}zcL,接下来通过链式法则反向传播到隐藏层和输入层的参数。我们以输出层权重 W2W_2W2 和偏置 b2b_2b2为例(隐藏层参数同理)。

1. 输出层的线性计算关系

输出层的logitszzz是隐藏层输出hhh 与权重 W2W_2W2、偏置 b2b_2b2的线性组合:z=h⋅W2+b2z = h \cdot W_2 + b_2z=hW2+b2

  • 维度说明(假设隐藏层输出hhh是1×100的向量):
    hhh:1×100(批量大小=1时的隐藏层输出);
    W2W_2W2:100×3(隐藏层到输出层的权重矩阵,每行对应隐藏层一个神经元,每列对应一个类别);
    b2b_2b2:1×3(输出层偏置);
    zzz:1×3(输出层logits)。
2. 计算对输出层权重W2W_2W2 的梯度

根据矩阵求导规则:∂L∂W2=hT⋅∂L∂z\frac{\partial \mathcal{L}}{\partial W_2} = h^T \cdot \frac{\partial \mathcal{L}}{\partial z}W2L=hTzL

  • hTh^ThT:100×1(隐藏层输出的转置);
  • ∂L∂z\frac{\partial \mathcal{L}}{\partial z}zL:1×3(logits的梯度向量,即 [0.05, -0.05, 0]);
  • 结果 ∂L∂W2\frac{\partial \mathcal{L}}{\partial W_2}W2L:100×3(与W2W_2W2维度一致,可直接用于更新)。
3. 计算对输出层偏置b2b_2b2 的梯度

偏置 b2b_2b2线性的,导数为1):就是的梯度直接等于logits的梯度(因为偏置对每个logit的贡献∂L∂b2=∂L∂z=[0.05,−0.05,0]\frac{\partial \mathcal{L}}{\partial b_2} = \frac{\partial \mathcal{L}}{\partial z} = [0.05, -0.05, 0]b2L=zL=[0.05,0.05,0]

4. 反向传播到隐藏层

隐藏层的梯度计算同理,利用链式法则:∂L∂h=∂L∂z⋅W2T\frac{\partial \mathcal{L}}{\partial h} = \frac{\partial \mathcal{L}}{\partial z} \cdot W_2^ThL=zLW2T

  • W2TW_2^TW2T:3×100(输出层权重的转置);
  • 结果 ∂L∂h\frac{\partial \mathcal{L}}{\partial h}hL:1×100(与隐藏层输出hhh维度一致)。

再结合隐藏层的激活函数(如ReLU)的导数,可计算出对隐藏层权重W1W_1W1 和偏置 b1b_1b1的梯度,最终完成所有参数的梯度计算。

五、参数更新(梯度下降执行)

得到所有参数的梯度后,用梯度下降法更新参数(以输出层权重W2W_2W2 和偏置 b2b_2b2 为例):

  • 权重更新:W2=W2−η⋅∂L∂W2W_2 = W_2 - \eta \cdot \frac{\partial \mathcal{L}}{\partial W_2}W2=W2ηW2L
  • 偏置更新:b2=b2−η⋅∂L∂b2b_2 = b_2 - \eta \cdot \frac{\partial \mathcal{L}}{\partial b_2}b2=b2ηb2L 假设学习率 η=0.001\eta = 0.001η=0.001,则:
    – 偏置 b0b_0b0(对应 z0z_0z0的偏置)更新:b0=b0−0.001×0.05=b0−0.00005b_0 = b_0 - 0.001 \times 0.05 = b_0 - 0.00005b0=b00.001×0.05=b00.00005(减小偏置,让z0z_0z0 变小,Q0Q_0Q0 下降);
    – 偏置 b1b_1b1(对应 z1z_1z1的偏置)更新:b1=b1−0.001×(−0.05)=b1+0.00005b_1 = b_1 - 0.001 \times (-0.05) = b_1 + 0.00005b1=b10.001×(0.05)=b1+0.00005(增大偏置,让z1z_1z1 变大,Q1Q_1Q1 上升);
    – 偏置 b2b_2b2无更新(梯度为0)。

六、迭代优化效果

经过一次反向传播和参数更新后,模型的logitsz0z_0z0会略微减小,z1z_1z1会略微增大,导致:

  • 预测分布 Q0Q_0Q0从0.35 → 0.33(更贴近P0=0.3P_0=0.3P0=0.3);
  • Q1Q_1Q1从0.45 → 0.47(更贴近P1=0.5P_1=0.5P1=0.5);
  • KL散度从0.0063 → 更小的值(如0.004)。

重复这个过程(迭代训练),直到KL散度收敛到极小值(如0.001以下),模型的预测分布就会完全贴近真实分布。

核心总结:

KL散度反向传播的关键

  1. 梯度简化奇迹:KL散度对logits的梯度最终简化为Qc−PcQ_c - P_cQcPc它在深度学习中广泛使用的原因;就是,无需复杂计算,这也
  2. 反向传播逻辑:损失(KL散度)→ logits梯度 → 输出层参数梯度 → 隐藏层参数梯度 → 梯度下降更新;
  3. 例子呼应:利用具体数值展示了梯度的正负和大小如何指导参数调整,让预测分布逐步贴近真实分布;
  4. 实际代码启示:在TensorFlow/PyTorch中,无需手动推导梯度,框架会自动计算,但理解这个过程能帮你调优模型(如学习率选择、损失函数设计)。

简单来说,KL散度用于反向传播的本质是:通过“预测概率与真实概率的差值”指导参数调整,让模型的输出分布越来越接近目标分布

posted @ 2026-01-07 17:14  gccbuaa  阅读(36)  评论(0)    收藏  举报