VDM
1.1 Introduction to VDM
VDM (Variational Diffusion Model) is built on the MHVAE (Markovian Hierarchical VAE), but differs from it in three ways:
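- For every time step \(t\), the latent \(\boldsymbol{z}_t\) has the same dimension as the data \(\boldsymbol{x}\), i.e. \(\boldsymbol{z}_t \in \mathbb{R}^d\) and \(\boldsymbol{x} \in \mathbb{R}^d\);
- For every time step \(t\), the latent \(\boldsymbol{z}_t\) is not produced by a learned neural network; it is drawn from a fixed Gaussian whose mean is a scaled copy of the previous latent \(\boldsymbol{z}_{t-1}\) (see the encoder in 1.3.2). Diffusion models therefore do not need to learn an encoder;
- As \(t\) grows, \(\boldsymbol{z}_t\) gradually approaches a standard normal distribution, so that at the final step \(T\) (for \(T\) large enough) \(\boldsymbol{z}_T \sim \mathcal{N}(\mathbf{0},\mathbf{I})\).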
Following the joint distribution of the MHVAE, the joint distribution of VDM is:
\[\begin{align}
\underbrace{p\left(\boldsymbol{x}_{0:T}\right)}_{\text{Joint Distribution}} = \underbrace{p\left(\boldsymbol{x}_T\right)}_{\text{Prior}} \prod_{t=1}^{T} \underbrace{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1}\mid \boldsymbol{x}_t\right)}_{\text{Decoder}}
\end{align}\]
Following the posterior of the MHVAE, the posterior of VDM is:
\[\begin{align}
\underbrace{q\left (\boldsymbol{x}_{1:T}\mid \boldsymbol{x}_0\right )}_{\text{Posterior Distribution}} = \prod_{t=1}^{T} \underbrace{q\left (\boldsymbol{x}_t\mid \boldsymbol{x}_{t-1}\right )}_{\text{Encoder}}
\end{align}\]
Note the following differences between the MHVAE formulas and the VDM formulas:
- every \(q_{\phi}\) becomes \(q\), because the VDM encoder is not modeled by a neural network;
- every \(\boldsymbol{z}_t\) becomes \(\boldsymbol{x}_t\), because in VDM the latents have the same dimension as the data;
- \(\boldsymbol{x}\) becomes \(\boldsymbol{x}_0\).
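1.2 How to derive the ELBo of VDM?
Starting from the ELBo of the MHVAE and replacing \(q_{\phi}\) with \(q\) and \(\boldsymbol{z}_t\) with \(\boldsymbol{x}_t\) gives the ELBo of VDM, which can then be decomposed as follows: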
\[\begin{align}
& \log \underbrace{p(\boldsymbol{x})}_{\text{Evidence}} \\
\geq & \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_{0: T}\right)}{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\right]}_{\text{ELBo of VDM}} \\
=& \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) \prod_{t=1}^T p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{\prod_{t=1}^T q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}\right)}\right] \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right) \prod_{t=2}^T p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right) \prod_{t=2}^T q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}\right)}\right] \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right) \prod_{t=2}^T p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right) \prod_{t=2}^T q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0\right)}\right] \quad \text{(Markov property)} \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)}{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}+\log \prod_{t=2}^T \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0\right)}\right] \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)}{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}+\log \prod_{t=2}^T \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{\frac{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right) q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0\right)}}\right] \quad \text{(Bayes' rule)} \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)}{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}+\log \prod_{t=2}^T \frac{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0\right)}{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}+\log \prod_{t=2}^T \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}\right] \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)}{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}+\log \frac{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}+\log \prod_{t=2}^T \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}\right] \quad \text{(telescoping product)} \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right) p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)}{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}+\sum_{t=2}^T \log \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}\right] \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]+\mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right)}{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}\right]+\sum_{t=2}^T \mathbb{E}_{q\left(\boldsymbol{x}_{1: T} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}\right] \\
= & \mathbb{E}_{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]+\mathbb{E}_{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}\left[\log \frac{p\left(\boldsymbol{x}_T\right)}{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}\right]+\sum_{t=2}^T \mathbb{E}_{q\left(\boldsymbol{x}_t, \boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0\right)}\left[\log \frac{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}\right] \\
= & \underbrace{\underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]}_{\boldsymbol{x}_0\approx\boldsymbol{x}_1}}_{\text{reconstruction term} \color{red}{\approx 0}}-\underbrace{D_{\mathrm{KL}}\left(\underbrace{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}_{\approx \mathcal{N}(\mathbf{0}, \mathbf{I})} \parallel \underbrace{p\left(\boldsymbol{x}_T\right)}_{=\mathcal{N}(\mathbf{0}, \mathbf{I})}\right)}_{\text{prior matching term}\color{red}{\approx 0}} -\underbrace{\sum_{t=2}^T \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[
D_{\mathrm{KL}}\left(\underbrace{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}_{\color{red}{\text {ground-truth posterior}}} \parallel \underbrace{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}_{\textcolor{green}{\text{Decoder of VDM}}}\right)
\right]}_{\text {denoising matching term }}}_{\textbf{Objective function to optimize}}
\end{align}
\]
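1.3 How to derive the VDM objective from the ELBo?
1.3.1 Reconstruction term
Conclusion: this term can be estimated by Monte Carlo sampling, but in practice, when \(T\) is large, the first encoder step adds very little noise, so \(\boldsymbol{x}_1 \approx \boldsymbol{x}_0\) and the term can be neglected.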
\[\begin{align}
\underbrace{\underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]}_{\boldsymbol{x}_0\approx\boldsymbol{x}_1}}_{\text{reconstruction term}} \approx 0
\end{align}\]
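As the conclusion notes, the reconstruction term can be estimated by Monte Carlo. The sketch below is a minimal 1-D version. It assumes the Gaussian encoder \(q(\boldsymbol{x}_1 \mid \boldsymbol{x}_0)\) introduced in 1.3.2 and a hypothetical Gaussian decoder whose mean is `decoder_mean(x1)` and whose standard deviation is a fixed `sigma`. These names and the decoder form are illustrative assumptions and are not part of the derivation.

```python
import math
import random


def gaussian_log_pdf(x, mean, var):
    """Log density of the 1-D Gaussian N(x; mean, var)."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def reconstruction_term_mc(x0, alpha_1, decoder_mean, sigma, n_samples=10_000, seed=0):
    """Monte Carlo estimate of E_{q(x_1|x_0)}[log p_theta(x_0 | x_1)] in 1-D.

    q(x_1 | x_0) = N(sqrt(alpha_1) * x0, 1 - alpha_1) is the VDM encoder;
    p_theta(x_0 | x_1) = N(decoder_mean(x_1), sigma**2) is a hypothetical decoder.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        # Reparameterized sample: x_1 = sqrt(alpha_1) x_0 + sqrt(1 - alpha_1) eps
        eps = rng.gauss(0.0, 1.0)
        x1 = math.sqrt(alpha_1) * x0 + math.sqrt(1.0 - alpha_1) * eps
        total += gaussian_log_pdf(x0, decoder_mean(x1), sigma ** 2)
    return total / n_samples
```

A decoder that approximately undoes the encoder step, such as `lambda x1: x1 / math.sqrt(alpha_1)`, scores much higher than an uninformative one. Maximizing this term rewards exactly that behavior.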
1.3.2 Prior matching term
In VDM we make the following assumption:
\[\begin{align}
\underbrace{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1} \right)}_{\textcolor{red}{\text{Encoder of VDM}}} & =\mathcal{N}\left(\boldsymbol{x}_t ; \sqrt{\alpha_t} \boldsymbol{x}_{t-1},\left(1-\alpha_t\right) \mathbf{I}\right) \quad \text{ with } \alpha_{t} \in (0, 1)
\end{align}\]
Using the reparameterization trick, this is equivalent to:
\[\begin{align}
\boldsymbol{x}_t & =\sqrt{\alpha_t} \boldsymbol{x}_{t-1}+\sqrt{1-\alpha_t} \boldsymbol{\epsilon} \quad \text { with } \boldsymbol{\epsilon} \sim \mathcal{N}(\boldsymbol{\epsilon} ; \mathbf{0}, \mathbf{I}) \text{, } \alpha_{t} \in (0, 1)
\end{align}\]
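The reparameterized encoder step translates directly into code. The following is a minimal sketch that samples a full forward trajectory. It assumes plain Python lists as vectors and a caller-supplied list of \(\alpha_t\) values. Any concrete schedule is an assumption, because the text does not fix one.

```python
import math
import random


def forward_step(x_prev, alpha_t, rng):
    """One encoder step q(x_t | x_{t-1}): x_t = sqrt(alpha_t) x_{t-1} + sqrt(1 - alpha_t) eps."""
    return [math.sqrt(alpha_t) * v + math.sqrt(1.0 - alpha_t) * rng.gauss(0.0, 1.0)
            for v in x_prev]


def sample_forward_chain(x0, alphas, seed=0):
    """Ancestral sampling of q(x_{1:T} | x_0) = prod_t q(x_t | x_{t-1}).

    alphas[t-1] holds alpha_t for t = 1..T; returns [x_0, x_1, ..., x_T].
    """
    rng = random.Random(seed)
    chain = [list(x0)]
    for alpha_t in alphas:
        # No learned parameters: each step only rescales and adds fresh noise
        chain.append(forward_step(chain[-1], alpha_t, rng))
    return chain
```

Every \(\boldsymbol{x}_t\) in the returned chain keeps the dimension of \(\boldsymbol{x}_0\), which matches the first property listed in 1.1.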
Applying the reparameterization recursively, unrolling back to \(\boldsymbol{x}_0\), gives:
\[\begin{align}
\boldsymbol{x}_t & =\sqrt{\alpha_t} \boldsymbol{x}_{t-1}+\sqrt{1-\alpha_t} \boldsymbol{\epsilon}_{t-1}^* \\
& =\sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}} \boldsymbol{x}_{t-2}+\sqrt{1-\alpha_{t-1}} \boldsymbol{\epsilon}_{t-2}^*\right)+\sqrt{1-\alpha_t} \boldsymbol{\epsilon}_{t-1}^* \\
& =\sqrt{\alpha_t \alpha_{t-1}} \boldsymbol{x}_{t-2}+\sqrt{\alpha_t-\alpha_t \alpha_{t-1}} \boldsymbol{\epsilon}_{t-2}^*+\sqrt{1-\alpha_t} \boldsymbol{\epsilon}_{t-1}^* \\
& =\sqrt{\alpha_t \alpha_{t-1}} \boldsymbol{x}_{t-2}+\sqrt{{\sqrt{\alpha_t-\alpha_t \alpha_{t-1}}}^2+{\sqrt{1-\alpha_t}}^2} \boldsymbol{\epsilon}_{t-2} \quad \text{(sum of independent Gaussians; } \lbrace \boldsymbol{\epsilon}_t^*,\boldsymbol{\epsilon}_t \rbrace_{t=0}^{T}\overset{\text{iid}}{\sim}\mathcal{N}(\boldsymbol{\epsilon}; \mathbf{0}, \mathbf{I})\text{)}\\
& =\sqrt{\alpha_t \alpha_{t-1}} \boldsymbol{x}_{t-2}+\sqrt{\alpha_t-\alpha_t \alpha_{t-1}+1-\alpha_t} \boldsymbol{\epsilon}_{t-2} \\
& =\sqrt{\alpha_t \alpha_{t-1}} \boldsymbol{x}_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} \boldsymbol{\epsilon}_{t-2} \\
& =\ldots \\
& =\sqrt{\prod_{i=1}^t \alpha_i} \boldsymbol{x}_0+\sqrt{1-\prod_{i=1}^t \alpha_i} \boldsymbol{\epsilon}_0 \quad \text{ with } \boldsymbol{\epsilon}_{0} \sim \mathcal{N}(\boldsymbol{\epsilon}; \mathbf{0}, \mathbf{I}) \\
& =\sqrt{\bar{\alpha}_t} \boldsymbol{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_0 \quad \text{ with } \boldsymbol{\epsilon}_{0} \sim \mathcal{N}(\boldsymbol{\epsilon}; \mathbf{0}, \mathbf{I}) \\
& \sim \mathcal{N}\left(\boldsymbol{x}_t ; \sqrt{\bar{\alpha}_t} \boldsymbol{x}_0,\left(1-\bar{\alpha}_t\right) \mathbf{I}\right)
\end{align}\]
Hence, given \(\boldsymbol{x}_0\), \(\boldsymbol{x}_t\) has a closed-form Gaussian distribution, where \(\bar{\alpha}_t = \prod_{i=1}^{t}\alpha_i\). This distribution is conditioned on \(\boldsymbol{x}_0\), because every step above keeps \(\boldsymbol{x}_0\) fixed. It is not the unconditional marginal \(q(\boldsymbol{x}_t)\):
\[\begin{align}
q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) = \mathcal{N}\left(\boldsymbol{x}_t ; \sqrt{\bar{\alpha}_t} \boldsymbol{x}_0,\left(1-\bar{\alpha}_t\right) \mathbf{I}\right)
\end{align}\]
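The closed form can be checked without sampling. Suppose \(\boldsymbol{x}_{t-1}\) given \(\boldsymbol{x}_0\) has per-dimension mean \(m_{t-1}\) and variance \(v_{t-1}\). One encoder step then gives \(m_t = \sqrt{\alpha_t}\, m_{t-1}\) and \(v_t = \alpha_t v_{t-1} + (1-\alpha_t)\). Starting from \(m_0 = x_0\) and \(v_0 = 0\), this recursion should reproduce \(\sqrt{\bar{\alpha}_t}\,x_0\) and \(1-\bar{\alpha}_t\). The 1-D sketch below compares the two; any \(\alpha_t\) values passed to it are arbitrary test numbers.

```python
import math


def moments_by_recursion(x0, alphas):
    """Mean and variance of x_t | x_0 (1-D), propagated one encoder step at a time."""
    mean, var = x0, 0.0
    for a in alphas:
        mean = math.sqrt(a) * mean      # the previous mean is rescaled by sqrt(alpha_t)
        var = a * var + (1.0 - a)       # variances of the two independent terms add
    return mean, var


def moments_closed_form(x0, alphas):
    """Closed form: mean sqrt(alpha_bar_t) * x0, variance 1 - alpha_bar_t."""
    alpha_bar = math.prod(alphas)
    return math.sqrt(alpha_bar) * x0, 1.0 - alpha_bar
```

Both functions should return the same pair up to floating-point rounding for any list of \(\alpha_t \in (0, 1)\).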
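When \(T\) is large enough (e.g. \(T = 1000\)) and the schedule keeps \(\alpha_t\) sufficiently below 1, the product of \(T\) factors in \((0, 1)\) becomes negligible: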
\[\alpha_t \in (0, 1) \implies \bar{\alpha}_T \approx 0
\]
Therefore:
\[q(\boldsymbol{x}_T \mid \boldsymbol{x}_0) \approx \mathcal{N}\left(\mathbf{0}, \mathbf{I}\right)
\]
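Conclusion: the prior matching term is negligible. It also contains no trainable parameters, so it does not affect the optimization: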
\[\underbrace{D_{\mathrm{KL}}\left(\underbrace{q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right)}_{\approx N(0, I)} \parallel \underbrace{p\left(\boldsymbol{x}_T\right)}_{=N(0, I)}\right)}_{\text{prior matching term}} \approx 0
\]
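To put a number on \(\bar{\alpha}_T \approx 0\), the sketch below assumes a linear schedule in which \(1-\alpha_t\) runs from \(10^{-4}\) to \(0.02\) over \(T = 1000\) steps. This is the choice popularized by DDPM; the text itself does not fix a schedule. The sketch then evaluates the prior matching term per dimension, using the closed-form KL divergence between two Gaussians.

```python
import math


def linear_alphas(T=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_t = 1 - beta_t with beta_t linearly spaced (an assumed schedule)."""
    return [1.0 - (beta_start + (beta_end - beta_start) * i / (T - 1)) for i in range(T)]


def prior_matching_kl(x0, alpha_bar_T):
    """Per-dimension KL( N(sqrt(abar_T) x0, 1 - abar_T) || N(0, 1) )."""
    mean = math.sqrt(alpha_bar_T) * x0
    var = 1.0 - alpha_bar_T
    return 0.5 * (var + mean ** 2 - 1.0 - math.log(var))


alpha_bar_T = math.prod(linear_alphas())
kl = prior_matching_kl(1.0, alpha_bar_T)
print(f"alpha_bar_T = {alpha_bar_T:.2e}, KL = {kl:.2e}")
```

Under this assumed schedule \(\bar{\alpha}_T\) is on the order of \(10^{-5}\), so the per-dimension KL is tiny. This is consistent with dropping the prior matching term.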
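1.3.3 Denoising matching term
We now derive the two distributions inside the denoising matching term, starting with \(q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)\). Squares such as \(\boldsymbol{x}_{t-1}^2\) below are per-dimension shorthand, which is valid because every covariance involved is isotropic: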
\[\begin{align}
& \underbrace{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right) }_{\text{ground-truth posterior}} \\
= & \frac{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0\right) q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0\right)}{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \\
= & \frac{\overbrace{\mathcal{N}\left(\boldsymbol{x}_t ; \sqrt{\alpha_t} \boldsymbol{x}_{t-1},\left(1-\alpha_t\right) \mathbf{I}\right)}^{\text{Markov property and Eq.(17)}} \overbrace{\mathcal{N}\left(\boldsymbol{x}_{t-1} ; \sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0,\left(1-\bar{\alpha}_{t-1}\right) \mathbf{I}\right)}^{\text{apply Eq.(29)}}}{\underbrace{\mathcal{N}\left(\boldsymbol{x}_t ; \sqrt{\bar{\alpha}_t} \boldsymbol{x}_0,\left(1-\bar{\alpha}_t\right) \mathbf{I}\right)}_{\text{apply Eq.(29)}}} \\
\propto & \exp \left\{-\left[\frac{\left(\boldsymbol{x}_t-\sqrt{\alpha_t} \boldsymbol{x}_{t-1}\right)^2}{2\left(1-\alpha_t\right)}+\frac{\left(\boldsymbol{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0\right)^2}{2\left(1-\bar{\alpha}_{t-1}\right)}-\frac{\left(\boldsymbol{x}_t-\sqrt{\bar{\alpha}_t} \boldsymbol{x}_0\right)^2}{2\left(1-\bar{\alpha}_t\right)}\right]\right\} \\
= & \exp \left\{-\frac{1}{2}\left[\frac{\left(\boldsymbol{x}_t-\sqrt{\alpha_t} \boldsymbol{x}_{t-1}\right)^2}{1-\alpha_t}+\frac{\left(\boldsymbol{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(\boldsymbol{x}_t-\sqrt{\bar{\alpha}_t} \boldsymbol{x}_0\right)^2}{1-\bar{\alpha}_t}\right]\right\} \\
= & \exp \left\{-\frac{1}{2}\left[\frac{\left(-2 \sqrt{\alpha_t} \boldsymbol{x}_t \boldsymbol{x}_{t-1}+\alpha_t \boldsymbol{x}_{t-1}^2\right)}{1-\alpha_t}+\frac{\left(\boldsymbol{x}_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_{t-1} \boldsymbol{x}_0\right)}{1-\bar{\alpha}_{t-1}}+C\left(\boldsymbol{x}_t, \boldsymbol{x}_0\right)\right]\right\} \\
\propto & \exp \left\{-\frac{1}{2}\left[-\frac{2 \sqrt{\alpha_t} \boldsymbol{x}_t \boldsymbol{x}_{t-1}}{1-\alpha_t}+\frac{\alpha_t \boldsymbol{x}_{t-1}^2}{1-\alpha_t}+\frac{\boldsymbol{x}_{t-1}^2}{1-\bar{\alpha}_{t-1}}-\frac{2 \sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_{t-1} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right]\right\} \\
= & \exp \left\{-\frac{1}{2}\left[\left(\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) \boldsymbol{x}_{t-1}^2-2\left(\frac{\sqrt{\alpha_t} \boldsymbol{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right) \boldsymbol{x}_{t-1}\right]\right\} \\
= & \exp \left\{-\frac{1}{2}\left[\frac{\alpha_t\left(1-\bar{\alpha}_{t-1}\right)+1-\alpha_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)} \boldsymbol{x}_{t-1}^2-2\left(\frac{\sqrt{\alpha_t} \boldsymbol{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right) \boldsymbol{x}_{t-1}\right]\right\} \\
= & \exp \left\{-\frac{1}{2}\left[\frac{\alpha_t-\bar{\alpha}_t+1-\alpha_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)} \boldsymbol{x}_{t-1}^2-2\left(\frac{\sqrt{\alpha_t} \boldsymbol{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right) \boldsymbol{x}_{t-1}\right]\right\} \\
= & \exp \left\{-\frac{1}{2}\left[\frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)} \boldsymbol{x}_{t-1}^2-2\left(\frac{\sqrt{\alpha_t} \boldsymbol{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right) \boldsymbol{x}_{t-1}\right]\right\} \\
= & \exp \left\{-\frac{1}{2}\left(\frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}\right)\left[\boldsymbol{x}_{t-1}^2-2 \frac{\left(\frac{\sqrt{\alpha_t} \boldsymbol{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right)}{\frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}} \boldsymbol{x}_{t-1}\right]\right\} \\
= & \exp \left\{-\frac{1}{2}\left(\frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}\right)\left[\boldsymbol{x}_{t-1}^2-2 \frac{\left(\frac{\sqrt{\alpha_t} \boldsymbol{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0}{1-\bar{\alpha}_{t-1}}\right)\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \boldsymbol{x}_{t-1}\right]\right\} \\
= & \exp \left\{-\frac{1}{2}\left(\frac{1}{\frac{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t}}\right)\left[\boldsymbol{x}_{t-1}^2-2 \frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \boldsymbol{x}_0}{1-\bar{\alpha}_t} \boldsymbol{x}_{t-1}\right]\right\} \\
\propto & \mathcal{N}\left(\boldsymbol{x}_{t-1} ; \underbrace{\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \boldsymbol{x}_0}{1-\bar{\alpha}_t}}_{\boldsymbol{\mu}_q\left(\boldsymbol{x}_t, \boldsymbol{x}_0\right)}, \underbrace{\frac{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{I}}_{\boldsymbol{\Sigma}_q(t)}\right)
\end{align}\]
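The algebra above can be sanity-checked numerically in 1-D. By Bayes' rule, \(q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)\, q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)\) must equal \(q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})\, q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0)\) at every pair of points. The following minimal sketch evaluates both sides using the derived \(\boldsymbol{\mu}_q\) and \(\boldsymbol{\Sigma}_q(t)\). The helper names are illustrative, and \(t \ge 2\) is assumed so that \(1-\bar{\alpha}_{t-1} > 0\).

```python
import math


def normal_pdf(x, mean, var):
    """Density of the 1-D Gaussian N(x; mean, var)."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def posterior_params(x_t, x0, alpha_t, alpha_bar_prev):
    """mu_q(x_t, x_0) and per-dimension Sigma_q(t) of q(x_{t-1} | x_t, x_0), t >= 2."""
    alpha_bar_t = alpha_t * alpha_bar_prev
    mean = (math.sqrt(alpha_t) * (1 - alpha_bar_prev) * x_t
            + math.sqrt(alpha_bar_prev) * (1 - alpha_t) * x0) / (1 - alpha_bar_t)
    var = (1 - alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    return mean, var


def bayes_sides(x_prev, x_t, x0, alpha_t, alpha_bar_prev):
    """Both sides of q(x_{t-1}|x_t,x_0) q(x_t|x_0) = q(x_t|x_{t-1}) q(x_{t-1}|x_0)."""
    alpha_bar_t = alpha_t * alpha_bar_prev
    mu_q, var_q = posterior_params(x_t, x0, alpha_t, alpha_bar_prev)
    lhs = (normal_pdf(x_prev, mu_q, var_q)
           * normal_pdf(x_t, math.sqrt(alpha_bar_t) * x0, 1 - alpha_bar_t))
    rhs = (normal_pdf(x_t, math.sqrt(alpha_t) * x_prev, 1 - alpha_t)
           * normal_pdf(x_prev, math.sqrt(alpha_bar_prev) * x0, 1 - alpha_bar_prev))
    return lhs, rhs
```

With the correct \(\boldsymbol{\mu}_q\) and \(\boldsymbol{\Sigma}_q(t)\), the two sides agree up to rounding for arbitrary inputs. A wrong coefficient in either expression would break the equality.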
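Following Eq. (44), we also model \(p_{\theta}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)\) as a Gaussian and set its covariance equal to the one in Eq. (44). This is possible because \(\boldsymbol{\Sigma}_q(t)\) depends only on the fixed coefficients \(\alpha_t\), not on \(\boldsymbol{x}_0\):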
\[\begin{align}
p_{\boldsymbol{\theta}}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) = \mathcal{N}(\boldsymbol{x}_{t-1}; \underbrace{\boldsymbol{\mu}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)}_{\text{learned by model}}, \boldsymbol{\Sigma}_{q}(t))
\end{align}\]
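Following \(\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\) in Eq. (44), and since \(\boldsymbol{x}_0\) is unknown at generation time, we replace \(\boldsymbol{x}_0\) with a prediction \(\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\) made by the model: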
\[\begin{align}
\boldsymbol{\mu}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right)
=\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \overbrace{\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right)}^{\text{learned by model}}}{1-\bar{\alpha}_t}
\end{align}\]
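In code, the only learned quantity is the prediction \(\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\). Everything else is fixed by the \(\alpha\) coefficients. The 1-D sketch below uses `x0_hat` as a stand-in for the output of a hypothetical network; the function names are illustrative. It also shows one reverse sampling step with the variance fixed to \(\boldsymbol{\Sigma}_q(t)\).

```python
import math
import random


def mu_theta(x_t, x0_hat, alpha_t, alpha_bar_prev):
    """Model mean: mu_q(x_t, x_0) with x_0 replaced by the prediction x0_hat."""
    alpha_bar_t = alpha_t * alpha_bar_prev
    return (math.sqrt(alpha_t) * (1 - alpha_bar_prev) * x_t
            + math.sqrt(alpha_bar_prev) * (1 - alpha_t) * x0_hat) / (1 - alpha_bar_t)


def reverse_step(x_t, x0_hat, alpha_t, alpha_bar_prev, rng):
    """Sample x_{t-1} ~ N(mu_theta, Sigma_q(t)); the variance is fixed, not learned."""
    alpha_bar_t = alpha_t * alpha_bar_prev
    var_q = (1 - alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    return mu_theta(x_t, x0_hat, alpha_t, alpha_bar_prev) + math.sqrt(var_q) * rng.gauss(0.0, 1.0)
```

Because \(\boldsymbol{\mu}_{\boldsymbol{\theta}}\) is linear in `x0_hat`, the difference \(\boldsymbol{\mu}_{\boldsymbol{\theta}} - \boldsymbol{\mu}_q\) reduces to \(\frac{\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)}{1-\bar{\alpha}_t}(\hat{\boldsymbol{x}}_{\boldsymbol{\theta}} - \boldsymbol{x}_0)\). The simplification of the objective in 2.1 relies on exactly this fact.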
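2.1 Objective function of VDM
The reconstruction term and the prior matching term in Eq. (15) are small enough to neglect. Dropping them, and maximizing the ELBo in place of \(\log p(\boldsymbol{x})\), gives the objective: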
\[\begin{align}
&\operatorname{arg}\max_{\theta} \log{p(\boldsymbol{x})} \\
\approx &\operatorname{arg}\min_{\theta} \underbrace{\sum_{t=2}^T \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[
D_{\mathrm{KL}}\left(\underbrace{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}_{\color{red}{\text {ground-truth posterior}}} \parallel \underbrace{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}_{\textcolor{red}{\text{Decoder of VDM}}}\right)
\right]}_{\text {denoising matching term }}}_{\textbf{Objective function to optimize}}
\end{align}\]
Starting from Eq. (48), the VDM objective simplifies as follows:
\[\begin{align}
&\operatorname{arg}\min \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ D_{\text{KL}}(q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_{t}, \boldsymbol{x}_{0})\parallel p_{\theta}(\boldsymbol{x}_{t-1}|\boldsymbol{x}_{t}))\right] \\
=& \operatorname{arg}\min\sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ D_{\operatorname{KL}}(\mathcal{N}(\boldsymbol{x}_{t-1}; \boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0), \boldsymbol{\Sigma}_{q}(t)) \parallel \mathcal{N}(\boldsymbol{x}_{t-1}; \boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t,\boldsymbol{t}), \boldsymbol{\Sigma}_{q}(t)))\right] \\
=& \operatorname{arg}\min\sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ \dfrac{1}{2}[\log\dfrac{|\boldsymbol{\Sigma}_q(t)|}{|\boldsymbol{\Sigma}_q(t)|}-d+\operatorname{tr}(\boldsymbol{\Sigma}_q(t)^{-1}\boldsymbol{\Sigma}_q(t))+(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, \boldsymbol{t})-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0))^T\boldsymbol{\Sigma}_q(t)^{-1}(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, \boldsymbol{t})-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0))]\right] \quad \text{(KL divergence between two Gaussians)}\\
=& \operatorname{arg}\min\sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ \dfrac{1}{2}\left[\log 1-d+d+(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, \boldsymbol{t})-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0))^T\boldsymbol{\Sigma}_q(t)^{-1}(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, \boldsymbol{t})-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0))\right]\right] \\
=& \operatorname{arg}\min_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ \frac{1}{2}\left[\left(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, \boldsymbol{t})-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right)^{T}\boldsymbol{\Sigma}_{q}(t)^{-1}\left(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, \boldsymbol{t})-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right)\right]\right] \\
=& \operatorname{arg}\min\limits_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ \dfrac{1}{2}\left[\left(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t,\boldsymbol{t})-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right)^{T}\left(\sigma_{q}^{2}(t)\mathbf{I}\right)^{-1}\left(\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t,\boldsymbol{t})-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right)\right]\right] \quad \text{(}\boldsymbol{\Sigma}_q(t)=\sigma_q^2(t)\mathbf{I}\text{, } \sigma_q^2(t)=\tfrac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\text{)} \\
=& \operatorname{arg}\min_{\theta}\sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ \frac{1}{2\sigma_{q}^{2}(t)}\left[\left\|\boldsymbol{\mu}_{\theta}(\boldsymbol{x}_t, \boldsymbol{t})-\boldsymbol{\mu}_{q}(\boldsymbol{x}_t, \boldsymbol{x}_0)\right\|_{2}^{2}\right] \right] \\
=& \operatorname{arg}\min_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ \frac{1}{2\sigma_{q}^{2}(t)}\left[\left\| \underbrace{\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \overbrace{\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right)}^{\text{learned by model}}}{1-\bar{\alpha}_t}}_{\boldsymbol{\mu}_{\theta} \text{ apply Eq.(46)}} - \underbrace{\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \boldsymbol{x}_0}{1-\bar{\alpha}_t}}_{\boldsymbol{\mu}_{q} \text{ apply Eq.(44)}} \right\|_{2}^{2}\right] \right] \\
=& \operatorname{arg}\min_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac{1}{2\sigma_{q}^{2}(t)} \left[ \left\| \frac{\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_{t})\hat{x}_{\theta}(x_t,t)}{1-\bar{\alpha}_t} - \frac{\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_{t})x_0}{1-\bar{\alpha}_t} \right\|_{2}^{2} \right] \right] \\
=& \operatorname{arg}\min_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac{1}{2\sigma_{q}^{2}(t)} \left[ \left\| \frac{\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_{t})}{1-\bar{\alpha}_t} (\hat{x}_{\theta}(x_t,t) - x_0) \right\|_{2}^{2} \right] \right] \\
=& \underbrace{\operatorname{arg}\min_{\theta} \sum_{t=2}^{T}\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac{1}{2\sigma_{q}^{2}(t)} \frac{\bar{\alpha}_{t-1}(1-\alpha_{t})^2}{(1-\bar{\alpha}_t)^2} \left[ \left\| \hat{x}_{\theta}(x_t,t) - x_0 \right\|_{2}^{2} \right] \right]}_{\text{Objective function of Diffusion Model}}
\end{align}\]
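The weight in front of \(\|\hat{\boldsymbol{x}}_{\theta} - \boldsymbol{x}_0\|_2^2\) simplifies once \(\sigma_q^2(t) = \frac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\) is substituted. One factor of \((1-\alpha_t)\) and one of \((1-\bar{\alpha}_t)\) cancel, leaving \(\frac{\bar{\alpha}_{t-1}(1-\alpha_t)}{2(1-\bar{\alpha}_{t-1})(1-\bar{\alpha}_t)}\). Since \(\bar{\alpha}_{t-1} - \bar{\alpha}_t = \bar{\alpha}_{t-1}(1-\alpha_t)\), this also equals \(\frac{1}{2}\left(\frac{\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t-1}} - \frac{\bar{\alpha}_t}{1-\bar{\alpha}_t}\right)\). The small sketch below checks the three forms numerically; its test values are arbitrary.

```python
import math


def weight_as_written(alpha_t, alpha_bar_prev):
    """1 / (2 sigma_q^2(t)) * abar_{t-1} (1 - alpha_t)^2 / (1 - abar_t)^2."""
    alpha_bar_t = alpha_t * alpha_bar_prev
    sigma2_q = (1 - alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    return alpha_bar_prev * (1 - alpha_t) ** 2 / (2 * sigma2_q * (1 - alpha_bar_t) ** 2)


def weight_simplified(alpha_t, alpha_bar_prev):
    """After cancelling: abar_{t-1} (1 - alpha_t) / (2 (1 - abar_{t-1}) (1 - abar_t))."""
    alpha_bar_t = alpha_t * alpha_bar_prev
    return alpha_bar_prev * (1 - alpha_t) / (2 * (1 - alpha_bar_prev) * (1 - alpha_bar_t))


def weight_ratio_gap(alpha_t, alpha_bar_prev):
    """Equivalent form: half the drop of abar / (1 - abar) from step t-1 to step t."""
    alpha_bar_t = alpha_t * alpha_bar_prev

    def ratio(ab):
        return ab / (1 - ab)

    return 0.5 * (ratio(alpha_bar_prev) - ratio(alpha_bar_t))
```

The last form shows that the weight of each time step is set by how much \(\bar{\alpha}/(1-\bar{\alpha})\) drops between consecutive steps.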
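Estimating the sum over \(t\) by Monte Carlo, with \(t\) drawn uniformly from \(\{2, \ldots, T\}\), gives the following. The constant factor \(T-1\) does not change the minimizer.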
\[\begin{align}
\operatorname{arg}\min_{\theta} \mathbb{E}_{t\sim \mathbf{U}(2,T)} \left [ \mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac{1}{2\sigma_{q}^{2}(t)} \frac{\bar{\alpha}_{t-1}(1-\alpha_{t})^2}{(1-\bar{\alpha}_t)^2} \left[ \left\| \hat{x}_{\theta}(x_t,t) - x_0 \right\|_{2}^{2} \right] \right] \right]
\end{align}\]
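Put together, one Monte Carlo sample of this objective needs a data point, a uniformly drawn \(t\), and a single draw of \(\boldsymbol{x}_t\) from the closed-form \(q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)\). The minimal sketch below assumes list-valued data and a hypothetical predictor `x0_predictor(x_t, t)` in place of a neural network. Averaging over many samples and minimizing with respect to the network parameters is left to whatever training framework is used.

```python
import math
import random


def x0_objective_sample(x0, x0_predictor, alphas, rng):
    """One-sample Monte Carlo estimate of the x_0-prediction objective.

    x0: list of floats (one data point); alphas[t-1] holds alpha_t for t = 1..T.
    x0_predictor(x_t, t) is a hypothetical model returning a list like x0.
    """
    T = len(alphas)
    t = rng.randint(2, T)                        # t ~ U{2, ..., T}
    alpha_t = alphas[t - 1]
    alpha_bar_prev = math.prod(alphas[: t - 1])  # abar_{t-1}
    alpha_bar_t = alpha_bar_prev * alpha_t
    # Draw x_t ~ q(x_t | x_0) with the closed form (no need to simulate t steps)
    x_t = [math.sqrt(alpha_bar_t) * v + math.sqrt(1 - alpha_bar_t) * rng.gauss(0.0, 1.0)
           for v in x0]
    sigma2_q = (1 - alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    weight = alpha_bar_prev * (1 - alpha_t) ** 2 / (2 * sigma2_q * (1 - alpha_bar_t) ** 2)
    x0_hat = x0_predictor(x_t, t)
    return weight * sum((p - v) ** 2 for p, v in zip(x0_hat, x0))
```

A perfect predictor that returns \(\boldsymbol{x}_0\) gives zero loss. Any other prediction is penalized by the weighted squared error.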
In summary, the VDM network learns to predict the original image \(\boldsymbol{x}_0\), since the objective is the weighted squared error \(\|\hat{\boldsymbol{x}}_{\theta}(\boldsymbol{x}_t, t) - \boldsymbol{x}_0\|_{2}^{2}\).