花书BPTT公式推导
花书第10.2.2节的计算循环神经网络的梯度看了好久,总算是把公式的推导给看懂了,记录一下过程。
首先,对于一个普通的RNN来说,其前向传播过程为:
$$\textbf{a}^{(t)}=\textbf{b}+\textbf{Wh}^{t-1}+\textbf{Ux}^{(t)}$$
$$\textbf{h}^t=tanh(\textbf{a}^{(t)})$$
$$\textbf{o}^{(t)} = \textbf{c} + \textbf{V}\textbf{h}^{(t)}$$
$$\hat{\textbf{y}}^{(t)} = softmax(\textbf{o}^{(t)})$$
先介绍一下等下计算过程中会用到的偏导数:
$$h = tanh(a) = \frac{e^a-e^{-a}}{e^a+e^{-a}}$$
$$\frac{\partial \textbf{h}}{\partial \textbf{a}} = diag(1-\textbf{h}^2)$$
另一个,当$\textbf{y}$采用one-hot并且损失函数$L$为交叉熵时:
$$\frac{\partial L}{\partial \textbf{o}^{(t)}} = \frac{\partial L}{\partial L^{(t)}}\frac{\partial L^{(t)}}{\partial \textbf{o}^{t}} = \hat{\textbf{y}}^{(t)}-\textbf{y}^{(t)}$$
【注】这里涉及到softmax求导的规律,如果不懂的话可以看看:传送门
接下来从RNN的尾部开始,逐步计算隐藏状态$\textbf{h}^t$的梯度。如果$\tau$是最后的时间步,$\textbf{h}^{(\tau)}$就是最后的隐藏输出。
$$\frac{\partial L}{\partial \textbf{h}^{(\tau)}} = \frac{\partial L}{\partial \textbf{o}^{(\tau)}}\frac{\partial \textbf{o}^{(\tau)}}{\partial \textbf{h}^{(\tau)}}= \textbf{V}^T(\hat{\textbf{y}}^{(\tau)}-\textbf{y}^{(\tau)})$$
然后一步步往前计算$\textbf{h}^t$的梯度,注意$\textbf{h}^{(t)}(t<\tau)$同时有$\textbf{o}^{(t)}$和$\textbf{h}^{(t+1)}$两个后续节点,所以:
$$\frac{\partial L}{\partial \textbf{h}^{(t)}}=(\frac{\partial \textbf{h}^{(t+1)}}{\partial \textbf{h}^{(t)}})^T\frac{\partial L}{\partial \textbf{h}^{(t+1)}}+(\frac{\partial \textbf{o}^{(t)}}{\partial \textbf{h}^{(t)}})^T\frac{\partial L}{\partial \textbf{o}^{(t)}}=(\frac{\partial \textbf{h}^{(t+1)}}{\partial \textbf{a}^{(t+1)}} \frac{\partial \textbf{a}^{(t+1)}}{\partial \textbf{h}^{(t)}})^T \frac{\partial L}{\partial \textbf{h}^{(t+1)}}+\textbf{V}^T(\hat{\textbf{y}}^{(t)}-\textbf{y}^{(t)})= \textbf{W}^T(diag(1-(\textbf{h}^{(t+1)})^2))\frac{\partial L}{\partial \textbf{h}^{(t+1)}}+\textbf{V}^T(\hat{\textbf{y}}^{(t)}-\textbf{y}^{(t)})$$
【注】这里的结果和花书有点不一样,不知道是花书有错误还是我这里错了?
剩下的参数计算起来就简单多了:
$$\frac{\partial L}{\partial \textbf{W}} = \sum_{t=1}^{\tau}\frac{\partial L}{\partial \textbf{h}^{(t)}}\frac{\partial \textbf{h}^{(t)}}{\partial \textbf{W}} = \sum_{t=1}^{\tau}\frac{\partial L}{\partial \textbf{h}^{(t)}}\frac{\partial \textbf{h}^{(t)}}{\partial \textbf{a}^{(t)}}\frac{\partial \textbf{a}^{(t)}}{\partial \textbf{W}} = \sum_{t=1}^{\tau}diag(1-(\textbf{h}^{(t)})^2)\frac{\partial L}{\partial \textbf{h}^{(t)}}(\textbf{h}^{(t-1)})^T$$
$$\frac{\partial L}{\partial \textbf{b}}= \sum\limits_{t=1}^{\tau}diag(1-(\textbf{h}^{(t)})^2)\frac{\partial L}{\partial \textbf{h}^{(t)}}$$
$$\frac{\partial L}{\partial \textbf{U}} =\sum\limits_{t=1}^{\tau}diag(1-(\textbf{h}^{(t)})^2)\frac{\partial L}{\partial \textbf{h}^{(t)}}(\textbf{x}^{(t)})^T$$
$$\frac{\partial L}{\partial \textbf{c}} = \sum\limits_{t=1}^{\tau}\frac{\partial L^{(t)}}{\partial \textbf{c}} = \sum\limits_{t=1}^{\tau}\hat{\textbf{y}}^{(t)} - \textbf{y}^{(t)}$$
$$\frac{\partial L}{\partial \textbf{V}} =\sum\limits_{t=1}^{\tau}\frac{\partial L^{(t)}}{\partial \textbf{V}} = \sum\limits_{t=1}^{\tau}(\hat{\textbf{y}}^{(t)} - \textbf{y}^{(t)}) (\textbf{h}^{(t)})^T$$
参考