torch.rsample和 Gumbel-Softmax的区别
一、核心概念对比
1. torch.rsample():连续分布的重参数化
- 应用场景:处理连续概率分布(如高斯分布、拉普拉斯分布等)。
- 核心思想:将随机变量的采样分解为 确定性变换 和 独立噪声,使梯度可通过噪声传递。
- 数学形式: 对于高斯分布 \(z \sim \mathcal{N}(\mu, \sigma^2)\),PyTorch 通过以下方式实现重参数化采样:\(z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1)\)其中 \(\odot\) 表示逐元素乘法。
- 代码示例:
import torch import torch.nn.functional as F def gumbel_softmax(logits, tau=1.0, hard=False): # 生成Gumbel噪声 gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20) # 添加噪声并应用softmax y_soft = F.softmax((logits + gumbel_noise) / tau, dim=-1) if hard: # 直通估计器:前向传播使用one-hot,反向传播使用soft版本 index = y_soft.max(-1, keepdim=True)[1] y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0) y = (y_hard - y_soft).detach() + y_soft return y else: return y_soft
2. Gumbel-Softmax:离散分布的重参数化
- 应用场景:处理离散概率分布(如分类分布、伯努利分布)。
- 核心思想:通过引入 Gumbel 噪声和温度软化,将不可微的离散采样转化为可微的连续近似。
- 数学形式: 对于 logits \(\mathbf{z}\),Gumbel-Softmax 采样为:\(\mathbf{y}_\text{soft} = \text{softmax}\left(\frac{\mathbf{z} + \mathbf{g}}{\tau}\right), \quad g_i = -\log(-\log u_i), \ u_i \sim \text{Uniform}(0,1)\)
- 代码示例:
import torch import torch.nn.functional as F def gumbel_softmax(logits, tau=1.0, hard=False): # 生成Gumbel噪声 gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20) # 添加噪声并应用softmax y_soft = F.softmax((logits + gumbel_noise) / tau, dim=-1) if hard: # 直通估计器:前向传播使用one-hot,反向传播使用soft版本 index = y_soft.max(-1, keepdim=True)[1] y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0) y = (y_hard - y_soft).detach() + y_soft return y else: return y_soft
二、关键区别
| 维度 | torch.rsample() | Gumbel-Softmax |
|---|---|---|
| 适用分布 | 连续分布(如高斯、拉普拉斯) | 离散分布(如分类、伯努利) |
| 采样性质 | 精确采样(保持分布的统计特性) | 近似采样(通过温度控制离散程度) |
| 输出类型 | 连续值(如浮点数向量) | 连续近似(soft 版本)或离散化(hard 版本) |
| 梯度传递方式 | 直接通过确定性变换传递 | 通过 softmax 函数和直通估计器传递 |
| 典型应用 | VAE 中的连续隐变量、强化学习连续动作 | 离散隐变量模型、离散动作空间 RL |
三、联系与相似性
-
核心目标一致: 两者都旨在解决随机采样过程中的梯度传播问题,使模型能够通过梯度下降优化概率分布的参数。
-
数学原理同源: 均基于重参数化思想 —— 将随机变量表示为 参数 和 独立噪声 的函数。例如:
- 高斯分布:\(z = \mu + \sigma \odot \epsilon\);
- 分类分布:\(y = \text{argmax}(\mathbf{z} + \mathbf{g})\)。
-
PyTorch 实现: 虽然 PyTorch 没有直接提供 Gumbel-Softmax 的内置函数,但可以通过自定义实现(如上述代码)与
rsample()类似的可微采样效果。
四、总结
torch.rsample()是处理连续分布的标准重参数化方法,直接应用于支持该接口的分布类(如Normal、Laplace)。- Gumbel-Softmax 是专门为离散分布设计的重参数化技巧,通过噪声扰动和软化操作实现可微近似。
两者是针对不同类型分布的互补技术,共同服务于概率模型的可微优化目标。在实际应用中,需根据变量类型(连续或离散)选择合适的重参数化方法。

浙公网安备 33010602011771号