LSTM和GRU

概述

长短期记忆 LSTM(Long Short Term Memory),该类型的神经网络可以利用上输入数据的时序信息。对于需要处理与顺序或时间强相关数据的领域(自然语言处理、天气预测等)相当合适。

GRU(Gate Recurrent Unit)可以视为 LSTM 的简化版本。运算量更小,却能达到 LSTM 相当的性能。

介绍 LSTM 之前,要先了解什么是 RNN。

RNN

递归神经网络 RNN(Recurssion Neural Network),通过让网络能接收上一时刻的网络输出达成处理时序数据的目标。

通常来说,网络通过输入 \(x\) 可以得到输出 \(y\)。而 RNN 的思路是将 \(i-1\) 时刻的输出 \(y\) 视为 “状态” \(h^{i-1}\),用为 \(i\) 时刻的网络输入。

如此,网络的输入有两个:\(x^i\),和上一个时刻的输出 \(h^{i-1}\)。网络的输出仍为一个,并且可以作为下一个时刻的网络输入 \(h^i\)

LSTM

RNN 有很多缺点(遗忘、梯度爆炸与梯度消失),现在更多使用 LSTM。

LSTM 引入了单元状态(cell state)的概念。网络的输入现在有 \(x\)、隐藏态(hidden state)\(h^{i-1}\)、单元状态 \(c^{i-1}\)。。

单元状态 \(c^{i}\) 变化很慢,通常是 \(c^{t-1}\) 的基础上加一些数值。而 \(h^i\) 对于不同节点有很大区别。

LSTM 具体细节

\(i\) 时刻,网络先将本次输入 \(x^t\) 和上一隐藏态 \(h^{i-1}\) 拼接,经由四个不同的矩阵(矩阵参数可学习)做乘法,获得用途各异的四个状态 \(z^f\)\(z^i\)\(z^o\)\(z\)

\[\begin{aligned} &z^f=\text{sigmoid}(W^f\cdot \text{concatenate}(x^t,h^{t-1}))\\ &z^i=\text{sigmoid}(W^i\cdot \text{concatenate}(x^t,h^{t-1}))\\ &z^o=\text{sigmoid}(W^o\cdot \text{concatenate}(x^t,h^{t-1}))\\ &z=\text{tanh}(W\cdot \text{concatenate}(x^t,h^{t-1})) \end{aligned} \]

一次运算的步骤如下:

  • \(i-1\) 时刻传来的单元状态 \(c^{i-1}\) 首先与 \(z^f\) 相乘,用于代表记忆的遗忘(forget)

  • \(z^i\)\(z\) 相乘,代表对记忆进行选择,哪些记忆需要记录(information)。上一步经过遗忘处理的 \(c^{i-1}\) 与需要记录的记忆进行加运算,完成记录。

此步完成后,\(c^{i-1}\) 化身为 \(c^{i}\) 作为下一步的单元状态输入

  • 最后用 \(z^o\) 控制输出(output)。\(z^o\) 与上一步的 \(c^{i-1}\)\(\text{tanh}\) 结果相乘,获得本时刻的输出 \(y^i\),并作为下一步的隐藏态输入 \(h^t\)

GRU

GRU 可以实现与 LSTM 相当的性能,且运算量更低。

GRU 具体细节

GRU 没有单元状态 \(c^{i}\)。网络接收两个输入:当前输入 \(x^i\)、上一隐藏状态 \(h^{i-1}\)。两个输入经过两个不同的矩阵(矩阵参数可学习)做乘法,获得两个门控(gate):

\[\begin{aligned} &r=\text{sigmoid}(W^r\cdot \text{concatenate}(x^t,h^{t-1}))\\ &z=\text{sigmoid}(W^z\cdot \text{concatenate}(x^t,h^{t-1})) \end{aligned} \]

\(r\) 为重置门控(reset gate),\(z\) 为更新门控(update gate)。

一次运算的步骤如下:

  • \(i-1\) 时刻传来隐藏状态 \(h^{i-1}\)\(r\) 相乘,获得 \({h^{i-1}}'\)。这一步代表有选择性地保留记忆(遗忘)
  • \({h^{t-1}}'\) 与输入 \(x^i\) 拼接,再乘一个参数可学习的矩阵,取 \(\text{tanh}\) 获得 \(h'\)。这一步让 \(h'\) 记忆了当前时刻的状态(记录)
  • \((1-z)\) 乘上 \(h^{i-1}\),用 \(z\) 乘上 \(h'\),将两者的和视为当前的隐藏状态 \(h^i\)。可见 \(h^i\) 结合了以前的记忆与现在的状态,代表记忆的更新

现在,将 \(h^i\) 视为下一时刻的输入,即完成了一次运算。

参考来源

posted @ 2024-03-27 23:27  倒地  阅读(12)  评论(0编辑  收藏  举报