LSTM的隐藏状态和细胞状态

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)  # 全连接层用于输出预测

    def forward(self, x):
        # 初始化隐藏状态和细胞状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        
        # 隐藏状态和细胞状态还可以是随机的
        # h0 = torch.randn(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        # c0 = torch.randn(self.num_layers, x.size(0), self.hidden_size).to(x.device)        

        # 假设输入x的形状是 (batch_size, seq_length, input_size)
        lstm_out = self.lstm(x, (h0, c0))
        
        # lstm_out 形状为 (batch_size, seq_length, hidden_size)
        # 取最后一个时间步的输出作为全连接层的输入
        out = lstm_out[:, -1, :]
        out = self.fc(out)  # 形状变为 (batch_size, output_size)
        return out

forward 方法内部初始化了两个张量 h0c0 分别代表隐藏状态和细胞状态。这两个状态都设置为零张量,并且它们的维度是根据 num_layersbatch_sizehidden_size 来确定的。这确保了对于每一个新序列,LSTM 都会从零状态开始处理,避免了不同序列之间潜在的状态混淆。

具体来说:

  • h0 是最顶层(即最后一层)LSTM 的初始隐藏状态,形状为 (num_layers * num_directions, batch_size, hidden_size)。
  • c0 是最顶层(即最后一层)LSTM 的初始细胞状态,形状与 h0 相同。

当调用 self.lstm(x, hidden) 时,hidden 就是这个元组 (h0, c0),它告诉 LSTM 使用什么作为序列开始时的内部状态。如果未提供 hidden(即传递 None),那么默认情况下所有隐藏状态和细胞状态都会被初始化为零张量。

    def forward(self, x, hidden=None):
        # 如果没有提供hidden,则初始化为全零张量
        if hidden is None:
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
            c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
            hidden = (h0, c0)

        # 假设输入x的形状是 (batch_size, seq_length, input_size)
        lstm_out, (hidden_state, cell_state) = self.lstm(x, hidden)
        
        # lstm_out 形状为 (batch_size, seq_length, hidden_size)
        # 取最后一个时间步的输出作为全连接层的输入
        out = lstm_out[:, -1, :]
        out = self.fc(out)  # 形状变为 (batch_size, output_size)
        return out, (hidden_state, cell_state)

如果在连续的序列间保留隐藏状态(例如,在生成任务或某些类型的序列预测任务中),则不应该每次都重新初始化这些状态,而是应该将前一个序列的最终状态传递给下一个序列作为初始状态。

    def forward(self, x):
        # 假设输入x的形状是 (batch_size, seq_length, input_size)
        lstm_out= self.lstm(x)
        
        # lstm_out 形状为 (batch_size, seq_length, hidden_size)
        # 取最后一个时间步的输出作为全连接层的输入
        out = lstm_out[:, -1, :]
        out = self.fc(out)  # 形状变为 (batch_size, output_size)
        return out
posted @ 2025-01-19 15:43  华小电  阅读(448)  评论(0)    收藏  举报