生成式对抗网络GAN
在传统的神经网络任务中,我们通常把一个网络当作一个函数f(x),给定输入x,网络就会输出一个对应的结果 y。比如图像分类任务中,输入是一张图片,输出是一个分类标签。这是一种 判别式模型(Discriminative Model),它学的是输入和输出之间的映射关系。但在生成式模型(Generative Model) 中,输入会增加一个随机分布中sample出来的z,网络输入x和z,输出y是可以从中采样的复杂分布。
网络通常会采用两种方式来处理x和z(1)拼接(Concatenate):直接把 两个向量拼接在一起,变成一个更长的向量,输入到神经网络中。(2)相加(Element-wise Add):如果x和 z维度相同,可以直接相加作为输入。我们的目标不再是“判断”某个输入属于哪一类,而是希望模型能够“生成”数据——比如生成看起来真实的图片、音频,甚至文本。换句话说,我们希望网络本身就是一个“生成器”,可以从某种潜在的随机性中创造出无限多样的输出。
1 GAN(Generative Adversarial Network)
1.1 GAN 的基本概念和工作原理
生成对抗网络(GAN)是由 Ian Goodfellow 等人在 2014 年提出的一种生成式模型。与传统的神经网络不同,GAN 由两个相互竞争的网络组成:生成器(Generator) 和 判别器(Discriminator)。
以生成二次元人脸为例:
生成器(Generator):把x拿掉,生成器只输入随机噪声 z ,这种叫做unconditional generation。假设z是从normal distribution中sample出来的向量,这个向量一般是low-dim的向量,维度是自定义的,Generator输入z后产生一个64x64x3的向量,整理后可以得到一张二次元人脸的图像。
判别器(Discriminator):对应的判别器输入一张图片(可能是来自真实数据集,也可能是来自生成器),输出一个数字,数字越大表示输入的图像越像真实的二次元的人脸。
1.2 GAN 的训练机制
- 初始化生成器和判别器参数
- 在每个训练迭代中
- 固定住生成器,更新判别器。具体的随机采样一些向量z,输入到生成器,得到一些生成的图像,然后从真实数据中采样一些二次元人脸,训练判别器分别两者之间的差异。比如用二分类器,或者逻辑回归。
-
- 固定住判别器,更新生成器。具体的把两个网络接起来变成一个大网络,其中判别器的参数是固定的,训练生成器,使得分数越大越好。
GAN的训练目标是通过对抗训练,生成器和判别器在博弈中不断提高自己的能力:
- 生成器的目标:生成足够真实的数据,使得判别器无法区分它们是来自真实数据还是生成的假数据。
假设生成的数据分布为PG,真实数据分布为Pdata,生成器的目标就是
Divergence是PG和Pdata之间的某种距离,距离越小两个分布就越相近。尽管不知道PG和Pdata是什么样的分布,但是可以从中sample。
- 判别器的目标:尽可能准确地区分真实数据和生成数据,减少自己的分类错误。
分别从PG和Pdata中sample,Discriminator要学会给PG和中sample到的样本打1,给Pdata中sample到的样本打0,要做的是一个二分类任务。判别器的总 BCE 损失(对所有样本)为
损失是要最小化,等价于最大化目标函数V:
这个maxV和JS divergence有关,因此生成器的目标就是
之前的训练步骤就是在解这个minmax问题,为什么训练步骤可以解这个函数,见论文推导。设计不同的objective function解minmax问题,就对应不同的divergence(参考)。
1.3 WGAN:基于 Wasserstein 距离的改进方法
在大多数情况下,PG和Pdata之间是没有重叠的,两个都是高维空间的低维manifold,相当于二维空间的两条直线。即使两个分布是有重叠的,但是sample的点不够多,重叠的范围也非常小。当两个分布没有重叠时,JS divergence是存在问题的。 JS divergence是非度量的,不能为生成器提供一个明确的目标方向来指导其优化。比如下面的情况,PG和Pdata越来越接近,但是JS始终为log2,知道两者重合才会变成0。
为了改善 GAN 的训练稳定性,研究者提出了Wasserstein GAN(WGAN)。它的核心思想是:
替换原本的 JS 散度为 Wasserstein 距离(又叫 Earth Mover’s Distance)。
你可以把它理解为:把一堆土(生成的数据分布)搬到另一堆土堆(真实数据分布)需要花费的最小“搬运成本”。这是一种更光滑、梯度更稳定的距离度量方式。
计算Wasserstein 距离就是解下面的公式
注意:
-
WGAN 中的判别器不再是“真假分类器”,而是一个 Critic,它的输出是任意实数,用于度量样本的“真实性”得分。
-
为了满足理论条件,Critic 的梯度需要是1-Lipschitz的。这个限制是让Discriminator变得平滑,如果没有这个限制,D会给生成的x负无穷,给真实的x正无穷,训练会无法收敛,max始终是无穷大。
💡Q: 如何让Discriminator变得比较平滑?
最早用权重裁剪,限制权重在-c到c之间,后续改进为 WGAN-GP(加入梯度惩罚项),训练更稳定,另一个常用、效果非常不错的方式是:Spectral Normalization,归一化权重的最大奇异值,不需要像 WGAN-GP 那样计算复杂的梯度惩罚项,也避免了 weight clipping 带来的训练困难。
1.4 GAN 面临的挑战
尽管已经有了如WGAN这样的改进方法,GAN在训练过程中仍然面临诸多挑战。这主要源于其对抗性训练机制的本质特性:生成器(Generator)与判别器(Discriminator)在训练过程中是相互博弈、彼此促进的关系。模型的优化是一个动态博弈过程,如果其中一方(例如判别器)训练不足或性能不稳定,就会导致另一方(如生成器)无法获得有效的反馈信号,从而影响整体训练效果。这种相互依赖使得GAN的训练过程高度不稳定,调试和收敛都较为困难。目前主流的生成模型,除了GAN,还有变分自编码器(VAE),流模型(Flow-based Models)。GAN在图像生成质量上往往优于VAE和Flow模型,但它的训练更不稳定、可解释性更弱。
💡Q: 为什么GAN难用于文本生成?
答:因为文本是离散的,而GAN的训练依赖于梯度的反向传播。在文本生成中,生成器通常输出一个概率分布,再通过 argmax 或 采样 选择一个词。这种“选词”的过程是非可导的,梯度无法穿过这个离散选择,从而导致生成器无法优化。在多个值相等或者接近的时候,梯度无法确定往哪个方向优化。比如在 argmax 选词时,即便概率稍微改变,只要最大值没变,输出结果也不会变,因此梯度是 0 或未定义的,无法有效训练生成器。
💡Q: 那CNN里也有max pooling,为什么没问题?
答:CNN 中的 max pooling 虽然也是 max 操作,但它出现在网络的中间层,反向传播时我们可以将梯度传给最大值的位置,其他位置设为 0(这叫次梯度 subgradient)。这种近似梯度在实践中效果不错,因此不会影响 CNN 的训练。
模式坍缩 (Mode Collapse)
Mode Collapse 是指生成器只学会生成真实数据中的一小部分模式,导致输出缺乏多样性。换句话说,虽然真实数据有很多种可能,但生成器反复生成的是同一种或几种“看起来不错”的样本。例如:你训练 GAN 生成手写数字,但生成器最终只会生成“数字 3”,而忽略了其他数字。这样虽然图片质量可能还可以,但多样性完全丢失。这是因为生成器只专注于“骗过判别器”的目标,而不是完整地复现数据分布。一旦找到某个容易成功的样本类型,就会反复生成,从而陷入局部最优。
模式遗漏(Mode Dropping)
Mode Dropping 指的是生成器完全忽略了真实数据中的某些模式,即使这些模式在训练数据中是存在的,生成器却没有学会去生成它们。假设你训练一个 GAN 模型用于生成真实人脸,训练数据中包含了不同年龄、性别、肤色、发型的人脸图像。但训练后的生成器:只会生成某一肤色的面孔。尽管这些样本在真实数据中是存在的,生成器却忽略了这些“模式”。这就属于典型的 Mode Dropping —— 生成结果看起来多样,但其实缺失了某些重要的群体特征。
1.5 GAN评估指标
GAN 在训练中容易出现模式坍缩(Mode Collapse)和模式遗漏(Mode Dropping)等问题,即生成器生成的样本看起来质量不错,但实则重复或覆盖不全。这时我们就需要一些定量评估指标来判断两个关键问题:(1) 生成图像质量好不好?(2) 生成图像够不够多样?不像分类器有准确率指标,GAN 的 Generator 没有明确的评价指标,但有一些常用方法可以参考:
Inception Score(IS)
- 用一个预训练的分类器(如 Inception-v3)去分类生成的图像
- 一张图片丢到CNN去分类,结果分布越集中,quality越高
- 一堆图片的平均分布,越平均diversity越大
- good quality,large diversity→large IS,说明生成样本质量越好、类别多样。
在二次元人脸生成中,分类器的输出可能都是人脸,diversity小,不适合这个场景。
Frechet Inception Distane(FID)
-
将生成图像和真实图像分别输入一个预训练好的 CNN(通常是 Inception v3);
-
在网络的某一层(通常是 softmax 前的一层)提取特征向量(即使输入图像都是人脸,这些向量也会有所不同,因为它们捕捉的是更高层的语义信息(比如脸的姿态、表情、风格等)
-
假设这两组特征向量分别服从一个多维高斯分布,FID 会分别计算这两组特征向量的 均值(μ) 和 协方差矩阵(Σ),然后用 Frechet 距离来衡量它们之间的差异
如果 FID 值很小,说明生成的图像和真实图像非常接近,但这并不意味着生成器生成的样本多样化。生成器可能只是记住了真实数据的某些特征,导致它只能生成“很相似”的样本,而失去了多样性,比如学会了对真实图像进行翻转。真实特征分布可能远非高斯,所以在做次元人脸生成中主要是用FID和人眼就去看。
2 Conditional GAN(条件生成对抗网络)
传统GAN中,生成器是无条件的——它只接收随机噪声z作为输入。而 Conditional GAN 则引入条件信息x,例如类别标签、文本描述、图像等,引导生成器生成“符合条件”的样本。
应用: 文本生成图像
需要收集一些图片和对应的标注, 输入x是一段文字red eyes,可以用rnn或者transformer encoder把它变成一段向量。期望输入red eyes,generator就输出一个红眼睛的图片,每次的输出都是不一样的红眼睛,取决于sample到不一样的z。
💡Q: 如何训练Conditional GAN?
如果像之前的GAN一样,判别器只判断图像是否是真实的,生成器就不用在意输入x,只要产生清晰的图像就可以。在 Conditional GAN 中,我们的目标不只是生成“看起来真实的图像”,更重要的是:图像还要和输入的条件匹配。需要准备文字和图像成对的资料(positive),以及文字和机器产生出来的图片(negative),还需要把文字和图像乱配作为(negative)。
应用: 图像生成图像 (image translation or pix2pix)
比如输入黑白的图片让生成器着色,或者对图像去雾。输入一张图片生成一张图片,可以用supervised的方法,由于同样的输入可能对应到不一样的输出,机器学到把所有的可能平均起来,所以产生的图片会很模糊。如果用GAN的方法,再加入判别器,判别器输入生成器生成的图像和condition,然后输出分数。GAN方法产生的图像比较清楚,但是可能会产生输入没有的东西,比如下面的房子左上角有奇怪的东西。当GAN和supervised同时使用,效果会比较好,生成器在训练的时候一方面要骗过判别器,但又要使得产生的图片和目标越接近越好。
3 CycleGAN:无监督图像到图像的转换
在之前的unconditional generation中,输入是一个简单分布,输出是一个复杂分布,现在稍微转换下,输入是x domain图片的分布,输出是y domain图片的分布。假设我们将真实人脸转换为二次元人脸,像之前训练GAN一样,从x domain中sample一张图片,输入到Generator,再用学过ydomain的discriminator去给图像打分,这是有问题的。Generator会无视输入的图片,只产生一张像y domain的二次元图片就可以,这个图片和输入的真实人脸没有。如何强化输入和输出的关系呢?之前的conditional gan也讲过类似的问题,但是现在没有成对的数据去训练discriminator学习输入和输出的关系。
4 作业HW6
链接给出了PDF和code:李宏毅2021&2022机器学习。
代码:

import random import torch import numpy as np import os import glob import torch.nn as nn import torch.nn.functional as F import torchvision import torchvision.transforms as transforms from torch import optim from torch.autograd import Variable from torch.utils.data import Dataset, DataLoader import matplotlib.pyplot as plt from tqdm import tqdm def same_seeds(seed): # Python built-in random module random.seed(seed) # Numpy np.random.seed(seed) # Torch torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True same_seeds(2021) class CrypkoDataset(Dataset): def __init__(self, fnames, transform): self.transform = transform self.fnames = fnames self.num_samples = len(self.fnames) def __getitem__(self, idx): fname = self.fnames[idx] # 1. Load the image img = torchvision.io.read_image(fname) # 2. Resize and normalize the images using torchvision. img = self.transform(img) return img def __len__(self): return self.num_samples def get_dataset(root): fnames = glob.glob(os.path.join(root, '*')) # 1. Resize the image to (64, 64) # 2. Linearly map [0, 1] to [-1, 1] compose = [ transforms.ToPILImage(), transforms.Resize((64, 64)), transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), ] transform = transforms.Compose(compose) dataset = CrypkoDataset(fnames, transform) return dataset def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) class Generator(nn.Module): """ Input shape: (N, in_dim) Output shape: (N, 3, 64, 64) """ def __init__(self, in_dim, dim=64): super(Generator, self).__init__() def dconv_bn_relu(in_dim, out_dim): return nn.Sequential( nn.ConvTranspose2d(in_dim, out_dim, 5, 2, padding=2, output_padding=1, bias=False), nn.BatchNorm2d(out_dim), nn.ReLU() ) self.l1 = nn.Sequential( nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False), nn.BatchNorm1d(dim * 8 * 4 * 4), nn.ReLU() ) self.l2_5 = nn.Sequential( dconv_bn_relu(dim * 8, dim * 4), dconv_bn_relu(dim * 4, dim * 2), dconv_bn_relu(dim * 2, dim), nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1), nn.Tanh() ) self.apply(weights_init) def forward(self, x): y = self.l1(x) y = y.view(y.size(0), -1, 4, 4) y = self.l2_5(y) return y class Discriminator(nn.Module): """ Input shape: (N, 3, 64, 64) Output shape: (N, ) """ def __init__(self, in_dim, dim=64, use_sigmoid=True): super(Discriminator, self).__init__() def conv_bn_lrelu(in_dim, out_dim): return nn.Sequential( nn.Conv2d(in_dim, out_dim, 5, 2, 2), nn.BatchNorm2d(out_dim), nn.LeakyReLU(0.2), ) """ Medium: Remove the last sigmoid layer for WGAN. """ layers = [ nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2), conv_bn_lrelu(dim, dim * 2), conv_bn_lrelu(dim * 2, dim * 4), conv_bn_lrelu(dim * 4, dim * 8), nn.Conv2d(dim * 8, 1, 4), ] if use_sigmoid: layers.append(nn.Sigmoid()) self.ls = nn.Sequential(*layers) self.apply(weights_init) def forward(self, x): y = self.ls(x) y = y.view(-1) return y def train(baseline="Simple", show_img=True): # Training hyperparameters batch_size = 64 z_sample = Variable(torch.randn(100, z_dim)).cuda() lr = 1e-4 if baseline == "Simple": n_epoch = 50 # 50 n_critic = 1 # 训练 1 次判别器,再训练 1 次生成器 elif baseline == "Medium": """ Medium: WGAN, 50 epoch, n_critic=5, clip_value=0.01 """ n_epoch = 50 n_critic = 5 # 先训练 5 次判别器,再训练 1 次生成器 clip_value = 0.01 # Model G = Generator(in_dim=z_dim).cuda() if baseline == "Simple": D = Discriminator(3).cuda() elif baseline == "Medium": D = Discriminator(3, use_sigmoid=False).cuda() G.train() D.train() # Loss criterion = nn.BCELoss() # Optimizer if baseline == "Simple": opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999)) opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999)) elif baseline == "Medium": """ Medium: Use RMSprop for WGAN. """ opt_D = torch.optim.RMSprop(D.parameters(), lr=lr) opt_G = torch.optim.RMSprop(G.parameters(), lr=lr) # DataLoader dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2) steps = 0 for e, epoch in enumerate(range(n_epoch)): progress_bar = tqdm(dataloader) for i, data in enumerate(progress_bar): imgs = data imgs = imgs.cuda() bs = imgs.size(0) # ============================================ # Train D # ============================================ z = Variable(torch.randn(bs, z_dim)).cuda() r_imgs = Variable(imgs).cuda() f_imgs = G(z) if baseline == "Simple": # Label r_label = torch.ones((bs)).cuda() f_label = torch.zeros((bs)).cuda() # Model forwarding r_logit = D(r_imgs.detach()) f_logit = D(f_imgs.detach()) # Compute the loss for the discriminator. r_loss = criterion(r_logit, r_label) f_loss = criterion(f_logit, f_label) loss_D = (r_loss + f_loss) / 2 elif baseline == "Medium": # WGAN Loss loss_D = -torch.mean(D(r_imgs)) + torch.mean(D(f_imgs)) # Model backwarding D.zero_grad() loss_D.backward() # Update the discriminator. opt_D.step() if baseline == "Medium": """ Medium: Clip weights of discriminator. """ for p in D.parameters(): p.data.clamp_(-clip_value, clip_value) # ============================================ # Train G # ============================================ if steps % n_critic == 0: # Generate some fake images. z = Variable(torch.randn(bs, z_dim)).cuda() f_imgs = G(z) # Model forwarding f_logit = D(f_imgs) if baseline == "Simple": # Compute the loss for the generator. loss_G = criterion(f_logit, r_label) elif baseline == "Medium": # WGAN Loss loss_G = -torch.mean(D(f_imgs)) # Model backwarding G.zero_grad() loss_G.backward() # Update the generator. opt_G.step() steps += 1 # Set the info of the progress bar # Note that the value of the GAN loss is not directly related to # the quality of the generated images. progress_bar.set_postfix({ 'Loss_D': round(loss_D.item(), 4), 'Loss_G': round(loss_G.item(), 4), 'Epoch': e + 1, 'Step': steps, }) G.eval() f_imgs_sample = (G(z_sample).data + 1) / 2.0 filename = os.path.join(log_dir, f'Epoch_{epoch + 1:03d}.jpg') torchvision.utils.save_image(f_imgs_sample, filename, nrow=10) print(f' | Save some samples to {filename}.') # Show generated images in the jupyter notebook. if show_img: grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10) plt.figure(figsize=(10, 10)) plt.imshow(grid_img.permute(1, 2, 0)) plt.show() G.train() if (e + 1) % 5 == 0 or e == 0: # Save the checkpoints. torch.save(G.state_dict(), os.path.join(ckpt_dir, 'G.pth')) torch.save(D.state_dict(), os.path.join(ckpt_dir, 'D.pth')) def inference(): G = Generator(z_dim) G.load_state_dict(torch.load(os.path.join(ckpt_dir, 'G.pth'))) G.eval() G.cuda() # Generate 1000 images and make a grid to save them. n_output = 1000 z_sample = Variable(torch.randn(n_output, z_dim)).cuda() imgs_sample = (G(z_sample).data + 1) / 2.0 log_dir = os.path.join('logs') filename = os.path.join(log_dir, 'result.jpg') torchvision.utils.save_image(imgs_sample, filename, nrow=10) # Show 30 of the images. grid_img = torchvision.utils.make_grid(imgs_sample[:30].cpu(), nrow=10) plt.figure(figsize=(10, 10)) plt.imshow(grid_img.permute(1, 2, 0)) plt.show() if __name__ == '__main__': dataset = get_dataset('faces') # 注意,这些数值的范围是 [-1, 1],所以显示比较暗 # images = [dataset[i] for i in range(16)] # grid_img = torchvision.utils.make_grid(images, nrow=4) # plt.figure(figsize=(10, 10)) # plt.imshow(grid_img.permute(1, 2, 0)) # plt.show() # 我们需要将它们转换到有效的范围 [0, 1],才能正确显示。 # images = [(dataset[i] + 1) / 2 for i in range(16)] # grid_img = torchvision.utils.make_grid(images, nrow=4) # plt.figure(figsize=(10, 10)) # plt.imshow(grid_img.permute(1, 2, 0)) # plt.show() z_dim = 100 log_dir = os.path.join('logs') ckpt_dir = os.path.join('checkpoints') os.makedirs(log_dir, exist_ok=True) os.makedirs(ckpt_dir, exist_ok=True) train(baseline="Medium", show_img=False) inference()
两个指标:
-
FID(Frechet Inception Distance):衡量生成图片和真实图片的差异,越低越好。
-
AFD(Attribute FID Distance):衡量生成图像属性的多样性或质量,越高越好。
项目 | 评分标准 | 分数 |
---|---|---|
✅ 代码部分 Code | 提交完整、可运行的代码 | 4 分 |
✅ 简单基准 Simple | FID ≤ 30000 且 AFD ≥ 0.00 | 2 分 |
✅ 中等基准 Medium | FID ≤ 11800 且 AFD ≥ 0.43 | 2 分 |
✅ 强基准 Strong | FID ≤ 9300 且 AFD ≥ 0.53 | 1 分 |
✅ 最强基准 Boss | FID ≤ 8200 且 AFD ≥ 0.68 | 1 分 |
🎁 额外加分 Bonus | 击败 Boss 基准 + 提交 < 100 字的英文 PDF 报告 | 0.5 分 |
- 从判别器中移除最后的 sigmoid 层。
- 计算损失时不取对数(log)。
- 将判别器的权重裁剪到一个常数范围内。
- 使用 RMSProp 或 SGD 作为优化器。
用Spectral Normalization GAN (SNGAN)可以达到strong分数,主要是在判别器的每一层的权重进行谱归一化(Spectral Normalization)。
训练50epoch结果
训练50epoch结果