Variational Diffusion Models (VDM)

1.1 Introduction to VDM

VDM (Variational Diffusion Models) builds on the MHVAE (Markovian hierarchical VAE) model, but differs from it in three ways:

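  • For every time step \(t\), the latent variable \(\boldsymbol{z}_t\) has the same dimension as the data \(\boldsymbol{x}\), i.e. \(\boldsymbol{z}_t \in \mathbb{R}^d\) and \(\boldsymbol{x} \in \mathbb{R}^d\);

  • For every time step \(t\), the distribution of the latent variable \(\boldsymbol{z}_t\) is not learned by a neural network. It is a fixed Gaussian centered on a scaled copy of the previous latent \(\boldsymbol{z}_{t-1}\), so diffusion models do not need to learn an encoder;

  • As \(t\) grows, \(\boldsymbol{z}_t\) gradually approaches a standard normal distribution, so that at the final step \(\boldsymbol{z}_T \sim \mathcal{N}(\mathbf{0},\mathbf{I})\) approximately, provided \(T\) is large enough.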

By analogy with the MHVAE joint distribution, the joint distribution of VDM is:

\[\begin{align} \underbrace{p\left (\boldsymbol{x}_{0:T}\right )}_{\text{Joint Distribution}} = \underbrace{p\left (\boldsymbol{x}_T\right )}_{\text{Prior}} \prod_{t=1}^{T} \underbrace{p_{\theta}\left (\boldsymbol{x}_{t-1}\mid \boldsymbol{x}_t\right )}_{\text{Decoder}} \end{align}\]

By analogy with the MHVAE posterior, the posterior of VDM is:

\[\begin{align} \underbrace{q\left (\boldsymbol{x}_{1:T}\mid \boldsymbol{x}_0\right )}_{\text{Posterior Distribution}} = \prod_{t=1}^{T} \underbrace{q\left (\boldsymbol{x}_t\mid \boldsymbol{x}_{t-1}\right )}_{\text{Encoder}} \end{align}\]

Note the following differences between the MHVAE formulas and the VDM formulas:

  • every \(q_{\phi}\) becomes \(q\), because the VDM encoder is not modeled by a neural network;
  • every \(\boldsymbol{z}_t\) becomes \(\boldsymbol{x}_t\), because in VDM the latent has the same dimension as the data;
  • every \(\boldsymbol{x}\) becomes \(\boldsymbol{x}_0\).

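1.2 How is the ELBo of VDM derived?

Starting from the MHVAE ELBo and replacing \(q_{\phi}\) with \(q\), \(\boldsymbol{z}_t\) with \(\boldsymbol{x}_t\), and \(\boldsymbol{x}\) with \(\boldsymbol{x}_0\) gives the ELBo of VDM: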

\[\begin{align} & \log \underbrace{p(\boldsymbol{x}_0)}_{\text{Evidence}} \\
\geq & \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_{0: T}\right)}{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\right]}_{\text{ELBo of VDM}} \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) \prod_{t=1}^T p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{\prod_{t=1}^T q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}\right)}\right] \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right) \prod_{t=2}^T p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right) \prod_{t=2}^T q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}\right)}\right] \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right) \prod_{t=2}^T p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right) \prod_{t=2}^T q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0\right)}\right] \quad \text{(Markov property)} \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)}{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}+\log \prod_{t=2}^T \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0\right)}\right] \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)}{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}+\log \prod_{t=2}^T \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{\frac{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right) q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0\right)}}\right] \quad \text{(Bayes' rule)} \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)}{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}+\log \prod_{t=2}^T \frac{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0\right)}{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}+\log \prod_{t=2}^T \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}\right] \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)}{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}+\log \frac{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}+\log \prod_{t=2}^T \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}\right] \quad \text{(telescoping product)} \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)}{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}+\sum_{t=2}^T \log \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}\right] \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]+\mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right)}{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}\right]+\sum_{t=2}^T \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}\right] \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]+\mathbb{E}_{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right)}{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}\right]+\sum_{t=2}^T \mathbb{E}_{q\left(\boldsymbol{x}_t, \boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}\right] \\
= & \underbrace{\underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]}_{\boldsymbol{x}_0 \approx \boldsymbol{x}_1}}_{\text{reconstruction term} \color{red}{\approx 0}}-\underbrace{D_{\mathrm{KL}}\left(\underbrace{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}_{\approx \mathcal{N}(\mathbf{0}, \mathbf{I})} \parallel \underbrace{p\left(\boldsymbol{x}_T\right)}_{=\mathcal{N}(\mathbf{0}, \mathbf{I})}\right)}_{\text{prior matching term}\color{red}{\approx 0}} -\underbrace{\sum_{t=2}^T \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ D_{\mathrm{KL}}\left(\underbrace{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}_{\color{red}{\text {complexity posterior}}} \parallel \underbrace{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}_{\textcolor{green}{\text{Decoder of VDM}}}\right) \right]}_{\text {denoising matching term }}}_{\textbf{Objective function to optimize}} \end{align} \]

1.3 How is the VDM objective derived from the ELBo?

1.3.1 Reconstruction term

Conclusion: this term can be estimated by Monte Carlo sampling, but when \(T\) is large the first noising step is tiny, so \(\boldsymbol{x}_1 \approx \boldsymbol{x}_0\) and the term is treated as negligible here.

\[\begin{align} \underbrace{\underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]}_{\boldsymbol{x}_0 \approx \boldsymbol{x}_1}}_{\text{reconstruction term}} \approx 0 \end{align}\]

1.3.2 Prior matching term

VDM makes the following assumption about the encoder:

\[\begin{align} \underbrace{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1} \right)}_{\textcolor{red}{\text{Encoder of VDM}}} & =\mathcal{N}\left(\boldsymbol{x}_t ; \sqrt{\alpha_t} \boldsymbol{x}_{t-1},\left(1-\alpha_t\right) \mathbf{I}\right) \quad \text{ with } \boldsymbol{\alpha_{t} \in (0, 1)} \end{align}\]

Applying the reparameterization trick gives:

\[\begin{align} \boldsymbol{x}_t & =\sqrt{\alpha_t} \boldsymbol{x}_{t-1}+\sqrt{1-\alpha_t} \boldsymbol{\epsilon} \quad \text { with } \boldsymbol{\epsilon} \sim \mathcal{N}(\boldsymbol{\epsilon} ; \mathbf{0}, \mathbf{I}) \text{, }\boldsymbol{\alpha_{t} \in (0, 1)} \end{align}\]
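The reparameterized step above can be sketched in a few lines of plain Python. This is only an illustration: the function name `forward_step`, the example vector, and the value \(\alpha_t = 0.9\) are assumptions made for the example, not part of the original text.

```python
import math
import random

def forward_step(x_prev, alpha_t, rng):
    """Draw x_t ~ N(sqrt(alpha_t) * x_prev, (1 - alpha_t) I), coordinate by coordinate."""
    scale = math.sqrt(alpha_t)            # shrinks the previous state
    noise_std = math.sqrt(1.0 - alpha_t)  # standard deviation of the fresh noise
    return [scale * v + noise_std * rng.gauss(0.0, 1.0) for v in x_prev]

rng = random.Random(0)                    # seeded only to make the sketch repeatable
x_prev = [1.0, -2.0, 0.5]                 # hypothetical x_{t-1}
x_t = forward_step(x_prev, alpha_t=0.9, rng=rng)
print(x_t)
```

Because the step has no learnable parameters, nothing here needs training, which is the sense in which VDM has no learned encoder.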

Applying the reparameterization trick recursively gives:

\[\begin{align} \boldsymbol{x}_t & =\sqrt{\alpha_t} \boldsymbol{x}_{t-1}+\sqrt{1-\alpha_t} \boldsymbol{\epsilon}_{t-1}^* \\
& =\sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}} \boldsymbol{x}_{t-2}+\sqrt{1-\alpha_{t-1}} \boldsymbol{\epsilon}_{t-2}^*\right)+\sqrt{1-\alpha_t} \boldsymbol{\epsilon}_{t-1}^* \\
& =\sqrt{\alpha_t \alpha_{t-1}} \boldsymbol{x}_{t-2}+\sqrt{\alpha_t-\alpha_t \alpha_{t-1}} \boldsymbol{\epsilon}_{t-2}^*+\sqrt{1-\alpha_t} \boldsymbol{\epsilon}_{t-1}^* \\
& =\sqrt{\alpha_t \alpha_{t-1}} \boldsymbol{x}_{t-2}+\sqrt{\sqrt{\alpha_t-\alpha_t \alpha_{t-1}}^2+\sqrt{1-\alpha_t}^2} \boldsymbol{\epsilon}_{t-2} \quad \text{(sum of independent Gaussians, } \lbrace \boldsymbol{\epsilon}_t^*, \boldsymbol{\epsilon}_t \rbrace_{t=0}^{T} \overset{iid}{\sim} \mathcal{N}(\boldsymbol{\epsilon}; \mathbf{0}, \mathbf{I})\text{)} \\
& =\sqrt{\alpha_t \alpha_{t-1}} \boldsymbol{x}_{t-2}+\sqrt{\alpha_t-\alpha_t \alpha_{t-1}+1-\alpha_t} \boldsymbol{\epsilon}_{t-2} \\
& =\sqrt{\alpha_t \alpha_{t-1}} \boldsymbol{x}_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} \boldsymbol{\epsilon}_{t-2} \\
& =\ldots \\
& =\sqrt{\prod_{i=1}^t \alpha_i} \boldsymbol{x}_0+\sqrt{1-\prod_{i=1}^t \alpha_i}\, \boldsymbol{\epsilon}_0 \quad \text{ with } \boldsymbol{\epsilon}_{0} \sim \mathcal{N}(\boldsymbol{\epsilon}; \mathbf{0}, \mathbf{I}) \\
& =\sqrt{\bar{\alpha}_t} \boldsymbol{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_0 \quad \text{ with } \bar{\alpha}_t = \prod_{i=1}^t \alpha_i \\
& \sim \mathcal{N}\left(\boldsymbol{x}_t ; \sqrt{\bar{\alpha}_t} \boldsymbol{x}_0,\left(1-\bar{\alpha}_t\right) \mathbf{I}\right) \end{align}\]

This expresses \(\boldsymbol{x}_t\) in terms of \(\boldsymbol{x}_0\) and Gaussian noise only, so it gives the distribution of \(\boldsymbol{x}_t\) conditioned on \(\boldsymbol{x}_0\) in closed form. This is not the marginal \(q(\boldsymbol{x}_t)\), which would additionally require averaging over the data distribution:

\[\begin{align} q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) = \mathcal{N}\left(\boldsymbol{x}_t ; \sqrt{\bar{\alpha}_t} \boldsymbol{x}_0,\left(1-\bar{\alpha}_t\right) \mathbf{I}\right) \end{align}\]
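As a sanity check on the closed form for \(q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)\), the sketch below pushes the mean coefficient and the noise variance through the chain one step at a time and compares them with \(\sqrt{\bar{\alpha}_t}\) and \(1-\bar{\alpha}_t\). The four-step schedule and the helper names are hypothetical, chosen only for illustration.

```python
import math

def alpha_bars(alphas):
    """Cumulative products: alpha_bar_t = alpha_1 * ... * alpha_t."""
    out, prod = [], 1.0
    for a in alphas:
        prod *= a
        out.append(prod)
    return out

def moments_by_recursion(alphas):
    """Track x_t = m_t * x_0 + noise of variance v_t through the one-step transitions."""
    m, v = 1.0, 0.0
    for a in alphas:
        m = math.sqrt(a) * m    # the mean coefficient is scaled by sqrt(alpha_t)
        v = a * v + (1.0 - a)   # old noise is scaled by alpha_t, fresh noise is added
    return m, v

alphas = [0.99, 0.97, 0.95, 0.9]    # hypothetical schedule
m, v = moments_by_recursion(alphas)
abar = alpha_bars(alphas)[-1]
print(m, math.sqrt(abar))   # two ways to get the coefficient of x_0
print(v, 1.0 - abar)        # two ways to get the noise variance
```

The two printed pairs agree up to floating-point rounding, which is the content of the recursion: variances of independent Gaussians add, so only the running product of the \(\alpha_i\) matters.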

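When \(T\) is large enough (for example \(T = 1000\)) and every \(\alpha_t \in (0, 1)\) keeps adding noise, the product \(\bar{\alpha}_T\) of \(T\) factors below 1 becomes vanishingly small:

\[\bar{\alpha}_T = \prod_{t=1}^{T} \alpha_t \approx 0 \]

Therefore:

\[q(\boldsymbol{x}_T \mid \boldsymbol{x}_0) \approx \mathcal{N}\left(\boldsymbol{x}_T; \mathbf{0}, \mathbf{I}\right) \]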

Conclusion: the prior matching term is negligible:

\[\underbrace{D_{\mathrm{KL}}\left(\underbrace{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}_{\approx N(0, I)} \parallel \underbrace{p\left(\boldsymbol{x}_T\right)}_{=N(0, I)}\right)}_{\text{prior matching term}} \approx 0 \]
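To get a feel for how small the prior matching term is, the sketch below evaluates it for one scalar coordinate. The noise schedule is an assumption: a linear schedule \(\beta_t = 1-\alpha_t\) running from \(10^{-4}\) to \(0.02\) over \(T = 1000\) steps, which is a common choice in diffusion implementations but is not specified in this text. The data value `x0 = 1.0` is likewise hypothetical.

```python
import math

T = 1000
# Assumed linear schedule (not from the text): beta_t from 1e-4 to 0.02, alpha_t = 1 - beta_t.
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alpha_bar_T = math.prod(1.0 - b for b in betas)

def kl_to_standard_normal(mean, var):
    """KL( N(mean, var) || N(0, 1) ) for one scalar dimension."""
    return 0.5 * (var + mean * mean - 1.0 - math.log(var))

x0 = 1.0  # one hypothetical data coordinate
kl = kl_to_standard_normal(math.sqrt(alpha_bar_T) * x0, 1.0 - alpha_bar_T)
print(alpha_bar_T, kl)
```

Under this assumed schedule both printed numbers are far below \(10^{-3}\), consistent with dropping the term. A schedule that adds noise much more slowly would need a larger \(T\) for the same approximation to hold.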

1.3.3 Denoising matching term

The sub-expressions of the denoising matching term are derived one at a time below:

\[\begin{align} & \underbrace{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right) }_{\text{complexity posterior}} \\
= & \frac{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0\right) q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0\right)}{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \\
= & \frac{\overbrace{\mathcal{N}\left(\boldsymbol{x}_t ; \sqrt{\alpha_t} \boldsymbol{x}_{t-1},\left(1-\alpha_t\right) \mathbf{I}\right)}^{\text{apply the Markov property}} \overbrace{\mathcal{N}\left(\boldsymbol{x}_{t-1} ; \sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0,\left(1-\bar{\alpha}_{t-1}\right) \mathbf{I}\right)}^{\text{apply Eq.(29)}}}{\underbrace{\mathcal{N}\left(\boldsymbol{x}_t ; \sqrt{\bar{\alpha}_t} \boldsymbol{x}_0,\left(1-\bar{\alpha}_t\right) \mathbf{I}\right)}_{\text{apply Eq.(29)}}} \\
\propto & \exp \left\{-\left[\frac{\left(\boldsymbol{x}_t-\sqrt{\alpha_t} \boldsymbol{x}_{t-1}\right)^2}{2\left(1-\alpha_t\right)}+\frac{\left(\boldsymbol{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0\right)^2}{2\left(1-\bar{\alpha}_{t-1}\right)}-\frac{\left(\boldsymbol{x}_t-\sqrt{\bar{\alpha}_t} \boldsymbol{x}_0\right)^2}{2\left(1-\bar{\alpha}_t\right)}\right]\right\} \\
= & \exp \left\{-\frac{1}{2}\left[\frac{\left(\boldsymbol{x}_t-\sqrt{\alpha_t} \boldsymbol{x}_{t-1}\right)^2}{1-\alpha_t}+\frac{\left(\boldsymbol{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(\boldsymbol{x}_t-\sqrt{\bar{\alpha}_t} \boldsymbol{x}_0\right)^2}{1-\bar{\alpha}_t}\right]\right\} \\
= & \exp \left\{-\frac{1}{2}\left[\frac{\left(-2 \sqrt{\alpha_t} \boldsymbol{x}_t \boldsymbol{x}_{t-1}+\alpha_t \boldsymbol{x}_{t-1}^2\right)}{1-\alpha_t}+\frac{\left(\boldsymbol{x}_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_{t-1} \boldsymbol{x}_0\right)}{1-\bar{\alpha}_{t-1}}+C\left(\boldsymbol{x}_t, \boldsymbol{x}_0\right)\right]\right\} \\
\propto & \exp \left\{-\frac{1}{2}\left[-\frac{2 \sqrt{\alpha_t}
\boldsymbol{x}_t \boldsymbol{x}_{t-1}}{1-\alpha_t}+\frac{\alpha_t \boldsymbol{x}_{t-1}^2}{1-\alpha_t}+\frac{\boldsymbol{x}_{t-1}^2}{1-\bar{\alpha}_{t-1}}-\frac{2 \sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_{t-1} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right]\right\} \\ = & \exp \left\{-\frac{1}{2}\left[\left(\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) \boldsymbol{x}_{t-1}^2-2\left(\frac{\sqrt{\alpha_t} \boldsymbol{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right) \boldsymbol{x}_{t-1}\right]\right\} \\ = & \exp \left\{-\frac{1}{2}\left[\frac{\alpha_t\left(1-\bar{\alpha}_{t-1}\right)+1-\alpha_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)} \boldsymbol{x}_{t-1}^2-2\left(\frac{\sqrt{\alpha_t} \boldsymbol{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right) \boldsymbol{x}_{t-1}\right]\right\} \\ = & \exp \left\{-\frac{1}{2}\left[\frac{\alpha_t-\bar{\alpha}_t+1-\alpha_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)} \boldsymbol{x}_{t-1}^2-2\left(\frac{\sqrt{\alpha_t} \boldsymbol{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right) \boldsymbol{x}_{t-1}\right]\right\} \\ = & \exp \left\{-\frac{1}{2}\left[\frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)} \boldsymbol{x}_{t-1}^2-2\left(\frac{\sqrt{\alpha_t} \boldsymbol{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right) \boldsymbol{x}_{t-1}\right]\right\} \\ = & \exp \left\{-\frac{1}{2}\left(\frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}\right)\left[\boldsymbol{x}_{t-1}^2-2 \frac{\left(\frac{\sqrt{\alpha_t} \boldsymbol{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right)}{\frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}} \boldsymbol{x}_{t-1}\right]\right\} \\ = & 
\exp \left\{-\frac{1}{2}\left(\frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}\right)\left[\boldsymbol{x}_{t-1}^2-2 \frac{\left(\frac{\sqrt{\alpha_t} \boldsymbol{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right)\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \boldsymbol{x}_{t-1}\right]\right\} \\ = & \exp \left\{-\frac{1}{2}\left(\frac{1}{\frac{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t}}\right)\left[\boldsymbol{x}_{t-1}^2-2 \frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \boldsymbol{x}_0}{1-\bar{\alpha}_t} \boldsymbol{x}_{t-1}\right]\right\} \\ \propto & \mathcal{N}(\boldsymbol{x}_{t-1} ; \underbrace{\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \boldsymbol{x}_0}{1-\bar{\alpha}_t}}_{\mu_q\left(\boldsymbol{x}_t, \boldsymbol{x}_0\right)}, \underbrace{\left.\frac{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{I}\right)}_{\boldsymbol{\Sigma}_q(t)} \end{align}\]
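The closed-form posterior can be checked numerically against Bayes' rule. The sketch below works with a single scalar coordinate; the helper names and all numeric values are hypothetical and only serve the check.

```python
import math

def log_normal(x, mean, var):
    """Log-density of a scalar Gaussian N(mean, var) evaluated at x."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)

def posterior_params(x_t, x_0, alpha_t, abar_prev):
    """Mean and variance of q(x_{t-1} | x_t, x_0) in the scalar case."""
    abar_t = alpha_t * abar_prev
    mean = (math.sqrt(alpha_t) * (1.0 - abar_prev) * x_t
            + math.sqrt(abar_prev) * (1.0 - alpha_t) * x_0) / (1.0 - abar_t)
    var = (1.0 - alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)
    return mean, var

# Hypothetical values for one coordinate.
alpha_t, abar_prev = 0.9, 0.6          # alpha_t and alpha_bar_{t-1}
x_0, x_t, x_prev = 0.7, -0.3, 0.2      # x_prev plays the role of x_{t-1}
abar_t = alpha_t * abar_prev

# Bayes' rule in log space: q(x_t | x_{t-1}) q(x_{t-1} | x_0) / q(x_t | x_0).
log_bayes = (log_normal(x_t, math.sqrt(alpha_t) * x_prev, 1.0 - alpha_t)
             + log_normal(x_prev, math.sqrt(abar_prev) * x_0, 1.0 - abar_prev)
             - log_normal(x_t, math.sqrt(abar_t) * x_0, 1.0 - abar_t))

mean, var = posterior_params(x_t, x_0, alpha_t, abar_prev)
log_closed_form = log_normal(x_prev, mean, var)
print(log_bayes, log_closed_form)
```

The two log-densities agree up to rounding, including the normalizing constant, so the proportionality signs in the derivation lose nothing once the result is read off as a Gaussian.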

Following Eq. (44), we also write \(p_{\theta}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)\) as a Gaussian and set its covariance equal to the one in Eq. (44):

\[\begin{align} p_{\theta}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) = \mathcal{N}(\boldsymbol{x}_{t-1}; \underbrace{\mu_{\theta}(\boldsymbol{x}_t, t)}_{\text{learned by model}}, \Sigma_{q}(t)) \end{align}\]

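The decoder does not have access to \(\boldsymbol{x}_0\), so we mirror the form of \(\mu_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\) in Eq. (44) and replace \(\boldsymbol{x}_0\) with a network prediction \(\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\):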

\[\begin{align} \boldsymbol{\mu}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right) =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \overbrace{\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right)}^{\text{learned by model}}}{1-\bar{\alpha}_t} \end{align}\]

2.1 The VDM objective function

Since the reconstruction term and the prior matching term in Eq. (15) are small enough to be neglected, maximizing the ELBo reduces to the following objective:

\[\begin{align} &\operatorname{arg}\max_{\theta} \log{p(\boldsymbol{x}_0)} \\ \approx &\operatorname{arg}\min_{\theta} \underbrace{\sum_{t=2}^T \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ D_{\mathrm{KL}}\left(\underbrace{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}_{\color{red}{\text {complexity posterior}}} \parallel \underbrace{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}_{\textcolor{red}{\text{Decoder of VDM}}}\right) \right]}_{\text {denoising matching term }}}_{\textbf{Objective function to optimize}} \end{align}\]

Starting from Eq. (48), the VDM objective becomes:

\[\begin{align} &\operatorname{arg}\min_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ D_{\mathrm{KL}}\left(q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}, \boldsymbol{x}_{0})\parallel p_{\theta}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t})\right)\right] \\
=& \operatorname{arg}\min_{\theta}\sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ D_{\mathrm{KL}}\left(\mathcal{N}(\boldsymbol{x}_{t-1}; \boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0), \boldsymbol{\Sigma}_{q}(t)) \parallel \mathcal{N}(\boldsymbol{x}_{t-1}; \boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, t), \boldsymbol{\Sigma}_{q}(t))\right)\right] \\
=& \operatorname{arg}\min_{\theta}\sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ \frac{1}{2}\left[\log\frac{|\boldsymbol{\Sigma}_q(t)|}{|\boldsymbol{\Sigma}_q(t)|}-d+\operatorname{tr}\left(\boldsymbol{\Sigma}_q(t)^{-1}\boldsymbol{\Sigma}_q(t)\right)+\left(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, t)-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right)^T\boldsymbol{\Sigma}_q(t)^{-1}\left(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, t)-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right)\right]\right] \quad \text{(closed-form KL between two Gaussians)}\\
=& \operatorname{arg}\min_{\theta}\sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ \frac{1}{2}\left[\log 1-d+d+\left(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, t)-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right)^T\boldsymbol{\Sigma}_q(t)^{-1}\left(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, t)-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right)\right]\right] \\
=& \operatorname{arg}\min_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ \frac{1}{2}\left[\left(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, t)-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right)^{T}\boldsymbol{\Sigma}_{q}(t)^{-1}\left(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, t)-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right)\right]\right] \\
=& \operatorname{arg}\min_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ \frac{1}{2}\left[\left(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, t)-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right)^{T}\left(\sigma_{q}^{2}(t)\mathbf{I}\right)^{-1}\left(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, t)-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right)\right]\right] \quad \text{(} \boldsymbol{\Sigma}_q(t) = \sigma_q^2(t)\mathbf{I} \text{ by Eq.(44))} \\
=& \operatorname{arg}\min_{\theta}\sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ \frac{1}{2\sigma_{q}^{2}(t)}\left[\left\|\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, t)-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right\|_{2}^{2}\right] \right] \\
=& \operatorname{arg}\min_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ \frac{1}{2\sigma_{q}^{2}(t)}\left[\left\| \underbrace{\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \overbrace{\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right)}^{\text{learned by model}}}{1-\bar{\alpha}_t}}_{\mu_{\theta} \text{ apply Eq.(46)}} - \underbrace{\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \boldsymbol{x}_0}{1-\bar{\alpha}_t}}_{\mu_{q} \text{ apply Eq.(44)}} \right\|_{2}^{2}\right] \right] \\
=& \operatorname{arg}\min_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac{1}{2\sigma_{q}^{2}(t)} \left[ \left\| \frac{\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_{t})\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)}{1-\bar{\alpha}_t} - \frac{\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_{t})\boldsymbol{x}_0}{1-\bar{\alpha}_t} \right\|_{2}^{2} \right] \right] \\
=& \operatorname{arg}\min_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac{1}{2\sigma_{q}^{2}(t)} \left[ \left\| \frac{\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_{t})}{1-\bar{\alpha}_t} \left(\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t) - \boldsymbol{x}_0\right) \right\|_{2}^{2} \right] \right] \\
=& \underbrace{\operatorname{arg}\min_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac{1}{2\sigma_{q}^{2}(t)} \frac{\bar{\alpha}_{t-1}(1-\alpha_{t})^2}{(1-\bar{\alpha}_t)^2} \left[ \left\| \hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t) - \boldsymbol{x}_0 \right\|_{2}^{2} \right] \right]}_{\text{Objective function of Diffusion Model}} \end{align}\]
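The reduction from a KL divergence to a weighted squared error can be verified with numbers. The sketch below uses one scalar coordinate, so \(\boldsymbol{\Sigma}_q(t)\) is just \(\sigma_q^2(t)\). All values are hypothetical, and `x0_hat` merely stands in for a network output \(\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\).

```python
import math

def kl_same_variance(mu_q, mu_p, var):
    """KL( N(mu_q, var) || N(mu_p, var) ) for scalars: only the mean gap survives."""
    return (mu_p - mu_q) ** 2 / (2.0 * var)

def posterior_mean(x_t, x0_like, alpha_t, abar_prev):
    """The shared mean formula, fed either the true x_0 or a prediction of it."""
    abar_t = alpha_t * abar_prev
    return (math.sqrt(alpha_t) * (1.0 - abar_prev) * x_t
            + math.sqrt(abar_prev) * (1.0 - alpha_t) * x0_like) / (1.0 - abar_t)

alpha_t, abar_prev = 0.9, 0.6                  # alpha_t and alpha_bar_{t-1}
abar_t = alpha_t * abar_prev
sigma2 = (1.0 - alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)

x_t, x_0, x0_hat = -0.3, 0.7, 0.55             # x0_hat: stand-in for a network prediction

kl = kl_same_variance(posterior_mean(x_t, x_0, alpha_t, abar_prev),
                      posterior_mean(x_t, x0_hat, alpha_t, abar_prev),
                      sigma2)
weight = abar_prev * (1.0 - alpha_t) ** 2 / (2.0 * sigma2 * (1.0 - abar_t) ** 2)
print(kl, weight * (x0_hat - x_0) ** 2)
```

Both printed values coincide up to rounding: the \(\boldsymbol{x}_t\) terms of the two means cancel, leaving only the gap between the prediction and \(\boldsymbol{x}_0\) scaled by a time-dependent weight.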

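Replacing the sum over \(t\) with an expectation over a uniformly sampled time step changes the objective only by the constant factor \(T-1\), which does not affect the minimizer, and makes Monte Carlo estimation possible:

\[\begin{align} \operatorname{arg}\min_{\theta} \mathbb{E}_{t\sim \mathcal{U}\{2,\dots,T\}} \left [ \mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac{1}{2\sigma_{q}^{2}(t)} \frac{\bar{\alpha}_{t-1}(1-\alpha_{t})^2}{(1-\bar{\alpha}_t)^2} \left[ \left\| \hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t) - \boldsymbol{x}_0 \right\|_{2}^{2} \right] \right] \right] \end{align}\]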
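A minimal sketch of this Monte Carlo estimate for a single data point follows. Everything named here is an assumption made for illustration: the function names `loss_weight` and `mc_loss`, the 100-step linear schedule, and the two toy predictors passed as `predict_x0`, which stand in for a trained network \(\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}\).

```python
import math
import random

def loss_weight(t, alphas, abars):
    """Weight multiplying ||x0_hat - x0||^2 at step t (1-indexed, t >= 2)."""
    a_t, ab_t, ab_prev = alphas[t - 1], abars[t - 1], abars[t - 2]
    sigma2 = (1.0 - a_t) * (1.0 - ab_prev) / (1.0 - ab_t)
    return ab_prev * (1.0 - a_t) ** 2 / (2.0 * sigma2 * (1.0 - ab_t) ** 2)

def mc_loss(x0, predict_x0, alphas, rng, n_samples=64):
    """Monte Carlo estimate of the objective for one data point x0 (a list of floats)."""
    abars, prod = [], 1.0
    for a in alphas:
        prod *= a
        abars.append(prod)
    T, total = len(alphas), 0.0
    for _ in range(n_samples):
        t = rng.randint(2, T)                    # t ~ Uniform{2, ..., T}, both ends inclusive
        ab_t = abars[t - 1]
        x_t = [math.sqrt(ab_t) * v + math.sqrt(1.0 - ab_t) * rng.gauss(0.0, 1.0)
               for v in x0]                      # x_t ~ q(x_t | x_0)
        x0_hat = predict_x0(x_t, t)
        sq_err = sum((p - v) ** 2 for p, v in zip(x0_hat, x0))
        total += loss_weight(t, alphas, abars) * sq_err
    return total / n_samples

# Assumed schedule (not from the text): 100 steps, beta_t linear from 1e-4 to 0.02.
alphas = [1.0 - (1e-4 + (0.02 - 1e-4) * i / 99) for i in range(100)]
x0 = [0.5, -1.0]
rng = random.Random(0)
print(mc_loss(x0, lambda x_t, t: x_t, alphas, rng))  # naive predictor: returns x_t unchanged
print(mc_loss(x0, lambda x_t, t: x0, alphas, rng))   # oracle predictor: returns the true x0
```

The oracle predictor drives the estimate to zero, while the naive one does not, which is the behavior a training loop would exploit by adjusting the parameters of the predictor with gradient descent.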

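In summary, the VDM model learns to predict the original image \(\boldsymbol{x}_0\) from its noised version \(\boldsymbol{x}_t\), because the objective reduces to the weighted squared error \(\|\hat{\boldsymbol{x}}_{\theta}(\boldsymbol{x}_t, t) - \boldsymbol{x}_0\|_{2}^{2}\).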

posted @ 2024-04-27 00:28  RenjieW