RNN(循环神经网络):带“记忆”的神经网络

RNN通俗解读:带“记忆”的神经网络

RNN是Recurrent Neural Network(循环神经网络) 的缩写,核心是解决「序列数据」的处理问题——和CNN处理空间结构数据(如图像)不同,RNN专门处理有“先后顺序”的序列数据(比如文本、语音、时间序列),因为它自带“记忆功能”,能把前一步的信息传递到后一步,就像人读句子时会根据前文理解后文一样。

一、先搞懂:RNN为什么存在?(对比普通神经网络)

普通神经网络(比如你之前写的线性回归、简单CNN)的特点是“输入输出独立”:每一次计算只依赖当前输入,比如用“房子面积”预测房价,输入和输入之间没有关联。

但现实中很多数据是“有顺序的”:

  • 文本:“我爱吃苹果”,“苹果”的含义依赖前面的“吃”;
  • 时间序列:股票价格,今天的价格和昨天、前天的价格相关;
  • 语音:一句话的后一个音节依赖前一个音节。

RNN的核心设计就是让网络“记住”前面的信息,用“循环体”把前一步的状态传递到当前步,从而捕捉序列的时序依赖。

二、RNN的核心工作原理(通俗版)

可以把RNN的每一步计算理解为“当前输入 + 历史记忆 → 新输出 + 新记忆”:

  1. 循环体(核心):RNN有一个重复执行的“循环单元”,每一步接收两个输入——「当前时刻的序列数据」(比如一句话中的第i个单词)、「上一时刻的隐藏状态(记忆)」;
  2. 状态更新:循环单元结合这两个输入,输出「当前时刻的结果」,并更新「隐藏状态」(把当前信息存入“记忆”);
  3. 序列遍历:按顺序处理序列的每一个元素,直到整个序列处理完。

用简单公式总结(不用记,理解逻辑即可):

当前隐藏状态 = f(当前输入 × 输入权重 + 上一隐藏状态 × 循环权重 + 偏置)
当前输出 = g(当前隐藏状态 × 输出权重 + 偏置)

其中f通常是Tanh/ReLU激活函数(你之前学过的!),这也是为什么RNN隐藏层常用Tanh——它的输出范围(-1,1),能更好地传递正负向的“记忆信息”。

三、RNN的基本结构(可视化)

以处理一句话(3个单词)为例:

输入1(第1个单词)→ 循环单元 → 输出1 + 隐藏状态1
输入2(第2个单词)+ 隐藏状态1 → 循环单元 → 输出2 + 隐藏状态2
输入3(第3个单词)+ 隐藏状态2 → 循环单元 → 输出3 + 隐藏状态3

可见:每一步的计算都依赖“上一步的记忆”,这就是RNN的“循环”本质。

四、MindSpore中RNN的使用(结合你熟悉的语法)

MindSpore的nn模块提供了RNNLSTM(RNN的改进版)等类,核心还是通过construct方法定义序列处理逻辑,举个简单例子:

import mindspore as ms
import mindspore.nn as nn

# 定义简单的RNN模型(处理文本序列,输入维度10,隐藏层维度20,1层)
class SimpleRNN(nn.Cell):
    def __init__(self):
        super().__init__()
        self.rnn = nn.RNN(
            input_size=10,  # 每个序列元素的特征维度(比如单词的embedding维度)
            hidden_size=20, # 隐藏状态维度(记忆的容量)
            num_layers=1,   # RNN的层数
            activation='tanh' # 隐藏层激活函数,默认Tanh(适配RNN的记忆传递)
        )
        self.fc = nn.Dense(20, 1)  # 输出层(比如预测序列的下一个值)
    
    def construct(self, x):
        # x的形状:(序列长度, 批量大小, 输入维度)
        rnn_out, hidden = self.rnn(x)  # rnn_out是每一步的输出,hidden是最后一步的隐藏状态
        # 用最后一步的隐藏状态做预测
        output = self.fc(hidden)
        return output

# 模拟序列输入:序列长度5,批量大小2,输入维度10
x = ms.Tensor(np.random.randn(5, 2, 10), dtype=ms.float32)
model = SimpleRNN()
pred = model(x)
print(pred.shape)  # 输出(1, 2, 1):(RNN层数, 批量大小, 输出维度)

五、RNN的关键注意点(结合你之前学的知识)

  1. 激活函数选型:RNN隐藏层优先用Tanh(零中心化,适配序列的正负波动),也可用ReLU,但不要用Sigmoid(梯度消失严重);
  2. 纯RNN的缺陷:处理长序列时(比如超过20个元素),容易出现“梯度消失”,无法记住太久远的信息——因此实际中常用LSTM/GRU(RNN的变体,专门解决长序列记忆问题);
  3. 和CNN的区别:CNN是“空间局部特征提取”,RNN是“时序依赖捕捉”,二者常结合(比如处理视频:CNN提帧特征,RNN提时序特征)。

六、核心总结

RNN就是「带记忆的神经网络」,核心是通过“循环体”传递历史信息,专门处理文本、语音、时间序列等有顺序的数;

  • 普通NN:输入独立,无记忆;
  • CNN:空间特征提取,无时序记忆;
  • RNN:时序特征提取,有记忆。

之前学的激活函数、construct方法等知识,完全可以迁移到RNN的使用中——比如RNN隐藏层的Tanh激活、输出层按任务选Softmax/Sigmoid等,逻辑是相通的。

posted @ 2025-12-13 01:59  wangya216  阅读(2)  评论(0)    收藏  举报