GAN网络
解析一下GAN网络处理mnist图片数据集的代码
先看一下引入的包
import numpy as np import matplotlib from matplotlib import pyplot as plt import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers,optimizers,losses from tensorflow.keras.callbacks import EarlyStopping from tensorflow.python.keras import backend as K from tensorflow.keras.utils import plot_model from IPython.display import Image import cv2 import PIL import json, os import sys import labelme import labelme.utils as utils import glob import itertools
前面都是比较常用的,
from tensorflow.keras.callbacks import EarlyStopping 提前终止训练,使用方法:
import keras early_stopping=keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None, restore_best_weights=False) model.fit(callbacks = [early_stopping])
from IPython.display import Image 可以用来显示图片
import PIL 图像处理库 使用方法
//读取图片并显示 from PIL import Image, ImageDraw #指定路径 sample_image_path = os.path.join(RAW_DATA_DIR, 'normal_1/images/img_0.png') #读入图片 sample_image = Image.open(sample_image_path) print sample_image.format, "%dx%d" % sample_image.size, sample_image.mode #输出 plt.title('Sample Image') plt.imshow(sample_image) plt.show()
labelme是数据集标注用的
glob是python自带的模块,可以支持find+通配符匹配的功能。使用方法:

---------------------------------------包介绍完毕了--------------------------------------------------------------
首先看一下GAN类的成员变量
class GAN(): def __init__(self, #定义全局变量 ): self.img_shape = (28, 28, 1) #输入图片28x28 self.save_path = r'./GAN.h5' #模型保存的位置 self.img_path = r'./photo' #图片保存的位置 self.batch_size = 20 # self.latent_dim = 100 #keras输入是100维度的张量 self.sample_interval=1 #生成器生成图片的周期 和epoch有关 self.epoch=10 #100 #建立GAN模型的方法 self.generator_model = self.build_generator() #生成器对象 self.discriminator_model = self.build_discriminator() #判别器对象 self.model = self.bulid_model() #GAN模型训练
补充好注释之后很清晰了,我们先来看一下生成器和判别器的样子如下:可以看到他们就是两个神经网络,一个通过训练来分别真假的图片,一个通过训练试图混淆生成器的判断(生成更加真实的图片)。从代码上能够看出这两个东西中生成器似乎更加复杂一点,但是大体的结构是相似的。
def build_generator(self):#生成器 input=keras.Input(shape=self.latent_dim) x=layers.Dense(256)(input) x=layers.LeakyReLU(alpha=0.2)(x) x=layers.BatchNormalization(momentum=0.8)(x) x = layers.Dense(512)(x) x = layers.LeakyReLU(alpha=0.2)(x) x = layers.BatchNormalization(momentum=0.8)(x) x = layers.Dense(1024)(x) x = layers.LeakyReLU(alpha=0.2)(x) x = layers.BatchNormalization(momentum=0.8)(x) x=layers.Dense(np.prod(self.img_shape),activation='sigmoid')(x) output=layers.Reshape(self.img_shape)(x) model=keras.Model(inputs=input,outputs=output,name='generator') model.summary() return model def build_discriminator(self):#判别器 input=keras.Input(shape=self.img_shape) #输入是图片 x=layers.Flatten(input_shape=self.img_shape)(input) #展开 x=layers.Dense(512)(x) #全连接 x=layers.LeakyReLU(alpha=0.2)(x) x=layers.Dense(256)(x) x=layers.LeakyReLU(alpha=0.2)(x) output=layers.Dense(1,activation='sigmoid')(x) model=keras.Model(inputs=input,outputs=output,name='discriminator') model.summary() return model
之后建立gan模型,input输入到生成器,生成器生成的图片是判别器的输入,判别器的输出是最终的输出。所以模型的总输入是input,总的输出是判别器的输出。这里判别器不训练就让生成器训练。
def bulid_model(self):#建立GAN模型 self.discriminator_model.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(0.0001, 0.000001), metrics=['accuracy']) #对判别器进行设置loss和优化器 self.discriminator_model.trainable = False#使判别器不训练 inputs = keras.Input(shape=self.latent_dim) # img = self.generator_model(inputs) # outputs = self.discriminator_model(img) model = keras.Model(inputs=inputs, outputs=outputs) model.summary() #输出计算过程 model.compile(optimizer=keras.optimizers.Adam(0.0001, 0.000001), loss='binary_crossentropy', ) return model
load data就是读取数据的函数,这里有用keras读取mnist数据集的方法,tf2版本不能够直接下载mnist数据集了所以这种办法还是可以使用的
def load_data(self): (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data() train_images = train_images /255 #将像素值归一化 train_images = np.expand_dims(train_images, axis=3) #在axis=3的维度后面增加一个维度 数值是1 print('img_number:',train_images.shape) return train_images
之后是最多的train部分的代码了,读数据、生成标签。计算好步长后按照之前设置的epoch进行训练,每个epoch将训练集进行打乱,generate_sample_images它的功能是让生成器根据噪声生成图片并保存,能让我们看到训练效果的渐变。之后在每一个step中,生成索引并通过索引读取到目标图片。生成高斯噪声并用生成器以此生成训练的gan_imgs。之后在这里训练判别器,之前我们建的valid对应真实的图片,fake对应生成器生成的图片,他们两个的loss是判别器的loss。同时,训练生成器,生成器通过噪声进行生成,每10次将判别器和生成器的loss输出一下,最后保存模型,over。
def train(self): train_images=self.load_data()#读取数据 #生成标签 valid = np.ones((self.batch_size, 1)) fake = np.zeros((self.batch_size, 1)) step=int(train_images.shape[0]/self.batch_size)#计算步长 print('step:',step) for epoch in range(self.epoch): train_images = (tf.random.shuffle(train_images)).numpy()#每个epoch打乱一次 if epoch % self.sample_interval == 0: self.generate_sample_images(epoch) for i in range(step): idx = np.arange(i*self.batch_size,i*self.batch_size+self.batch_size,1)#生成索引 imgs =train_images[idx]#读取索引对应的图片 noise = np.random.normal(0, 1, (self.batch_size, 100)) # 生成标准的高斯分布噪声 gan_imgs = self.generator_model.predict(noise)#通过噪声生成图片 #----------------------------------------------训练判别器 discriminator_loss_real = self.discriminator_model.train_on_batch(imgs, valid) # 真实数据对应标签1 discriminator_loss_fake = self.discriminator_model.train_on_batch(gan_imgs, fake) # 生成的数据对应标签0 discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake) #----------------------------------------------- 训练生成器 noise = np.random.normal(0, 1, (self.batch_size, 100)) generator_loss = self.model.train_on_batch(noise, valid) if i%10==0:#每十步进行输出 print("epoch:%d step:%d [discriminator_loss: %f, acc: %.2f%%] [generator_loss: %f]" % ( epoch,i,discriminator_loss[0], 100 * discriminator_loss[1], generator_loss)) self.model.save(self.save_path)#存储模型
之后就是最后我们的pred的过程:
def pred(self):#载入模型并生成图片 model=keras.models.load_model(self.save_path) #载入模型 model.summary() #判别器参数报备 noise = np.random.normal(0, 1, (1, self.latent_dim)) #来个噪声 generator=keras.Model(inputs=model.layers[1].input,outputs=model.layers[1].output) #输入输出 generator.summary() #生成器参数报备 img=np.squeeze(generator.predict([noise])) #删除所有1维度的条目 plt.imshow(img) plt.show() print(img.shape)
训练中遇到的问题:
1.训练困难,需要精心设计模型结构,并小心协调 和 的训练程度
这个非常明显,很容易出现生成器和分类器有一个打败了另一方的情况,loss会突然变大然后恒定不变,处理的方案就是增加dropout的比率,并且降低batchsize 的值和学习率。
2.损失函数无法指示训练过程,缺乏一个有意义的指标和生成图片的质量相关联
缺乏一个指标来判断训练的过程,我们只能通过loss来判断,但是实际上loss也不是一个合适的指标来判断模型的质量。
3.模式崩坏(mode collapse),生成的图片虽然看起来像是真的,但是缺乏多样性
之后可能要看一下pix2pix的代码看一下他的网络是什么样子的。

浙公网安备 33010602011771号