dynamic_rnn转nn.GRU详细记录

(原文发表在知乎专栏上,时间为2020年8月13日)

今天在将一份 TensorFlow 代码转换为 PyTorch 时遇到了一点困难，经过多次 debug 之后终于弄清楚了这里应该如何进行转换，因此记录下来。

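直接上代码吧。为了让两边的结果可以直接对比，这里我将网络层的权重和偏置全部初始化为 0。

需要注意，权重全为 0 时有以下性质：
- 重置门和更新门都等于 sigmoid(0)=0.5，候选状态为 tanh(0)=0，因此输出恰好是 0.5×h_0（可与下面的结果对照）。
- 所以这个实验只验证了输入和隐状态的维度约定（batch-major 与 time-major、h_0 的形状），并不能验证权重布局。

两者的权重布局和公式有以下差别：
- TF 的 GRUCell 把 [x, h] 拼接后用一个 kernel 同时计算 r、u 两个门。它在与候选 kernel 相乘之前，就先把 r 乘到 h 上。
- PyTorch 的 GRU 分别存储 weight_ih 和 weight_hh（按 r、z、n 顺序），并且 r 是在 W_hn·h 算完之后才乘上去的。

因此迁移真实权重时，需要拆分、转置 kernel，并注意两者候选状态的公式并不完全等价。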

import torch
import torch.nn as nn
import numpy as np
import tensorflow as tf
from tensorflow.keras import initializers

# Problem sizes. The TF cell and both PyTorch models are built from these, so
# changing e.g. TIME_STEPS or INPUT_SIZE keeps the three code paths comparable.
BATCH_SIZE, TIME_STEPS, INPUT_SIZE, HIDDEN_SIZE = 3, 1, 5, 5

# Named `inputs` (not `input`) so the builtin input() is not shadowed.
inputs = np.random.rand(BATCH_SIZE, TIME_STEPS, INPUT_SIZE)   # (batch size, time steps, features), float64
hidden = np.random.rand(BATCH_SIZE, HIDDEN_SIZE)              # (batch size, hidden size), float64

print("input: ", inputs.shape)
print(inputs)
print("hidden: ", hidden.shape)
print(hidden)

print("="*20, ' tensorflow result ', "="*20)
# TF1-style GRU cell with every kernel and bias initialized to zero.
cell = tf.compat.v1.nn.rnn_cell.GRUCell(
    HIDDEN_SIZE,
    kernel_initializer=initializers.Zeros(),
    bias_initializer=initializers.Zeros(),
)
# dynamic_rnn is called with its default time_major=False, i.e. batch-major input.
tf_output, tf_state = tf.compat.v1.nn.dynamic_rnn(cell, inputs, initial_state=hidden)
print(tf_output)        # (batch size, time steps, hidden size)
print(tf_state)         # (batch size, hidden size), state after the final time step
print('\n')

print("="*20, ' rnn cell result ', "="*20)
# nn.GRUCell driven manually along the time axis.
pytorch_rnn_cell = nn.GRUCell(INPUT_SIZE, HIDDEN_SIZE)
for param in pytorch_rnn_cell.parameters():
    nn.init.zeros_(param)
pytorch_input_cell = torch.from_numpy(inputs).permute(1, 0, 2).float()   # (time steps, batch size, features)
pytorch_hidden_cell = torch.from_numpy(hidden).float()                   # (batch size, hidden size)
pytorch_output_cell = []
# Iterate over every time step (dim 0) instead of a hard-coded count, so the
# loop stays correct when TIME_STEPS changes.
for x_t in pytorch_input_cell:                                            # x_t: (batch size, features)
    pytorch_hidden_cell = pytorch_rnn_cell(x_t, pytorch_hidden_cell)
    pytorch_output_cell.append(pytorch_hidden_cell)
print(pytorch_output_cell)
print('\n')

print("="*20, ' rnn result ', "="*20)
# nn.GRU with the default batch_first=False expects time-major input.
pytorch_rnn = nn.GRU(INPUT_SIZE, HIDDEN_SIZE)
for param in pytorch_rnn.parameters():
    nn.init.zeros_(param)
pytorch_input = torch.from_numpy(inputs).permute(1, 0, 2).float()        # (time steps, batch size, features)
# nn.GRU's h_0 is (num_layers * num_directions, batch size, hidden size);
# this is a single unidirectional layer, hence unsqueeze(0).
pytorch_hidden = torch.from_numpy(hidden).unsqueeze(0).float()           # (1, batch size, hidden size)
pytorch_output, pytorch_state = pytorch_rnn(pytorch_input, pytorch_hidden)
print(pytorch_output, pytorch_output.shape)     # (time steps, batch size, hidden size)
print(pytorch_state, pytorch_state.shape)       # (1, batch size, hidden size)

# Verify numerically (not only by eye) that all three paths agree. TF computes
# in float64 here while PyTorch uses float32, hence the tolerances. PyTorch
# outputs are time-major, so they are permuted back to batch-major first.
np.testing.assert_allclose(
    tf_output.numpy(),
    pytorch_output.permute(1, 0, 2).detach().numpy(),
    rtol=1e-5, atol=1e-6,
)
np.testing.assert_allclose(
    tf_output.numpy(),
    torch.stack(pytorch_output_cell).permute(1, 0, 2).detach().numpy(),
    rtol=1e-5, atol=1e-6,
)

最后的结果如下

input:  (3, 1, 5)
[[[0.98175333 0.59281082 0.47678967 0.70612923 0.73616147]]

 [[0.8363702  0.85099391 0.75740424 0.30633335 0.20097122]]

 [[0.60316062 0.21921029 0.16052985 0.25654177 0.40698399]]]
hidden:  (3, 5)
[[0.46976021 0.19681885 0.59240364 0.79540728 0.27608136]
 [0.39461795 0.29340918 0.4515729  0.6921841  0.44068605]
 [0.89315058 0.72514622 0.2925488  0.45433305 0.59910906]]
====================  tensorflow result  ====================
tf.Tensor(
[[[0.23488011 0.09840942 0.29620182 0.39770364 0.13804068]]

 [[0.19730898 0.14670459 0.22578645 0.34609205 0.22034303]]

 [[0.44657529 0.36257311 0.1462744  0.22716653 0.29955453]]], shape=(3, 1, 5), dtype=float64)
tf.Tensor(
[[0.23488011 0.09840942 0.29620182 0.39770364 0.13804068]
 [0.19730898 0.14670459 0.22578645 0.34609205 0.22034303]
 [0.44657529 0.36257311 0.1462744  0.22716653 0.29955453]], shape=(3, 5), dtype=float64)


====================  rnn cell result  ====================
[tensor([[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],
        [0.1973, 0.1467, 0.2258, 0.3461, 0.2203],
        [0.4466, 0.3626, 0.1463, 0.2272, 0.2996]], grad_fn=<AddBackward0>)]


====================  rnn result  ====================
tensor([[[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],
         [0.1973, 0.1467, 0.2258, 0.3461, 0.2203],
         [0.4466, 0.3626, 0.1463, 0.2272, 0.2996]]], grad_fn=<StackBackward>) torch.Size([1, 3, 5])
tensor([[[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],
         [0.1973, 0.1467, 0.2258, 0.3461, 0.2203],
         [0.4466, 0.3626, 0.1463, 0.2272, 0.2996]]], grad_fn=<StackBackward>) torch.Size([1, 3, 5])

Process finished with exit code 0
posted @ 2025-11-20 15:22  JCChan  阅读(5)  评论(0)    收藏  举报