RNN前向传播

完整的 RNN 前向传播流程,包括:

  1. 输入的 one-hot 编码处理;
  2. 传入 RNN 模块;
  3. 使用全连接层(线性层)输出每个时间步对应的词表预测;
  4. 返回输出和最终的隐藏状态。

下面我们对这段代码逐行进行详细解释,并结合你之前的变量形状推导,帮助你理解每一步发生了什么。


📌 代码片段

X = F.one_hot(inputs.T.long(), self.vocab_size)
X = X.to(torch.float32)
state_list, state = self.rnn(X, state)
output = self.linear(state_list.reshape((-1, state_list.shape[-1])))
return output, state

🧮 变量说明与形状变化

假设参数如下:

  • seq_len:序列长度(T)
  • batch_size:批次大小(N)
  • vocab_size:词表大小(V)
  • hidden_size:RNN 隐藏层维度(H)
  • num_layers:堆叠层数(L)

1. 输入 inputs

假设 inputs 是一个二维张量,表示一批数据中的多个序列,其形状为:

inputs.shape = (batch_size, seq_len)

例如:

inputs = torch.tensor([[1, 2, 3], [4, 5, 6]])  # batch_size=2, seq_len=3

2. 转置输入 + one-hot 编码

X = F.one_hot(inputs.T.long(), self.vocab_size)

步骤解析:

  • inputs.T: 转置后变成 (seq_len, batch_size)
  • F.one_hot(..., vocab_size): 将每个整数索引转换为 one-hot 向量

✅ 所以:

X.shape = (seq_len, batch_size, vocab_size)

✅ 注意:one-hot 编码会自动在最后增加一维,所以结果是三维张量。


3. 转换为浮点类型

X = X.to(torch.float32)
  • 不影响形状,只是确保后续计算支持梯度。
  • 结果仍是:(seq_len, batch_size, vocab_size)

4. 传入 RNN 模块

state_list, state = self.rnn(X, state)
  • self.rnn 是一个 nn.RNN 实例。
  • X 是输入序列:(seq_len, batch_size, vocab_size)
  • state 是初始隐藏状态(可选),默认为 None,此时自动初始化为零张量。

输出:

  • state_list: RNN 的输出,即每一时间步最后一层的隐藏状态
    • shape: (seq_len, batch_size, hidden_size) 或根据 batch_first 设置为 (batch_size, seq_len, hidden_size)
  • state: 最终所有层的隐藏状态
    • shape: (num_layers, batch_size, hidden_size)

5. 全连接层映射到词表空间

output = self.linear(state_list.reshape((-1, state_list.shape[-1])))

目标:

将 RNN 的输出从隐藏空间映射到词表空间(即输出每个词的概率)。

步骤解析:

  • state_list.shape = (seq_len, batch_size, hidden_size)
    (batch_size, seq_len, hidden_size)(取决于 batch_first
  • reshape((-1, state_list.shape[-1])):
    • 展平前两个维度(seq_len * batch_size
    • 得到:(seq_len * batch_size, hidden_size)
  • self.linear 是一个线性层,定义为:
    self.linear = nn.Linear(hidden_size, vocab_size)
    
  • 所以输出:
    output.shape = (seq_len * batch_size, vocab_size)
    

✅ 这个输出可以用于语言模型任务中计算交叉熵损失(CrossEntropyLoss),因为它的形状正好匹配标签格式。


6. 返回值

return output, state
变量名 形状 描述
output (seq_len * batch_size, vocab_size) 每个时间步的词表预测概率分布
state (num_layers, batch_size, hidden_size) 所有层在最后一个时间步的隐藏状态

🔁 完整流程图示

inputs: (batch_size, seq_len)
    ↓
转置 → inputs.T: (seq_len, batch_size)
    ↓
one-hot编码 → X: (seq_len, batch_size, vocab_size)
    ↓
传入 RNN → state_list: (seq_len, batch_size, hidden_size), state: (num_layers, batch_size, hidden_size)
    ↓
展平 + 线性层 → output: (seq_len * batch_size, vocab_size)

📝 示例说明(具体数值)

假设:

  • batch_size = 2
  • seq_len = 3
  • vocab_size = 10
  • hidden_size = 20

则:

  • inputs.shape = (2, 3)
  • X.shape = (3, 2, 10)
  • state_list.shape = (3, 2, 20)
  • output.shape = (6, 10)
  • state.shape = (num_layers, 2, 20)

📌 补充说明

  • 如果你想使用 nn.CrossEntropyLoss() 来训练这个模型,你需要:

    • 标签也应该是形状为 (seq_len, batch_size) 的张量(或 (batch_size, seq_len),需对应调整)
    • 损失函数会自动处理 softmax 和 NLL(负对数似然)
  • 如果你希望返回的是 logits(未归一化的分数),而不是 softmax 概率,这一步是正确的。


✅ 总结

你这段代码实现了:

  • 将原始 token ID 转换为 one-hot 编码;
  • 输入 RNN 模型;
  • 使用全连接层将隐藏状态映射到词表空间;
  • 返回最终输出和隐藏状态,可用于下一轮迭代或损失计算。

如果你还想进一步扩展这个模型,比如:

  • 改用嵌入层(nn.Embedding)代替 one-hot 编码;
  • 使用 LSTM/GRU 替代简单 RNN;
  • 添加 dropout、双向结构等;

欢迎继续提问,我可以帮你一步步搭建完整模型。

posted @ 2025-05-21 16:34  玉米面手雷王  阅读(65)  评论(0)    收藏  举报