torch.rsample和 Gumbel-Softmax的区别

一、核心概念对比

1. torch.rsample():连续分布的重参数化

  • 应用场景:处理连续概率分布(如高斯分布、拉普拉斯分布等)。
  • 核心思想:将随机变量的采样分解为 确定性变换 和 独立噪声,使梯度可通过噪声传递。
  • 数学形式: 对于高斯分布 \(z \sim \mathcal{N}(\mu, \sigma^2)\),PyTorch 通过以下方式实现重参数化采样:\(z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1)\)其中 \(\odot\) 表示逐元素乘法。
  • 代码示例:
    import torch
    import torch.nn.functional as F
    
    def gumbel_softmax(logits, tau=1.0, hard=False):
        # 生成Gumbel噪声
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
        # 添加噪声并应用softmax
        y_soft = F.softmax((logits + gumbel_noise) / tau, dim=-1)
        
        if hard:
            # 直通估计器:前向传播使用one-hot,反向传播使用soft版本
            index = y_soft.max(-1, keepdim=True)[1]
            y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
            y = (y_hard - y_soft).detach() + y_soft
            return y
        else:
            return y_soft
     

2. Gumbel-Softmax:离散分布的重参数化

  • 应用场景:处理离散概率分布(如分类分布、伯努利分布)。
  • 核心思想:通过引入 Gumbel 噪声和温度软化,将不可微的离散采样转化为可微的连续近似。
  • 数学形式: 对于 logits \(\mathbf{z}\),Gumbel-Softmax 采样为:\(\mathbf{y}_\text{soft} = \text{softmax}\left(\frac{\mathbf{z} + \mathbf{g}}{\tau}\right), \quad g_i = -\log(-\log u_i), \ u_i \sim \text{Uniform}(0,1)\)
  • 代码示例:
    import torch
    import torch.nn.functional as F
    
    def gumbel_softmax(logits, tau=1.0, hard=False):
        # 生成Gumbel噪声
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
        # 添加噪声并应用softmax
        y_soft = F.softmax((logits + gumbel_noise) / tau, dim=-1)
        
        if hard:
            # 直通估计器:前向传播使用one-hot,反向传播使用soft版本
            index = y_soft.max(-1, keepdim=True)[1]
            y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
            y = (y_hard - y_soft).detach() + y_soft
            return y
        else:
            return y_soft
     

二、关键区别

维度torch.rsample()Gumbel-Softmax
适用分布 连续分布(如高斯、拉普拉斯) 离散分布(如分类、伯努利)
采样性质 精确采样(保持分布的统计特性) 近似采样(通过温度控制离散程度)
输出类型 连续值(如浮点数向量) 连续近似(soft 版本)或离散化(hard 版本)
梯度传递方式 直接通过确定性变换传递 通过 softmax 函数和直通估计器传递
典型应用 VAE 中的连续隐变量、强化学习连续动作 离散隐变量模型、离散动作空间 RL

三、联系与相似性

  1. 核心目标一致: 两者都旨在解决随机采样过程中的梯度传播问题,使模型能够通过梯度下降优化概率分布的参数。
  2. 数学原理同源: 均基于重参数化思想 —— 将随机变量表示为 参数 和 独立噪声 的函数。例如:
    • 高斯分布:\(z = \mu + \sigma \odot \epsilon\);
    • 分类分布:\(y = \text{argmax}(\mathbf{z} + \mathbf{g})\)。
  3. PyTorch 实现: 虽然 PyTorch 没有直接提供 Gumbel-Softmax 的内置函数,但可以通过自定义实现(如上述代码)与rsample()类似的可微采样效果。

四、总结

  • torch.rsample() 是处理连续分布的标准重参数化方法,直接应用于支持该接口的分布类(如NormalLaplace)。
  • Gumbel-Softmax 是专门为离散分布设计的重参数化技巧,通过噪声扰动和软化操作实现可微近似。

两者是针对不同类型分布的互补技术,共同服务于概率模型的可微优化目标。在实际应用中,需根据变量类型(连续或离散)选择合适的重参数化方法。
posted @ 2025-06-15 18:26  有何m不可  阅读(56)  评论(0)    收藏  举报