关于Gumbel softmax

以下是包含所有内容的 Markdown 格式:


使用 Gumbel-Softmax 和 One-hot 的区别和示例

1. Gumbel-Softmax 和 One-hot 的区别

定义

  • One-hot 向量:一个严格的离散向量,其中只有一个位置为 1,其余位置为 0。例如,假设类别数为 4,某个 one-hot 向量可能为 [0, 1, 0, 0],表示选择了第 2 个类别。one-hot 是离散的、不可微的。

  • Gumbel-Softmax 向量:通过引入 Gumbel 随机噪声对 softmax 进行变换,从而得到一个“近似” one-hot 的向量。它通常在神经网络训练中使用温度参数 tau 控制向量的平滑程度。当 tau 值很小时,Gumbel-Softmax 向量会接近 one-hot,但仍然保持可微性,使得模型能够端到端优化。

    • 例如,在 tau 较小时,Gumbel-Softmax 可能生成一个类似 [0.01, 0.98, 0.01, 0.00] 的向量,表示高概率选择第 2 个类别,但它仍然是一个连续的分布,可以用梯度优化。

特性对比

特性 One-hot 向量 Gumbel-Softmax 向量
是否离散 近似离散(取决于 tau
是否可微 是(适合神经网络中的反向传播)
温度控制 不存在 tau 控制平滑度,低 tau 接近 one-hot
应用场景 适合生成离散分类标签 适合在分类任务中采样单一类别的嵌入向量
随机性 无(固定为某个类别的 one-hot) 引入了 Gumbel 随机噪声,提供采样灵活性

2. 使用场景

  • One-hot 向量:适用于最终决策阶段,例如在测试阶段生成分类标签、索引某个类别或输出离散标签。

  • Gumbel-Softmax:适用于神经网络训练中的中间层或嵌入阶段。它在保持 one-hot 近似的同时仍然允许反向传播,适合在训练过程中从概率分布中采样类别,而不会中断梯度流。

3. 训练过程中的梯度流

在训练中,Gumbel-Softmax 的主要优势在于它生成的向量是可微的,允许通过采样的类别继续进行反向传播。与严格的 one-hot 向量相比,这种“软化的离散”使得模型可以基于不同类别的可能性进行学习,并根据实际分布来调整参数。


示例

通过一个具体的例子来展示原始代码和 Gumbel-Softmax 代码的不同计算效果。

假设输入

  • 假设 smooth_prob 是模型预测的概率分布,表示在每个位置上某个节点属于不同类别的概率。
  • 假设 smooth_prob 有 3 个节点(或位置),每个节点在 4 个类别上的概率分布:
smooth_prob = torch.tensor([
    [0.1, 0.2, 0.6, 0.1],  # 节点 1 的概率分布
    [0.3, 0.4, 0.1, 0.2],  # 节点 2 的概率分布
    [0.25, 0.25, 0.25, 0.25]  # 节点 3 的概率分布(均匀分布)
])
  • 假设 res_embeddings 是一个嵌入矩阵,每个类别的嵌入维度为 2:
res_embeddings = torch.tensor([
    [1.0, 0.0],  # 类别 1 的嵌入
    [0.0, 1.0],  # 类别 2 的嵌入
    [0.5, 0.5],  # 类别 3 的嵌入
    [0.2, 0.8]   # 类别 4 的嵌入
])

计算过程

1. 原始方法(直接加权)

在原始方法中,通过 smooth_probres_embeddings 的矩阵乘法实现加权嵌入:

H_original = smooth_prob.mm(res_embeddings)

计算过程
对于每个节点,结果是按概率对类别嵌入的加权和。

  • 节点 1 的结果:
    \( H[1] = 0.1 \times [1.0, 0.0] + 0.2 \times [0.0, 1.0] + 0.6 \times [0.5, 0.5] + 0.1 \times [0.2, 0.8] = [0.45, 0.45] \)

  • 节点 2 的结果:
    \( H[2] = 0.3 \times [1.0, 0.0] + 0.4 \times [0.0, 1.0] + 0.1 \times [0.5, 0.5] + 0.2 \times [0.2, 0.8] = [0.23, 0.55] \)

  • 节点 3 的结果(均匀分布):
    \( H[3] = 0.25 \times [1.0, 0.0] + 0.25 \times [0.0, 1.0] + 0.25 \times [0.5, 0.5] + 0.25 \times [0.2, 0.8] = [0.425, 0.575] \)

所以,H_original 的结果为:

tensor([
    [0.45, 0.45],
    [0.23, 0.55],
    [0.425, 0.575]
])

2. 使用 Gumbel-Softmax

在 Gumbel-Softmax 方法中,我们使用 F.gumbel_softmaxsmooth_prob 进行采样:

import torch.nn.functional as F

sampled_prob = F.gumbel_softmax(smooth_prob, tau=1.0, hard=True)  # 近似 one-hot 的采样结果
H_gumbel = sampled_prob.mm(res_embeddings)

假设 Gumbel-Softmax 采样结果为以下近似 one-hot 向量(每行接近 one-hot 格式):

sampled_prob = torch.tensor([
    [0, 0, 1, 0],  # 节点 1 采样到类别 3
    [1, 0, 0, 0],  # 节点 2 采样到类别 1
    [0, 0, 1, 0]   # 节点 3 采样到类别 3
])

计算过程
在采样结果中,每个节点选择了一个类别的嵌入(近似 one-hot),因此:

  • 节点 1 选择了类别 3 的嵌入 [0.5, 0.5]
  • 节点 2 选择了类别 1 的嵌入 [1.0, 0.0]
  • 节点 3 选择了类别 3 的嵌入 [0.5, 0.5]

所以,H_gumbel 的结果为:

tensor([
    [0.5, 0.5],
    [1.0, 0.0],
    [0.5, 0.5]
])

对比

  • 原始方法:输出是平滑的加权嵌入,反映了每个类别的概率。
  • Gumbel-Softmax 方法:输出是从 smooth_prob 中采样的类别嵌入,每个节点选择一个类别(类似于 one-hot),更离散化。

适用场景

  • 平滑加权:适用于希望概率信息对嵌入有直接贡献的情况。
  • Gumbel-Softmax:适合希望从分布中采样一个单一类别嵌入,同时保持可微性。
posted @ 2024-10-27 15:03  GraphL  阅读(399)  评论(0)    收藏  举报