生成式对抗网络GAN

【 李宏毅机器学习】生成式对抗网络GAN

  在传统的神经网络任务中,我们通常把一个网络当作一个函数f(x),给定输入x,网络就会输出一个对应的结果 y。比如图像分类任务中,输入是一张图片,输出是一个分类标签。这是一种 判别式模型(Discriminative Model),它学的是输入和输出之间的映射关系。但在生成式模型(Generative Model) 中,输入会增加一个随机分布中sample出来的z,网络输入x和z,输出y是可以从中采样的复杂分布。

网络通常会采用两种方式来处理x和z(1)拼接(Concatenate):直接把 两个向量拼接在一起,变成一个更长的向量,输入到神经网络中。(2)相加(Element-wise Add):如果x和 z维度相同,可以直接相加作为输入。我们的目标不再是“判断”某个输入属于哪一类,而是希望模型能够“生成”数据——比如生成看起来真实的图片、音频,甚至文本。换句话说,我们希望网络本身就是一个“生成器”,可以从某种潜在的随机性中创造出无限多样的输出。

1 GAN(Generative Adversarial Network)

1.1 GAN 的基本概念和工作原理

生成对抗网络(GAN)是由 Ian Goodfellow 等人在 2014 年提出的一种生成式模型。与传统的神经网络不同,GAN 由两个相互竞争的网络组成:生成器(Generator)判别器(Discriminator)

以生成二次元人脸为例:

生成器(Generator):把x拿掉,生成器只输入随机噪声 z ,这种叫做unconditional generation。假设z是从normal distribution中sample出来的向量,这个向量一般是low-dim的向量,维度是自定义的,Generator输入z后产生一个64x64x3的向量,整理后可以得到一张二次元人脸的图像。

判别器(Discriminator):对应的判别器输入一张图片(可能是来自真实数据集,也可能是来自生成器),输出一个数字,数字越大表示输入的图像越像真实的二次元的人脸。

1.2 GAN 的训练机制

  • 初始化生成器和判别器参数
  • 在每个训练迭代中
    • 固定住生成器,更新判别器。具体的随机采样一些向量z,输入到生成器,得到一些生成的图像,然后从真实数据中采样一些二次元人脸,训练判别器分别两者之间的差异。比如用二分类器,或者逻辑回归。

        

    •  固定住判别器,更新生成器。具体的把两个网络接起来变成一个大网络,其中判别器的参数是固定的,训练生成器,使得分数越大越好。

              

GAN的训练目标是通过对抗训练,生成器和判别器在博弈中不断提高自己的能力:

  • 生成器的目标:生成足够真实的数据,使得判别器无法区分它们是来自真实数据还是生成的假数据。

假设生成的数据分布为PG,真实数据分布为Pdata,生成器的目标就是

Divergence是PG和Pdata之间的某种距离,距离越小两个分布就越相近。尽管不知道PG和Pdata是什么样的分布,但是可以从中sample。

  • 判别器的目标:尽可能准确地区分真实数据和生成数据,减少自己的分类错误。

分别从PG和Pdata中sample,Discriminator要学会给PG和中sample到的样本打1,给Pdata中sample到的样本打0,要做的是一个二分类任务。判别器的总 BCE 损失(对所有样本)为

   损失是要最小化,等价于最大化目标函数V:

 

   这个maxV和JS divergence有关,因此生成器的目标就是

之前的训练步骤就是在解这个minmax问题,为什么训练步骤可以解这个函数,见论文推导。设计不同的objective function解minmax问题,就对应不同的divergence(参考)。

1.3 WGAN:基于 Wasserstein 距离的改进方法

在大多数情况下,PG和Pdata之间是没有重叠的,两个都是高维空间的低维manifold,相当于二维空间的两条直线。即使两个分布是有重叠的,但是sample的点不够多,重叠的范围也非常小。当两个分布没有重叠时,JS divergence是存在问题的。 JS divergence是非度量的,不能为生成器提供一个明确的目标方向来指导其优化。比如下面的情况,PG和Pdata越来越接近,但是JS始终为log2,知道两者重合才会变成0。

为了改善 GAN 的训练稳定性,研究者提出了Wasserstein GAN(WGAN)。它的核心思想是:

替换原本的 JS 散度为 Wasserstein 距离(又叫 Earth Mover’s Distance)。

你可以把它理解为:把一堆土(生成的数据分布)搬到另一堆土堆(真实数据分布)需要花费的最小“搬运成本”。这是一种更光滑、梯度更稳定的距离度量方式。

计算Wasserstein 距离就是解下面的公式

注意:

  • WGAN 中的判别器不再是“真假分类器”,而是一个 Critic,它的输出是任意实数,用于度量样本的“真实性”得分。

  • 为了满足理论条件,Critic 的梯度需要是1-Lipschitz的。这个限制是让Discriminator变得平滑,如果没有这个限制,D会给生成的x负无穷,给真实的x正无穷,训练会无法收敛,max始终是无穷大。

💡Q: 如何让Discriminator变得比较平滑?

最早用权重裁剪,限制权重在-c到c之间,后续改进为 WGAN-GP(加入梯度惩罚项),训练更稳定,另一个常用、效果非常不错的方式是:Spectral Normalization,归一化权重的最大奇异值,不需要像 WGAN-GP 那样计算复杂的梯度惩罚项,也避免了 weight clipping 带来的训练困难。

1.4 GAN 面临的挑战

尽管已经有了如WGAN这样的改进方法,GAN在训练过程中仍然面临诸多挑战。这主要源于其对抗性训练机制的本质特性:生成器(Generator)与判别器(Discriminator)在训练过程中是相互博弈、彼此促进的关系。模型的优化是一个动态博弈过程,如果其中一方(例如判别器)训练不足或性能不稳定,就会导致另一方(如生成器)无法获得有效的反馈信号,从而影响整体训练效果。这种相互依赖使得GAN的训练过程高度不稳定,调试和收敛都较为困难。目前主流的生成模型,除了GAN,还有变分自编码器(VAE)流模型(Flow-based Models)。GAN在图像生成质量上往往优于VAE和Flow模型,但它的训练更不稳定、可解释性更弱。

💡Q: 为什么GAN难用于文本生成?

答:因为文本是离散的,而GAN的训练依赖于梯度的反向传播。在文本生成中,生成器通常输出一个概率分布,再通过 argmax采样 选择一个词。这种“选词”的过程是非可导的,梯度无法穿过这个离散选择,从而导致生成器无法优化。在多个值相等或者接近的时候,梯度无法确定往哪个方向优化。比如在 argmax 选词时,即便概率稍微改变,只要最大值没变,输出结果也不会变,因此梯度是 0 或未定义的,无法有效训练生成器。

💡Q: 那CNN里也有max pooling,为什么没问题?

答:CNN 中的 max pooling 虽然也是 max 操作,但它出现在网络的中间层,反向传播时我们可以将梯度传给最大值的位置,其他位置设为 0(这叫次梯度 subgradient)。这种近似梯度在实践中效果不错,因此不会影响 CNN 的训练。

模式坍缩 (Mode Collapse)

Mode Collapse 是指生成器只学会生成真实数据中的一小部分模式,导致输出缺乏多样性。换句话说,虽然真实数据有很多种可能,但生成器反复生成的是同一种或几种“看起来不错”的样本。
例如:你训练 GAN 生成手写数字,但生成器最终只会生成“数字 3”,而忽略了其他数字。这样虽然图片质量可能还可以,但多样性完全丢失。这是因为生成器只专注于“骗过判别器”的目标,而不是完整地复现数据分布。一旦找到某个容易成功的样本类型,就会反复生成,从而陷入局部最优。

模式遗漏(Mode Dropping)

Mode Dropping 指的是生成器完全忽略了真实数据中的某些模式,即使这些模式在训练数据中是存在的,生成器却没有学会去生成它们。假设你训练一个 GAN 模型用于生成真实人脸,训练数据中包含了不同年龄、性别、肤色、发型的人脸图像。但训练后的生成器:只会生成某一肤色的面孔。尽管这些样本在真实数据中是存在的,生成器却忽略了这些“模式”。这就属于典型的 Mode Dropping —— 生成结果看起来多样,但其实缺失了某些重要的群体特征

1.5 GAN评估指标

GAN 在训练中容易出现模式坍缩(Mode Collapse)模式遗漏(Mode Dropping)等问题,即生成器生成的样本看起来质量不错,但实则重复或覆盖不全。这时我们就需要一些定量评估指标来判断两个关键问题:(1) 生成图像质量好不好?(2) 生成图像够不够多样?不像分类器有准确率指标,GAN 的 Generator 没有明确的评价指标,但有一些常用方法可以参考:

Inception Score(IS)

  • 用一个预训练的分类器(如 Inception-v3)去分类生成的图像
  • 一张图片丢到CNN去分类,结果分布越集中,quality越高
  • 一堆图片的平均分布,越平均diversity越大
  • good quality,large diversity→large IS,说明生成样本质量越好、类别多样。

在二次元人脸生成中,分类器的输出可能都是人脸,diversity小,不适合这个场景。

Frechet Inception Distane(FID)

  • 将生成图像和真实图像分别输入一个预训练好的 CNN(通常是 Inception v3);

  • 在网络的某一层(通常是 softmax 前的一层)提取特征向量(即使输入图像都是人脸,这些向量也会有所不同,因为它们捕捉的是更高层的语义信息(比如脸的姿态、表情、风格等)

  • 假设这两组特征向量分别服从一个多维高斯分布,FID 会分别计算这两组特征向量的 均值(μ)协方差矩阵(Σ),然后用 Frechet 距离来衡量它们之间的差异

如果 FID 值很小,说明生成的图像和真实图像非常接近,但这并不意味着生成器生成的样本多样化。生成器可能只是记住了真实数据的某些特征,导致它只能生成“很相似”的样本,而失去了多样性,比如学会了对真实图像进行翻转。真实特征分布可能远非高斯,所以在做次元人脸生成中主要是用FID和人眼就去看。

2 Conditional GAN(条件生成对抗网络)

传统GAN中,生成器是无条件的——它只接收随机噪声z作为输入。而 Conditional GAN 则引入条件信息x,例如类别标签、文本描述、图像等,引导生成器生成“符合条件”的样本。

应用: 文本生成图像

需要收集一些图片和对应的标注, 输入x是一段文字red eyes,可以用rnn或者transformer encoder把它变成一段向量。期望输入red eyes,generator就输出一个红眼睛的图片,每次的输出都是不一样的红眼睛,取决于sample到不一样的z。

💡Q: 如何训练Conditional GAN?

如果像之前的GAN一样,判别器只判断图像是否是真实的,生成器就不用在意输入x,只要产生清晰的图像就可以。在 Conditional GAN 中,我们的目标不只是生成“看起来真实的图像”,更重要的是:图像还要和输入的条件匹配。需要准备文字和图像成对的资料(positive),以及文字和机器产生出来的图片(negative),还需要把文字和图像乱配作为(negative)。

应用: 图像生成图像 (image translation or pix2pix)

比如输入黑白的图片让生成器着色,或者对图像去雾。输入一张图片生成一张图片,可以用supervised的方法,由于同样的输入可能对应到不一样的输出,机器学到把所有的可能平均起来,所以产生的图片会很模糊。如果用GAN的方法,再加入判别器,判别器输入生成器生成的图像和condition,然后输出分数。GAN方法产生的图像比较清楚,但是可能会产生输入没有的东西,比如下面的房子左上角有奇怪的东西。当GAN和supervised同时使用,效果会比较好,生成器在训练的时候一方面要骗过判别器,但又要使得产生的图片和目标越接近越好。

3 CycleGAN:无监督图像到图像的转换

  在之前的unconditional generation中,输入是一个简单分布,输出是一个复杂分布,现在稍微转换下,输入是x domain图片的分布,输出是y domain图片的分布。假设我们将真实人脸转换为二次元人脸,像之前训练GAN一样,从x domain中sample一张图片,输入到Generator,再用学过ydomain的discriminator去给图像打分,这是有问题的。Generator会无视输入的图片,只产生一张像y domain的二次元图片就可以,这个图片和输入的真实人脸没有。如何强化输入和输出的关系呢?之前的conditional gan也讲过类似的问题,但是现在没有成对的数据去训练discriminator学习输入和输出的关系。

  CycleGAN(由 Jun-Yan Zhu 等人在 2017 年提出,旨在解决没有成对数据的图像转换问题。传统的图像到图像的转换任务(比如图像翻译、风格转换)通常需要输入和目标图像一一对应的成对数据。然而,在实际应用中,往往没有这样的配对数据集,这就使得传统的 GAN 方法难以应用。CycleGAN 通过设计一对生成器和判别器,解决了这一问题。它的核心思想是 循环一致性(Cycle Consistency),即:如果将一张图像从一个域映射到另一个域,然后再映射回来,应该得到原始的图像。这样的映射关系可以确保生成的图像既能保持域间的转换效果,又不失原始图像的特征。以真实人脸转换为二次元人脸为例,在训练的时候增加一个额外的目标,输入在经过x到y domain的转换后再经过y到xdomain的转换,两次转换后和原来的向量越接近越好,同时也可以做另一个方向上的训练。

4 作业HW6

链接给出了PDF和code:李宏毅2021&2022机器学习

代码:

import random

import torch
import numpy as np
import os
import glob

import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm


def same_seeds(seed):
    # Python built-in random module
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Torch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


same_seeds(2021)


class CrypkoDataset(Dataset):
    def __init__(self, fnames, transform):
        self.transform = transform
        self.fnames = fnames
        self.num_samples = len(self.fnames)

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        # 1. Load the image
        img = torchvision.io.read_image(fname)
        # 2. Resize and normalize the images using torchvision.
        img = self.transform(img)
        return img

    def __len__(self):
        return self.num_samples


def get_dataset(root):
    fnames = glob.glob(os.path.join(root, '*'))
    # 1. Resize the image to (64, 64)
    # 2. Linearly map [0, 1] to [-1, 1]
    compose = [
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
    transform = transforms.Compose(compose)
    dataset = CrypkoDataset(fnames, transform)
    return dataset


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class Generator(nn.Module):
    """
    Input shape: (N, in_dim)
    Output shape: (N, 3, 64, 64)
    """

    def __init__(self, in_dim, dim=64):
        super(Generator, self).__init__()

        def dconv_bn_relu(in_dim, out_dim):
            return nn.Sequential(
                nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
                                   padding=2, output_padding=1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.ReLU()
            )

        self.l1 = nn.Sequential(
            nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(dim * 8 * 4 * 4),
            nn.ReLU()
        )
        self.l2_5 = nn.Sequential(
            dconv_bn_relu(dim * 8, dim * 4),
            dconv_bn_relu(dim * 4, dim * 2),
            dconv_bn_relu(dim * 2, dim),
            nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
            nn.Tanh()
        )
        self.apply(weights_init)

    def forward(self, x):
        y = self.l1(x)
        y = y.view(y.size(0), -1, 4, 4)
        y = self.l2_5(y)
        return y


class Discriminator(nn.Module):
    """
    Input shape: (N, 3, 64, 64)
    Output shape: (N, )
    """

    def __init__(self, in_dim, dim=64, use_sigmoid=True):
        super(Discriminator, self).__init__()

        def conv_bn_lrelu(in_dim, out_dim):
            return nn.Sequential(
                nn.Conv2d(in_dim, out_dim, 5, 2, 2),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU(0.2),
            )

        """ Medium: Remove the last sigmoid layer for WGAN. """
        layers = [
            nn.Conv2d(in_dim, dim, 5, 2, 2),
            nn.LeakyReLU(0.2),
            conv_bn_lrelu(dim, dim * 2),
            conv_bn_lrelu(dim * 2, dim * 4),
            conv_bn_lrelu(dim * 4, dim * 8),
            nn.Conv2d(dim * 8, 1, 4),
        ]

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.ls = nn.Sequential(*layers)

        self.apply(weights_init)

    def forward(self, x):
        y = self.ls(x)
        y = y.view(-1)
        return y


def train(baseline="Simple", show_img=True):
    # Training hyperparameters
    batch_size = 64
    z_sample = Variable(torch.randn(100, z_dim)).cuda()
    lr = 1e-4

    if baseline == "Simple":
        n_epoch = 50  # 50
        n_critic = 1  # 训练 1 次判别器,再训练 1 次生成器
    elif baseline == "Medium":
        """ Medium: WGAN, 50 epoch, n_critic=5, clip_value=0.01 """
        n_epoch = 50
        n_critic = 5  # 先训练 5 次判别器,再训练 1 次生成器
        clip_value = 0.01

    # Model
    G = Generator(in_dim=z_dim).cuda()
    if baseline == "Simple":
        D = Discriminator(3).cuda()
    elif baseline == "Medium":
        D = Discriminator(3, use_sigmoid=False).cuda()
    G.train()
    D.train()

    # Loss
    criterion = nn.BCELoss()

    # Optimizer
    if baseline == "Simple":
        opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
        opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    elif baseline == "Medium":
        """ Medium: Use RMSprop for WGAN. """
        opt_D = torch.optim.RMSprop(D.parameters(), lr=lr)
        opt_G = torch.optim.RMSprop(G.parameters(), lr=lr)

    # DataLoader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    steps = 0
    for e, epoch in enumerate(range(n_epoch)):
        progress_bar = tqdm(dataloader)
        for i, data in enumerate(progress_bar):
            imgs = data
            imgs = imgs.cuda()

            bs = imgs.size(0)

            # ============================================
            #  Train D
            # ============================================
            z = Variable(torch.randn(bs, z_dim)).cuda()
            r_imgs = Variable(imgs).cuda()
            f_imgs = G(z)

            if baseline == "Simple":
                # Label
                r_label = torch.ones((bs)).cuda()
                f_label = torch.zeros((bs)).cuda()

                # Model forwarding
                r_logit = D(r_imgs.detach())
                f_logit = D(f_imgs.detach())

                # Compute the loss for the discriminator.
                r_loss = criterion(r_logit, r_label)
                f_loss = criterion(f_logit, f_label)
                loss_D = (r_loss + f_loss) / 2
            elif baseline == "Medium":
                # WGAN Loss
                loss_D = -torch.mean(D(r_imgs)) + torch.mean(D(f_imgs))

            # Model backwarding
            D.zero_grad()
            loss_D.backward()

            # Update the discriminator.
            opt_D.step()

            if baseline == "Medium":
                """ Medium: Clip weights of discriminator. """
                for p in D.parameters():
                   p.data.clamp_(-clip_value, clip_value)

            # ============================================
            #  Train G
            # ============================================
            if steps % n_critic == 0:
                # Generate some fake images.
                z = Variable(torch.randn(bs, z_dim)).cuda()
                f_imgs = G(z)

                # Model forwarding
                f_logit = D(f_imgs)

                if baseline == "Simple":
                    # Compute the loss for the generator.
                    loss_G = criterion(f_logit, r_label)
                elif baseline == "Medium":
                    # WGAN Loss
                    loss_G = -torch.mean(D(f_imgs))

                # Model backwarding
                G.zero_grad()
                loss_G.backward()

                # Update the generator.
                opt_G.step()

            steps += 1

            # Set the info of the progress bar
            #   Note that the value of the GAN loss is not directly related to
            #   the quality of the generated images.
            progress_bar.set_postfix({
                'Loss_D': round(loss_D.item(), 4),
                'Loss_G': round(loss_G.item(), 4),
                'Epoch': e + 1,
                'Step': steps,
            })

        G.eval()
        f_imgs_sample = (G(z_sample).data + 1) / 2.0
        filename = os.path.join(log_dir, f'Epoch_{epoch + 1:03d}.jpg')
        torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
        print(f' | Save some samples to {filename}.')

        # Show generated images in the jupyter notebook.
        if show_img:
            grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid_img.permute(1, 2, 0))
            plt.show()
        G.train()

        if (e + 1) % 5 == 0 or e == 0:
            # Save the checkpoints.
            torch.save(G.state_dict(), os.path.join(ckpt_dir, 'G.pth'))
            torch.save(D.state_dict(), os.path.join(ckpt_dir, 'D.pth'))


def inference():
    G = Generator(z_dim)
    G.load_state_dict(torch.load(os.path.join(ckpt_dir, 'G.pth')))
    G.eval()
    G.cuda()
    # Generate 1000 images and make a grid to save them.
    n_output = 1000
    z_sample = Variable(torch.randn(n_output, z_dim)).cuda()
    imgs_sample = (G(z_sample).data + 1) / 2.0
    log_dir = os.path.join('logs')
    filename = os.path.join(log_dir, 'result.jpg')
    torchvision.utils.save_image(imgs_sample, filename, nrow=10)

    # Show 30 of the images.
    grid_img = torchvision.utils.make_grid(imgs_sample[:30].cpu(), nrow=10)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()


if __name__ == '__main__':
    dataset = get_dataset('faces')

    # 注意,这些数值的范围是 [-1, 1],所以显示比较暗
    # images = [dataset[i] for i in range(16)]
    # grid_img = torchvision.utils.make_grid(images, nrow=4)
    # plt.figure(figsize=(10, 10))
    # plt.imshow(grid_img.permute(1, 2, 0))
    # plt.show()

    # 我们需要将它们转换到有效的范围 [0, 1],才能正确显示。
    # images = [(dataset[i] + 1) / 2 for i in range(16)]
    # grid_img = torchvision.utils.make_grid(images, nrow=4)
    # plt.figure(figsize=(10, 10))
    # plt.imshow(grid_img.permute(1, 2, 0))
    # plt.show()

    z_dim = 100
    log_dir = os.path.join('logs')
    ckpt_dir = os.path.join('checkpoints')
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)

    train(baseline="Medium", show_img=False)

    inference()
View Code

两个指标:

  • FID(Frechet Inception Distance):衡量生成图片和真实图片的差异,越低越好。

  • AFD(Attribute FID Distance):衡量生成图像属性的多样性或质量,越高越好。

项目评分标准分数
✅ 代码部分 Code 提交完整、可运行的代码 4 分
✅ 简单基准 Simple FID ≤ 30000 且 AFD ≥ 0.00 2 分
✅ 中等基准 Medium FID ≤ 11800 且 AFD ≥ 0.43 2 分
✅ 强基准 Strong FID ≤ 9300 且 AFD ≥ 0.53 1 分
✅ 最强基准 Boss FID ≤ 8200 且 AFD ≥ 0.68 1 分
🎁 额外加分 Bonus 击败 Boss 基准 + 提交 < 100 字的英文 PDF 报告 0.5 分
 代码给出的例子是DCGAN,用WGAN可以达到medium分数,要做的修改主要是
  • 从判别器中移除最后的 sigmoid 层。
  • 计算损失时不取对数(log)。
  • 将判别器的权重裁剪到一个常数范围内。
  • 使用 RMSProp 或 SGD 作为优化器。

用Spectral Normalization GAN (SNGAN)可以达到strong分数,主要是在判别器的每一层的权重进行谱归一化(Spectral Normalization)。

 

训练50epoch结果

 

训练50epoch结果

 

posted @ 2025-04-21 09:34  湾仔码农  阅读(350)  评论(0)    收藏  举报