实战迁移学习：用 VGG19、ResNet50、InceptionV3 解决猫狗大战问题
一、实践流程
1、数据预处理
主要是对训练数据做随机缩放、旋转、平移、通道偏移和水平翻转等图像变换（数据增强），尽可能让训练数据多样化；训练和验证数据的像素值都会被缩放到 [0, 1] 区间。
另外，训练数据采用分批、乱序（shuffle）读取的方式，避免模型按目录顺序（例如先全是猫、再全是狗）进行训练；验证数据则按固定顺序读取。
- 
# Data preparation (excerpt of the DataGen method from the full listing).
def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
    """Return a Keras directory iterator over ``dir_path``.

    Pixel values are always rescaled by 1/255. When ``is_train`` is true,
    random zoom / rotation / channel-shift / translation / horizontal-flip
    augmentation is added and batches are shuffled; otherwise images are
    only rescaled and are not shuffled.
    """
    augment_kwargs = {}
    if is_train:
        augment_kwargs = dict(
            zoom_range=0.25,
            rotation_range=15.,
            channel_shift_range=25.,
            width_shift_range=0.02,
            height_shift_range=0.02,
            horizontal_flip=True,
            fill_mode='constant',
        )
    datagen = ImageDataGenerator(rescale=1./255, **augment_kwargs)

    # Shuffle only training data so batches are not grouped by class folder.
    return datagen.flow_from_directory(
        dir_path,
        target_size=(img_row, img_col),
        batch_size=batch_size,
        shuffle=is_train,
    )
2、载入现有模型
这是整个流程的核心：使用在 ImageNet 上训练好的权重作为我们的特征提取器。注意要设置 include_top=False，去掉原模型最后的分类层。
- 
# Load InceptionV3 with ImageNet weights as a feature extractor.
# include_top=False drops the original ImageNet classifier layers and
# pooling=None keeps the last convolutional block's 4-D output as-is.
base_model = InceptionV3(
    include_top=False,
    weights='imagenet',
    pooling=None,
    input_shape=(img_rows, img_cols, color),
    classes=nb_classes,  # per the Keras docs, only used when include_top=True
)
然后冻结这些层（把 trainable 设为 False），因为它们已经训练好了，训练时只更新我们新添加的分类层。下面给出两种可选的分类层写法：全局平均池化 + 全连接层，或者 Flatten 后直接分类。
- 
# Freeze every pretrained layer: its weights are already trained, so only the
# new classification head built below is updated during training.
for layer in base_model.layers:
    layer.trainable = False

# Choose exactly one classification head:
#   'gap_dense' - global average pooling + 1024-unit ReLU layer + softmax
#   'flatten'   - flatten the feature map, then classify directly (softmax)
head_type = 'flatten'

x = base_model.output
if head_type == 'gap_dense':
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
elif head_type == 'flatten':
    x = Flatten()(x)
else:
    raise ValueError('unknown head_type: %r' % (head_type,))
predictions = Dense(nb_classes, activation='softmax')(x)
3、训练模型
这里使用 fit_generator 函数，它不需要一次性把大量数据加载进内存，并且生成器与模型并行执行以提高效率，例如可以在 CPU 上实时做数据增强，同时在 GPU 上训练模型。（在较新的 Keras / TensorFlow 2.x 中 fit_generator 已被弃用，可以直接把生成器传给 model.fit。）
- 
# Train from generators so the dataset never has to fit in memory at once;
# the returned History object records per-epoch metrics for plotting.
history_ft = model.fit_generator(train_generator,
                                 epochs=epochs,
                                 steps_per_epoch=steps_per_epoch,
                                 validation_data=validation_generator,
                                 validation_steps=validation_steps)
二、猫狗大战数据集
训练数据约 540 MB，测试数据约 270 MB，可以到 Kaggle 官网下载：
https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data
下载后把图片按类别分别放到 dog 和 cat 两个子目录中（flow_from_directory 会根据子目录名推断类别标签）。注意 Kaggle 提供的测试集没有标签，验证集需要从训练集中划分出来。
三、训练
训练时 Keras 会自动下载预训练权重，例如 vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5。如果已经提前下载好，可以修改源代码（例如 resnet50.py）让它直接读取本地的权重文件。更简单的做法有两种：一是把权重文件放到 ~/.keras/models/ 目录下，Keras 会直接使用缓存中的文件；二是把本地 .h5 文件的路径直接传给 weights 参数。
1、VGG19
VGG19 的深度为 26 层，模型大小达到 549 MB（约 1.44 亿个参数，其中大部分在顶部的全连接层中）。原模型最后有 3 个全连接层做分类器，所以这里也加了一个 1024 单元的全连接层，训练 10 轮后准确率达到 89%。
2、ResNet50
ResNet50 的深度达到 168 层，但模型大小只有 99 MB（约 2560 万个参数）。分类层简单一些，Flatten 后用一层直接分类，训练 10 轮后准确率达到 96%。
3、inception_v3
InceptionV3 的深度为 159 层，模型大小 92 MB（约 2390 万个参数），训练 10 轮的结果如下。
下面是一层直接分类的结果：
下面是加了一个 512 单元全连接层的结果（完整代码中默认使用 1024 单元），大家可以随意调整分类层进行测试。
四、完整的代码
- 
# -*- coding: utf-8 -*-
- 
import os
- 
from keras.utils import plot_model
- 
from keras.applications.resnet50 import ResNet50
- 
from keras.applications.vgg19 import VGG19
- 
from keras.applications.inception_v3 import InceptionV3
- 
from keras.layers import Dense,Flatten,GlobalAveragePooling2D
- 
from keras.models import Model,load_model
- 
from keras.optimizers import SGD
- 
from keras.preprocessing.image import ImageDataGenerator
- 
import matplotlib.pyplot as plt
- 
- 
class PowerTransferMode:
    """Transfer-learning helpers for an image-folder classification task.

    Each ``*_model`` builder loads an ImageNet-pretrained backbone without its
    top classifier, freezes every backbone layer, attaches a small trainable
    softmax head, and compiles the result with Nesterov SGD and categorical
    cross-entropy. Per the Keras docs, the ``classes`` argument forwarded to
    the backbone constructors only takes effect when ``include_top=True``.
    """

    # ---- data preparation ------------------------------------------------

    def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
        """Return a Keras directory iterator over ``dir_path``.

        Args:
            dir_path: directory with one sub-directory per class; Keras infers
                the labels from the sub-directory names.
            img_row, img_col: target image height and width.
            batch_size: number of images per batch.
            is_train: when true, apply random augmentation and shuffle the
                samples; otherwise only rescale and do not shuffle.

        Returns:
            The iterator produced by ``ImageDataGenerator.flow_from_directory``.
        """
        if is_train:
            # Random zoom/rotation/channel-shift/translation/flip augmentation
            # to make the training data more varied.
            datagen = ImageDataGenerator(rescale=1./255,
                                         zoom_range=0.25, rotation_range=15.,
                                         channel_shift_range=25.,
                                         width_shift_range=0.02,
                                         height_shift_range=0.02,
                                         horizontal_flip=True,
                                         fill_mode='constant')
        else:
            datagen = ImageDataGenerator(rescale=1./255)

        # The default class_mode ('categorical') yields one-hot labels, which
        # matches the softmax + categorical_crossentropy models below.
        return datagen.flow_from_directory(
            dir_path, target_size=(img_row, img_col),
            batch_size=batch_size,
            # Shuffle training data so batches are not grouped by class
            # directory; keep the validation order fixed.
            shuffle=is_train)

    # ---- shared model-building helpers ----------------------------------

    @staticmethod
    def _input_shape(img_rows, img_cols, RGB):
        """Return the channels-last input shape (rows, cols, 3 or 1).

        NOTE(review): the ImageNet weights are 3-channel, so RGB=False is
        presumably rejected by Keras when weights='imagenet' -- confirm.
        """
        return (img_rows, img_cols, 3 if RGB else 1)

    @staticmethod
    def _flatten_head(x, nb_classes):
        """Flatten the feature map and classify it with one softmax layer."""
        x = Flatten()(x)
        return Dense(nb_classes, activation='softmax')(x)

    @staticmethod
    def _gap_dense_head(x, nb_classes, units=1024):
        """Global-average-pool, one ReLU Dense(units) layer, then softmax."""
        x = GlobalAveragePooling2D()(x)
        x = Dense(units, activation='relu')(x)
        return Dense(nb_classes, activation='softmax')(x)

    def _compile_frozen_model(self, base_model, build_head, lr, decay, momentum,
                              plot_file, is_plot_model):
        """Freeze ``base_model``, attach ``build_head``'s output, compile it.

        Args:
            base_model: pretrained backbone built with include_top=False.
            build_head: callable mapping the backbone output tensor to the
                prediction tensor.
            lr, decay, momentum: SGD hyper-parameters (Nesterov momentum).
            plot_file: PNG path for the architecture diagram.
            is_plot_model: write the diagram to ``plot_file`` when true.

        Returns:
            The compiled ``Model``.
        """
        # Freeze every backbone layer so only the new head is trained and the
        # backbone acts as a fixed (bottleneck) feature extractor.
        for layer in base_model.layers:
            layer.trainable = False

        predictions = build_head(base_model.output)
        model = Model(inputs=base_model.input, outputs=predictions)
        sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
        model.compile(loss='categorical_crossentropy', optimizer=sgd,
                      metrics=['accuracy'])

        if is_plot_model:
            plot_model(model, to_file=plot_file, show_shapes=True)
        return model

    # ---- model builders ---------------------------------------------------

    def ResNet50_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2,
                       img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
        """Build a frozen ResNet50 backbone with a Flatten -> softmax head.

        Args:
            lr, decay, momentum: SGD hyper-parameters (Nesterov momentum).
            nb_classes: number of output classes.
            img_rows, img_cols: input image height and width.
            RGB: 3 input channels when true, 1 otherwise.
            is_plot_model: save an architecture diagram to
                'resnet50_model.png' when true.

        Returns:
            The compiled ``Model``.
        """
        base_model = ResNet50(weights='imagenet', include_top=False, pooling=None,
                              input_shape=self._input_shape(img_rows, img_cols, RGB),
                              classes=nb_classes)
        return self._compile_frozen_model(
            base_model, lambda x: self._flatten_head(x, nb_classes),
            lr, decay, momentum, 'resnet50_model.png', is_plot_model)

    def VGG19_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2,
                    img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
        """Build a frozen VGG19 backbone with a GAP -> Dense(1024) -> softmax head.

        Args:
            lr, decay, momentum: SGD hyper-parameters (Nesterov momentum).
            nb_classes: number of output classes.
            img_rows, img_cols: input image height and width.
            RGB: 3 input channels when true, 1 otherwise.
            is_plot_model: save an architecture diagram to 'vgg19_model.png'
                when true.

        Returns:
            The compiled ``Model``.
        """
        base_model = VGG19(weights='imagenet', include_top=False, pooling=None,
                           input_shape=self._input_shape(img_rows, img_cols, RGB),
                           classes=nb_classes)
        return self._compile_frozen_model(
            base_model, lambda x: self._gap_dense_head(x, nb_classes),
            lr, decay, momentum, 'vgg19_model.png', is_plot_model)

    def InceptionV3_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2,
                          img_rows=197, img_cols=197, RGB=True,
                          is_plot_model=False):
        """Build a frozen InceptionV3 backbone with a GAP -> Dense(1024) -> softmax head.

        Args:
            lr, decay, momentum: SGD hyper-parameters (Nesterov momentum).
            nb_classes: number of output classes.
            img_rows, img_cols: input image height and width.
            RGB: 3 input channels when true, 1 otherwise.
            is_plot_model: save an architecture diagram to
                'inception_v3_model.png' when true.

        Returns:
            The compiled ``Model``.
        """
        base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,
                                 input_shape=self._input_shape(img_rows, img_cols, RGB),
                                 classes=nb_classes)
        return self._compile_frozen_model(
            base_model, lambda x: self._gap_dense_head(x, nb_classes),
            lr, decay, momentum, 'inception_v3_model.png', is_plot_model)

    # ---- training ---------------------------------------------------------

    def train_model(self, model, epochs, train_generator, steps_per_epoch,
                    validation_generator, validation_steps, model_url,
                    is_load_model=False):
        """Train ``model`` from generators, then save it to ``model_url``.

        If ``is_load_model`` is true and ``model_url`` exists, the model saved
        there replaces ``model`` so training resumes from it; otherwise
        ``model`` is trained as given. After training, the full model
        (``model.save``) is written to ``model_url``, overwriting any file.

        Returns:
            The History object returned by ``fit_generator``.
        """
        if is_load_model and os.path.exists(model_url):
            model = load_model(model_url)

        history_ft = model.fit_generator(
            train_generator,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            validation_data=validation_generator,
            validation_steps=validation_steps)

        model.save(model_url, overwrite=True)
        return history_ft

    # ---- plotting ---------------------------------------------------------

    @staticmethod
    def _history_series(history, *keys):
        """Return ``history.history[k]`` for the first key ``k`` present.

        Older Keras records accuracy as 'acc'/'val_acc'; Keras 2.3+ and
        tf.keras record it as 'accuracy'/'val_accuracy'.

        Raises:
            KeyError: if none of ``keys`` is in the history.
        """
        for key in keys:
            if key in history.history:
                return history.history[key]
        raise KeyError('none of %s found in training history' % (keys,))

    def plot_training(self, history):
        """Plot training vs. validation accuracy and loss per epoch."""
        acc = self._history_series(history, 'acc', 'accuracy')
        val_acc = self._history_series(history, 'val_acc', 'val_accuracy')
        loss = history.history['loss']
        val_loss = history.history['val_loss']
        epochs = range(1, len(acc) + 1)

        # Open a fresh figure so the curves never land on a previous plot.
        plt.figure()
        plt.plot(epochs, acc, 'b-', label='training')
        plt.plot(epochs, val_acc, 'r-', label='validation')
        plt.title('Training and validation accuracy')
        plt.xlabel('epoch')
        plt.legend()

        plt.figure()
        plt.plot(epochs, loss, 'b-', label='training')
        plt.plot(epochs, val_loss, 'r-', label='validation')
        plt.title('Training and validation loss')
        plt.xlabel('epoch')
        plt.legend()

        plt.show()
- 
- 
- 
if __name__ == '__main__':
    image_size = 197
    batch_size = 32
    epochs = 10
    steps_per_epoch = 600
    validation_steps = 60
    # Backbone to train: 'vgg19', 'resnet50' or 'inception_v3'.
    backbone = 'resnet50'

    transfer = PowerTransferMode()

    # Each directory needs one sub-directory per class (e.g. cat/ and dog/);
    # flow_from_directory infers the labels from those sub-directory names.
    train_generator = transfer.DataGen('data/cat_dog_Dataset/train', image_size, image_size, batch_size, True)
    validation_generator = transfer.DataGen('data/cat_dog_Dataset/test', image_size, image_size, batch_size, False)

    # backbone -> (model builder, save architecture diagram?, model file).
    # train_model stores the full model via model.save, despite the
    # '_weights' suffix in the file names.
    configs = {
        'vgg19': (transfer.VGG19_model, False, 'vgg19_model_weights.h5'),
        'resnet50': (transfer.ResNet50_model, False, 'resnet50_model_weights.h5'),
        'inception_v3': (transfer.InceptionV3_model, True, 'inception_v3_model_weights.h5'),
    }
    build_model, plot_arch, model_file = configs[backbone]

    model = build_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=plot_arch)
    history_ft = transfer.train_model(model, epochs, train_generator, steps_per_epoch,
                                      validation_generator, validation_steps,
                                      model_file, is_load_model=False)

    # Training vs. validation accuracy / loss curves.
    transfer.plot_training(history_ft)
 
                    
                
 
 
                
            
         浙公网安备 33010602011771号
浙公网安备 33010602011771号