关于Gumbel softmax
以下是包含所有内容的 Markdown 格式:
使用 Gumbel-Softmax 和 One-hot 的区别和示例
1. Gumbel-Softmax 和 One-hot 的区别
定义
-
One-hot 向量:一个严格的离散向量,其中只有一个位置为 1,其余位置为 0。例如,假设类别数为 4,某个 one-hot 向量可能为
[0, 1, 0, 0],表示选择了第 2 个类别。one-hot 是离散的、不可微的。 -
Gumbel-Softmax 向量:通过引入 Gumbel 随机噪声对 softmax 进行变换,从而得到一个“近似” one-hot 的向量。它通常在神经网络训练中使用温度参数
tau控制向量的平滑程度。当tau值很小时,Gumbel-Softmax 向量会接近 one-hot,但仍然保持可微性,使得模型能够端到端优化。- 例如,在
tau较小时,Gumbel-Softmax 可能生成一个类似[0.01, 0.98, 0.01, 0.00]的向量,表示高概率选择第 2 个类别,但它仍然是一个连续的分布,可以用梯度优化。
- 例如,在
特性对比
| 特性 | One-hot 向量 | Gumbel-Softmax 向量 |
|---|---|---|
| 是否离散 | 是 | 近似离散(取决于 tau) |
| 是否可微 | 否 | 是(适合神经网络中的反向传播) |
| 温度控制 | 不存在 | tau 控制平滑度,低 tau 接近 one-hot |
| 应用场景 | 适合生成离散分类标签 | 适合在分类任务中采样单一类别的嵌入向量 |
| 随机性 | 无(固定为某个类别的 one-hot) | 引入了 Gumbel 随机噪声,提供采样灵活性 |
2. 使用场景
-
One-hot 向量:适用于最终决策阶段,例如在测试阶段生成分类标签、索引某个类别或输出离散标签。
-
Gumbel-Softmax:适用于神经网络训练中的中间层或嵌入阶段。它在保持 one-hot 近似的同时仍然允许反向传播,适合在训练过程中从概率分布中采样类别,而不会中断梯度流。
3. 训练过程中的梯度流
在训练中,Gumbel-Softmax 的主要优势在于它生成的向量是可微的,允许通过采样的类别继续进行反向传播。与严格的 one-hot 向量相比,这种“软化的离散”使得模型可以基于不同类别的可能性进行学习,并根据实际分布来调整参数。
示例
通过一个具体的例子来展示原始代码和 Gumbel-Softmax 代码的不同计算效果。
假设输入
- 假设
smooth_prob是模型预测的概率分布,表示在每个位置上某个节点属于不同类别的概率。 - 假设
smooth_prob有 3 个节点(或位置),每个节点在 4 个类别上的概率分布:
smooth_prob = torch.tensor([
[0.1, 0.2, 0.6, 0.1], # 节点 1 的概率分布
[0.3, 0.4, 0.1, 0.2], # 节点 2 的概率分布
[0.25, 0.25, 0.25, 0.25] # 节点 3 的概率分布(均匀分布)
])
- 假设
res_embeddings是一个嵌入矩阵,每个类别的嵌入维度为 2:
res_embeddings = torch.tensor([
[1.0, 0.0], # 类别 1 的嵌入
[0.0, 1.0], # 类别 2 的嵌入
[0.5, 0.5], # 类别 3 的嵌入
[0.2, 0.8] # 类别 4 的嵌入
])
计算过程
1. 原始方法(直接加权)
在原始方法中,通过 smooth_prob 和 res_embeddings 的矩阵乘法实现加权嵌入:
H_original = smooth_prob.mm(res_embeddings)
计算过程:
对于每个节点,结果是按概率对类别嵌入的加权和。
-
节点 1 的结果:
\( H[1] = 0.1 \times [1.0, 0.0] + 0.2 \times [0.0, 1.0] + 0.6 \times [0.5, 0.5] + 0.1 \times [0.2, 0.8] = [0.45, 0.45] \) -
节点 2 的结果:
\( H[2] = 0.3 \times [1.0, 0.0] + 0.4 \times [0.0, 1.0] + 0.1 \times [0.5, 0.5] + 0.2 \times [0.2, 0.8] = [0.23, 0.55] \) -
节点 3 的结果(均匀分布):
\( H[3] = 0.25 \times [1.0, 0.0] + 0.25 \times [0.0, 1.0] + 0.25 \times [0.5, 0.5] + 0.25 \times [0.2, 0.8] = [0.425, 0.575] \)
所以,H_original 的结果为:
tensor([
[0.45, 0.45],
[0.23, 0.55],
[0.425, 0.575]
])
2. 使用 Gumbel-Softmax
在 Gumbel-Softmax 方法中,我们使用 F.gumbel_softmax 对 smooth_prob 进行采样:
import torch.nn.functional as F
sampled_prob = F.gumbel_softmax(smooth_prob, tau=1.0, hard=True) # 近似 one-hot 的采样结果
H_gumbel = sampled_prob.mm(res_embeddings)
假设 Gumbel-Softmax 采样结果为以下近似 one-hot 向量(每行接近 one-hot 格式):
sampled_prob = torch.tensor([
[0, 0, 1, 0], # 节点 1 采样到类别 3
[1, 0, 0, 0], # 节点 2 采样到类别 1
[0, 0, 1, 0] # 节点 3 采样到类别 3
])
计算过程:
在采样结果中,每个节点选择了一个类别的嵌入(近似 one-hot),因此:
- 节点 1 选择了类别 3 的嵌入
[0.5, 0.5] - 节点 2 选择了类别 1 的嵌入
[1.0, 0.0] - 节点 3 选择了类别 3 的嵌入
[0.5, 0.5]
所以,H_gumbel 的结果为:
tensor([
[0.5, 0.5],
[1.0, 0.0],
[0.5, 0.5]
])
对比
- 原始方法:输出是平滑的加权嵌入,反映了每个类别的概率。
- Gumbel-Softmax 方法:输出是从
smooth_prob中采样的类别嵌入,每个节点选择一个类别(类似于 one-hot),更离散化。
适用场景
- 平滑加权:适用于希望概率信息对嵌入有直接贡献的情况。
- Gumbel-Softmax:适合希望从分布中采样一个单一类别嵌入,同时保持可微性。

浙公网安备 33010602011771号