Classifier-Free Guidance
论文

-
输入要求:算法需要一个无条件训练的概率 \(p_{\text{uncond}}\)。
-
重复执行:
- 数据采样:从数据集中采样带有条件的样本 \((\mathbf{x}, \mathbf{c})\)。
- 条件丢弃:以概率 \(p_{\text{uncond}}\) 随机丢弃条件 \(\mathbf{c}\),以进行无条件训练。
- 采样对数信噪比(log SNR):从分布 \(p(\lambda)\) 中采样 \(\lambda\)。
- 噪声采样:从标准正态分布 \(\mathcal{N}(\mathbf{0}, \mathbf{I})\) 中采样噪声 \(\boldsymbol{\epsilon}\)。
- 数据扰动:使用采样的 \(\lambda\),通过公式 \(\mathbf{z}_\lambda = \alpha_\lambda \mathbf{x} + \sigma_\lambda \boldsymbol{\epsilon}\) 对数据进行扰动。
- 梯度更新:对去噪模型进行优化,更新参数 \(\theta\),使得模型输出的噪声预测 \(\boldsymbol{\epsilon}_\theta(\mathbf{z}_\lambda, \mathbf{c})\) 接近真实噪声 \(\boldsymbol{\epsilon}\)。
-
直到收敛:重复上述步骤,直到模型收敛。

-
输入要求:
- \(w\):引导强度,控制无分类器引导的影响。
- \(\mathbf{c}\):用于条件采样的条件信息。
- \(\lambda_1, \ldots, \lambda_T\):递增的 log SNR 序列,定义了采样过程中的噪声水平变化。
-
初始化:
- 从标准正态分布 \(\mathcal{N}(\mathbf{0}, \mathbf{I})\) 中采样初始噪声 \(\mathbf{z}_1\)。
-
迭代采样:
-
对于每个时间步 \(t\):
- 形成无分类器引导的得分:计算调整后的噪声预测 \(\tilde{\boldsymbol{\epsilon}}_t\),结合了条件和无条件的噪声预测。
- 采样步骤:更新中间变量 \(\tilde{\mathbf{x}}_t\),用于生成下一个时间步的样本。
- 更新样本:根据当前和下一个时间步的 \(\lambda\),更新样本 \(\mathbf{z}_{t+1}\)。如果是最后一个时间步,则直接将 \(\tilde{\mathbf{x}}_t\) 作为最终输出。
-
-
返回结果:
- 返回最终生成的样本 \(\mathbf{z}_{T+1}\)。
Classifier-Free Diffusion Guidance | arXiv
简化
无分类器引导(Classifier-Free Guidance,CFG)是一种用于生成模型的技术,特别是在扩散模型(Diffusion Models)中,用于提高生成样本的质量。其基本思想是通过引入一个控制参数来平衡生成样本的多样性和质量。
设有一个扩散模型,其生成过程可以表示为条件概率分布 \(p(x \mid y)\),其中 \(x\) 是生成的样本,\(y\) 是条件信息(如类别标签)。在 CFG 中,我们使用两个模型:
- 一个条件模型 \(p_\theta(x \mid y)\),用于给定条件 \(y\),生成样本 \(x\)。
- 一个无条件模型 \(p_\theta(x)\),用于不给定任何条件生成样本 \(x\)。
CFG 的核心思想是通过以下公式来引导生成过程:
\[\tilde{p}_\theta(x \mid y) = p_\theta(x) + w (p_\theta(x \mid y) - p_\theta(x))
\]
其中,\(w\) 是一个控制参数,称为引导权重(guidance weight)。这个参数用于平衡条件模型和无条件模型的影响:
- 当 \(w = 1\) 时,\(\tilde{p}_\theta(x \mid y)\) 等于 \(p_\theta(x \mid y)\),即完全依赖条件模型。
- 当 \(w > 1\) 时,引导过程会倾向于生成更符合条件 \(y\) 的样本,但可能会牺牲一些多样性。
- 当 \(w < 1\) 时,生成过程会更倾向于无条件模型,增加样本的多样性。
glide-text2im/notebooks/text2im.ipynb
def model_fn(x_t, ts, **kwargs):
"""
带有 CFG 的前向过程
:param x_t: t 步的加噪图像 x_t
:param ts: 时间步 t
:param kwargs: 带有模型条件的字典
:return: CFG 输出
"""
half = x_t[: len(x_t) // 2]
combined = th.cat([half, half], dim=0) # 确保有条件和无条件生成使用相同的输入
model_out = model(combined, ts, **kwargs) # 有条件和无条件生成
eps, rest = model_out[:, :3], model_out[:, 3:] # 模型预测分布参数
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) # 取出 eps 参数的有条件和无条件预测结果
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) # 使用公式计算 CFG 结果
eps = th.cat([half_eps, half_eps], dim=0)
return th.cat([eps, rest], dim=1)

浙公网安备 33010602011771号