DDPM

1. DDPM

1.1 Deriving the DDPM objective from the [[VDM]] objective

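DDPM is the model obtained when the network learns to predict the source noise \(\boldsymbol{\epsilon}_0\) rather than \(\boldsymbol{x}_0\). Rearranging Eq. (35) to express \(\boldsymbol{x}_0\) in terms of \(\boldsymbol{x}_t\) and \(\boldsymbol{\epsilon}_0\) gives: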

\[\begin{align} \boldsymbol{x}_0=\frac{\boldsymbol{x}_t-\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}_0}{\sqrt{\bar{\alpha}_t}} \quad \text{ with } \boldsymbol{\epsilon}_{0} \sim \mathcal{N}(\boldsymbol{\epsilon}; \mathbf{0}, \mathbf{I}) \end{align}\]
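The rearrangement above can be checked numerically. The following is a minimal scalar sketch, not the author's code: the function names `q_sample` and `recover_x0` and the values of `x0` and `alpha_bar_t` are assumptions chosen for illustration.

```python
import math
import random

def q_sample(x0, alpha_bar_t, eps):
    """Forward noising, Eq. (35): x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    return math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * eps

def recover_x0(xt, alpha_bar_t, eps):
    """Inverse of q_sample for a known noise draw (the rearranged form above)."""
    return (xt - math.sqrt(1.0 - alpha_bar_t) * eps) / math.sqrt(alpha_bar_t)

rng = random.Random(0)
x0 = 0.7                      # hypothetical data point
alpha_bar_t = 0.35            # hypothetical cumulative product of alphas
eps = rng.gauss(0.0, 1.0)     # source noise eps_0 ~ N(0, 1)

xt = q_sample(x0, alpha_bar_t, eps)
x0_rec = recover_x0(xt, alpha_bar_t, eps)  # matches x0 up to rounding
```

Because \(\boldsymbol{x}_0\) is fully determined by \(\boldsymbol{x}_t\) and \(\boldsymbol{\epsilon}_0\), predicting either quantity carries the same information at a given timestep.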

Substituting this expression for \(\boldsymbol{x}_0\) into Eq. (52) gives:

\[\begin{align} \boldsymbol{\mu}_q(\boldsymbol{x}_t,\boldsymbol{x}_0) = & \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\boldsymbol{x}_0}{1-\bar{\alpha}_t} \\ =& \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\frac{\boldsymbol{x}_t-\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}_0}{\sqrt{\bar{\alpha}_t}}}{1-\bar{\alpha}_t} \\ =& \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t+(1-\alpha_t)\frac{\boldsymbol{x}_t-\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}_0}{\sqrt{\alpha_t}}}{1-\bar{\alpha}_t} \\ =& \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t}{1-\bar{\alpha}_t}+\frac{(1-\alpha_t)\boldsymbol{x}_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}-\frac{(1-\alpha_t)\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}_0}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}} \\ =& \left(\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}+\frac{1-\alpha_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\right)\boldsymbol{x}_t-\frac{(1-\alpha_t)\sqrt{1-\bar{\alpha}_t}}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\boldsymbol{\epsilon}_0 \\ =& \left(\frac{\alpha_t(1-\bar{\alpha}_{t-1})}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}+\frac{1-\alpha_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\right)\boldsymbol{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\boldsymbol{\epsilon}_0 \\ =& \frac{\alpha_t-\bar{\alpha}_t+1-\alpha_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\boldsymbol{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\boldsymbol{\epsilon}_0 \\ =& \frac{1-\bar{\alpha}_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\boldsymbol{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\boldsymbol{\epsilon}_0 \\ =& \frac1{\sqrt{\alpha_t}}\boldsymbol{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\boldsymbol{\epsilon}_0 \quad \text{ with } \boldsymbol{\epsilon}_{0} \sim \mathcal{N}(\boldsymbol{\epsilon}; \mathbf{0}, \mathbf{I}) \end{align}\]
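As a sanity check on the algebra, the first and last lines of the derivation can be compared numerically. This is a scalar sketch under assumptions: the schedule `alphas` is a hypothetical example, and the helper names are not from the original post.

```python
import math

alphas = [0.99, 0.98, 0.97, 0.96]      # hypothetical alpha_1..alpha_4
alpha_bars = [1.0]                      # alpha_bar_0 := 1 (empty product)
for a in alphas:
    alpha_bars.append(alpha_bars[-1] * a)

def mu_q_from_x0(xt, x0, t):
    """Posterior mean written in terms of x_t and x_0 (first line of the derivation)."""
    a_t, ab_t, ab_prev = alphas[t - 1], alpha_bars[t], alpha_bars[t - 1]
    num = math.sqrt(a_t) * (1 - ab_prev) * xt + math.sqrt(ab_prev) * (1 - a_t) * x0
    return num / (1 - ab_t)

def mu_q_from_eps(xt, eps, t):
    """Posterior mean written in terms of x_t and eps_0 (last line of the derivation)."""
    a_t, ab_t = alphas[t - 1], alpha_bars[t]
    return xt / math.sqrt(a_t) - (1 - a_t) / (math.sqrt(1 - ab_t) * math.sqrt(a_t)) * eps

x0, eps = 0.5, -1.2                     # arbitrary test values
gaps = []
for t in range(2, len(alphas) + 1):
    ab_t = alpha_bars[t]
    xt = math.sqrt(ab_t) * x0 + math.sqrt(1 - ab_t) * eps   # Eq. (35)
    gaps.append(abs(mu_q_from_x0(xt, x0, t) - mu_q_from_eps(xt, eps, t)))
max_gap = max(gaps)                     # should be at rounding-error level
```

The two forms agree only when \(\boldsymbol{x}_t\), \(\boldsymbol{x}_0\), and \(\boldsymbol{\epsilon}_0\) are tied together by Eq. (35), which is why the sketch builds `xt` from `x0` and `eps` before comparing.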

The denoising transition mean learned by the model, \(\boldsymbol{\mu}_\theta(\boldsymbol{x}_t,t)\), is then rewritten in a form analogous to Eq. (78):

\[\begin{align} \boldsymbol{\mu}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t) &=\frac1{\sqrt{\alpha_t}}\boldsymbol{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\boldsymbol{\hat{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t) \\ &=\frac1{\sqrt{\alpha_t}} \left( \boldsymbol{x}_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol{\hat{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right) \end{align}\]

Combining this with Eq. (56) yields the objective:

\[\begin{align}
& \arg\min_{\boldsymbol{\theta}} \underbrace{\sum_{t=2}^T \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[ D_{\mathrm{KL}}\left(\underbrace{q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)}_{\textcolor{red}{\text{ground-truth posterior}}} \parallel \underbrace{p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)}_{\textcolor{red}{\text{decoder of VDM}}}\right) \right]}_{\text{denoising matching term}}}_{\textbf{Objective function to optimize}} \\
=& \arg\min_{\boldsymbol{\theta}} \sum_{t=2}^T \mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ D_{\mathrm{KL}}\left(\mathcal{N}\left(\boldsymbol{x}_{t-1};\boldsymbol{\mu}_{q}\left(\boldsymbol{x}_t,\boldsymbol{x}_0\right),\boldsymbol{\Sigma}_{q}\left(t\right)\right)\parallel\mathcal{N}\left(\boldsymbol{x}_{t-1};\boldsymbol{\mu}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right),\boldsymbol{\Sigma}_{q}\left(t\right)\right)\right) \right] \\
=& \arg\min_{\boldsymbol{\theta}} \sum_{t=2}^T \mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac1{2\sigma_q^2(t)}\left[\left\|\frac1{\sqrt{\alpha_t}}\boldsymbol{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\boldsymbol{\hat{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)-\frac1{\sqrt{\alpha_t}}\boldsymbol{x}_t+\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\boldsymbol{\epsilon}_0\right\|_2^2\right] \right] \\
=& \arg\min_{\boldsymbol{\theta}} \sum_{t=2}^T \mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac1{2\sigma_q^2(t)}\left[\left\|\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\boldsymbol{\epsilon}_0-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\boldsymbol{\hat{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\|_2^2\right] \right] \\
=& \arg\min_{\boldsymbol{\theta}} \sum_{t=2}^T \mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac1{2\sigma_q^2(t)}\left[\left\|\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\left(\boldsymbol{\epsilon}_0-\boldsymbol{\hat{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right)\right\|_2^2\right] \right] \\
=& \arg\min_{\boldsymbol{\theta}} \sum_{t=2}^T \mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac1{2\sigma_q^2(t)}\frac{(1-\alpha_t)^2}{(1-\bar{\alpha}_t)\alpha_t}\left[\left\|\boldsymbol{\epsilon}_0-\boldsymbol{\hat{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\|_2^2\right] \right] \\
=& \arg\min_{\boldsymbol{\theta}} \sum_{t=2}^T \mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[ \frac{1}{2\sigma_q^2(t)}\frac{(1-\alpha_t)^2}{(1-\bar{\alpha}_t)\alpha_t}\left[\left\|\boldsymbol{\epsilon}_0-\boldsymbol{\hat{\epsilon}}_{\boldsymbol{\theta}}(\underbrace{\sqrt{\bar{\alpha}_t}\boldsymbol{x}_0+\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}_0}_{\text{apply Eq.(35)}}, t)\right\|_2^2\right]\right] \quad \text{ with } \boldsymbol{\epsilon}_{0} \sim \mathcal{N}(\boldsymbol{\epsilon}; \mathbf{0}, \mathbf{I}) \\
\Rightarrow& \arg\min_{\boldsymbol{\theta}} \underbrace{\sum_{t=2}^T \mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)} \left[\left\|\boldsymbol{\epsilon}_0-\boldsymbol{\hat{\epsilon}}_{\boldsymbol{\theta}}(\underbrace{\sqrt{\bar{\alpha}_t}\boldsymbol{x}_0+\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}_0}_{\text{apply Eq.(35)}}, t)\right\|_2^2\right]}_{\text{Simplified objective function of DDPM}}
\end{align}\]

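The term \(\left \| \boldsymbol{\epsilon}_0 - \boldsymbol{\hat{\epsilon}_{\theta}}\left( \boldsymbol{x}_t, t\right) \right \|_{2}^{2}\) in Eq. (18) shows that the model learns to predict the source noise \(\boldsymbol{\epsilon}_0\) that was added to the data. Note that the final step of the derivation drops the time-dependent weight \(\frac{1}{2\sigma_q^2(t)}\frac{(1-\alpha_t)^2}{(1-\bar{\alpha}_t)\alpha_t}\), so the simplified objective is a reweighted version of the original one and not strictly equivalent to it.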
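To see how the per-timestep weight in front of the squared noise error behaves, it can be evaluated for a small schedule. This sketch rests on two assumptions that are not stated in this section: the schedule `alphas` is a hypothetical example, and the posterior variance is taken to be \(\sigma_q^2(t) = \frac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\), the standard closed form from the VDM derivation.

```python
alphas = [0.99, 0.98, 0.97, 0.96, 0.95]   # hypothetical alpha_1..alpha_5
alpha_bars = [1.0]                         # alpha_bar_0 := 1 (empty product)
for a in alphas:
    alpha_bars.append(alpha_bars[-1] * a)

def sigma_q_sq(t):
    """Assumed posterior variance: (1 - a_t)(1 - abar_{t-1}) / (1 - abar_t)."""
    a_t, ab_t, ab_prev = alphas[t - 1], alpha_bars[t], alpha_bars[t - 1]
    return (1 - a_t) * (1 - ab_prev) / (1 - ab_t)

def weight(t):
    """Coefficient multiplying ||eps_0 - eps_hat||^2 before simplification."""
    a_t, ab_t = alphas[t - 1], alpha_bars[t]
    return (1.0 / (2.0 * sigma_q_sq(t))) * (1 - a_t) ** 2 / ((1 - ab_t) * a_t)

def weight_closed_form(t):
    """Same coefficient after cancelling terms: (1 - a_t) / (2 a_t (1 - abar_{t-1}))."""
    a_t, ab_prev = alphas[t - 1], alpha_bars[t - 1]
    return (1 - a_t) / (2.0 * a_t * (1 - ab_prev))

weights = {t: weight(t) for t in range(2, len(alphas) + 1)}
```

For this schedule the weight at \(t=2\) is larger than at \(t=3\), so dropping it changes how much each timestep contributes to the loss.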

1.2 DDPM training procedure

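Training uses the simplified objective shown in Eq. (20). In practice the expectation over \(q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)\) is not evaluated explicitly: \(\boldsymbol{x}_t\) is obtained by sampling \(\boldsymbol{\epsilon}_0\) and applying Eq. (35).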
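A toy version of this training loop can be sketched in plain Python. Everything specific here is an assumption made for illustration: the beta schedule, the two-point data distribution, the learning rate, and above all the "network", which is replaced by one scalar weight per timestep (`eps_hat`). Sampling \(t\) uniformly stands in for the sum over \(t=2,\dots,T\).

```python
import math
import random

T = 10
# Hypothetical linear beta schedule; alpha_t = 1 - beta_t.
betas = [0.1 + 0.4 * i / (T - 1) for i in range(T)]
alpha_bars = []
prod = 1.0
for b in betas:
    prod *= 1.0 - b
    alpha_bars.append(prod)          # alpha_bars[t - 1] holds alpha_bar_t

def q_sample(x0, t, eps):
    """Eq. (35): draw x_t directly from x_0 and the source noise."""
    ab = alpha_bars[t - 1]
    return math.sqrt(ab) * x0 + math.sqrt(1.0 - ab) * eps

def eps_hat(w, xt, t):
    """Toy stand-in for the noise-prediction network: one scalar per timestep."""
    return w[t - 1] * xt

def train_step(w, rng, lr=0.02):
    x0 = rng.choice([-1.0, 1.0])     # toy data distribution (assumption)
    t = rng.randint(2, T)            # uniform timestep in {2, ..., T}
    eps = rng.gauss(0.0, 1.0)        # source noise eps_0
    xt = q_sample(x0, t, eps)
    err = eps - eps_hat(w, xt, t)
    w[t - 1] += lr * 2.0 * err * xt  # gradient descent on err ** 2
    return err * err

def mean_loss(w, seed, n=2000):
    """Monte Carlo estimate of the simplified objective on fresh samples."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x0 = rng.choice([-1.0, 1.0])
        t = rng.randint(2, T)
        eps = rng.gauss(0.0, 1.0)
        total += (eps - eps_hat(w, q_sample(x0, t, eps), t)) ** 2
    return total / n

w = [0.0] * T
loss_before = mean_loss(w, seed=1)
rng = random.Random(0)
for _ in range(5000):
    train_step(w, rng)
loss_after = mean_loss(w, seed=1)
```

A real implementation would replace `eps_hat` with a neural network and the manual update with an optimizer; the structure of each step (sample \(\boldsymbol{x}_0\), \(t\), and \(\boldsymbol{\epsilon}_0\), form \(\boldsymbol{x}_t\), regress on the noise) is what the sketch is meant to show.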

1.3 DDPM inference procedure

Note: \(p_\theta(\boldsymbol{x}_{t-1}\mid \boldsymbol{x}_t)=\mathcal{N}(\boldsymbol{x}_{t-1};\boldsymbol{\mu}_\theta(\boldsymbol{x}_t,t),\boldsymbol{\Sigma}_\theta(\boldsymbol{x}_t,t))\)
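Sampling from this Gaussian step by step, from \(t=T\) down to \(t=1\), can be sketched as follows. This is a scalar illustration under stated assumptions, not the author's procedure: the schedule is hypothetical, the variance is fixed to the posterior variance \(\sigma_q^2(t) = \frac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\) used in the objective instead of a learned \(\boldsymbol{\Sigma}_\theta\), and `zero_eps_model` is a placeholder for a trained noise predictor.

```python
import math
import random

T = 10
betas = [0.1 + 0.4 * i / (T - 1) for i in range(T)]   # hypothetical schedule
alphas = [1.0 - b for b in betas]
alpha_bars = [1.0]                                     # alpha_bar_0 := 1
for a in alphas:
    alpha_bars.append(alpha_bars[-1] * a)

def mu_theta(xt, t, eps_pred):
    """Model mean: (x_t - (1 - a_t) / sqrt(1 - abar_t) * eps_hat) / sqrt(a_t)."""
    a_t, ab_t = alphas[t - 1], alpha_bars[t]
    return (xt - (1.0 - a_t) / math.sqrt(1.0 - ab_t) * eps_pred) / math.sqrt(a_t)

def sigma_q_sq(t):
    """Assumed fixed variance; it is exactly zero at t = 1 because abar_0 = 1."""
    a_t, ab_t, ab_prev = alphas[t - 1], alpha_bars[t], alpha_bars[t - 1]
    return (1.0 - a_t) * (1.0 - ab_prev) / (1.0 - ab_t)

def sample(eps_model, rng):
    """Ancestral sampling: start from x_T ~ N(0, 1) and denoise down to x_0."""
    x = rng.gauss(0.0, 1.0)
    for t in range(T, 0, -1):
        z = rng.gauss(0.0, 1.0)
        x = mu_theta(x, t, eps_model(x, t)) + math.sqrt(sigma_q_sq(t)) * z
    return x

def zero_eps_model(xt, t):
    """Placeholder predictor (hypothetical); a trained model would go here."""
    return 0.0

x0_sample = sample(zero_eps_model, random.Random(0))
```

With a trained predictor in place of `zero_eps_model`, each loop iteration draws \(\boldsymbol{x}_{t-1}\) from \(p_\theta(\boldsymbol{x}_{t-1}\mid \boldsymbol{x}_t)\); the last step adds no noise because the assumed variance vanishes at \(t=1\).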

posted @ 2024-06-11 14:38  RenjieW