RNN前向传播
完整的 RNN 前向传播流程,包括:
- 输入的 one-hot 编码处理;
- 传入 RNN 模块;
- 使用全连接层(线性层)输出每个时间步对应的词表预测;
- 返回输出和最终的隐藏状态。
下面我们对这段代码逐行进行详细解释,并结合你之前的变量形状推导,帮助你理解每一步发生了什么。
📌 代码片段
X = F.one_hot(inputs.T.long(), self.vocab_size)
X = X.to(torch.float32)
state_list, state = self.rnn(X, state)
output = self.linear(state_list.reshape((-1, state_list.shape[-1])))
return output, state
🧮 变量说明与形状变化
假设参数如下:
seq_len:序列长度(T)batch_size:批次大小(N)vocab_size:词表大小(V)hidden_size:RNN 隐藏层维度(H)num_layers:堆叠层数(L)
1. 输入 inputs
假设 inputs 是一个二维张量,表示一批数据中的多个序列,其形状为:
inputs.shape = (batch_size, seq_len)
例如:
inputs = torch.tensor([[1, 2, 3], [4, 5, 6]]) # batch_size=2, seq_len=3
2. 转置输入 + one-hot 编码
X = F.one_hot(inputs.T.long(), self.vocab_size)
步骤解析:
inputs.T: 转置后变成(seq_len, batch_size)F.one_hot(..., vocab_size): 将每个整数索引转换为 one-hot 向量
✅ 所以:
X.shape = (seq_len, batch_size, vocab_size)
✅ 注意:one-hot 编码会自动在最后增加一维,所以结果是三维张量。
3. 转换为浮点类型
X = X.to(torch.float32)
- 不影响形状,只是确保后续计算支持梯度。
- 结果仍是:
(seq_len, batch_size, vocab_size)
4. 传入 RNN 模块
state_list, state = self.rnn(X, state)
self.rnn是一个nn.RNN实例。X是输入序列:(seq_len, batch_size, vocab_size)state是初始隐藏状态(可选),默认为None,此时自动初始化为零张量。
输出:
state_list: RNN 的输出,即每一时间步最后一层的隐藏状态- shape:
(seq_len, batch_size, hidden_size)或根据batch_first设置为(batch_size, seq_len, hidden_size)
- shape:
state: 最终所有层的隐藏状态- shape:
(num_layers, batch_size, hidden_size)
- shape:
5. 全连接层映射到词表空间
output = self.linear(state_list.reshape((-1, state_list.shape[-1])))
目标:
将 RNN 的输出从隐藏空间映射到词表空间(即输出每个词的概率)。
步骤解析:
state_list.shape = (seq_len, batch_size, hidden_size)
或(batch_size, seq_len, hidden_size)(取决于batch_first)reshape((-1, state_list.shape[-1])):- 展平前两个维度(
seq_len * batch_size) - 得到:
(seq_len * batch_size, hidden_size)
- 展平前两个维度(
self.linear是一个线性层,定义为:self.linear = nn.Linear(hidden_size, vocab_size)- 所以输出:
output.shape = (seq_len * batch_size, vocab_size)
✅ 这个输出可以用于语言模型任务中计算交叉熵损失(CrossEntropyLoss),因为它的形状正好匹配标签格式。
6. 返回值
return output, state
| 变量名 | 形状 | 描述 |
|---|---|---|
output |
(seq_len * batch_size, vocab_size) |
每个时间步的词表预测概率分布 |
state |
(num_layers, batch_size, hidden_size) |
所有层在最后一个时间步的隐藏状态 |
🔁 完整流程图示
inputs: (batch_size, seq_len)
↓
转置 → inputs.T: (seq_len, batch_size)
↓
one-hot编码 → X: (seq_len, batch_size, vocab_size)
↓
传入 RNN → state_list: (seq_len, batch_size, hidden_size), state: (num_layers, batch_size, hidden_size)
↓
展平 + 线性层 → output: (seq_len * batch_size, vocab_size)
📝 示例说明(具体数值)
假设:
batch_size = 2seq_len = 3vocab_size = 10hidden_size = 20
则:
inputs.shape = (2, 3)X.shape = (3, 2, 10)state_list.shape = (3, 2, 20)output.shape = (6, 10)state.shape = (num_layers, 2, 20)
📌 补充说明
-
如果你想使用
nn.CrossEntropyLoss()来训练这个模型,你需要:- 标签也应该是形状为
(seq_len, batch_size)的张量(或(batch_size, seq_len),需对应调整) - 损失函数会自动处理 softmax 和 NLL(负对数似然)
- 标签也应该是形状为
-
如果你希望返回的是 logits(未归一化的分数),而不是 softmax 概率,这一步是正确的。
✅ 总结
你这段代码实现了:
- 将原始 token ID 转换为 one-hot 编码;
- 输入 RNN 模型;
- 使用全连接层将隐藏状态映射到词表空间;
- 返回最终输出和隐藏状态,可用于下一轮迭代或损失计算。
如果你还想进一步扩展这个模型,比如:
- 改用嵌入层(
nn.Embedding)代替 one-hot 编码; - 使用 LSTM/GRU 替代简单 RNN;
- 添加 dropout、双向结构等;
欢迎继续提问,我可以帮你一步步搭建完整模型。

浙公网安备 33010602011771号