GAN: Generative Adversarial Networks

Crypko uses GANs to generate anime waifus, so out of curiosity (definitely not because I want an anime waifu of my own), can't we train a GAN model ourselves and do some creating too?
The code below is based on PyTorch, and the model definitions borrow from DCGAN. If you are a TensorFlow user and not familiar with PyTorch, you can read the DCGAN source code directly instead (that implementation is written in TensorFlow).

1 Dataset and Model Definitions

1.1 Dataset Definition

The dataset comes from [1]: https://pan.baidu.com/s/1gJ00Uipghq991Piq-j29dQ (extraction code: 2hls)

import os
import PIL.Image as Image
from torch.utils.data import Dataset

class AnimeDataset(Dataset):
    def __init__(self, data_dir, trans=None, filter_list=None):
        """
        Args:
            data_dir: directory containing the images
            trans: transforms applied to each image on load
            filter_list: accepted filename extensions; defaults to common image formats
        """
        super(AnimeDataset, self).__init__()
        self.data_dir = data_dir
        self.trans = trans
        if filter_list is None:
            filter_list = ['.jpg', '.jpeg', '.webp', '.bmp']
        self.img_names = [name for name in os.listdir(self.data_dir)
                          if name.endswith(tuple(filter_list))]

    def __getitem__(self, index):
        path_img = os.path.join(self.data_dir, self.img_names[index])
        img = Image.open(path_img).convert('RGB')
        if self.trans is not None:
            img = self.trans(img)
        return img

    def __len__(self):
        n = len(self.img_names)
        if n == 0:
            raise Exception('No images found under this path, please check it')
        return n
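
A quick usage sketch (a standalone snippet; the directory is the one used by the training script below, so adjust it to wherever you unpacked the dataset):

import torchvision.transforms as transforms

trans = transforms.Compose([transforms.Resize(96),
                            transforms.CenterCrop(96),
                            transforms.ToTensor()])
dataset = AnimeDataset('../data/anime/face', trans=trans)
print(len(dataset))      # number of images found
print(dataset[0].shape)  # torch.Size([3, 96, 96])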

1.2 Generator

The Generator uses transposed convolution (also called deconvolution or fractionally strided convolution); here is a brief explanation.

We are already fairly familiar with ordinary convolution; its output-size formulas are:

\[w^{'}=\left\lfloor\frac{w+2 \cdot padding-kernel\_size}{stride}\right\rfloor+1\\ h^{'}=\left\lfloor\frac{h+2 \cdot padding-kernel\_size}{stride}\right\rfloor+1 \]

\(w^{'}\) and \(h^{'}\) are the width and height of the image after the convolution.
Transposed convolution is simply the reverse process: given the post-convolution \(w^{'}\) and \(h^{'}\), it recovers the pre-convolution \(w\) and \(h\); the formulas are just the two equations above rearranged (assuming the division above is exact).

\[w=stride \cdot (w^{'}-1)+kernel\_size-2 \cdot padding\\ h=stride \cdot (h^{'}-1)+kernel\_size-2 \cdot padding \]
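
As a quick sanity check of the formula (a standalone snippet, not part of the model code): with kernel_size=4, stride=2, padding=1 a transposed convolution exactly doubles the spatial size, which is why the generator below grows 6 -> 12 -> 24 -> 48 -> 96.

import torch
from torch import nn

x = torch.randn(1, 8, 6, 6)
up = nn.ConvTranspose2d(8, 4, kernel_size=4, stride=2, padding=1)
# w = stride*(w'-1) + kernel_size - 2*padding = 2*(6-1) + 4 - 2 = 12
print(up(x).shape)  # torch.Size([1, 4, 12, 12])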

from torch import nn

class Generator(nn.Module):
    def __init__(self, nzd=100, ngf=64, c=3):
        """
        Args:
            nzd: noise vector channel dim
            ngf: number of generator feature maps
            c: number of output image channels
        """
        super(Generator, self).__init__()

        self.net = nn.Sequential(
            # (nzd, 1, 1)->(ngf * 8, 6, 6)
            *self.create_block(nzd, ngf * 8, 6, 1, 0),
            # (ngf * 8, 6, 6)->(ngf * 4, 12, 12)
            *self.create_block(ngf * 8, ngf * 4),
            # (ngf * 4, 12, 12)->(ngf * 2, 24, 24)
            *self.create_block(ngf * 4, ngf * 2),
            # (ngf * 2, 24, 24)->(ngf, 48, 48)
            *self.create_block(ngf * 2, ngf),
            # (ngf, 48, 48)->(c, 96, 96)
            *self.create_block(ngf, c, last=True),
        )

    def create_block(self, in_channel, out_channel, kernel_size=4, stride=2, padding=1, last=False):
        layer_list = [nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding, bias=False), ]
        if last:
            layer_list.append(nn.Tanh())
        else:
            layer_list.extend([nn.BatchNorm2d(out_channel),
                               nn.ReLU(inplace=True), ])
        return layer_list

    def forward(self, X):
        return self.net(X)

    def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1, b_std=0.02):
        init_weights(self.modules(), w_mean, w_std, b_mean, b_std)
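
As a quick shape check (a standalone sketch; the class defaults nzd=100, ngf=64 are used here, while the training script below uses nzd=50):

import torch

g = Generator()                  # defaults: nzd=100, ngf=64, c=3
z = torch.randn(2, 100, 1, 1)    # a batch of 2 noise vectors
print(g(z).shape)                # torch.Size([2, 3, 96, 96])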

1.3 Discriminator

class Discriminator(nn.Module):
    def __init__(self, ndf=64, c=3):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            # (c, 96, 96)->(ndf, 48, 48)
            *self.create_block(c, ndf),
            # (ndf, 48, 48)->(ndf*2, 24, 24)
            *self.create_block(ndf, ndf * 2),
            # (ndf*2, 24, 24)->(ndf*4, 12, 12)
            *self.create_block(ndf * 2, ndf * 4),
            # (ndf*4, 12, 12)->(ndf*8, 6, 6)
            *self.create_block(ndf * 4, ndf * 8),
            # (ndf * 8, 6, 6)->(ndf * 16, 3, 3)
            *self.create_block(ndf * 8, ndf * 16),
            # (ndf * 16, 3, 3)->(1, 1, 1)
            *self.create_block(ndf * 16, 1, 3, 1, 0, last=True)
        )

    def forward(self, X):
        return self.net(X)

    def create_block(self, in_channel, out_channel, kernel_size=4, stride=2, padding=1, last=False):
        layer_list = [nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias=False)]
        if last:
            layer_list.append(nn.Sigmoid())
        else:
            layer_list.extend([nn.BatchNorm2d(out_channel),
                               nn.LeakyReLU(0.2, inplace=True)])
        return layer_list

    def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1, b_std=0.02):
        init_weights(self.modules(), w_mean, w_std, b_mean, b_std)
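
And the matching check for the discriminator; each score lies in (0, 1) because of the final Sigmoid:

import torch

d = Discriminator()              # defaults: ndf=64, c=3
imgs = torch.randn(2, 3, 96, 96)
print(d(imgs).shape)             # torch.Size([2, 1, 1, 1]); flattened with .view(-1) during training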

1.4 Shared Utility Functions

import os
import torchvision.transforms as transforms
import imageio
from torch import nn
from torch.utils.data import DataLoader
from my_utils.animedataset import AnimeDataset

def init_weights(modules, w_mean=0., w_std=0.02, b_mean=1., b_std=0.02):
    """
    Initialize model parameters (DCGAN-style: normal init for conv layers,
    normal init around 1 for BatchNorm scale, zero bias)
    """
    for m in modules:
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, w_mean, w_std)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, b_mean, b_std)
            nn.init.constant_(m.bias.data, 0)

def load_data(filepath, batch_size, trans=None):
    """
    Load the dataset and wrap it in a DataLoader
    """
    train_set = AnimeDataset(filepath, trans)
    return DataLoader(train_set, batch_size=batch_size, num_workers=1, shuffle=True)


def init_trans(image_size):
    """
    Transforms applied to each image
    """
    return transforms.Compose([transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ])
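
Note that Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) maps ToTensor's [0, 1] range to [-1, 1], which matches the Tanh output of the generator, so real and fake images live in the same value range. A quick check (a standalone snippet):

import torch
import torchvision.transforms as transforms

norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
x = torch.rand(3, 96, 96)   # simulates ToTensor output in [0, 1]
y = norm(x)                 # computes (x - 0.5) / 0.5
print(y.min().item(), y.max().item())  # both within [-1, 1]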

def gen_gif(src, suffix, dst, filename):
    """
    Read the images ending in `suffix` from `src`, assemble them into an
    animated GIF, and save it to `dst` as `filename`
    Args:
        src: source directory
        suffix: filename suffix to match
        dst: directory to save the GIF in
        filename: name of the GIF file
    """
    # The per-epoch sample images are saved as '{epoch}_epoch.png', so the part
    # before the first '_' is a number; adjust this to your own naming scheme
    imgs_epoch = [int(name.split("_")[0]) for name in
                  list(filter(lambda x: x.endswith(suffix), os.listdir(src)))]
    imgs_epoch = sorted(imgs_epoch)

    imgs = list()
    for i in range(len(imgs_epoch)):
        img_name = os.path.join(src, f"{imgs_epoch[i]}{suffix}")
        imgs.append(imageio.imread(img_name))

    imageio.mimsave(os.path.join(dst, filename), imgs, fps=2)

2 Training

import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import os
import numpy as np
import matplotlib.pyplot as plt
import torchvision.utils as vutils
from visdom import Visdom
from my_utils.model import Discriminator, Generator
from my_utils.tools import load_data, init_trans, gen_gif

# Define some variables for later use
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dirs = {
    'img': './img',
    'img_data': '../data/anime/face',
    'model': './model'
}
image_size, checkpoint_interval, record_loss_interval = 96, 10, 10  # image size; checkpoint every 10 epochs; record the loss every 10 iterations
real_img_label, fake_img_label = 0.9, 0.1  # (smoothed) labels for real / fake images
# Hyperparameters (with more GPU memory you can increase batch_size, nzd and ngf)
epochs, lr, batch_size, beta1, nzd, ngf, ndf, channel = 20, 1e-4, 10, 0.5, 50, 64, 64, 3
# fixed_noise is used to monitor how the generator improves across epochs
fixed_noise = torch.randn(64, nzd, 1, 1, device=device)

viz = Visdom()

if __name__ == '__main__':
    # 0 Make sure the output directories exist
    os.makedirs(dirs['img'], exist_ok=True)
    os.makedirs(dirs['model'], exist_ok=True)
    # 1 Load the data
    train_iter = load_data(dirs['img_data'], batch_size, init_trans(image_size))
    # 2 Create the networks, loss function and optimizers
    loss = nn.BCELoss()
    g_net = Generator(nzd, ngf, channel).to(device)
    g_net.initialize_weights()  # DCGAN-style init defined in section 1.4
    g_optimizer = optim.Adam(g_net.parameters(), lr, betas=(beta1, 0.999))
    g_lr_scheduler = optim.lr_scheduler.StepLR(g_optimizer, step_size=8, gamma=0.1)
    d_net = Discriminator(ndf, channel).to(device)
    d_net.initialize_weights()
    d_optimizer = optim.Adam(d_net.parameters(), lr, betas=(beta1, 0.999))
    d_lr_scheduler = optim.lr_scheduler.StepLR(d_optimizer, step_size=8, gamma=0.1)
    # 3 Start training
    for epoch in range(1, epochs + 1):
        Len = len(train_iter)  # number of iterations per epoch
        record_times = Len // record_loss_interval  # how many loss records to take per epoch
        # 4 columns: g_loss, D(x), D(G(x)), loss_d (= loss on real + loss on fake)
        loss_record = torch.zeros(size=(record_times, 4))
        cnt = 0
        print('#' * 30 + f'\n epoch={epoch} start \n' + '#' * 30)
        for i, data in enumerate(train_iter):
            data = data.to(device)
            ############################
            # (1) Update D network
            ###########################
            d_net.zero_grad()
            b_size = len(data)
            real_label = torch.full((b_size,), real_img_label, device=device)
            fake_label = torch.full((b_size,), fake_img_label, device=device)
            noise = torch.randn((b_size, nzd, 1, 1), device=device)
            # Generate a batch of fake images from the noise; fake_img.shape: (b_size, c=3, h=96, w=96)
            fake_img = g_net(noise)
            out_d_real = d_net(data)
            out_d_fake = d_net(fake_img.detach())
            loss_d_real = loss(out_d_real.view(-1), real_label)
            loss_d_fake = loss(out_d_fake.view(-1), fake_label)
            loss_d_real.backward()
            loss_d_fake.backward()
            d_optimizer.step()
            loss_d = loss_d_real + loss_d_fake  # discriminator loss on real + fake images
            d_x = out_d_real.mean().item()  # D(x): mean discriminator score on the real images
            d_g_x = out_d_fake.mean().item()  # D(G(x)): mean discriminator score on the fake images
            ############################
            # (2) Update G network
            ###########################
            g_net.zero_grad()
            # Now that the discriminator has been updated, score the fake images
            # again and use those scores to update the generator
            out_d_fake2 = d_net(fake_img)  # shape: (b_size, 1, 1, 1)
            loss_g = loss(out_d_fake2.view(-1), real_label)
            loss_g.backward()
            g_optimizer.step()
            d_g_x2 = out_d_fake2.mean().item()  # D(G(x)) after the discriminator update

            if i % record_loss_interval == 0 and i != 0:
                print(f'[{epoch}/{epochs}]\t[{i}/{Len}]\t'
                      f'Loss_g={loss_g.item():.4f}\tD(x)={d_x:.4f}\tD(G(x))={d_g_x:.4f}/{d_g_x2:.4f}')
                # record the losses
                loss_record[cnt] = torch.tensor([loss_g.item(), d_x, d_g_x, loss_d.item()])
                cnt += 1

        # decay the learning rates at the end of each epoch
        d_lr_scheduler.step()
        g_lr_scheduler.step()

        # after each epoch, generate fake images from fixed_noise to check on g_net
        with torch.no_grad():
            fake = g_net(fixed_noise).detach().cpu()
        img_grid = vutils.make_grid(fake, padding=2, normalize=True).numpy()
        img_grid = np.transpose(img_grid, (1, 2, 0))
        plt.imsave(os.path.join(dirs['img'], f'{epoch}_epoch.png'), img_grid)

        # plot this epoch's mean losses with visdom
        viz.line(Y=loss_record[:, [0, -1]].mean(keepdim=True, dim=0),
                 X=torch.full(size=(1, 2), fill_value=epoch),
                 win='g loss & d loss',
                 update='append',
                 opts=dict(title='Mean: g loss & d loss', legend=['g', 'd']))
        viz.line(Y=loss_record[:, [1, 2]].mean(keepdim=True, dim=0),
                 X=torch.full(size=(1, 2), fill_value=epoch),
                 win='D(x) & D(G(x))',
                 update='append',
                 opts=dict(title='Mean: D(x) & D(G(x))', legend=['D(x)', 'D(G(x))']))

        # save a checkpoint
        if epoch % checkpoint_interval == 0:
            checkpoint = {"g_model_state_dict": g_net.state_dict(),
                          "d_model_state_dict": d_net.state_dict(),
                          "epoch": epoch}
            path_checkpoint = os.path.join(dirs['model'], "checkpoint_{}_epoch.pkl".format(epoch))
            torch.save(checkpoint, path_checkpoint)

    # `suffix` is the sample-image filename suffix; the loop above saves one
    # sample image per epoch with a filename ending in '_epoch.png'
    gen_gif(src=dirs['img'], suffix='_epoch.png', dst=dirs['img'], filename='progress.gif')
    print("done")

3 Results

The GIF above was generated from 3000 96*96 face images trained for 20 epochs. You can see the results improving over time, but they are still not good enough, mainly because of GPU limits:

  • The original dataset has 50k 96*96 images; only 3000 of them were used
  • batch_size=10, nzd=64, ngf=64: GPU memory was too small to go any larger
  • Only 20 epochs were trained; consider increasing that as well

4 References

[1] GAN学习指南:从原理入门到制作生成Demo (GAN study guide: from the principles to a generation demo)
[2] DCGAN原理分析与代码解读 (DCGAN: principles and code walkthrough)
[3] Radford, A., Metz, L., & Chintala, S. (2015). Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. arXiv:1511.06434.
