用 pytorch 实现 一个rnn

原文链接

import torch
torch.__version__

class RNN(object):
    def __init__(self,input_size,hidden_size):
        super().__init__()
        self.W_xh=torch.nn.Linear(input_size,hidden_size) #因为最后的操作是相加 所以hidden要和output 的shape一致
        self.W_hh=torch.nn.Linear(hidden_size,hidden_size)
        
    def __call__(self,x,hidden):
        return self.step(x,hidden)
    def step(self, x, hidden):
        #前向传播的一步
        h1=self.W_hh(hidden)
        w1=self.W_xh(x)
        out = torch.tanh( h1+w1)
        hidden=self.W_hh.weight
        return out,hidden
rnn = RNN(20,50)
input = torch.randn( 32 , 20)
h_0 =torch.randn(32 , 50) :

for i in range(seq_len):
    output,hn= rnn(input[i, :], h_0)
print(output.size(),h_0.size())
seq_len = input.shape[0]


posted @ 2022-08-19 22:51  luoganttcc  阅读(5)  评论(0)    收藏  举报