# Walk through the internal computation flow of an LSTM to deepen understanding
# of how an LSTM works.
# (Refer to the single-layer LSTM structure diagram.)
# (Refer to the multi-layer LSTM structure diagram.)
# Verify the computation performed by a single LSTM cell: run one step through
# torch's built-in nn.LSTMCell, then reproduce the same h_t / c_t by hand from
# the cell's weight matrices, and print both for comparison.
import torch
import torch.nn as nn
# 1. Set up the input features.
feature_size = 4
batch_size = 1
hidden_size = 10
x = torch.randn(batch_size, feature_size)
# 2. Use torch's built-in LSTMCell to compute one step's ht, ct.
lstm = nn.LSTMCell(input_size=feature_size, hidden_size=hidden_size, bias=False)
h0 = torch.zeros(size=(batch_size, hidden_size))
c0 = torch.zeros(size=(batch_size, hidden_size))
ht, ct = lstm(x, (h0, c0))
print(f'调用LSTMCell模块计算{ht}')
print(f'调用LSTMCell模块计算{ct}')
# 3. Manually compute the same cell's outputs ho, co.
# In theory an LSTM has 4 separate weight matrices per input (W_ii, W_if, W_ig, W_io),
# but torch stacks them into a single (4*hidden_size, in_size) matrix to simplify
# the computation; likewise for the hidden-state weights.
wih = lstm.weight_ih
whh = lstm.weight_hh
# 3.1 Project the previous hidden state and the current input, sum them, then
# split the (batch, 4*hidden_size) pre-activations into the four gates.
# torch's stacked gate order is: input (i), forget (f), cell candidate (g), output (o).
ht_1 = torch.mm(input=h0, mat2=torch.t(whh))
xt = torch.mm(input=x, mat2=torch.t(wih))
hx = torch.add(ht_1, xt)
# Split along dim=1 so this also works for batch_size > 1 (the previous
# reshape(-1, hidden_size) row-indexing trick was only valid for batch_size == 1).
i, f, g, o = torch.chunk(hx, chunks=4, dim=1)
# 3.2 Forget gate: scale the old cell state by sigmoid(f).
c1 = torch.multiply(input=c0, other=torch.sigmoid(f))
# 3.3 Input gate: add the new candidate information sigmoid(i) * tanh(g).
c2 = torch.add(input=c1, other=torch.multiply(input=torch.sigmoid(i), other=torch.tanh(g)))
# 3.4 Output gate: the new cell state co, and ho as a filtered view of it.
co = c2
ho = torch.multiply(input=torch.tanh(c2), other=torch.sigmoid(o))
print(f'手动根据结构图计算{ho}')
print(f'手动根据结构图计算{co}')
'''
调用LSTMCell模块计算tensor([[-0.0115, -0.0040, 0.0376, -0.0131, 0.0128, -0.0104, 0.0382, -0.0359,
-0.0498, 0.0463]], grad_fn=<MulBackward0>)
调用LSTMCell模块计算tensor([[-0.0244, -0.0077, 0.0661, -0.0289, 0.0255, -0.0201, 0.0829, -0.0789,
-0.1051, 0.0888]], grad_fn=<AddBackward0>)
手动根据结构图计算tensor([[-0.0115, -0.0040, 0.0376, -0.0131, 0.0128, -0.0104, 0.0382, -0.0359,
-0.0498, 0.0463]], grad_fn=<MulBackward0>)
手动根据结构图计算tensor([[-0.0244, -0.0077, 0.0661, -0.0289, 0.0255, -0.0201, 0.0829, -0.0789,
-0.1051, 0.0888]], grad_fn=<AddBackward0>)
'''