[学习笔记]Flow Matching with MNIST

一刻也没有为扩散模型感到悲哀,接下来赶到战场的是——流匹配!

[学习笔记]DDPM图片降噪 - 阿基米德的澡盆 - 博客园

(牛马圣体是这样的0.0无聊到看代码写博客)

继续贴仓库

ajmddzp/flow-matching-ddmp

不多废话了,继续写吧

0.什么是流匹配

generated_samples_fm

 

generation_progress_fm

流匹配就是一种很牛逼的方法,用来生成一些东西()

它是一种生成方法,大概原理就是将噪声图像提升到高维,然后利用一个向量场,让高维特征移动,从而达到生成的目的。

通过之前学的一个最最最简单的demo,大概已经知道了流匹配的运用过程

[学习笔记]流匹配(Flow Matching) - 阿基米德的澡盆 - 博客园

(虽然没人看但还是贴个链接)

这个还原MNIST的demo让我对流匹配和生成式模型有了更好的了解

1.流匹配的理论基础

说到流匹配的理论基础,那可是相当的简单直白。

就一个最简单的式子:

$$v = \frac{dx}{dt}$$

一个速度场。就是让噪声点往它该去的地方去。

对比一下扩散模型

如果继续用墨水扩散与复原的例子类比

那么,扩散模型就是对每一个点抽丝剥茧地计算它来自哪里,最有可能去向哪里

流匹配就像是在高维对墨水整体进行受力分析,看穿了墨水滴的变化过程,然后一通操作,吧唧还原了

(潮水啊,我已归来)

这个原理如此之直观与简洁,甚至它的训练步数和采样步数都远远小于扩散模型(DDPM采样1000步,训练500轮;Flow Matching采样100步,训练100轮)

甚至Stable Diffusion3都开始使用流匹配了。

这或许就是数学形式与物理还有代码的完美统一吧。

2.流匹配的实现方法

还是老样子,直接上代码。

哈基米给出的改进代码,其实大致流程与扩散模型差别不大,代码也比DDPM更加好懂。

首先是fm_model.py

  1 import torch
  2 import torch.nn as nn
  3 import math
  4 
  5 # --- 保持不变的基础组件 ---
  6 
  7 
  8 class SinusoidalPositionalEmbedding(nn.Module):
  9     # (保持原代码不变)
 10     def __init__(self, dim):
 11         super().__init__()
 12         self.dim = dim
 13 
 14     def forward(self, time):
 15         device = time.device
 16         half_dim = self.dim // 2
 17         emb = math.log(10000) / (half_dim - 1)
 18         emb = torch.exp(torch.arange(
 19             half_dim, device=device, dtype=torch.float32) * -emb)
 20         # 注意:这里 time 如果是 [0,1] 的小数,嵌入后的频率会很低
 21         # 所以通常建议传入前把 time * 1000,或者在这里乘一个缩放因子
 22         emb = time[:, None].float() * emb[None, :]
 23         emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
 24         return emb
 25 
 26 
 27 class Block(nn.Module):
 28     # (保持原代码不变)
 29     def __init__(self, in_ch, out_ch, time_emb_dim, up=False, down=False):
 30         super().__init__()
 31         self.time_mlp = nn.Linear(time_emb_dim, out_ch)
 32         if up:
 33             self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1)
 34         elif down:
 35             self.conv = nn.Conv2d(in_ch, out_ch, 3, 2, 1)
 36         else:
 37             self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
 38         self.norm = nn.GroupNorm(8, out_ch)
 39         self.act = nn.SiLU()
 40 
 41     def forward(self, x, t):
 42         t_emb = self.time_mlp(t)[:, :, None, None]
 43         h = self.conv(x)
 44         h = h + t_emb  # 这里的逻辑完全一样
 45         h = self.norm(h)
 46         h = self.act(h)
 47         return h
 48 
 49 # --- 略微修改的 UNet ---
 50 
 51 
 52 class SimpleUNet(nn.Module):
 53     """
 54     修改点:
 55     Flow Matching 的时间 t 是连续的 [0, 1] 浮点数。
 56     为了复用原本为整数步设计的 PositionalEmbedding,
 57     我们在 forward 里把 t 放大 (比如 * 1000)。
 58     """
 59 
 60     def __init__(self, image_size=28, in_channels=1, time_emb_dim=32):
 61         super().__init__()
 62         # ... (结构定义保持完全不变,同原代码) ...
 63         self.time_emb_dim = time_emb_dim
 64         self.time_mlp = nn.Sequential(
 65             SinusoidalPositionalEmbedding(time_emb_dim),
 66             nn.Linear(time_emb_dim, time_emb_dim),
 67             nn.SiLU(),
 68             nn.Linear(time_emb_dim, time_emb_dim)
 69         )
 70         self.down1 = Block(in_channels, 64, time_emb_dim)
 71         self.down2 = Block(64, 128, time_emb_dim, down=True)
 72         self.down3 = Block(128, 256, time_emb_dim, down=True)
 73         self.mid1 = Block(256, 256, time_emb_dim)
 74         self.mid2 = Block(256, 256, time_emb_dim)
 75         self.up1 = Block(256 + 256, 128, time_emb_dim, up=True)
 76         self.up2 = Block(128 + 128, 64, time_emb_dim, up=True)
 77         self.up3 = Block(64 + 64, 64, time_emb_dim, up=False)
 78         self.out = nn.Conv2d(64, in_channels, 1)
 79 
 80     def forward(self, x, t):
 81         # 核心修改:如果是 FM,t 是 [0, 1] 的 float
 82         # 放大 1000 倍以适应 Positional Embedding 的频率敏感度
 83         t = t * 1000
 84 
 85         t_emb = self.time_mlp(t)
 86 
 87         # ... (后续 forward 逻辑保持完全不变) ...
 88         x1 = self.down1(x, t_emb)
 89         x2 = self.down2(x1, t_emb)
 90         x3 = self.down3(x2, t_emb)
 91         x = self.mid1(x3, t_emb)
 92         x = self.mid2(x, t_emb)
 93         x = self.up1(torch.cat([x, x3], dim=1), t_emb)
 94         x = self.up2(torch.cat([x, x2], dim=1), t_emb)
 95         x = self.up3(torch.cat([x, x1], dim=1), t_emb)
 96         return self.out(x)
 97 
 98 # --- 全新的 Flow Matching 处理器 ---
 99 
100 
101 class FlowMatchingProcess:
102     """
103     流匹配过程 (Optimal Transport Conditional Flow Matching)
104     也就是 Rectified Flow 的简化版
105     """
106 
107     def __init__(self, device='cuda'):
108         self.device = device
109         # FM 不需要 beta, alpha 这些复杂的 schedule
110         # 它只需要简单的线性插值
111 
112     def get_train_tuple(self, x_start):
113         """
114         构造训练数据
115         x_start (x0): 真实数据 (Data)
116         x_1 (x1): 纯噪声 (Noise)
117         t: 时间 [0, 1]
118         """
119         # 1. 随机采样时间 t ~ Uniform[0, 1]
120         b = x_start.shape[0]
121         t = torch.rand(b, device=self.device)
122         # 随机采样时间
123 
124         # 2. 生成目标噪声 x1
125         x_1 = torch.randn_like(x_start)
126 
127         # 3. 线性插值构建 x_t (Straight Line)
128         # 路径公式:x_t = (1 - t) * x_0 + t * x_1
129         # t=0 是原图,t=1 是噪声 (这里采用了 OT-CFM 的标准定义,也可以反过来)
130         # 为了方便,这里我们定义 t=0 为 Data, t=1 为 Noise
131         # 注意:这个路径必须广播 t 的维度
132         t_expand = t.view(-1, 1, 1, 1)
133         x_t = (1 - t_expand) * x_start + t_expand * x_1
134         # 线性插值
135 
136         # 4. 计算目标速度 (Target Velocity)
137         # 对 x_t 求导 dx_t/dt = x_1 - x_0
138         # 模型需要预测的就是这个“向量场”
139         target_v = x_1 - x_start
140 
141         return x_t, t, target_v
142 
143     @torch.no_grad()
144     # 装饰器,不计算梯度
145     def sample(self, model, shape, steps=50):
146         """
147         采样过程:求解 ODE
148         使用欧拉法 (Euler Method) 从 t=1 (噪声) 走到 t=0 (数据)
149         """
150         model.eval()
151         b = shape[0]
152 
153         # 1. 从纯噪声开始 (t=1)
154         x = torch.randn(shape, device=self.device)
155 
156         # 2. 构造时间步 (从 1.0 降到 0.0)
157         # linspace 生成如 [1.0, 0.98, ..., 0.0]
158         time_steps = torch.linspace(1.0, 0.0, steps + 1, device=self.device)
159         dt = 1.0 / steps  # 步长
160 
161         for i in range(steps):
162             # 当前时间 t
163             t_current = time_steps[i]
164             # 扩展成 batch 维度
165             t_batch = torch.full((b,), t_current, device=self.device)
166 
167             # 3. 模型预测速度 v_pred
168             # 预测的是 dx/dt
169             v_pred = model(x, t_batch)
170 
171             # 4. 欧拉法更新 (Euler Step)
172             # x_{t-1} = x_t - v * dt
173             # 因为我们在倒着走 (从1到0),所以是减去
174             x = x - v_pred * dt
175 
176         model.train()
177         return x

View Code

解释一下各个模块都在干什么

class SinusoidalPositionalEmbedding(nn.Module):
依旧是时间编码,与DDPM一点变化都没有
class Block(nn.Module):
U-Net组件模块,依旧一点变化都没有
class SimpleUNet(nn.Module):
首先是init函数,没有变化
其次是forward函数:对应修改了时间码,就是在原先基础上*了1000
class FlowMatchingProcess:
实现流匹配的类
首先是init,极度简单,指定设备。我极度怀疑它是否真的有存在的必要
接下来get_train_tuple是采样训练数据,也很简单,就是随机叠加一个噪声,然后线性插值,在清晰图片与噪声中间平滑叠加
再接下来,是反向采样sample。就是很直观,预测一步,反向走一步

真的,流匹配的代码易懂程度很高,甚至说得上优美。

接下来是采样sample.py

  1 import torch
  2 import numpy as np
  3 import matplotlib.pyplot as plt
  4 # 假设你在 ddpm_model.py 中已经添加了 FlowMatchingProcess 类
  5 from ddpm_model import SimpleUNet, FlowMatchingProcess
  6 
  7 
  8 def sample_and_visualize(
  9     model_path='fm_mnist_final.pth',  # 建议修改模型后缀以区分
 10     num_samples=16,
 11     steps=50,  # Flow Matching 不需要 1000 步,50 步通常足够
 12     device='cuda' if torch.cuda.is_available() else 'cpu',
 13     save_path='generated_samples_fm.png'
 14 ):
 15     """从训练好的 Flow Matching 模型采样并可视化"""
 16     print(f"使用设备: {device}")
 17 
 18     # 加载模型
 19     # 注意:这里的模型权重必须是使用 Flow Matching Loss 训练出来的
 20     model = SimpleUNet(image_size=28, in_channels=1,
 21                        time_emb_dim=32).to(device)
 22     model.load_state_dict(torch.load(model_path, map_location=device))
 23     model.eval()
 24 
 25     # 创建 Flow Matching 过程
 26     # FM 不需要像 DDPM 那样在初始化时预计算 alpha/beta 表
 27     flow = FlowMatchingProcess(device=device)
 28 
 29     print(f"使用欧拉法生成 {num_samples} 个样本 (Steps={steps})...")
 30 
 31     # 采样 (调用 ODE 求解器)
 32     # shape: [Batch, Channel, Height, Width]
 33     samples = flow.sample(model, shape=(num_samples, 1, 28, 28), steps=steps)
 34     samples = samples.cpu()
 35 
 36     # 反归一化:从[-1, 1]回到[0, 1] (假设数据处理逻辑没变)
 37     samples = (samples + 1.0) / 2.0
 38     samples = torch.clamp(samples, 0.0, 1.0)
 39 
 40     # 可视化 (这部分代码保持不变)
 41     grid_size = int(np.sqrt(num_samples))
 42     fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))
 43     axes = axes.flatten()
 44 
 45     for i in range(num_samples):
 46         img = samples[i].squeeze().numpy()
 47         axes[i].imshow(img, cmap='gray')
 48         axes[i].axis('off')
 49 
 50     plt.tight_layout()
 51     plt.savefig(save_path, dpi=150, bbox_inches='tight')
 52     print(f"生成的样本已保存到 {save_path}")
 53 
 54     return samples
 55 
 56 
 57 def sample_with_progress(
 58     model_path='fm_mnist_final.pth',
 59     steps=50,  # 总步数
 60     device='cuda' if torch.cuda.is_available() else 'cpu',
 61     save_path='generation_progress_fm.png'
 62 ):
 63     """可视化 Flow Matching 的生成轨迹 (ODE 积分过程)"""
 64     print(f"使用设备: {device}")
 65 
 66     # 加载模型
 67     model = SimpleUNet(image_size=28, in_channels=1,
 68                        time_emb_dim=32).to(device)
 69     model.load_state_dict(torch.load(model_path, map_location=device))
 70     model.eval()
 71 
 72     # 初始化
 73     # 纯噪声 (对应 t=1.0)
 74     x = torch.randn(1, 1, 28, 28, device=device)
 75 
 76     # 定义要捕捉的帧 (根据步数索引)
 77     # 例如:第0步(开始), 第10步, ... 第50步(结束)
 78     capture_indices = [0, 5, 10, 20, 30, 40, 49]
 79     images = []
 80 
 81     # 准备时间步 (从 1.0 降到 0.0)
 82     time_steps = torch.linspace(1.0, 0.0, steps + 1, device=device)
 83     dt = 1.0 / steps  # 步长
 84 
 85     print("开始逐步积分...")
 86     with torch.no_grad():
 87         for i in range(steps):
 88             # 1. 记录当前状态 (如果是我们要捕捉的关键帧)
 89             if i in capture_indices:
 90                 t_val = time_steps[i].item()  # 获取当前 t 的数值用于标题
 91                 img = x.cpu().squeeze().numpy()
 92                 img = (img + 1.0) / 2.0
 93                 img = np.clip(img, 0.0, 1.0)
 94                 images.append((f"t={t_val:.2f}", img))
 95 
 96             # 2. 欧拉法更新 (Euler Step)
 97             # 获取当前时间 t
 98             t_current = time_steps[i]
 99             t_batch = torch.full((1,), t_current, device=device)
100 
101             # 预测速度 v (dx/dt)
102             v_pred = model(x, t_batch)
103 
104             # 更新位置: x_{next} = x_{curr} - v * dt
105             # (因为我们在倒退时间,从1到0,所以是减去)
106             x = x - v_pred * dt
107 
108         # 添加最后一步的结果 (t=0.0)
109         img = x.cpu().squeeze().numpy()
110         img = (img + 1.0) / 2.0
111         img = np.clip(img, 0.0, 1.0)
112         images.append(("t=0.00", img))
113 
114     # 可视化生成过程
115     fig, axes = plt.subplots(1, len(images), figsize=(15, 3))
116     for idx, (label, img) in enumerate(images):
117         axes[idx].imshow(img, cmap='gray')
118         axes[idx].set_title(label)
119         axes[idx].axis('off')
120 
121     plt.tight_layout()
122     plt.savefig(save_path, dpi=150, bbox_inches='tight')
123     print(f"生成过程已保存到 {save_path}")
124 
125 
126 if __name__ == '__main__':
127     # 生成样本
128     samples = sample_and_visualize(
129         model_path='fm_mnist_final.pth',
130         num_samples=64,
131         steps=50,  # 推荐使用 20-50 步
132         save_path='generated_samples_fm.png'
133     )
134 
135     # 可视化生成过程
136     sample_with_progress(
137         model_path='fm_mnist_final.pth',
138         steps=50,
139         save_path='generation_progress_fm.png'
140     )
View Code

继续拆解模块

sample_and_visualize
采样一堆图片,然后列表展示
就很简单,非常简单
指定设备、加载权重,把噪声往网络里一丢,完事了??
就是加了一个反归一化,让输出在一定范围内,然后可视化
sample_with_progress
单图多步骤采样
也是,指定设备,指定步长,然后丢给神经网络,完事了
最后可视化
就这么简单,优雅

最后是训练用的train.py

  1 import torch
  2 import torch.nn as nn
  3 import torch.optim as optim
  4 from torch.utils.data import DataLoader
  5 from torchvision import datasets, transforms
  6 import matplotlib.pyplot as plt
  7 
  8 # 假设你在 ddpm_model.py 中已经添加了 FlowMatchingProcess 类
  9 # 如果没有,请参考上一条回复中的代码添加到 ddpm_model.py
 10 from ddpm_model import SimpleUNet, FlowMatchingProcess
 11 
 12 
 13 def train_flow_matching(
 14     num_epochs=100,  # Flow Matching 收敛通常更快,可以适当减少 epoch
 15     batch_size=128,
 16     learning_rate=1e-4,  # 建议稍微降低一点 LR,或者保持 2e-4 也可以
 17     device='cuda' if torch.cuda.is_available() else 'cpu'
 18 ):
 19     """训练 Flow Matching 模型生成 MNIST 数字"""
 20     print(f"使用设备: {device}")
 21 
 22     # 1. 数据预处理
 23     # 保持不变:归一化到 [-1, 1]
 24     transform = transforms.Compose([
 25         transforms.ToTensor(),
 26         transforms.Normalize((0.5,), (0.5,))
 27     ])
 28 
 29     # 加载 MNIST 数据集
 30     print("加载 MNIST 数据集...")
 31     train_dataset = datasets.MNIST(
 32         root='./data',
 33         train=True,
 34         download=True,
 35         transform=transform
 36     )
 37     train_loader = DataLoader(
 38         train_dataset,
 39         batch_size=batch_size,
 40         shuffle=True,
 41         num_workers=2
 42     )
 43 
 44     # 2. 创建模型和 Flow Matching 过程
 45     # 模型结构复用 SimpleUNet,参数无需修改
 46     model = SimpleUNet(image_size=28, in_channels=1,
 47                        time_emb_dim=32).to(device)
 48 
 49     # 实例化 Flow Matching 处理器
 50     # 注意:这里不再需要 num_timesteps 参数,因为时间是连续的 [0, 1]
 51     flow = FlowMatchingProcess(device=device)
 52 
 53     # 优化器
 54     optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
 55     criterion = nn.MSELoss()
 56 
 57     # 训练循环
 58     print("开始 Flow Matching 训练...")
 59     losses = []
 60 
 61     for epoch in range(num_epochs):
 62         epoch_loss = 0.0
 63         num_batches = 0
 64 
 65         for batch_idx, (images, _) in enumerate(train_loader):
 66             # x_start (也就是 x0, 真实数据)
 67             x_start = images.to(device)
 68 
 69             # --- Flow Matching 核心逻辑 ---
 70 
 71             # 1. 获取训练元组 (x_t, t, target_v)
 72             # 这一步替代了 DDPM 中的随机采样 t 和加噪过程
 73             # flow 内部会自动生成噪声 x1,并计算线性插值 x_t
 74             x_t, t, target_v = flow.get_train_tuple(x_start)
 75 
 76             # 2. 模型预测
 77             # 输入:当前插值状态 x_t, 时间 t
 78             # 输出:预测的速度场 v_pred (意图去拟合 target_v)
 79             pred_v = model(x_t, t)
 80 
 81             # 3. 计算损失
 82             # 我们希望模型预测的速度 v_pred 尽可能接近真实方向 (x1 - x0)
 83             loss = criterion(pred_v, target_v)
 84 
 85             # ---------------------------
 86 
 87             # 反向传播
 88             optimizer.zero_grad()
 89             loss.backward()
 90             # 梯度裁剪 (保持这个好习惯)
 91             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
 92             optimizer.step()
 93 
 94             epoch_loss += loss.item()
 95             num_batches += 1
 96 
 97             if batch_idx % 100 == 0:
 98                 print(
 99                     f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.6f}")
100 
101         avg_loss = epoch_loss / num_batches
102         losses.append(avg_loss)
103 
104         print(f"Epoch [{epoch+1}/{num_epochs}], 平均损失: {avg_loss:.6f}")
105 
106         # 每10个epoch保存一次模型
107         if (epoch + 1) % 10 == 0:
108             torch.save(model.state_dict(), f'fm_mnist_epoch_{epoch+1}.pth')
109             print(f"模型已保存到 fm_mnist_epoch_{epoch+1}.pth")
110 
111     # 保存最终模型
112     torch.save(model.state_dict(), 'fm_mnist_final.pth')
113     print("最终模型已保存到 fm_mnist_final.pth")
114 
115     # 绘制损失曲线
116     plt.figure(figsize=(10, 6))
117     plt.plot(losses)
118     plt.xlabel('Epoch')
119     plt.ylabel('Loss (MSE between v_pred and x1-x0)')
120     plt.title('Flow Matching Training Loss')
121     plt.grid(True)
122     plt.savefig('fm_training_loss.png')
123     print("损失曲线已保存到 fm_training_loss.png")
124 
125     return model
126 
127 
128 if __name__ == '__main__':
129     # Flow Matching 通常可以用更少的 epoch 达到相同的效果
130     model = train_flow_matching(
131         num_epochs=100,
132         batch_size=128,
133         learning_rate=1e-4
134     )
View Code

训练依旧是一大堆优化器啊啥的,我没有细看,也就步详细展开了

3.Q&A

Q1:盆老师,说是流匹配训练了一个向量场,那为什么在去噪的时候还是简单的减去呢

x = x - v_pred * dt#这不还是把噪声减去吗

A1:这就是流匹配的高明之处了,减去噪声是一个表象,可以看作是高维空间的投影。对于MNIST数据集,可以把每个输入都看作28*28维度的一个点(时间不算在内是因为已经把时间编码加到输入上了),之后,流匹配指导的是这样一个高维空间的点的“移动”,体现出来,就是各个像素点有不同的变化,也就是加减一个噪声。

4.对比扩散模型与流匹配

首先,最直观的感受就是:流匹配太简洁了,形式与理论高度统一,代码简洁易懂不需要数学推导

其次,是去噪的过程。

形象的比喻是:扩散模型要倒着走,一点一点猜上一步的脚印,然后倒退回起点,所以很慢,并且路径可能非常曲折。

流匹配是直接知道了目标,大步流星往回退,非常快。

之后是学习内容。

扩散模型学习的是噪声分布,这一步应该去除哪里的噪声,下一步应该去除哪里的噪声

流匹配学习的是高维的向量场,高维点应该往哪里去,带着整体移动,表现为减去噪声。

所以,流匹配之于扩散模型,就像是扩散模型之于Pixel RNN,这回真是降维打击了。

5.写在最后

流匹配算是最近看的比较舒服的理论了,因为不需要数学推导因为它的代码和理论高度统一,而且极度直观。怪不得pi0和其他VLA都在使用流匹配进行策略或者奖励函数的生成。我想起了考研方浩的名言:高山看海。从更高维度看问题,直接秒杀

(完)

posted @ 2026-01-07 22:47  阿基米德的澡盆  阅读(10)  评论(0)    收藏  举报