Rnn 中 state 为什么要detach

不同采样方法的状态初始化

序列数据的两种采样方法(顺序分区随机采样)会导致隐状态初始化逻辑的显著差异。

顺序分区(Sequential Partitioning)

  • 隐状态初始化策略:仅在每个迭代周期的起始位置初始化隐状态。由于相邻小批量的子序列在时间上是连续的(如第i个小批量的最后一个样本与下一个小批量的第一个样本相邻),因此前一个小批量的隐状态会被直接作为下一个样本的初始状态。
  • 梯度传播问题:若不进行干预,隐状态会携带整个迭代周期内所有小批量的梯度信息,导致计算图跨多个小批量扩展。这会显著增加内存占用,并可能引发梯度爆炸/消失问题。
  • 解决方案:在每次处理小批量前,通过detach_()方法切断梯度传播路径。此时,隐状态仍保留前序信息(用于建模序列依赖),但其梯度计算仅限于当前小批量的时间步内。

随机采样(Random Sampling)

  • 隐状态初始化策略:每个迭代周期内所有样本的起始位置是随机选择的,因此每个小批量的隐状态需要独立初始化。这种设计破坏了序列的连续性,隐状态不再传递历史信息。
#@save
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):
    """训练网络一个迭代周期(定义见第8章)"""
    ...
    state = None  # 初始化隐状态为None
    for X, Y in train_iter:
        if state is None or use_random_iter:
            # 在第一次迭代或使用随机抽样时初始化state
            state = net.begin_state(batch_size=X.shape[0], device=device)
        else:
            if isinstance(net, nn.Module) and not isinstance(state, tuple):
                # GRU的隐状态是张量,需直接调用detach_
                state.detach_()
            else:
                # LSTM的隐状态是(h, c)元组,需逐元素detach
                for s in state:
                    s.detach_()
    ...

顺序分区中detach_()的原理图解

不分离隐藏状态时的梯度流

graph TD
    subgraph 第1次迭代
        X1[输入X₁] --> RNN1[RNN单元]
        H0[初始隐状态H₀] --> RNN1
        RNN1 --> Y1[输出Y₁]
        RNN1 --> H1[隐状态H₁]
    end
    
    subgraph 第2次迭代
        X2[输入X₂] --> RNN2[RNN单元]
        H1 --> RNN2
        RNN2 --> Y2[输出Y₂]
        RNN2 --> H2[隐状态H₂]
    end
    
    subgraph 反向传播
        Y1 --> BP1[梯度传播至X₁]
        Y2 --> BP2[梯度传播至X₂]
        BP1 --> H1
        BP2 --> H2
        H1 --> H0
        H2 --> H1
    end

问题:梯度需从Y₂回传至H₂ → H₁ → H₀,导致计算图跨多个小批量扩展,内存消耗大且梯度可能不稳定。

分离隐藏状态后的梯度流

graph TD
    subgraph 第1次迭代
        X1[输入X₁] --> RNN1[RNN单元]
        H0[初始隐状态H₀] --> RNN1
        RNN1 --> Y1[输出Y₁]
        RNN1 --> H1[隐状态H₁]
    end
    
    subgraph 第2次迭代
        X2[输入X₂] --> RNN2[RNN单元]
        H1Detach[H₁.detach_()] --> RNN2
        RNN2 --> Y2[输出Y₂]
        RNN2 --> H2[隐状态H₂]
    end
    
    subgraph 反向传播
        Y1 --> BP1[梯度传播至X₁]
        Y2 --> BP2[梯度传播至X₂]
        BP1 --> H1
        BP2 --> H2
        H1Detach -- 断开梯度连接 --> H0
    end

改进H₁.detach_()切断了H₁与H₀的梯度连接,使Y₂的梯度仅传播到X₂和H₁,从而避免了跨小批量的梯度累积。


通俗讲解:为什么需要detach_()

  • 不分离隐状态:想象你正在读一本小说,每页的内容都依赖于前页的记忆。如果一直记住所有页的内容(隐状态未分离),随着页数增加,你的大脑会因存储过多历史信息而负担过重,甚至出现记忆混乱(梯度爆炸/消失)。

  • 分离隐状态:现在改为每读完一页就清空记忆(仅保留当前页的内容)。这样你可以专注于当前页的理解,而不会被前面的内容干扰。虽然你“忘记”了更早的页,但当前页的信息仍能帮助你理解后续内容(隐状态保留前序信息,但切断历史梯度)。


关键点总结

  1. 顺序分区利用隐状态传递序列依赖,但需通过detach_()限制梯度传播范围。
  2. 随机采样因缺乏序列连续性,需每次独立初始化隐状态。
  3. detach_()的作用是切断梯度流,而非清除隐状态本身的信息。这平衡了序列建模需求与计算效率。
posted @ 2025-05-13 15:18  玉米面手雷王  阅读(55)  评论(0)    收藏  举报