算法学习笔记(13): RNN
RNN 也就是循环神经网络。
其核心公式也就两个:
\[\begin{aligned}
z_t &= W_h h_{t - 1} + W_x x_t + b_1 \\
h_t &= f_1(z_t) \\
o_t &= W_y h_t + b_2 \\
y_t &= f_2(o_t)
\end{aligned}
\]
其核心就是状态的合并,也就是 \(h_t\) 依赖于上一个状态和当前的输入。
输出将状态解码即可。
但是如果我们不考虑激活函数,去展开整个过程,我们会发现有:
\[h_t \leftarrow W_h^{t - t_0} W_x x_{t_0}
\]
如此神秘的东西导致 \(h_t\) 能够依赖的 \(t_0\) 注定不能太长(要么数值爆炸,要么就归于 \(0\) 了,就算是 0.99 经过 100 轮后,也只有 0.36 左右了),一定会有数值衰减(应该不能有数值爆炸的 QwQ)。所以有了一些简单的变体 [[GRU & LSTM]]
考虑第 \(t_0\) 步的损失 \(J_{t_0}\),考虑求偏导:
\[\begin{aligned}
\nabla_{y_t} L &= \frac {\partial L_t} {\partial y_t} \\
\delta_t^o &= \frac {\partial L} {\partial o_t} = \nabla y_t L \odot f'_2(o_t) \\
\delta_t^h &= \frac {\partial o_t} {\partial h_t}^T \delta_t^o + \frac {\partial z_{t + 1}} {\partial h_t}^T \delta_{t + 1}^z \\
&= W_y^T \delta_t^o + W_h^T \delta_{t + 1}^z \\
\delta_t^z &= \delta_t^h \odot f_1'(z_t)
\end{aligned}
\]
于是吗:
\[\nabla_{W_y} L = \sum_{t=1}^T \delta_t^o h_t^\top, \quad \nabla_{b_2} L = \sum_{t=1}^T \delta_t^o
\]
\[\nabla_{W_h} L = \sum_{t=1}^T \delta_t^z h_{t-1}^\top, \quad \nabla_{W_x} L = \sum_{t=1}^T \delta_t^z x_t^\top, \quad \nabla_{b_1} L = \sum_{t=1}^T \delta_t^z
\]
其中 \(\odot\) 为逐元素乘法,\((\cdot)^\top\) 为矩阵转置。权重 \(W_h, W_x, W_y\) 在各时间步共享,故梯度需沿时间维累加,该连乘结构是梯度消失/爆炸(Vanishing/Exploding Gradient)的数学根源。

浙公网安备 33010602011771号