[学习笔记]DDPM图片降噪

之前学习了一下流匹配的最简单的demo,前两天学长讲了一下扩散模型的简单应用,就尽快学一下然后把博客写上来。

先贴一下仓库

ajmddzp/flow-matching-ddmp

然后是学长的博客

生成扩散模型漫谈(一):DDPM = 拆楼 + 建楼 - 科学空间|Scientific Spaces

代码也是由学长提供的,强无敌的学长

(一个看起来无比古老的博客)

那么接下来,进入正题,开始打这个boss——扩散模型

0.什么是扩散模型

generated_samples

 

generation_progress

学长的一个比喻非常形象:拆楼与盖楼。

训练的过程就是把楼给拆了

生成的过程就是把楼搭好喽

这里的拆楼,就是在清晰的图片上加噪声

这里的搭楼,就是把雪花图上的噪声消除。

至于为什么叫扩散,因为它的过程很像把一滴墨滴到水中,扩散完的灰色墨水往回逆推。

反正我觉得这个名字还是比较形象的。

1.扩散模型的理论基础

这个部分学长的博客里讲得非常明白,大概就是利用高斯分布的叠加性。至于没法通过观测获得的那些量,那就利用强大的神经网络来进行拟合。

其次,这大概是我第一次真正接触到U-Net这个东西,本来它只存在于人工智能与机器视觉的课堂上,这下它真的走到了工程里

U-Net形如其名,网络像一个U形,因为直接添加了从浅层到深层的连接。

众所周知,深层的网络会包含更加复杂的特征,感受野较大但无法注意到细节;

浅层的网络会注意到一些细节的信息,比如边缘、角点之类

如果直接用MLP一路算下去,就算还原了图片也只能看到一个平滑的光杆(笑

U-Net开辟了一条直接从浅层网络指向深层网络的路径,很好地解决了这个问题。

实现方法:

1 # 上采样(带跳跃连接)
2         x = self.up1(torch.cat([x, x3], dim=1), t)  # [B, 128, 14, 14]
3         # cat:连接矩阵,dim=1表示按列连接,dim=0表示按行连接
4         # 拼接特征,达到连接的效果
5         x = self.up2(torch.cat([x, x2], dim=1), t)  # [B, 64, 28, 28]
6         x = self.up3(torch.cat([x, x1], dim=1), t)  # [B, 64, 28, 28]    
View Code

也是比较直白。

2.扩散模型的实现方法

还是直接贴代码吧,我感觉注释挺详细了。

不过扩散模型的数学基础有点小复杂,跟代码对应起来有点困难。

不过基本上都是按照数学推导一步一步来的

ddpm_model.py:

  1 import torch
  2 import torch.nn as nn
  3 import math
  4 
  5 
  6 class SinusoidalPositionalEmbedding(nn.Module):
  7     """时间步的位置编码"""
  8     # 具有线性性质,并且不重复
  9 
 10     def __init__(self, dim):
 11         super().__init__()
 12         self.dim = dim
 13 
 14     def forward(self, time):
 15         device = time.device  # 获取张量所在的设备,gpu计算步骤
 16         half_dim = self.dim // 2
 17         emb = math.log(10000) / (half_dim - 1)
 18         emb = torch.exp(torch.arange(
 19             half_dim, device=device, dtype=torch.float32) * -emb)
 20         emb = time[:, None].float() * emb[None, :]
 21         emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
 22         return emb
 23 
 24 
 25 class Block(nn.Module):
 26     """U-Net的基础块"""
 27 
 28     def __init__(self, in_ch, out_ch, time_emb_dim, up=False, down=False):
 29         super().__init__()
 30         self.time_mlp = nn.Linear(time_emb_dim, out_ch)  # 简单的线性全连接层
 31         if up:  # U-Net,上下采样
 32             self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1)  # 转置卷积
 33         elif down:
 34             # stride=2 for downsampling
 35             # 简单的卷积层,in out kernel_size stride padding
 36             self.conv = nn.Conv2d(in_ch, out_ch, 3, 2, 1)
 37         else:
 38             self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
 39         self.norm = nn.GroupNorm(8, out_ch)  # 组归一化,总之就是归一化,作用差不多
 40         self.act = nn.SiLU()  # sigmoid,你终于出现了
 41 
 42     def forward(self, x, t):
 43         # 时间嵌入
 44         t_emb = self.time_mlp(t)[:, :, None, None]
 45         # 卷积
 46         h = self.conv(x)
 47         # 添加时间嵌入
 48         h = h + t_emb
 49         # 妙哉时间嵌入
 50         # 时间条件注入:通过广播加法,给整张图打上“时间底色”。
 51         # 这不是覆盖特征,而是调节特征,让模型根据当前的时间步 t 来处理图像。
 52 
 53         # 归一化和激活
 54         h = self.norm(h)
 55         h = self.act(h)
 56         return h
 57 
 58 
 59 class SimpleUNet(nn.Module):
 60     """简单的U-Net架构用于DDPM"""
 61 
 62     def __init__(self, image_size=28, in_channels=1, time_emb_dim=32):
 63         super().__init__()
 64         self.time_emb_dim = time_emb_dim
 65 
 66         # 时间嵌入
 67         self.time_mlp = nn.Sequential(  # 时间嵌入层:使用 nn.Sequential 封装成一个线性流水线
 68             # 作用:将标量时间步 t 转化为高维向量,经过位置编码 -> 线性变换 -> 激活 -> 线性变换
 69             SinusoidalPositionalEmbedding(time_emb_dim),
 70             nn.Linear(time_emb_dim, time_emb_dim),
 71             nn.SiLU(),
 72             nn.Linear(time_emb_dim, time_emb_dim)
 73         )
 74 
 75         # 下采样路径
 76         self.down1 = Block(in_channels, 64, time_emb_dim)
 77         self.down2 = Block(64, 128, time_emb_dim, down=True)
 78         self.down3 = Block(128, 256, time_emb_dim, down=True)
 79 
 80         # 中间层
 81         self.mid1 = Block(256, 256, time_emb_dim)
 82         self.mid2 = Block(256, 256, time_emb_dim)
 83         # 中间层 (Bottleneck):
 84         # 在最低分辨率(7x7)和最高通道数(256)下处理图像的全局语义信息。
 85         # 这里是连接编码器(Encoder)和解码器(Decoder)的桥梁,负责对提取的深层特征进行转换。
 86 
 87         # 上采样路径
 88         self.up1 = Block(256 + 256, 128, time_emb_dim, up=True)  # 7x7 -> 14x14
 89         self.up2 = Block(128 + 128, 64, time_emb_dim,
 90                          up=True)  # 14x14 -> 28x28
 91         # 28x28 -> 28x28 (不再上采样)
 92         self.up3 = Block(64 + 64, 64, time_emb_dim, up=False)
 93 
 94         # 输出层
 95         self.out = nn.Conv2d(64, in_channels, 1)
 96 
 97     def forward(self, x, timestep):
 98         """
 99         x: [B, C, H, W] 噪声图像
100         timestep: [B] 时间步
101         """
102         # 时间嵌入
103         t = self.time_mlp(timestep)
104 
105         # 下采样
106         x1 = self.down1(x, t)  # [B, 64, 28, 28]
107         x2 = self.down2(x1, t)  # [B, 128, 14, 14]
108         x3 = self.down3(x2, t)  # [B, 256, 7, 7]
109 
110         # 中间层
111         x = self.mid1(x3, t)
112         x = self.mid2(x, t)
113 
114         # 上采样(带跳跃连接)
115         x = self.up1(torch.cat([x, x3], dim=1), t)  # [B, 128, 14, 14]
116         # cat:连接矩阵,dim=1表示按列连接,dim=0表示按行连接
117         x = self.up2(torch.cat([x, x2], dim=1), t)  # [B, 64, 28, 28]
118         x = self.up3(torch.cat([x, x1], dim=1), t)  # [B, 64, 28, 28]
119 
120         # 输出
121         return self.out(x)
122 
123 
124 class DiffusionProcess:
125     """扩散过程"""
126 
127     def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02, device='cuda'):
128         # num_timesteps: 扩散步数
129         # beta_start: beta的起始值
130         # beta_end: beta的结束值
131         self.num_timesteps = num_timesteps
132         self.device = device
133 
134         # 线性beta调度
135         self.betas = torch.linspace(
136             beta_start, beta_end, num_timesteps, device=device)
137         # torch.linspace:生成一个等差数列,线性插值
138         self.alphas = 1.0 - self.betas  # 保留多少原图信息,alpha=1-beta
139 
140         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
141         # torch.cumprod:计算一个张量的累积积
142         # 对于输入张量 x,输出 y 的第 i 个元素为:
143         # # y[i] = x[0] × x[1] × ... × x[i]
144         self.alphas_cumprod_prev = torch.cat(
145             [torch.tensor([1.0], device=device), self.alphas_cumprod[:-1]])
146         # t-1 的alphas_cumprod
147 
148         # 用于采样的参数
149         self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
150         self.sqrt_one_minus_alphas_cumprod = torch.sqrt(
151             1.0 - self.alphas_cumprod)
152         # sqrt(1-alphas_cumprod),数学这一块
153 
154         # 用于去噪的参数
155         self.posterior_variance = self.betas * \
156             (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
157 
158     def q_sample(self, x_start, t, noise=None):
159         """前向扩散过程:添加噪声"""
160         # x_start: 原始图像
161         if noise is None:
162             noise = torch.randn_like(x_start)
163             # randn_like:生成一个与给定张量相同形状的随机张量,并使用给定张量的数据类型和设备,正态分布
164 
165         sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
166         sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(
167             -1, 1, 1, 1)
168 
169         return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
170     # 按照数学模型加噪声
171 
172     def p_sample(self, model, x_t, t):
173         """反向去噪过程:单步采样"""
174         # 预测噪声
175         predicted_noise = model(x_t, t)
176 
177         # 计算系数
178         alpha_t = self.alphas[t].view(-1, 1, 1, 1)
179         alpha_cumprod_t = self.alphas_cumprod[t].view(-1, 1, 1, 1)
180         beta_t = self.betas[t].view(-1, 1, 1, 1)
181 
182         # 计算x_0的预测值
183         pred_x_start = (x_t - torch.sqrt(1.0 - alpha_cumprod_t)
184                         * predicted_noise) / torch.sqrt(alpha_cumprod_t)
185         pred_x_start = torch.clamp(pred_x_start, -1.0, 1.0)
186         # torch.clamp:将张量中的元素限制在给定的区间内
187         # 不同于归一化,更像二值化,将超出范围的值截断为边界值
188         # 先预测一个x0,然后根据这个x0预测x_t,相当于锚点
189 
190         # 对于t=0,直接返回预测的x_0
191         if t[0] == 0:
192             return pred_x_start
193 
194         # 计算alpha_cumprod_prev
195         alpha_cumprod_prev_t = self.alphas_cumprod_prev[t].view(-1, 1, 1, 1)
196 
197         # 后验均值的系数
198         posterior_mean_coef1 = torch.sqrt(
199             alpha_cumprod_prev_t) * beta_t / (1.0 - alpha_cumprod_t)
200         posterior_mean_coef2 = torch.sqrt(
201             alpha_t) * (1.0 - alpha_cumprod_prev_t) / (1.0 - alpha_cumprod_t)
202 
203         # 计算后验均值
204         posterior_mean = posterior_mean_coef1 * \
205             pred_x_start + posterior_mean_coef2 * x_t
206 
207         # 计算后验方差
208         posterior_variance = self.posterior_variance[t].view(-1, 1, 1, 1)
209 
210         # 采样
211         noise = torch.randn_like(x_t)
212         # 叠加噪声,补充细节:把“平均的模糊”变成“具体的纹理”
213         return posterior_mean + torch.sqrt(posterior_variance) * noise
214 
215     def p_sample_loop(self, model, shape, num_samples=1):
216         """完整的反向采样过程"""
217         model.eval()
218         with torch.no_grad():
219             # 从纯噪声开始
220             x = torch.randn(num_samples, *shape, device=self.device)
221             # torch.randn:生成一个张量,张量的元素从标准正态分布中随机采样
222 
223             # 逐步去噪
224             for i in reversed(range(self.num_timesteps)):
225                 # reversed:返回一个迭代器,迭代器返回列表或元组中的元素,按从后到前的顺序
226                 t = torch.full((num_samples,), i,
227                                device=self.device, dtype=torch.long)
228                 # torch.full:生成一个张量,张量的元素填充为给定值
229                 # torch.long:表示整数
230                 x = self.p_sample(model, x, t)
231                 # 丢进神经网络,统统丢进神经网络!
232 
233         model.train()
234         return x
View Code

这个模块定义了U-Net、扩散过程

简单介绍一下各个模块:

class SinusoidalPositionalEmbedding(nn.Module):
这个class在forward里定义了一个时间编码器,总之就是把时间步嵌入到向量里,告诉流匹配模型当前到哪个时间步了
class Block(nn.Module):
这个类定义了U-Net网络的基本构成,U-Net由多个略有不同的Block组成
首先是init,定义了各个层
其次是forward,加入了时间嵌入
class SimpleUNet(nn.Module):
这个类具体实现了U-Net
在Block前,又增加了一些层,一些时间嵌入啊、sigmoid啊、线性层之类的
接着就是具体的采样过程了
class DiffusionProcess:
这个类是整个扩散过程与逆扩散过程的实现
在init中,计算了一些后续要用到的量以便查表(比如数学推导中用到的乘法前缀和,以便查表)
q_sample函数即向清晰的图像中加噪声,不同的时间步有不同的噪声量
p_sample即单步去噪过程
p_sample_loop即多次使用单步降噪,获得最终图片

这个文件大概就是这样

接下来是sample.py

  1 import torch
  2 import numpy as np
  3 import matplotlib.pyplot as plt
  4 from ddpm_model import SimpleUNet, DiffusionProcess
  5 
  6 
  7 def sample_and_visualize(
  8     model_path='ddpm_mnist_final.pth',
  9     num_samples=16,
 10     num_timesteps=1000,
 11     device='cuda' if torch.cuda.is_available() else 'cpu',
 12     save_path='generated_samples.png'
 13 ):
 14     """从训练好的模型采样并可视化"""
 15     print(f"使用设备: {device}")
 16 
 17     # 加载模型
 18     model = SimpleUNet(image_size=28, in_channels=1,
 19                        time_emb_dim=32).to(device)
 20     model.load_state_dict(torch.load(model_path, map_location=device))
 21     model.eval()
 22     # 网络初始化
 23 
 24     # 创建扩散过程
 25     diffusion = DiffusionProcess(num_timesteps=num_timesteps, device=device)
 26 
 27     print(f"生成 {num_samples} 个样本...")
 28 
 29     # 采样
 30     samples = diffusion.p_sample_loop(
 31         model, shape=(1, 28, 28), num_samples=num_samples)
 32     # 去噪过程
 33     samples = samples.cpu()
 34     # 把数据搬运到cpu上
 35 
 36     # 反归一化:从[-1, 1]回到[0, 1]
 37     samples = (samples + 1.0) / 2.0
 38     samples = torch.clamp(samples, 0.0, 1.0)
 39     # 先归一化,再截断
 40 
 41     # 可视化
 42     grid_size = int(np.sqrt(num_samples))
 43     fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))
 44     axes = axes.flatten()
 45 
 46     for i in range(num_samples):
 47         img = samples[i].squeeze().numpy()
 48         axes[i].imshow(img, cmap='gray')
 49         axes[i].axis('off')
 50 
 51     plt.tight_layout()
 52     plt.savefig(save_path, dpi=150, bbox_inches='tight')
 53     print(f"生成的样本已保存到 {save_path}")
 54 
 55     return samples
 56 
 57 
 58 def sample_with_progress(
 59     model_path='ddpm_mnist_final.pth',
 60     num_samples=16,
 61     num_timesteps=1000,
 62     device='cuda' if torch.cuda.is_available() else 'cpu',
 63     save_path='generation_progress.png'
 64 ):
 65     """可视化生成过程(中间步骤)"""
 66     print(f"使用设备: {device}")
 67 
 68     # 加载模型
 69     model = SimpleUNet(image_size=28, in_channels=1,
 70                        time_emb_dim=32).to(device)
 71     model.load_state_dict(torch.load(model_path, map_location=device))
 72     model.eval()
 73 
 74     # 创建扩散过程
 75     diffusion = DiffusionProcess(num_timesteps=num_timesteps, device=device)
 76 
 77     # 从纯噪声开始
 78     x = torch.randn(1, 1, 28, 28, device=device)
 79 
 80     # 选择要可视化的时间步
 81     timesteps_to_show = [999, 800, 600, 400, 200, 100, 50, 0]
 82     images = []
 83 
 84     with torch.no_grad():
 85         for i in reversed(range(num_timesteps)):
 86             t = torch.full((1,), i, device=device, dtype=torch.long)
 87             x = diffusion.p_sample(model, x, t)
 88 
 89             if i in timesteps_to_show:
 90                 img = x.cpu().squeeze().numpy()
 91                 img = (img + 1.0) / 2.0
 92                 img = np.clip(img, 0.0, 1.0)
 93                 images.append((i, img))
 94 
 95     # 可视化生成过程
 96     fig, axes = plt.subplots(1, len(images), figsize=(15, 3))
 97     for idx, (t, img) in enumerate(images):
 98         axes[idx].imshow(img, cmap='gray')
 99         axes[idx].set_title(f't={t}')
100         axes[idx].axis('off')
101 
102     plt.tight_layout()
103     plt.savefig(save_path, dpi=150, bbox_inches='tight')
104     print(f"生成过程已保存到 {save_path}")
105 
106 
107 if __name__ == '__main__':
108     # 生成样本
109     samples = sample_and_visualize(
110         model_path='ddpm_mnist_final.pth',
111         num_samples=64,
112         save_path='generated_samples.png'
113     )
114 
115     # 可视化生成过程
116     sample_with_progress(
117         model_path='ddpm_mnist_final.pth',
118         save_path='generation_progress.png'
119     )
View Code
def sample_and_visualize
这个函数生成了一批图片,然后一起降噪,并可视化为一个表格(图1所示)
加载模型,生成样本,扩散!
def sample_with_progress
这个函数生成了一张图片,把降噪过程展示了出来(图2所示)
加载模型,生成样本,扩散!

接下来是train.py(其实就没什么好看的了)

  1 import torch
  2 import torch.nn as nn
  3 import torch.optim as optim
  4 from torch.utils.data import DataLoader
  5 from torchvision import datasets, transforms
  6 import matplotlib.pyplot as plt
  7 from ddpm_model import SimpleUNet, DiffusionProcess
  8 
  9 
 10 def train_ddpm(
 11     num_epochs=500,
 12     batch_size=128,
 13     learning_rate=2e-4,
 14     num_timesteps=1000,
 15     device='cuda' if torch.cuda.is_available() else 'cpu'
 16 ):
 17     """训练DDPM模型生成MNIST数字"""
 18     print(f"使用设备: {device}")
 19     
 20     # 数据预处理:归一化到[-1, 1]
 21     transform = transforms.Compose([
 22         transforms.ToTensor(),
 23         transforms.Normalize((0.5,), (0.5,))  # 从[0,1]归一化到[-1,1]
 24     ])
 25     
 26     # 加载MNIST数据集
 27     print("加载MNIST数据集...")
 28     train_dataset = datasets.MNIST(
 29         root='./data',
 30         train=True,
 31         download=True,
 32         transform=transform
 33     )
 34     train_loader = DataLoader(
 35         train_dataset,
 36         batch_size=batch_size,
 37         shuffle=True,
 38         num_workers=2
 39     )
 40     
 41     # 创建模型和扩散过程
 42     model = SimpleUNet(image_size=28, in_channels=1, time_emb_dim=32).to(device)
 43     diffusion = DiffusionProcess(num_timesteps=num_timesteps, device=device)
 44     
 45     # 优化器
 46     optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
 47     criterion = nn.MSELoss()
 48     
 49     # 训练循环
 50     print("开始训练...")
 51     losses = []
 52     
 53     for epoch in range(num_epochs):
 54         epoch_loss = 0.0
 55         num_batches = 0
 56         
 57         for batch_idx, (images, _) in enumerate(train_loader):
 58             images = images.to(device)
 59             
 60             # 随机采样时间步
 61             t = torch.randint(0, num_timesteps, (images.shape[0],), device=device)
 62             
 63             # 采样噪声
 64             noise = torch.randn_like(images)
 65             
 66             # 前向扩散:添加噪声
 67             x_t = diffusion.q_sample(images, t, noise)
 68             
 69             # 预测噪声
 70             predicted_noise = model(x_t, t)
 71             
 72             # 计算损失
 73             loss = criterion(predicted_noise, noise)
 74             
 75             # 反向传播
 76             optimizer.zero_grad()
 77             loss.backward()
 78             # 梯度裁剪
 79             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
 80             optimizer.step()
 81             
 82             epoch_loss += loss.item()
 83             num_batches += 1
 84             
 85             if batch_idx % 100 == 0:
 86                 print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.6f}")
 87         
 88         avg_loss = epoch_loss / num_batches
 89         losses.append(avg_loss)
 90         
 91         print(f"Epoch [{epoch+1}/{num_epochs}], 平均损失: {avg_loss:.6f}")
 92         
 93         # 每10个epoch保存一次模型
 94         if (epoch + 1) % 10 == 0:
 95             torch.save(model.state_dict(), f'ddpm_mnist_epoch_{epoch+1}.pth')
 96             print(f"模型已保存到 ddpm_mnist_epoch_{epoch+1}.pth")
 97     
 98     # 保存最终模型
 99     torch.save(model.state_dict(), 'ddpm_mnist_final.pth')
100     print("最终模型已保存到 ddpm_mnist_final.pth")
101     
102     # 绘制损失曲线
103     plt.figure(figsize=(10, 6))
104     plt.plot(losses)
105     plt.xlabel('Epoch')
106     plt.ylabel('Loss')
107     plt.title('Training Loss')
108     plt.grid(True)
109     plt.savefig('training_loss.png')
110     print("损失曲线已保存到 training_loss.png")
111     
112     return model, diffusion
113 
114 
115 if __name__ == '__main__':
116     model, diffusion = train_ddpm(
117         num_epochs=500,
118         batch_size=128,
119         learning_rate=2e-4,
120         num_timesteps=1000
121     )
View Code

3.Q&A

列一些我学习过程中遇到的比较神奇的问题吧,记录一下,也希望能够帮助到一些人入门(坑)

感谢gemini的大力支持没有gemini我死这了

Q1:为什么时间码与前向传播的特征直接就相加了?这样不会导致特征消融或是时间码重复吗?比如两个不同的图片,加上不同的时间码反而相同了?

A1:应该不会,这个时间码更像一种“偏置”,即将整个图片的像素都“变亮”或者“变暗”了固定值。当然,这里是对于提取出的特征的偏执,所以不仅不会重复,还利于广播(反正我是似懂非懂的)

Q2:为什么在单步去噪过程中要直接预测一个x0(即初始图像)呢?

A2:在数学推导里有出现x0,但是我们不知道真正的x0,所以只能用现阶段的模型预测一个,不管它有多离谱,先用着,有比没有强,当作一个锚点来进行指导(依旧似懂非懂,反正工程上存在肯定有它的好处就对了)

翻了一下对着gemini哈气的界面,暂时就找到这两个。那大概就写到这吧

(完)

posted @ 2026-01-07 22:29  阿基米德的澡盆  阅读(6)  评论(0)    收藏  举报