AIGC拾遗:Classifier Guidance
背景
在生成任务中,我们往往需要添加某些先验条件,以得到符合要求的结果,如生成特定类别的图片时,文生图,图生视频时。因此,相较于无条件生成模型,条件生成模型的应用范围广得多。本文回顾了Classier Guidance的方法,该方法额外训练一个分类器(classifier),并采用其梯度进行引导(guidance),在不改变预训练的无条件扩散模型权重的前提下,实现条件生成。此类方法最早源于2021 nips的 Diffusion Models Beat GANs on Image Synthesis。
分类器引导
首先明确一下待解决的问题:已知一个无条件生成模型,实现条件\(y\)下的采样生成。即已知\(p(x_{t-1}|x_{t})\),求\(p(x_{t-1}|x_{t}, y)\)。由贝叶斯公式可得
第二个等号成立基于假设\(p(y|x_{t-1}, x_{t})=p(y|x_{t-1})\)。该假设的合理性在于:\(x_{t-1}\)为\(x_{t}\)的单步去噪版本,当去噪步数足够多时,\(x_{t-1}\)与\(x_{t}\)非常接近;同时,分类任务中额外增加含噪副本直觉上来说并不会增加分类结果;此外,由于加噪过程和条件\(y\)无关,因此有下式
公式\eqref{1}的log形式为
如上文所说,\(x_{t-1}\)与\(x_{t}\)非常接近时,\(\log{p(y|x_{t-1})}\)可以在\(x_{t}\)处进行泰勒展开
将式\eqref{3}代入式\eqref{2},并假设\(p(x_{t-1}|x_{t}) \sim \mathcal{N}(x_{t-1}; \mu(x_{t}), \sigma_{t}^{2}I))\)有
其中,\(z \sim \mathcal{N}(x_{t-1};\mu(x_{t})+\sigma_{t}^{2}\nabla_{x_{t}}\log{p(y|x_{t})}, \sigma_{t}^{2}I)\)。\(C_{1}\)为高斯分布的对数归一化常数项\(-\frac{D}{2}\log{2\pi\sigma_{t}^{2}}\)。\(C_{2}\)为\((\mu(x_{t})-x_{t})\nabla_{x_{t}}\log{p(y|x_{t})}\)\(+\frac{\sigma_{t}^{2}}{2}(\nabla_{x_{t}}\log{p(y|x_{t})})^{2}\),上述两项均为常数。
忽略\(C_{2}\)后,\(x_{t-1}\)可以用下述公式进行采样
梯度缩放
作者指出,在实际使用时,需要缩放分类器的梯度,公式\eqref{5}应改为
\(\gamma\)为缩放因子,增大\(\gamma\)时,生成结果的保真度提高,反之,生成结果的多样性提高。作者对此的解释为:\(\gamma\nabla_{x_{t}}\log{p(y|x_{t})}=\nabla_{x_{t}}\log{p^{\gamma}(y|x_{t})}\),当\(\gamma>1\)时,\(p^{\gamma}(y|x_{t})\)相较于\(p(y|x_{t})\)更加尖锐,能够让模型更加聚焦于分类器的模式。
相关代码
计算\(\gamma\nabla_{x_{t}}\log{p(y|x_{t})}\)
def get_classifer_grad(x_t, t, y):
with torch.enable_grad():
x_in = x_t.detach().requires_grad_(True)
logits = classifier(x_in, t)
log_probs = F.log_softmax(logits, dim=-1)
selected = log_probs[range(len(logits)), y.view(-1)]
return torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale
采样代码
def condition_mean(self, get_classifer_grad, p_mean_var, x, t, model_kwargs=None):
"""
Compute the mean for the previous step, given a function cond_fn that
computes the gradient of a conditional log probability with respect to
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
condition on y.
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
"""
gradient = get_classifer_grad(x, self._scale_timesteps(t), **model_kwargs)
new_mean = (
p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
)
return new_mean
总结
本文回顾了classifier guidance的方法,该方法能够直接利用预训练的无条件生成模型进行条件生成采样。但该方法存在几个缺点:
-
需要额外训练一个条件加噪分类器。
-
当条件是text或者pose之类的连续变量而并非离散的类别时,无法使用分类器,因此该方法不再适用。
-
部分公式的解释不够清楚,例如公式\eqref{4}中,为什么能忽略\(C_{2}\)的影响以及梯度缩放的合理性。
参考资料
https://proceedings.neurips.cc/paper/2021/hash/49ad23d1ec9fa4bd8d77d02681df5cfa-Abstract.html

浙公网安备 33010602011771号