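Denoising Diffusion Implicit Models (DDIM)
DDPM has one very annoying problem: sampling needs a large number of iterative steps, which is slow. Several methods have been proposed to address this, such as one-step diffusion models. One of the earliest and best known is DDIM.
Paper: https://arxiv.org/pdf/2010.02502
Reference blog post: https://zhuanlan.zhihu.com/p/666552214?utm_id=0
Training is identical to DDPM. Only inference changes: the sampling process is accelerated, and the results also become somewhat more stable.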
DDIM assumption

DM (diffusion model) assumption
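In the usual DDPM notation, where \bar\alpha_t is the cumulative product of \alpha_t = 1 - \beta_t (the `alpha_bars` of the code further down), the two assumptions can be written as below. The undetermined coefficients k_t and m_t and the free standard deviation \sigma_t are names chosen for this sketch, so treat the parametrization as an assumption.

```latex
% DDIM assumption: the reverse conditional is Gaussian,
% with a mean that is linear in x_0 and x_t
q_\sigma(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(k_t\,x_0 + m_t\,x_t,\; \sigma_t^2 I\right)

% DM (DDPM) assumption: closed-form forward marginal
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I)
```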
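DDIM introduces a new diffusion assumption. Combining it with the original DDPM assumption, substitute x_t directly into the new assumption; this expresses x_{t-1} in terms of x_0 and noise. The original assumption, applied at step t-1, gives a second expression for the same quantity. Comparing the two determines the coefficients of the new assumption.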
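The matching step can be worked out explicitly. As an assumption of this sketch, the DDIM conditional is parametrized as x_{t-1} = k_t x_0 + m_t x_t + \sigma_t \epsilon' with undetermined k_t and m_t, and \bar\alpha_t is the cumulative product stored in `alpha_bars`.

```latex
% substitute x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon
% into the DDIM assumption; the two independent Gaussian terms merge into one
x_{t-1} = \left(k_t + m_t\sqrt{\bar\alpha_t}\right) x_0
        + \sqrt{m_t^2\,(1-\bar\alpha_t) + \sigma_t^2}\;\epsilon''

% the DDPM marginal must also hold at step t-1
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,x_0 + \sqrt{1-\bar\alpha_{t-1}}\;\epsilon''

% matching the coefficient of x_0 and the noise variance
k_t + m_t\sqrt{\bar\alpha_t} = \sqrt{\bar\alpha_{t-1}},
\qquad
m_t^2\,(1-\bar\alpha_t) + \sigma_t^2 = 1-\bar\alpha_{t-1}

% solution (taking the non-negative root for m_t)
m_t = \sqrt{\frac{1-\bar\alpha_{t-1}-\sigma_t^2}{1-\bar\alpha_t}},
\qquad
k_t = \sqrt{\bar\alpha_{t-1}} - \sqrt{\bar\alpha_t}\;m_t
```

The standard deviation \sigma_t stays free: any value with \sigma_t^2 \le 1-\bar\alpha_{t-1} satisfies both conditions.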

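With these coefficients, the DDIM assumption can be rewritten so that it involves only x_0, x_t and the free variance.

x_0 can then be eliminated using the diffusion model assumption, which replaces it with an estimate computed from x_t and the predicted noise.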
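Written out in the usual DDPM notation, these two steps look as follows. Here \bar\alpha_t is the cumulative product stored in `alpha_bars`, \epsilon_\theta(x_t, t) is the network's noise prediction, and \sigma_t is the free standard deviation; the symbol \hat{x}_0 for the estimate of x_0 is a name introduced for this sketch.

```latex
% DDIM assumption after the coefficients have been fixed
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,x_0
        + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\;
          \frac{x_t - \sqrt{\bar\alpha_t}\,x_0}{\sqrt{1-\bar\alpha_t}}
        + \sigma_t\,\epsilon

% estimate of x_0 from the diffusion model assumption
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\;\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}

% with x_0 = \hat{x}_0 the fraction above equals \epsilon_\theta(x_t, t), so
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\;\hat{x}_0
        + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t, t)
        + \sigma_t\,\epsilon
```

Setting \sigma_t = 0 removes the random term and gives the deterministic update used in the code below, where \hat{x}_0 is called `xstar`.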

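The update can of course also skip many steps at once, moving from step t to any earlier step rather than only to t-1.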
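Skipping is possible because the derivation only relies on the marginals at the two steps involved, never on a single-step transition. A sketch of the resulting update, where s < t is any earlier step (an index name introduced here; it plays the role of `prev_t = t - sample_step` in the code) and \bar\alpha is the cumulative product stored in `alpha_bars`:

```latex
% jump from step t to an arbitrary earlier step s < t
x_s = \sqrt{\bar\alpha_s}\;
      \frac{x_t - \sqrt{1-\bar\alpha_t}\;\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}
    + \sqrt{1-\bar\alpha_s-\sigma^2}\;\epsilon_\theta(x_t, t)
    + \sigma\,\epsilon

% deterministic case \sigma = 0, with \hat{x}_0 denoting the fraction above
x_s = \sqrt{\bar\alpha_s}\;\hat{x}_0 + \sqrt{1-\bar\alpha_s}\;\epsilon_\theta(x_t, t)
```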

The DDIM code is as follows (the DDPM step is included for comparison):
    import torch

    # Both functions are methods of the diffusion class that holds
    # self.betas, self.alphas and self.alpha_bars (1-D tensors indexed by step).

    # DDPM
    def sample_backward_step(self, x_t, t, net, simple_var=True, isUnsqueeze=True):
        n = x_t.shape[0]
        # One copy of the step index per sample, optionally shaped (n, 1).
        t_tensor = torch.tensor([t] * n, dtype=torch.long).to(x_t.device)
        if isUnsqueeze:
            t_tensor = t_tensor.unsqueeze(1)
        eps = net(x_t, t_tensor)

        if t == 0:
            # Final step: return the mean only. This also avoids
            # reading alpha_bars[t - 1] with t - 1 == -1.
            noise = 0
        else:
            if simple_var:
                var = self.betas[t]
            else:
                var = (1 - self.alpha_bars[t - 1]) / (1 - self.alpha_bars[t]) * self.betas[t]
            noise = torch.randn_like(x_t) * torch.sqrt(var)

        mean = (x_t - (1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) * eps) / torch.sqrt(self.alphas[t])
        return mean + noise

    # DDIM (deterministic, sigma = 0)
    def time_backward_step(self, x_t, t, net, sample_step=5, isUnsqueeze=True):
        n = x_t.shape[0]
        t_tensor = torch.tensor([t] * n, dtype=torch.long).to(x_t.device)
        if isUnsqueeze:
            t_tensor = t_tensor.unsqueeze(1)
        eps = net(x_t, t_tensor)

        # Estimate x_0 from x_t and the predicted noise, then clip to [-1, 1].
        xstar = torch.sqrt(1. / self.alpha_bars[t]) * x_t - torch.sqrt(1. / self.alpha_bars[t] - 1) * eps
        xstar = torch.clamp(xstar, -1, 1)

        # Jump sample_step steps back, stopping at step 0.
        prev_t = t - sample_step if t - sample_step > 0 else 0
        # Re-noise the estimate to the noise level of prev_t, reusing the same eps.
        pred_xt = torch.sqrt(1 - self.alpha_bars[prev_t]) * eps
        x_prev = torch.sqrt(self.alpha_bars[prev_t]) * xstar + pred_xt
        return x_prev
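The strided update can be checked on a toy scalar problem. The sketch below is a pure-Python illustration, not code from this post: `make_alpha_bars`, `ddim_step`, `sample` and the oracle noise predictor are hypothetical names, and a linear beta schedule from 1e-4 to 0.02 is assumed. With a predictor that always returns the true noise, jumping 50 steps at a time should land on the same point as the closed-form forward formula at t = 0.

```python
import math


def make_alpha_bars(n_steps=1000, beta_min=1e-4, beta_max=0.02):
    """Linear beta schedule (assumed) and its cumulative products alpha_bar_t."""
    alpha_bars, prod = [], 1.0
    for i in range(n_steps):
        beta = beta_min + (beta_max - beta_min) * i / (n_steps - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def ddim_step(x_t, eps, alpha_bars, t, prev_t):
    """Deterministic DDIM update (sigma = 0) from step t to step prev_t."""
    x0_pred = (x_t - math.sqrt(1.0 - alpha_bars[t]) * eps) / math.sqrt(alpha_bars[t])
    return math.sqrt(alpha_bars[prev_t]) * x0_pred + math.sqrt(1.0 - alpha_bars[prev_t]) * eps


def sample(x_T, eps_fn, alpha_bars, sample_step):
    """Walk from the last step down to step 0 in strides of sample_step."""
    x, t, calls = x_T, len(alpha_bars) - 1, 0
    while t > 0:
        prev_t = max(t - sample_step, 0)
        x = ddim_step(x, eps_fn(x, t), alpha_bars, t, prev_t)
        t, calls = prev_t, calls + 1
    return x, calls


# Toy usage: noise a scalar x0 to the last step, then denoise with stride 50.
alpha_bars = make_alpha_bars(1000)
x0, true_eps = 0.5, -1.2
x_T = math.sqrt(alpha_bars[-1]) * x0 + math.sqrt(1.0 - alpha_bars[-1]) * true_eps
x_final, n_calls = sample(x_T, lambda x, t: true_eps, alpha_bars, sample_step=50)
print(n_calls, x_final)
```

With a real network the prediction changes from step to step, so larger strides trade accuracy for speed. The clamp used in `time_backward_step` is left out here to keep the identity exact.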

DDIM results

DDPM results
DDIM inversion:

Null-text Inversion for Editing Real Images using Guided Diffusion Models
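    import torch

    # Method of the same diffusion class; uses self.alpha_bars and self.n_steps.
    def ddim_inverse(self, x_t, t, net, label, sample_step=5, w=10, isUnsqueeze=True):
        n = x_t.shape[0]
        t_tensor = torch.tensor([t] * n, dtype=torch.long).to(x_t.device)
        if isUnsqueeze:
            t_tensor = t_tensor.unsqueeze(1)

        # Conditional prediction; y is an all-ones tensor with the shape of label.
        cat = net(x_t, t_tensor, y=torch.ones_like(label) * 1)
        # ASSUMPTION: `un` is used but never defined in the original snippet.
        # It is taken here to be the unconditional prediction, obtained by
        # calling the network without a label.
        un = net(x_t, t_tensor)
        # Classifier-free guidance with weight w.
        eps = (w + 1) * cat - w * un

        # Move sample_step steps toward more noise, stopping at the last step.
        next_t = t + sample_step if t + sample_step < self.n_steps - 1 else self.n_steps - 1
        # Rescaled x_0 estimate at the noise level of next_t.
        xstar = torch.sqrt(self.alpha_bars[next_t] / self.alpha_bars[t]) * (x_t - torch.sqrt(1 - self.alpha_bars[t]) * eps)
        # Add back the predicted noise at the noise level of next_t.
        pred_xt = torch.sqrt(1 - self.alpha_bars[next_t]) * eps
        x_next = xstar + pred_xt
        return x_next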
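Inversion and sampling are the same deterministic map with the two time indices swapped. The scalar sketch below illustrates this under simplifying assumptions: `ddim_move` is a hypothetical helper, the `alpha_bar` values are made up, and the same noise estimate is reused in both directions.

```python
import math


def ddim_move(x, eps, ab_from, ab_to):
    """Deterministic DDIM map between two noise levels, given a noise estimate eps.

    ab_from and ab_to are the alpha_bar values of the current and target steps.
    ab_to < ab_from moves toward more noise (inversion); ab_to > ab_from denoises.
    """
    x0_pred = (x - math.sqrt(1.0 - ab_from) * eps) / math.sqrt(ab_from)
    return math.sqrt(ab_to) * x0_pred + math.sqrt(1.0 - ab_to) * eps


# Made-up values for illustration only.
x_t, eps = 0.3, -0.5
ab_t, ab_next = 0.9, 0.7

x_next = ddim_move(x_t, eps, ab_t, ab_next)     # inversion step, more noise
x_back = ddim_move(x_next, eps, ab_next, ab_t)  # sampling step, back again
print(x_next, x_back)
```

With a real network the noise estimate at `x_next` generally differs from the one at `x_t`, so the round trip is only approximate, and the error tends to grow with the stride and the guidance weight.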
