从零理解 RNN:隐状态

学习 RNN 时,最重要的不是一开始就陷入代码细节,而是先理解它到底想解决什么问题。

普通神经网络通常假设每个输入样本相互独立。例如输入一张图片,模型输出一个类别;输入一组特征,模型输出一个预测值。但序列数据不同,当前时刻的内容往往依赖之前的历史。

例如在语言模型中,模型看到:

我 今天 想 去

它要预测下一个 token,不能只看最后一个“去”,而要结合前面的“我 今天 想”。这说明序列建模的核心问题是:

如何让模型在处理当前输入时,仍然能够利用过去的信息?

RNN 的做法是:不直接保存所有历史 token,而是维护一个“隐状态”,用它来压缩过去的信息。


一、RNN 的核心思想:用隐状态保存历史

假设一个序列为:

$ x_1, x_2, x_3, \dots, x_t $

在第 $ t $ 个时间步,RNN 不仅接收当前输入 $ X_t $,还接收上一时刻传下来的隐状态 $ H_{t-1} $。

核心公式是:

$ H_t = \phi(X_tW_{xh} + H_{t-1}W_{hh} + b_h) $

其中:

  • $ X_t $:当前时间步的输入;
  • $ H_{t-1} $:上一时间步的隐状态;
  • $ H_t $:当前时间步的隐状态;
  • $ \phi $:激活函数,例如 $ \tanh $;
  • $ W_{xh}, W_{hh}, b_h $:可训练参数。

可以把它理解成一句话:

当前隐状态 = 当前输入的信息 + 过去记忆的信息。

用 Mermaid 表示数据流:

flowchart LR Xt["当前输入 $ X_t $"] --> A["输入到隐状态变换 $W_{xh}$"] Hprev["上一隐状态 $H_{t-1}$"] --> B["隐状态到隐状态变换 $W_{hh}$"] A --> C["相加 + $b_h$"] B --> C C --> D["激活函数 φ"] D --> Ht["当前隐状态 $H_t$"]

这里的 $ H_t $ 可以理解为模型在第 $ t $ 步之后形成的“当前记忆”。


二、RNN 的输出是如何产生的

得到当前隐状态 $ H_t $ 后,RNN 可以进一步产生当前时间步的输出:

$ O_t = H_tW_{hq} + b_q $

其中:

  • $ H_t $:当前隐状态;
  • $ O_t $:当前时间步输出;
  • $ W_{hq}, b_q $:输出层参数。

在语言模型中,$ O_t $ 通常会被映射到词表大小的维度,然后经过 softmax,得到“下一个 token”的概率分布。

例如词表中有 5 个 token:

["我", "你", "去", "吃", "饭"]

那么某个时间步的输出可以理解为:

下一个 token 是“我”的概率
下一个 token 是“你”的概率
下一个 token 是“去”的概率
下一个 token 是“吃”的概率
下一个 token 是“饭”的概率

三、\(W_{xh}\)\(W_{hh}\)\(W_{hq}\) 的含义

RNN 公式中的几个权重下标很容易让人困惑:

$ H_t = \phi(X_tW_{xh} + H_{t-1}W_{hh} + b_h) $

$ O_t = H_tW_{hq} + b_q $

这些下标不是运算,而是命名约定,用来说明权重连接的是哪两个部分。

参数 含义
\(W_{xh}\) 从输入 \(x\) 到隐状态 \(h\) 的权重
\(W_{hh}\) 从上一隐状态 \(h\) 到当前隐状态 \(h\) 的权重
\(W_{hq}\) 从隐状态 \(h\) 到输出 \(q\) 的权重
\(b_h\) 计算隐状态时的偏置
\(b_q\) 计算输出时的偏置

其中:

  • $ x $:input,表示输入;
  • $ h $:hidden state,表示隐状态;
  • $ q $:output,表示输出。

因此:

参数 含义
\(W_{xh}\) \(X_t\)\(H_t\) 所使用的权重
\(W_{hh}\) \(H_{t-1}\)\(H_t\) 所使用的权重
\(W_{hq}\) \(H_t\)\(O_t\) 所使用的权重

四、RNN 是如何沿时间展开的

虽然 RNN 公式看起来只写了一个时间步,但它会在序列上反复使用。

例如:

$ H_1 = \phi(X_1W_{xh} + H_0W_{hh} + b_h) $

$ H_2 = \phi(X_2W_{xh} + H_1W_{hh} + b_h) $

$ H_3 = \phi(X_3W_{xh} + H_2W_{hh} + b_h) $

可以看到,每一步的隐状态都会传给下一步。

flowchart LR X1["$X_1$"] --> R1["RNN 单元"] H0["$H_0$"] --> R1 R1 --> H1["$H_1$"] H1 --> R2["RNN 单元"] X2["$X_2$"] --> R2 R2 --> H2["$H_2$"] H2 --> R3["RNN 单元"] X3["$X_3$"] --> R3 R3 --> H3["$H_3$"] R1 --> O1["$O_1$"] R2 --> O2["$O_2$"] R3 --> O3["$O_3$"]

这个图里画了多个 RNN 单元,但它们并不是不同的网络,而是同一个 RNN 单元在不同时间步上的重复使用。


五、为什么 RNN 的参数要跨时间步共享

RNN 的一个重要特点是:

所有时间步使用同一套参数。

也就是说,无论序列长度是 10、100 还是 1000,每个时间步都使用同一组:

$ W_{xh}, W_{hh}, b_h, W_{hq}, b_q $

这样做有两个好处。

第一,参数数量不会随着序列长度增加。
如果每个时间步都使用一套独立参数,那么序列越长,参数越多,模型会变得难以训练。

第二,模型学到的是一种通用的序列处理规则。
比如语言模型中,“根据前文预测下一个 token”这个规则应该在每个位置都适用,而不是第 1 个位置用一套规则,第 100 个位置又用另一套规则。

所以 RNN 可以看成:

同一个带记忆的计算单元,在时间轴上反复运行。


六、BPTT、截断 BPTT 和梯度裁剪

RNN 的训练需要沿时间方向反向传播梯度,这叫:

$ BPTT $

即 Backpropagation Through Time,通过时间反向传播。

直观理解就是:

把 RNN 沿时间展开成一个很深的前馈网络,然后做反向传播。

但如果序列很长,完整 BPTT 会带来两个问题:

  1. 计算图太长,显存和计算开销很大;
  2. 梯度容易消失或爆炸。

因此训练时通常使用截断 BPTT:

只在固定长度的时间步内反向传播梯度。

例如 num_steps = 35,就可以理解为每次主要展开 35 个时间步进行训练。

此外,为了防止梯度爆炸,常使用梯度裁剪:

$ g \leftarrow \min\left(1, \frac{\theta}{|g|_2}\right)g $

其中:

  • $ g $:所有参数梯度合并后的梯度向量;
  • $ |g|_2 $:梯度的 L2 范数;
  • $ \theta $:设定的梯度阈值。

如果梯度范数超过阈值,就按比例缩小;如果没有超过,就保持不变。

它的作用是:

防止梯度过大导致训练发散。


七、小结

RNN 的核心不在于复杂公式,而在于一个非常朴素的想法:

当前输入本身不够,还需要结合过去的信息。

RNN 用隐状态 $ H_t $ 来保存历史信息,并在每个时间步重复使用同一套参数:

$ H_t = \phi(X_tW_{xh} + H_{t-1}W_{hh} + b_h) $

这个公式可以看作理解 RNN 的钥匙。

只要理解了:

  • $ H_t $ 是历史信息的压缩表示;
  • $ W_{xh} $ 表示输入到隐状态;
  • $ W_{hh} $ 表示旧隐状态到新隐状态;
  • 所有时间步共享参数;
  • 语言模型中每个时间步都预测下一个 token;
  • 训练时用 BPTT,并常配合截断和梯度裁剪;

那么 RNN 的基本逻辑就已经掌握了。

后续的 GRU 和 LSTM,本质上都是在普通 RNN 的基础上加入更复杂的门控机制,让模型更好地控制“记住什么”和“忘掉什么”。



posted @ 2026-06-05 20:22  icuic  阅读(1)  评论(0)    收藏  举报