GAN网络

解析一下GAN网络处理mnist图片数据集的代码

先看一下引入的包

import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.python.keras import backend as K
from tensorflow.keras.utils import plot_model
from IPython.display import Image
import cv2
import PIL
import json, os
import sys
import labelme
import labelme.utils as utils
import glob
import itertools

  前面都是比较常用的,

from tensorflow.keras.callbacks import EarlyStopping   提前终止训练,使用方法:
import keras
early_stopping=keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0,
                              patience=0, verbose=0, mode='auto',
                              baseline=None, restore_best_weights=False)
model.fit(callbacks = [early_stopping])
from IPython.display import Image  可以用来显示图片
import PIL 图像处理库  使用方法
//读取图片并显示
from PIL import Image, ImageDraw
#指定路径
sample_image_path = os.path.join(RAW_DATA_DIR, 'normal_1/images/img_0.png')
#读入图片
sample_image = Image.open(sample_image_path)
print  sample_image.format, "%dx%d" % sample_image.size, sample_image.mode
#输出
plt.title('Sample Image')
plt.imshow(sample_image)
plt.show()

   labelme是数据集标注用的

  glob是python自带的模块,可以支持find+通配符匹配的功能。使用方法:

 

 

  ---------------------------------------包介绍完毕了--------------------------------------------------------------

  首先看一下GAN类的成员变量

class GAN():
    def __init__(self,      #定义全局变量
                 ):
        self.img_shape = (28, 28, 1)  #输入图片28x28 
        self.save_path = r'./GAN.h5'   #模型保存的位置
        self.img_path = r'./photo'     #图片保存的位置
        self.batch_size = 20           #
        self.latent_dim = 100          #keras输入是100维度的张量
        self.sample_interval=1         #生成器生成图片的周期 和epoch有关
        self.epoch=10  #100
        #建立GAN模型的方法
        self.generator_model = self.build_generator()              #生成器对象
        self.discriminator_model = self.build_discriminator()      #判别器对象
        self.model = self.bulid_model()                 #GAN模型训练

  补充好注释之后很清晰了,我们先来看一下生成器和判别器的样子如下:可以看到他们就是两个神经网络,一个通过训练来分别真假的图片,一个通过训练试图混淆生成器的判断(生成更加真实的图片)。从代码上能够看出这两个东西中生成器似乎更加复杂一点,但是大体的结构是相似的。

    def build_generator(self):#生成器
        input=keras.Input(shape=self.latent_dim)
        x=layers.Dense(256)(input)
        x=layers.LeakyReLU(alpha=0.2)(x)
        x=layers.BatchNormalization(momentum=0.8)(x)
        x = layers.Dense(512)(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        x = layers.Dense(1024)(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        x=layers.Dense(np.prod(self.img_shape),activation='sigmoid')(x)
        output=layers.Reshape(self.img_shape)(x)
        model=keras.Model(inputs=input,outputs=output,name='generator')
        model.summary()
        return model

    def build_discriminator(self):#判别器
        input=keras.Input(shape=self.img_shape)   #输入是图片
        x=layers.Flatten(input_shape=self.img_shape)(input)   #展开
        x=layers.Dense(512)(x)                   #全连接 
        x=layers.LeakyReLU(alpha=0.2)(x)
        x=layers.Dense(256)(x)
        x=layers.LeakyReLU(alpha=0.2)(x)
        output=layers.Dense(1,activation='sigmoid')(x)
        model=keras.Model(inputs=input,outputs=output,name='discriminator')
        model.summary()
        return model

  之后建立gan模型,input输入到生成器,生成器生成的图片是判别器的输入,判别器的输出是最终的输出。所以模型的总输入是input,总的输出是判别器的输出。这里判别器不训练就让生成器训练。

    def bulid_model(self):#建立GAN模型
        self.discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=keras.optimizers.Adam(0.0001, 0.000001),
                                    metrics=['accuracy'])              #对判别器进行设置loss和优化器
 
        self.discriminator_model.trainable = False#使判别器不训练

        inputs = keras.Input(shape=self.latent_dim)                 #
        img = self.generator_model(inputs)                #
        outputs = self.discriminator_model(img)
        model = keras.Model(inputs=inputs, outputs=outputs)
        model.summary() #输出计算过程
        model.compile(optimizer=keras.optimizers.Adam(0.0001, 0.000001),
                      loss='binary_crossentropy',
                      )
        return model

  load data就是读取数据的函数,这里有用keras读取mnist数据集的方法,tf2版本不能够直接下载mnist数据集了所以这种办法还是可以使用的

    def load_data(self):
        (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
        train_images = train_images /255                  #将像素值归一化
        train_images = np.expand_dims(train_images, axis=3)     #在axis=3的维度后面增加一个维度 数值是1
        print('img_number:',train_images.shape)
        return train_images

  之后是最多的train部分的代码了,读数据、生成标签。计算好步长后按照之前设置的epoch进行训练,每个epoch将训练集进行打乱,generate_sample_images它的功能是让生成器根据噪声生成图片并保存,能让我们看到训练效果的渐变。之后在每一个step中,生成索引并通过索引读取到目标图片。生成高斯噪声并用生成器以此生成训练的gan_imgs。之后在这里训练判别器,之前我们建的valid对应真实的图片,fake对应生成器生成的图片,他们两个的loss是判别器的loss。同时,训练生成器,生成器通过噪声进行生成,每10次将判别器和生成器的loss输出一下,最后保存模型,over。

    def train(self):
        
        train_images=self.load_data()#读取数据
        #生成标签
        valid = np.ones((self.batch_size, 1))
        fake = np.zeros((self.batch_size, 1))
        step=int(train_images.shape[0]/self.batch_size)#计算步长
        print('step:',step)
        for epoch in range(self.epoch):            
            train_images = (tf.random.shuffle(train_images)).numpy()#每个epoch打乱一次
            if epoch % self.sample_interval == 0:
                self.generate_sample_images(epoch)

            for i in range(step):

                idx = np.arange(i*self.batch_size,i*self.batch_size+self.batch_size,1)#生成索引
                imgs =train_images[idx]#读取索引对应的图片
                noise = np.random.normal(0, 1, (self.batch_size, 100))  # 生成标准的高斯分布噪声
                gan_imgs = self.generator_model.predict(noise)#通过噪声生成图片
                #----------------------------------------------训练判别器
                discriminator_loss_real = self.discriminator_model.train_on_batch(imgs, valid)  # 真实数据对应标签1
                discriminator_loss_fake = self.discriminator_model.train_on_batch(gan_imgs, fake)  # 生成的数据对应标签0
                discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)
                #----------------------------------------------- 训练生成器
                noise = np.random.normal(0, 1, (self.batch_size, 100))
                generator_loss = self.model.train_on_batch(noise, valid)
                if i%10==0:#每十步进行输出
                    print("epoch:%d step:%d [discriminator_loss: %f, acc: %.2f%%] [generator_loss: %f]" % (
                        epoch,i,discriminator_loss[0], 100 * discriminator_loss[1], generator_loss))

        self.model.save(self.save_path)#存储模型

  之后就是最后我们的pred的过程:

    def pred(self):#载入模型并生成图片
        model=keras.models.load_model(self.save_path)   #载入模型
        model.summary()                    #判别器参数报备
        noise = np.random.normal(0, 1, (1, self.latent_dim))  #来个噪声
        generator=keras.Model(inputs=model.layers[1].input,outputs=model.layers[1].output)   #输入输出
        generator.summary()                      #生成器参数报备
        img=np.squeeze(generator.predict([noise])) #删除所有1维度的条目  
        plt.imshow(img)
        plt.show()
        print(img.shape)

 

训练中遇到的问题:

1.训练困难,需要精心设计模型结构,并小心协调  和  的训练程度

  这个非常明显,很容易出现生成器和分类器有一个打败了另一方的情况,loss会突然变大然后恒定不变,处理的方案就是增加dropout的比率,并且降低batchsize 的值和学习率。

2.损失函数无法指示训练过程,缺乏一个有意义的指标和生成图片的质量相关联

  缺乏一个指标来判断训练的过程,我们只能通过loss来判断,但是实际上loss也不是一个合适的指标来判断模型的质量。

3.模式崩坏(mode collapse),生成的图片虽然看起来像是真的,但是缺乏多样性

  

之后可能要看一下pix2pix的代码看一下他的网络是什么样子的。

 
posted @ 2021-04-13 13:01  灰人  阅读(181)  评论(0)    收藏  举报