验证LSTM内部实现流程,加深对LSTM的印象

LSTM结构图单层

LSTM结构图多层

# Verify one LSTM cell step: reproduce nn.LSTMCell's (ht, ct) by hand
# from its stacked weight matrices, to confirm understanding of the cell.
import torch
import torch.nn as nn

# 1. Problem sizes
feature_size = 4
batch_size = 1
hidden_size = 10

x = torch.randn(batch_size, feature_size)

# 2. Reference: torch's built-in LSTMCell computes (ht, ct) for one step.
#    bias=False so the manual computation below needs only the two weights.
lstm = nn.LSTMCell(input_size=feature_size, hidden_size=hidden_size, bias=False)

h0 = torch.zeros(size=(batch_size, hidden_size))
c0 = torch.zeros(size=(batch_size, hidden_size))
ht, ct = lstm(x, (h0, c0))

print(f'调用LSTMCell模块计算{ht}')
print(f'调用LSTMCell模块计算{ct}')

# 3. Manual computation of the same cell step -> (ho, co)
# In theory an LSTM has 4 input weight matrices (w_ii, w_if, w_ig, w_io),
# but torch stacks them into a single (4*hidden_size, in_size) matrix,
# stacked in gate order i, f, g, o (per the nn.LSTMCell docs).
wih = lstm.weight_ih  # (4*hidden_size, feature_size)
whh = lstm.weight_hh  # (4*hidden_size, hidden_size)

# 3.1 Combine previous hidden state with current input, then split into the
#     four gate pre-activations.
#     BUGFIX: the original `.reshape(-1, hidden_size)` + row indexing only
#     happens to work for batch_size == 1; for larger batches it silently
#     mixes batch rows and gates. Chunking along dim=1 is correct for any
#     batch size.
ht_1 = torch.mm(input=h0, mat2=torch.t(whh))
xt = torch.mm(input=x, mat2=torch.t(wih))
hx = torch.add(ht_1, xt)                       # (batch, 4*hidden_size)
i, f, g, o = torch.chunk(hx, chunks=4, dim=1)  # each (batch, hidden_size)
# 3.2 Forget gate: how much of the previous cell state to keep
c1 = torch.multiply(input=c0, other=torch.sigmoid(f))
# 3.3 Input gate: add the gated candidate cell state tanh(g)
c2 = torch.add(input=c1, other=torch.multiply(input=torch.sigmoid(i), other=torch.tanh(g)))
# 3.4 Output gate: new cell state co, hidden state ho = o * tanh(co)
co = c2
ho = torch.multiply(input=torch.tanh(c2), other=torch.sigmoid(o))

print(f'手动根据结构图计算{ho}')
print(f'手动根据结构图计算{co}')


'''
Sample output (exact values vary per run — weights and x are random;
the point is that the two pairs match):

调用LSTMCell模块计算tensor([[-0.0115, -0.0040,  0.0376, -0.0131,  0.0128, -0.0104,  0.0382, -0.0359,
         -0.0498,  0.0463]], grad_fn=<MulBackward0>)
调用LSTMCell模块计算tensor([[-0.0244, -0.0077,  0.0661, -0.0289,  0.0255, -0.0201,  0.0829, -0.0789,
         -0.1051,  0.0888]], grad_fn=<AddBackward0>)


手动根据结构图计算tensor([[-0.0115, -0.0040,  0.0376, -0.0131,  0.0128, -0.0104,  0.0382, -0.0359,
         -0.0498,  0.0463]], grad_fn=<MulBackward0>)
手动根据结构图计算tensor([[-0.0244, -0.0077,  0.0661, -0.0289,  0.0255, -0.0201,  0.0829, -0.0789,
         -0.1051,  0.0888]], grad_fn=<AddBackward0>)

'''
posted @ 2022-03-14 11:47  旁人怎会懂  阅读(89)  评论(0编辑  收藏  举报