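Python Machine Learning Project: Fruit and Vegetable Recognition
I. Background
After working through two deep learning examples, MNIST handwritten digit recognition and Keras-based cat-vs-dog image classification, I wanted to become more familiar with the basic usage and network structure of TensorFlow and Keras, so I set out to recognize fruit and vegetables in images.
II. Design of the Machine Learning Case
1. Dataset source
Fruit and vegetable dataset:
The dataset comes from Kaggle. Its author scraped the images from Bing Image Search in order to build an application that recognizes food items in captured photos and suggests recipes that can be made from them. It contains 4,291 images in 36 classes, split into three folders: train (roughly 100 images per class), test (roughly 10 per class), and validation (roughly 10 per class). Each of these folders contains one subfolder per fruit or vegetable, holding the images of that item. There are 10 kinds of fruit and 26 kinds of vegetables, 36 classes in total.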
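The counts quoted above can be checked by walking one split folder and counting the image files per class. The following is only a minimal sketch: the helper name `count_images` and the set of accepted extensions are assumptions made for illustration, not part of the original project.

```python
import os

# Assumed set of image extensions; adjust to what the dataset actually contains.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

def count_images(split_dir):
    """Return {class_name: number_of_image_files} for one split folder."""
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # ignore stray files next to the class folders
        counts[class_name] = sum(
            1 for f in os.listdir(class_dir)
            if f.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts
```

Calling `count_images` on the train folder and summing the values gives the number of training images; the length of the returned dictionary gives the number of classes.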
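2. Machine learning frameworks used
Convolutional Neural Network (CNN): a convolutional neural network is a feed-forward neural network with a deep structure that includes convolution operations. It is one of the most representative architectures in deep learning and has been very successful in image processing; on the standard ImageNet benchmark, many successful models are CNN-based.
TensorFlow: TensorFlow is Google's second-generation machine learning system, developed as the successor to DistBelief. Its name reflects how it works: a tensor is an N-dimensional array, and flow refers to computation on a dataflow graph, so the name describes tensors flowing from one end of the graph to the other. TensorFlow feeds complex data structures into neural networks for analysis and processing, and it can be used in many machine learning and deep learning areas such as speech recognition and image recognition.
Keras: Keras is a model-level library that provides high-level building blocks for developing deep learning models. It does not handle low-level operations such as tensor manipulation and differentiation itself; instead it relies on a specialized, highly optimized tensor library to do so. That tensor library is the Keras backend engine, for example TensorFlow.
3. Technical difficulties and solutions
Problem: some third-party libraries no longer provide a module that the code relies on. Solution: downgrade the library, or switch to a newer module that offers the same functionality.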
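One way to handle a module that exists in some library versions but not in others is to try several import locations in order. The sketch below is an illustration only; the helper name `import_first_available` is hypothetical, and the module names in the commented usage line are examples whose availability depends on the installed versions.

```python
import importlib

def import_first_available(candidates):
    """Return the first module in `candidates` that can be imported."""
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError:
            continue  # this location does not exist in the installed version
    raise ImportError(f"none of these modules could be imported: {candidates}")

# Illustrative usage (module names are examples, not a guarantee):
# keras = import_first_available(["keras", "tensorflow.keras"])
```

The alternative named in the text, pinning an older library version, avoids code changes but ties the project to that version.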
III. Implementation Steps
Fruit and vegetable recognition
1. Sort the original dataset into categories
1) Split the train, test, and validation sets into two categories, fruit and vegetable
import os
import shutil

# Folder that holds the original train/validation/test folders
ROOT = "C:/Users/linyicheng/Desktop"

# Class folder names exactly as they appear in the dataset
fruit = ['banana', 'apple', 'pear', 'grapes', 'orange', 'kiwi', 'watermelon',
         'pomegranate', 'pineapple', 'mango']
vegetables = ['cucumber', 'carrot', 'capsicum', 'onion', 'potato', 'lemon', 'tomato',
              'raddish', 'beetroot', 'cabbage', 'lettuce', 'spinach', 'soy beans',
              'cauliflower', 'bell pepper', 'chilli pepper', 'turnip', 'corn', 'sweetcorn',
              'sweetpotato', 'paprika', 'jalepeno', 'ginger', 'garlic', 'peas', 'eggplant']

def split_by_category(split):
    """Move every class folder of one split into fruit/<split> or vegetable/<split>."""
    src_root = os.path.join(ROOT, split)
    # The immediate subfolders of the split are the class names
    class_names = [name for name in os.listdir(src_root)
                   if os.path.isdir(os.path.join(src_root, name))]
    for name in class_names:
        if name in fruit:
            category = 'fruit'
        elif name in vegetables:
            category = 'vegetable'
        else:
            continue  # leave unknown folders where they are
        dst_root = os.path.join(ROOT, category, split)
        os.makedirs(dst_root, exist_ok=True)  # the target folder must exist before moving into it
        shutil.move(os.path.join(src_root, name), dst_root)

# Process the training, validation, and test sets
for split in ('train', 'validation', 'test'):
    split_by_category(split)
2. Preprocess the dataset and build the dataset directory, following the cat-vs-dog example
1) Rename the images in the train, test, and validation folders of fruit and vegetable
# Rename every .jpg to <class>_<index>.jpg, for both categories and all three splits
for category in ('fruit', 'vegetable'):
    for split in ('test', 'train', 'validation'):
        outer_path = os.path.join("C:/Users/linyicheng/Desktop", category, split)
        for folder in os.listdir(outer_path):  # one folder per class
            inner_path = os.path.join(outer_path, folder)
            i = 0
            for item in os.listdir(inner_path):  # images of this class
                if item.endswith('.jpg'):
                    src = os.path.join(os.path.abspath(inner_path), item)  # original path
                    # New path; change the pattern below if you prefer another naming scheme
                    dst = os.path.join(os.path.abspath(inner_path), str(folder) + '_' + str(i) + '.jpg')
                    try:
                        os.rename(src, dst)
                        i += 1
                    except OSError:
                        continue  # skip files that cannot be renamed
2) Copy the files into the base directory to build the dataset directory structure
import shutil

# Copy every image into base/<split>/<category>
for category in ('fruit', 'vegetable'):
    for split in ('test', 'train', 'validation'):
        # Target folder (an absolute path here; a relative path works as well)
        determination = os.path.join("C:/Users/linyicheng/Desktop/base", split, category)
        os.makedirs(determination, exist_ok=True)
        # Source folder
        path = os.path.join("C:/Users/linyicheng/Desktop", category, split)
        for folder in os.listdir(path):
            class_dir = os.path.join(path, folder)
            for file in os.listdir(class_dir):
                source = os.path.join(class_dir, file)
                deter = os.path.join(determination, file)
                shutil.copyfile(source, deter)
3. Set the paths
base_dir = "C:/Users/linyicheng/Desktop/base/"
train_data_dir = "C:/Users/linyicheng/Desktop/base/train/"
test_data_dir = "C:/Users/linyicheng/Desktop/base/test/"
val_data_dir = "C:/Users/linyicheng/Desktop/base/validation/"
train_fruit_data = "C:/Users/linyicheng/Desktop/base/train/fruit/"
test_fruit_data = "C:/Users/linyicheng/Desktop/base/test/fruit/"
test_vegetable_data = "C:/Users/linyicheng/Desktop/base/test/vegetable/"
4. Build the convolutional neural network
# Build the convolutional neural network
from keras import layers
from keras import models
model = models.Sequential()
"""
Output size of a convolution layer: (input size - kernel size) / stride + 1
Parameter count of a convolution layer:
(kernel height * kernel width * input channels + 1) * number of filters
Output size: 150-3+1=148, i.e. 148*148
Parameters: 32*3*3*3+32=896
"""
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))  # convolution layer 1
model.add(layers.MaxPooling2D((2, 2)))  # max pooling layer 1
# Output size: 148/2=74, i.e. 74*74
# Output size: 74-3+1=72, i.e. 72*72; parameters: 64*3*3*32+64=18496
# 32 is the number of output channels of convolution layer 1
model.add(layers.Conv2D(64, (3, 3), activation='relu'))  # convolution layer 2
model.add(layers.MaxPooling2D((2, 2)))  # max pooling layer 2
# Output size: 72/2=36, i.e. 36*36
# Output size: 36-3+1=34, i.e. 34*34; parameters: 128*3*3*64+128=73856
model.add(layers.Conv2D(128, (3, 3), activation='relu'))  # convolution layer 3
model.add(layers.MaxPooling2D((2, 2)))  # max pooling layer 3
# Output size: 34/2=17, i.e. 17*17
# Output size: 17-3+1=15, i.e. 15*15; parameters: 128*3*3*128+128=147584
model.add(layers.Conv2D(128, (3, 3), activation='relu'))  # convolution layer 4
model.add(layers.MaxPooling2D((2, 2)))  # max pooling layer 4
# Output size: 15/2=7 (rounded down), i.e. 7*7
model.add(layers.Flatten())
model.add(layers.Dense(512, activation='relu'))  # fully connected layer 1
model.add(layers.Dense(1, activation='sigmoid'))  # fully connected layer 2, the output layer
Calling model.summary() shows that the dense_1 (Dense) layer alone has more than 3 million parameters.
model.summary()
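The sizes and parameter counts in the comments above can be reproduced with plain arithmetic. The sketch below assumes 3x3 kernels, stride 1, no padding, and 2x2 pooling, as in the model above; the helper names are made up for illustration.

```python
def conv_output_size(size, kernel, stride=1):
    """Spatial size after a convolution without padding."""
    return (size - kernel) // stride + 1

def pool_output_size(size, pool=2):
    """Spatial size after max pooling (fractions are rounded down)."""
    return size // pool

def conv_params(kernel, in_channels, filters):
    """Weights plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * filters

def dense_params(inputs, units):
    """Weights plus one bias per unit."""
    return (inputs + 1) * units

size, channels, total = 150, 3, 0
for filters in (32, 64, 128, 128):
    total += conv_params(3, channels, filters)
    size = pool_output_size(conv_output_size(size, 3))
    channels = filters

flat = size * size * channels        # length of the flattened feature vector
dense_1 = dense_params(flat, 512)    # parameters of the first Dense layer
total += dense_1 + dense_params(512, 1)
print(size, flat, dense_1, total)
```

The first Dense layer dominates the total because every one of the flattened features connects to each of its 512 units.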
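5. Compile the model
# Compile the model
# The last layer is a single sigmoid unit, so binary cross-entropy is used as the loss,
# together with the RMSprop optimizer
from keras import optimizers
# Note: newer Keras versions name the lr argument learning_rate
model.compile(loss='binary_crossentropy',
              optimizer=optimizers.RMSprop(lr=1e-4),
              metrics=['acc'])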
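To make the choice of loss concrete, binary cross-entropy can be computed by hand for a few predictions. This is a plain-Python sketch of the standard formula, not the Keras implementation; the clipping constant `eps` is an assumed value.

```python
import math

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy for labels in {0, 1} and predicted probabilities."""
    losses = []
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        losses.append(-(t * math.log(p) + (1 - t) * math.log(1 - p)))
    return sum(losses) / len(losses)

# A confident correct prediction costs little, a confident wrong one costs a lot
print(binary_crossentropy([1], [0.9]), binary_crossentropy([1], [0.1]))
```

The loss grows without bound as a prediction moves toward the wrong extreme, which is what pushes the sigmoid output toward the correct label during training.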
6. Use ImageDataGenerator to read samples from the directories
from keras.preprocessing.image import ImageDataGenerator

# Normalization: rescale pixel values from [0, 255] to [0, 1]
train_datagen = ImageDataGenerator(rescale=1./255)
validation_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
# flow_from_directory takes the dataset location, resizes every image to target_size,
# and sets the batch size
# Load the training set
train_ds = train_datagen.flow_from_directory(
    train_data_dir,          # directory location
    class_mode='binary',     # two classes, so binary labels
    target_size=(150, 150),  # every image is resized to 150x150
    batch_size=20)
# Load the test set
test_ds = test_datagen.flow_from_directory(
    test_data_dir,
    class_mode='binary',
    target_size=(150, 150),
    batch_size=20)
# Load the validation set
val_ds = validation_datagen.flow_from_directory(
    val_data_dir,
    class_mode='binary',
    target_size=(150, 150),
    batch_size=20)

# Inspect the shape of one batch
for data_batch, labels_batch in train_ds:
    print('data batch shape:', data_batch.shape)
    print('labels batch shape:', labels_batch.shape)
    break
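It helps to know which label each folder receives. `flow_from_directory` infers the classes from the subfolder names in alphanumeric order and exposes the mapping as the generator's `class_indices` attribute. The sketch below mimics that ordering in plain Python; the helper name `infer_class_indices` is an assumption for illustration and is not a Keras function.

```python
import os

def infer_class_indices(directory):
    """Map each subfolder name to an integer label, in sorted order."""
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(classes)}
```

For a folder that contains fruit and vegetable subfolders, this yields fruit as 0 and vegetable as 1, so a sigmoid output close to 1 points to vegetable.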
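7. Train the model (30 epochs)
# Train the model for 30 epochs; change epochs to train longer or shorter
# Note: newer Keras versions accept generators in model.fit directly
history = model.fit_generator(
    train_ds,
    steps_per_epoch=100,
    epochs=30,
    validation_data=val_ds,
    validation_steps=50)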
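The step counts can be related to the dataset size with a little arithmetic. The sketch below assumes about 36 x 100 training images and 36 x 10 validation images, as suggested by the dataset description; the real counts may differ slightly, and the helper name is hypothetical.

```python
import math

def steps_for_full_pass(num_samples, batch_size):
    """Number of batches needed to see every sample once."""
    return math.ceil(num_samples / batch_size)

train_steps = steps_for_full_pass(3600, 20)  # assumed training set size
val_steps = steps_for_full_pass(360, 20)     # assumed validation set size
print(train_steps, val_steps)
```

Under these assumptions, 100 steps of 20 images cover only part of the training set in each epoch, while 50 validation steps request more batches than one pass over the validation set provides. How the surplus is handled depends on the Keras version, so setting both values to one full pass is the more predictable choice.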
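The training log shows a training accuracy (acc) of 0.9130 and a validation accuracy (val_acc) of 0.9803.
The model works reasonably well. Next, data augmentation is used to improve the accuracy further.
# Save the trained model as an h5 file
model.save('fruit_and_vegetable_30epoch.h5')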
8. Plot the loss and accuracy curves of the training run
import matplotlib.pyplot as plt

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

# Accuracy curves
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

# Loss curves in a second figure
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()
The plot shows the training accuracy rising roughly linearly over the epochs until it approaches 100%.
9. Decide for a single image whether it shows a fruit or a vegetable
from keras.preprocessing import image
from keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np

# Load the model
model = load_model('fruit_and_vegetable_30epoch.h5')
# Path of a local image
img_path = "C:/Users/linyicheng/Desktop/base/train/fruit/apple_41.jpg"
img = image.load_img(img_path, target_size=(150, 150))
# Convert it to a NumPy array of shape (150, 150, 3), scaled to [0, 1]
img_tensor = image.img_to_array(img) / 255
# Add a batch dimension so the shape becomes (1, 150, 150, 3)
img_tensor = np.expand_dims(img_tensor, axis=0)
# Run the prediction
prediction = model.predict(img_tensor)
# Print the predicted probability
print(prediction)
# flow_from_directory assigns labels alphabetically (fruit=0, vegetable=1),
# so an output above 0.5 means vegetable
if prediction[0][0] > 0.5:
    print('vegetable')
else:
    print('fruit')
plt.imshow(img)
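The direction of the 0.5 threshold depends on which class received label 1. A small helper can make that explicit by taking the label mapping as an argument. This is a sketch under the assumption that the labels came from `flow_from_directory` with `class_mode='binary'`, which assigns indices in alphabetical folder order; the function name is hypothetical.

```python
def interpret_prediction(probability, class_indices, threshold=0.5):
    """Turn a sigmoid output into (label, confidence).

    `probability` is the model's estimate for the class with index 1.
    """
    index_to_class = {index: name for name, index in class_indices.items()}
    if probability > threshold:
        return index_to_class[1], probability
    return index_to_class[0], 1 - probability

# With folders named fruit and vegetable the mapping is assumed to be:
print(interpret_prediction(0.9, {'fruit': 0, 'vegetable': 1}))
```

Passing the generator's own `class_indices` dictionary instead of a hand-written one keeps the printed label consistent with the labels used in training.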
10. Data augmentation with ImageDataGenerator
import os
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator

figure, ax = plt.subplots(nrows=1, ncols=4, sharex=True,
                          sharey=True, figsize=(16, 10))
ax = ax.flatten()  # flatten the subplot array to one dimension

datagen = ImageDataGenerator(
    rotation_range=40,       # range of random rotations, in degrees
    width_shift_range=0.2,   # horizontal shift, as a fraction of the width
    height_shift_range=0.2,  # vertical shift, as a fraction of the height
    shear_range=0.2,         # intensity of random shear transformations
    zoom_range=0.2,          # range of random zoom
    horizontal_flip=True,    # randomly flip half of the images horizontally
    fill_mode='nearest')     # how newly created pixels are filled

fnames = [os.path.join(train_fruit_data, fname) for fname in os.listdir(train_fruit_data)]
img_path = fnames[6]
img = image.load_img(img_path, target_size=(150, 150))
x = image.img_to_array(img) / 255
x = x.reshape((1,) + x.shape)  # shape (1, 150, 150, 3)
i = 0
for batch in datagen.flow(x, batch_size=1):
    ax[i].imshow(image.array_to_img(batch[0]), interpolation='nearest')
    i += 1
    if i % 4 == 0:
        break  # the generator loops forever, so stop after four images
plt.show()  # display the augmented images
# Even with data augmentation, the generated images remain highly similar to each other,
# because they all come from the same original image and add no new information.
# To reduce overfitting further, a Dropout layer can be added to the model
# right before the densely connected classifier.
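Two of the transformations above, a horizontal shift with nearest-pixel filling and a horizontal flip, are simple enough to show on a tiny image stored as nested lists. This is a plain-Python illustration of the idea, not the Keras implementation; all function names are made up for the example.

```python
def shift_row_nearest(row, offset):
    """Shift a row of pixels by `offset` positions (positive = right),
    filling the gap with the nearest edge pixel."""
    n = len(row)
    return [row[min(max(j - offset, 0), n - 1)] for j in range(n)]

def shift_image_nearest(img, offset):
    """Apply the same horizontal shift to every row."""
    return [shift_row_nearest(row, offset) for row in img]

def flip_horizontal(img):
    """Mirror every row."""
    return [row[::-1] for row in img]

tiny = [[1, 2, 3, 4],
        [5, 6, 7, 8]]
print(shift_image_nearest(tiny, 1))
print(flip_horizontal(tiny))
```

The shifted image repeats the edge pixel instead of inventing new content, which is why augmented copies of one photo stay strongly correlated with each other.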
11. Add a Dropout layer to the model before the densely connected classifier
# Add a Dropout layer to the model before the densely connected classifier
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu',
                        input_shape=(150, 150, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dropout(0.5))  # randomly drop half of the flattened features during training
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
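What a dropout layer does during training can be sketched in a few lines: each value is zeroed with probability `rate`, and the survivors are scaled up by 1 / (1 - rate) so the expected sum stays the same. The following plain-Python sketch illustrates this "inverted dropout" idea; it is not the Keras implementation, and the function signature is an assumption for the example.

```python
import random

def dropout(values, rate, training=True, rng=random):
    """Zero each value with probability `rate` and rescale the rest.

    At inference time (training=False) the values pass through unchanged.
    """
    if not training or rate == 0:
        return list(values)
    keep = 1.0 - rate
    return [v / keep if rng.random() >= rate else 0.0 for v in values]

rng = random.Random(0)  # fixed seed so the sketch is repeatable
print(dropout([1.0, 2.0, 3.0, 4.0], rate=0.5, rng=rng))
```

Because a different random subset is dropped at every step, the classifier cannot rely on any single feature, which is the mechanism that reduces overfitting.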
12. Compile the model
from keras import optimizers
# Note: newer Keras versions name the lr argument learning_rate
model.compile(loss='binary_crossentropy',
              optimizer=optimizers.RMSprop(lr=1e-4),
              metrics=['acc'])
13. Train the network with data augmentation and dropout
# Train the network with data augmentation and dropout
# The training generator normalizes and augments the images
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
# Validation data is only normalized, never augmented
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    # every image is resized to 150x150
    target_size=(150, 150),
    batch_size=32,
    # binary cross-entropy loss requires binary labels
    class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
    val_data_dir,
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')
14. Train the model (100 epochs)
history = model.fit_generator(
    train_generator,
    steps_per_epoch=100,
    epochs=100,
    validation_data=validation_generator,
    validation_steps=50)
# Save the trained model
model.save('fruit_and_vegetable_100epoch.h5')
15. Plot the loss and accuracy curves of the training run
import matplotlib.pyplot as plt

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

# Accuracy curves
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

# Loss curves in a second figure
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()
16. Recognize an image with the newly trained model
from keras.preprocessing import image
from keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np

# Load the model
model = load_model('fruit_and_vegetable_100epoch.h5')
# Path of a local image
img_path = "C:/Users/linyicheng/Desktop/base/train/fruit/apple_42.jpg"
img = image.load_img(img_path, target_size=(150, 150))
# Convert it to a NumPy array of shape (150, 150, 3), scaled to [0, 1]
img_tensor = image.img_to_array(img) / 255
# Add a batch dimension so the shape becomes (1, 150, 150, 3)
img_tensor = np.expand_dims(img_tensor, axis=0)
# Run the prediction
prediction = model.predict(img_tensor)
# Print the predicted probability
print(prediction)
# flow_from_directory assigns labels alphabetically (fruit=0, vegetable=1),
# so an output above 0.5 means vegetable
if prediction[0][0] > 0.5:
    print('vegetable')
else:
    print('fruit')
plt.imshow(img)
17. Recognize several images at once
import os
from PIL import Image
from keras.preprocessing import image
import matplotlib.pyplot as plt

def convertjpg(jpgfile):
    # Open the image as RGB and resize it to (150, 150)
    img = Image.open(jpgfile).convert('RGB')
    return img.resize((150, 150), Image.BILINEAR)

def show_predictions(folder):
    # Predict and display the first 8 images of a folder in a 2x4 grid
    figure, ax = plt.subplots(nrows=2, ncols=4, sharex=True, sharey=True, figsize=(24, 10))
    ax = ax.flatten()  # flatten the subplot array to one dimension
    names = [os.path.join(folder, name) for name in os.listdir(folder)]
    for j, name in enumerate(names[0:8]):
        img_show = convertjpg(name)
        img_scale = image.img_to_array(img_show)
        img_scale = img_scale.reshape(1, 150, 150, 3)  # reshape to (1, 150, 150, 3)
        img_scale = img_scale.astype('float32') / 255
        result = model.predict(img_scale)[0][0]  # predicted probability of class 1 (vegetable)
        if result > 0.5:
            title = 'vegetable'
            label = 'Predicted probability: ' + str(result)
        else:
            title = 'fruit'
            label = 'Predicted probability: ' + str(1 - result)
        ax[j].imshow(img_show, interpolation='nearest')
        ax[j].set_title(title, fontsize=24)    # subplot title
        ax[j].set_xlabel(label, fontsize=24)   # subplot x-axis label
    # Remove the ticks (the axes are shared, so this applies to every subplot)
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    plt.show()

show_predictions(test_vegetable_data)  # images from the vegetable test folder
show_predictions(test_fruit_data)      # images from the fruit test folder
全代码附上
1 import os 2 3 import shutil 4 5 #训练集数据处理 6 def train(file_path): 7 d=[] 8 s=[] 9 for root, dirs , files in os.walk(file_path): #读取文件并提取出文件路径和类名 10 d.append(root)#文件路径 11 s.append(dirs)#类名 12 for i in s: 13 if i!=s[0]: 14 s.remove(i) 15 b = [ i for item in s for i in item] 16 d.pop(0) 17 fruit=[ 'banana', 'apple', 'pear', 'grapes', 'orange','kiwi','watermelon', 'pomegranate', 'pineapple', 'mango'] 18 vegetables=['cucumber', 'carrot', 'capsicum', 'onion', 'potato', 'lemon', 'tomato', 19 'raddish', 'beetroot', 'cabbage', 'lettuce', 'spinach', 'soy beans', 20 'cauliflower', 'bell pepper', 'chilli pepper', 'turnip', 'corn', 'sweetcorn', 21 'sweetpotato', 'paprika', 'jalepeno', 'ginger', 'garlic', 'peas', 'eggplant'] 22 #用for循环对水果和蔬菜进行分类 23 #训练集数据处理 24 #水果 25 for i in b: 26 for j in fruit: 27 if i==j: 28 shutil.move(os.path.join("C:/Users/linyicheng/Desktop/train",i), 29 "C:/Users/linyicheng/Desktop/fruit/train") 30 #蔬菜 31 for i in b: 32 for j in vegetables: 33 if i==j: 34 shutil.move(os.path.join("C:/Users/linyicheng/Desktop/train",i), 35 "C:/Users/linyicheng/Desktop/vegetable/train") 36 37 train("C:/Users/linyicheng/Desktop/train") 38 #验证集数据处理 39 def validation(file_path): 40 d=[] 41 s=[] 42 for root, dirs , files in os.walk(file_path): 43 d.append(root) 44 s.append(dirs) 45 for i in s: 46 if i!=s[0]: 47 s.remove(i) 48 b = [ i for item in s for i in item] 49 d.pop(0) 50 fruit=[ 'banana', 'apple', 'pear', 'grapes', 'orange','kiwi','watermelon', 'pomegranate', 'pineapple', 'mango'] 51 vegetables=['cucumber', 'carrot', 'capsicum', 'onion', 'potato', 'lemon', 'tomato', 52 'raddish', 'beetroot', 'cabbage', 'lettuce', 'spinach', 'soy beans', 53 'cauliflower', 'bell pepper', 'chilli pepper', 'turnip', 'corn', 'sweetcorn', 54 'sweetpotato', 'paprika', 'jalepeno', 'ginger', 'garlic', 'peas', 'eggplant'] 55 #训练集数据处理 56 #水果 57 for i in b: 58 for j in fruit: 59 if i==j: 60 shutil.move(os.path.join("C:/Users/linyicheng/Desktop/validation",i), 61 
"C:/Users/linyicheng/Desktop/fruit/validation") 62 #蔬菜 63 for i in b: 64 for j in vegetables: 65 if i==j: 66 shutil.move(os.path.join("C:/Users/linyicheng/Desktop/validation",i), 67 "C:/Users/linyicheng/Desktop/vegetable/validation") 68 69 validation("C:/Users/linyicheng/Desktop/validation") 70 #测试集数据处理 71 def test(file_path): 72 d=[] 73 s=[] 74 for root, dirs , files in os.walk(file_path): 75 d.append(root) 76 s.append(dirs) 77 for i in s: 78 if i!=s[0]: 79 s.remove(i) 80 b = [ i for item in s for i in item] 81 d.pop(0) 82 fruit=[ 'banana', 'apple', 'pear', 'grapes', 'orange','kiwi','watermelon', 'pomegranate', 'pineapple', 'mango'] 83 vegetables=['cucumber', 'carrot', 'capsicum', 'onion', 'potato', 'lemon', 'tomato', 84 'raddish', 'beetroot', 'cabbage', 'lettuce', 'spinach', 'soy beans', 85 'cauliflower', 'bell pepper', 'chilli pepper', 'turnip', 'corn', 'sweetcorn', 86 'sweetpotato', 'paprika', 'jalepeno', 'ginger', 'garlic', 'peas', 'eggplant'] 87 #训练集数据处理 88 #水果 89 for i in b: 90 for j in fruit: 91 if i==j: 92 shutil.move(os.path.join("C:/Users/linyicheng/Desktop/test",i), 93 "C:/Users/linyicheng/Desktop/fruit/test") 94 #蔬菜 95 for i in b: 96 for j in vegetables: 97 if i==j: 98 shutil.move(os.path.join("C:/Users/linyicheng/Desktop/test",i), 99 "C:/Users/linyicheng/Desktop/vegetable/test") 100 101 test("C:/Users/linyicheng/Desktop/test") 102 #水果 103 outer_path ="C:/Users/linyicheng/Desktop/fruit/test" 104 folderlist = os.listdir(outer_path) #列举文件夹 105 106 for folder in folderlist: 107 inner_path = os.path.join(outer_path, folder) 108 total_num_folder = len(folderlist) #文件夹的总数 109 #打印文件夹的总数 110 filelist = os.listdir(inner_path) #列举图片 111 i = 0 112 for item in filelist: 113 total_num_file = len(filelist) #单个文件夹内图片的总数 114 if item.endswith('.jpg'): 115 src = os.path.join(os.path.abspath(inner_path), item) #原图的地址 116 dst = os.path.join(os.path.abspath(inner_path), str(folder) + '_' + str(i) + '.jpg') #新图的地址(这里可以把str(folder) + '_' + str(i) + '.jpg'改成你想改的名称) 117 try: 
118 os.rename(src, dst) 119 i += 1 120 except: 121 continue 122 outer_path ="C:/Users/linyicheng/Desktop/fruit/train" 123 folderlist = os.listdir(outer_path) #列举文件夹 124 125 for folder in folderlist: 126 inner_path = os.path.join(outer_path, folder) 127 total_num_folder = len(folderlist) #文件夹的总数 128 #打印文件夹的总数 129 130 filelist = os.listdir(inner_path) #列举图片 131 i = 0 132 for item in filelist: 133 total_num_file = len(filelist) #单个文件夹内图片的总数 134 if item.endswith('.jpg'): 135 src = os.path.join(os.path.abspath(inner_path), item) #原图的地址 136 dst = os.path.join(os.path.abspath(inner_path), str(folder) + '_' + str(i) + '.jpg') #新图的地址(这里可以把str(folder) + '_' + str(i) + '.jpg'改成你想改的名称) 137 try: 138 os.rename(src, dst) 139 140 i += 1 141 except: 142 continue 143 outer_path ="C:/Users/linyicheng/Desktop/fruit/validation" 144 folderlist = os.listdir(outer_path) #列举文件夹 145 146 for folder in folderlist: 147 inner_path = os.path.join(outer_path, folder) 148 total_num_folder = len(folderlist) #文件夹的总数 149 #打印文件夹的总数 150 151 filelist = os.listdir(inner_path) #列举图片 152 i = 0 153 for item in filelist: 154 total_num_file = len(filelist) #单个文件夹内图片的总数 155 if item.endswith('.jpg'): 156 src = os.path.join(os.path.abspath(inner_path), item) #原图的地址 157 dst = os.path.join(os.path.abspath(inner_path), str(folder) + '_' + str(i) + '.jpg') #新图的地址(这里可以把str(folder) + '_' + str(i) + '.jpg'改成你想改的名称) 158 try: 159 os.rename(src, dst) 160 161 i += 1 162 except: 163 continue 164 #蔬菜 165 outer_path ="C:/Users/linyicheng/Desktop/vegetable/test" 166 folderlist = os.listdir(outer_path) #列举文件夹 167 168 for folder in folderlist: 169 inner_path = os.path.join(outer_path, folder) 170 total_num_folder = len(folderlist) #文件夹的总数 171 #打印文件夹的总数 172 filelist = os.listdir(inner_path) #列举图片 173 i = 0 174 for item in filelist: 175 total_num_file = len(filelist) #单个文件夹内图片的总数 176 if item.endswith('.jpg'): 177 src = os.path.join(os.path.abspath(inner_path), item) #原图的地址 178 dst = os.path.join(os.path.abspath(inner_path), 
str(folder) + '_' + str(i) + '.jpg') #新图的地址(这里可以把str(folder) + '_' + str(i) + '.jpg'改成你想改的名称) 179 try: 180 os.rename(src, dst) 181 i += 1 182 except: 183 continue 184 outer_path ="C:/Users/linyicheng/Desktop/vegetable/train" 185 folderlist = os.listdir(outer_path) #列举文件夹 186 187 for folder in folderlist: 188 inner_path = os.path.join(outer_path, folder) 189 total_num_folder = len(folderlist) #文件夹的总数 190 #打印文件夹的总数 191 192 filelist = os.listdir(inner_path) #列举图片 193 i = 0 194 for item in filelist: 195 total_num_file = len(filelist) #单个文件夹内图片的总数 196 if item.endswith('.jpg'): 197 src = os.path.join(os.path.abspath(inner_path), item) #原图的地址 198 dst = os.path.join(os.path.abspath(inner_path), str(folder) + '_' + str(i) + '.jpg') #新图的地址(这里可以把str(folder) + '_' + str(i) + '.jpg'改成你想改的名称) 199 try: 200 os.rename(src, dst) 201 202 i += 1 203 except: 204 continue 205 outer_path ="C:/Users/linyicheng/Desktop/vegetable/validation" 206 folderlist = os.listdir(outer_path) #列举文件夹 207 208 for folder in folderlist: 209 inner_path = os.path.join(outer_path, folder) 210 total_num_folder = len(folderlist) #文件夹的总数 211 #打印文件夹的总数 212 213 filelist = os.listdir(inner_path) #列举图片 214 i = 0 215 for item in filelist: 216 total_num_file = len(filelist) #单个文件夹内图片的总数 217 if item.endswith('.jpg'): 218 src = os.path.join(os.path.abspath(inner_path), item) #原图的地址 219 dst = os.path.join(os.path.abspath(inner_path), str(folder) + '_' + str(i) + '.jpg') #新图的地址(这里可以把str(folder) + '_' + str(i) + '.jpg'改成你想改的名称) 220 try: 221 os.rename(src, dst) 222 223 i += 1 224 except: 225 continue 226 #目标文件夹,此处为相对路径,也可以改为绝对路径 227 import shutil 228 determination = "C:/Users/linyicheng/Desktop/base/test/fruit"#目标 229 if not os.path.exists(determination): 230 os.makedirs(determination) 231 232 #源文件夹路径 233 path = "C:/Users/linyicheng/Desktop/fruit/test" 234 folders = os.listdir(path) 235 for folder in folders: 236 dir = path + '/' + str(folder) 237 files = os.listdir(dir) 238 for file in files: 239 source = dir + '/' + 
str(file)
        deter = determination + '/' + str(file)
        shutil.copyfile(source, deter)

# Copy the fruit training images into base/train/fruit
determination = "C:/Users/linyicheng/Desktop/base/train/fruit"
if not os.path.exists(determination):
    os.makedirs(determination)
path = "C:/Users/linyicheng/Desktop/fruit/train"
folders = os.listdir(path)
for folder in folders:
    dir = path + '/' + str(folder)
    files = os.listdir(dir)
    for file in files:
        source = dir + '/' + str(file)
        deter = determination + '/' + str(file)
        shutil.copyfile(source, deter)

# Copy the fruit validation images into base/validation/fruit
determination = "C:/Users/linyicheng/Desktop/base/validation/fruit"
if not os.path.exists(determination):
    os.makedirs(determination)
path = "C:/Users/linyicheng/Desktop/fruit/validation"
folders = os.listdir(path)
for folder in folders:
    dir = path + '/' + str(folder)
    files = os.listdir(dir)
    for file in files:
        source = dir + '/' + str(file)
        deter = determination + '/' + str(file)
        shutil.copyfile(source, deter)

import shutil
# Copy the vegetable test images into base/test/vegetable
determination = "C:/Users/linyicheng/Desktop/base/test/vegetable"  # destination folder
if not os.path.exists(determination):
    os.makedirs(determination)

# Source folder
path = "C:/Users/linyicheng/Desktop/vegetable/test"
folders = os.listdir(path)
for folder in folders:
    dir = path + '/' + str(folder)
    files = os.listdir(dir)
    for file in files:
        source = dir + '/' + str(file)
        deter = determination + '/' + str(file)
        shutil.copyfile(source, deter)

# Copy the vegetable training images into base/train/vegetable
determination = "C:/Users/linyicheng/Desktop/base/train/vegetable"
if not os.path.exists(determination):
    os.makedirs(determination)
path = "C:/Users/linyicheng/Desktop/vegetable/train"
folders = os.listdir(path)
for folder in folders:
    dir = path + '/' + str(folder)
    files = os.listdir(dir)
    for file in files:
        source = dir + '/' + str(file)
        deter = determination + '/' + str(file)
        shutil.copyfile(source, deter)

# Copy the vegetable validation images into base/validation/vegetable
determination = "C:/Users/linyicheng/Desktop/base/validation/vegetable"
if not os.path.exists(determination):
    os.makedirs(determination)
path = "C:/Users/linyicheng/Desktop/vegetable/validation"
folders = os.listdir(path)
for folder in folders:
    dir = path + '/' + str(folder)
    files = os.listdir(dir)
    for file in files:
        source = dir + '/' + str(file)
        deter = determination + '/' + str(file)
        shutil.copyfile(source, deter)

base_dir = "C:/Users/linyicheng/Desktop/base/"
train_data_dir = "C:/Users/linyicheng/Desktop/base/train/"
test_data_dir = "C:/Users/linyicheng/Desktop/base/test/"
val_data_dir = "C:/Users/linyicheng/Desktop/base/validation/"
train_fruit_data = "C:/Users/linyicheng/Desktop/base/train/fruit/"
test_fruit_data = "C:/Users/linyicheng/Desktop/base/test/fruit/"
test_vegetable_data = "C:/Users/linyicheng/Desktop/base/test/vegetable/"

# Build the convolutional neural network
from keras import layers
from keras import models
model = models.Sequential()
"""
Output size of a convolution layer: (input size - kernel size) / stride + 1
Parameter count of a convolution layer:
(kernel height * kernel width * input channels + 1) * number of filters
Output size: 150 - 3 + 1 = 148, i.e. 148*148
Parameters: 32*3*3*3 + 32 = 896
"""
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))  # convolution layer 1
model.add(layers.MaxPooling2D((2, 2)))  # max-pooling layer 1
# Output size: 148/2 = 74, i.e. 74*74
# Output size: 74 - 3 + 1 = 72, i.e. 72*72; parameters: 64*3*3*32 + 64 = 18496
# 32 is the number of output channels of convolution layer 1
model.add(layers.Conv2D(64, (3, 3), activation='relu'))  # convolution layer 2
model.add(layers.MaxPooling2D((2, 2)))  # max-pooling layer 2
# Output size: 72/2 = 36, i.e. 36*36

# Output size: 36 - 3 + 1 = 34, i.e. 34*34; parameters: 128*3*3*64 + 128 = 73856
model.add(layers.Conv2D(128, (3, 3), activation='relu'))  # convolution layer 3
model.add(layers.MaxPooling2D((2, 2)))  # max-pooling layer 3
# Output size: 34/2 = 17, i.e. 17*17
# Output size: 17 - 3 + 1 = 15, i.e. 15*15; parameters: 128*3*3*128 + 128 = 147584
model.add(layers.Conv2D(128, (3, 3), activation='relu'))  # convolution layer 4
model.add(layers.MaxPooling2D((2, 2)))  # max-pooling layer 4; output size: 15/2 rounded down = 7, i.e. 7*7
model.add(layers.Flatten())
model.add(layers.Dense(512, activation='relu'))  # fully connected layer 1
model.add(layers.Dense(1, activation='sigmoid'))  # fully connected layer 2, used as the output layer
model.summary()

from keras import optimizers
# Newer Keras releases name this argument learning_rate instead of lr
model.compile(loss='binary_crossentropy',
              optimizer=optimizers.RMSprop(lr=1e-4),
              metrics=['acc'])

from keras.preprocessing.image import ImageDataGenerator
# Normalization: rescale pixel values to the range [0, 1]
train_datagen = ImageDataGenerator(rescale=1./255)
validation_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
# Load each split from its directory, resize every image to 150*150 and set the batch size
# Load the training set
train_ds = train_datagen.flow_from_directory(
    train_data_dir,
    class_mode='binary',
    target_size=(150, 150),
    batch_size=20)
# Load the test set
test_ds = test_datagen.flow_from_directory(
    test_data_dir,
    class_mode='binary',
    target_size=(150, 150),
    batch_size=20)
# Load the validation set
val_ds = validation_datagen.flow_from_directory(
    val_data_dir,
    class_mode='binary',
    target_size=(150, 150),
    batch_size=20)

# Check the shape of one batch
for data_batch, labels_batch in train_ds:
    print('data batch shape:', data_batch.shape)
    print('labels batch shape:', labels_batch.shape)
    break

# Train the model for 30 epochs; change epochs to train for longer or shorter
# Newer TensorFlow/Keras releases deprecate fit_generator; model.fit accepts generators there
history = model.fit_generator(
    train_ds,
    steps_per_epoch=100,
    epochs=30,
    validation_data=val_ds,
    validation_steps=50)

# Save the trained model as an h5 file
model.save('fruit_and_vegetable_30epoch.h5')

import matplotlib.pyplot as plt

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

from PIL import Image
from keras.preprocessing import image
from keras.models import load_model
import numpy as np
# Load the model
model = load_model('fruit_and_vegetable_30epoch.h5')
# Path of a local image
img_path = "C:/Users/linyicheng/Desktop/base/train/fruit/apple_41.jpg"
img = image.load_img(img_path, target_size=(150, 150))
# Convert it to a NumPy array of shape (150, 150, 3) and scale it to [0, 1]
img_tensor = image.img_to_array(img) / 255
# Add a batch dimension so the shape becomes (1, 150, 150, 3)
img_tensor = np.expand_dims(img_tensor, axis=0)
# Run the prediction
prediction = model.predict(img_tensor)
# Print the sigmoid output
print(prediction)
# flow_from_directory numbers the class folders alphabetically (fruit=0, vegetable=1),
# so an output above 0.5 means vegetable
if prediction > 0.5:
    print('vegetable')
else:
    print('fruit')
plt.imshow(img)

from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing import image
import os
import matplotlib.pyplot as plt
figure, ax = plt.subplots(nrows=1, ncols=4, sharex=True,
                          sharey=True, figsize=(16, 10))

datagen = ImageDataGenerator(
    rotation_range=40,       # range in degrees for random rotations
    width_shift_range=0.2,   # fraction of the width for random horizontal shifts
    height_shift_range=0.2,  # fraction of the height for random vertical shifts
    shear_range=0.2,         # intensity of the random shear transformation
    zoom_range=0.2,          # range for random zoom
    horizontal_flip=True,    # randomly flip half of the images horizontally
    fill_mode='nearest')     # how newly created pixels are filled
fnames = [os.path.join(train_fruit_data, fname) for fname in os.listdir(train_fruit_data)]
img_path = fnames[6]
img = image.load_img(img_path, target_size=(150, 150))
x = image.img_to_array(img) / 255
x = x.reshape((1,) + x.shape)
ax = ax.flatten()  # flatten the subplot array to one dimension
i = 0
for batch in datagen.flow(x, batch_size=1):
    imgplot = ax[i].imshow(image.array_to_img(batch[0]), cmap='Greys', interpolation='nearest')
    i += 1
    if i % 4 == 0:
        break
plt.show()  # show the augmented images
'''
Even with data augmentation, the generated images are still highly similar to each other,
because they all come from the same original image and carry no new information.
To reduce overfitting further, a Dropout layer can be added to the model
right before the densely connected classifier.
'''
# Add a Dropout layer to the model before the densely connected classifier
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu',
                        input_shape=(150, 150, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dropout(0.5))
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))

from keras import optimizers
model.compile(loss='binary_crossentropy',
              optimizer=optimizers.RMSprop(lr=1e-4),
              metrics=['acc'])

# Train the network with data augmentation and Dropout
# Normalization plus augmentation for the training set only
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    # Target size: all images are resized to 150x150
    target_size=(150, 150),
    batch_size=32,
    # Binary labels are needed because the loss is binary cross-entropy
    class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
    val_data_dir,
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')

history = model.fit_generator(
    train_generator,
    steps_per_epoch=100,
    epochs=100,
    validation_data=validation_generator,
    validation_steps=50)

model.save('fruit_and_vegetable_100epoch.h5')

import matplotlib.pyplot as plt

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

from keras.preprocessing import image
from keras.models import load_model
import numpy as np
# Load the model
model = load_model('fruit_and_vegetable_100epoch.h5')
# Path of a local image
img_path = "C:/Users/linyicheng/Desktop/base/train/fruit/apple_42.jpg"
img = image.load_img(img_path, target_size=(150, 150))
# Convert it to a NumPy array of shape (150, 150, 3) and scale it to [0, 1]
img_tensor = image.img_to_array(img) / 255
# Add a batch dimension so the shape becomes (1, 150, 150, 3)
img_tensor = np.expand_dims(img_tensor, axis=0)
# Run the prediction
prediction = model.predict(img_tensor)
# Print the sigmoid output
print(prediction)
# Class indices are fruit=0 and vegetable=1, so an output above 0.5 means vegetable
if prediction > 0.5:
    print('vegetable')
else:
    print('fruit')
plt.imshow(img)

from PIL import Image
from keras.preprocessing import image
import matplotlib.pyplot as plt
from keras.models import load_model
figure, ax = plt.subplots(nrows=2, ncols=4, sharex=True, sharey=True, figsize=(24, 10))
def convertjpg(jpgfile):
    # Resize the image to (150, 150); convert to RGB so the array always has 3 channels
    img = Image.open(jpgfile).convert('RGB')
    try:
        new_img = img.resize((150, 150), Image.BILINEAR)
        return new_img
    except Exception as e:
        print(e)

# Predict the first 8 images of the vegetable test set
snames = [os.path.join(test_vegetable_data, sname) for sname in os.listdir(test_vegetable_data)]
ax = ax.flatten()  # flatten the subplot array to one dimension
j = 0
for i in snames[0:8]:
    j += 1
    img_show = convertjpg(i)
    img_scale = image.img_to_array(img_show)
    img_scale = img_scale.reshape(1, 150, 150, 3)  # reshape to (1, 150, 150, 3)
    img_scale = img_scale.astype('float32') / 255
    result = model.predict(img_scale)  # run the prediction
    if result > 0.5:
        title = 'vegetable'
        label = 'Predicted probability: ' + str(result)
    else:
        title = 'fruit'
        label = 'Predicted probability: ' + str(1 - result)

    ax[j-1].imshow(img_show, cmap='Greys', interpolation='nearest')
    # Subplot title
    ax[j-1].set_title(title, fontsize=24)
    # Subplot x-axis label
    ax[j-1].set_xlabel(label, fontsize=24)
# Remove the ticks (the axes are shared, so this applies to every subplot)
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.show()

figure, ax = plt.subplots(nrows=2, ncols=4, sharex=True, sharey=True, figsize=(24, 10))

# Predict the first 8 images of the fruit test set
fnames = [os.path.join(test_fruit_data, fname) for fname in os.listdir(test_fruit_data)]
ax = ax.flatten()  # flatten the subplot array to one dimension
j = 0
for i in fnames[0:8]:
    j += 1
    img_show = convertjpg(i)
    img_scale = image.img_to_array(img_show)
    img_scale = img_scale.reshape(1, 150, 150, 3)  # reshape to (1, 150, 150, 3)
    img_scale = img_scale.astype('float32') / 255
    result = model.predict(img_scale)  # run the prediction
    if result > 0.5:
        title = 'vegetable'
        label = 'Predicted probability: ' + str(result)
    else:
        title = 'fruit'
        label = 'Predicted probability: ' + str(1 - result)
    ax[j-1].imshow(img_show, cmap='Greys', interpolation='nearest')
    # Subplot title
    ax[j-1].set_title(title, fontsize=24)
    # Subplot x-axis label
    ax[j-1].set_xlabel(label, fontsize=24)
# Remove the ticks
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.show()
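The output-size and parameter-count comments in the model definition can be checked with plain arithmetic, without installing Keras. The sketch below is only an illustration: the helper names (conv_output_size, pool_output_size, conv_params, trace_network) are made up for this check and are not part of Keras. It assumes 3x3 kernels, stride 1, no padding, and 2x2 max pooling, as in the model above.

```python
def conv_output_size(size, kernel=3, stride=1):
    """Side length after a convolution without padding."""
    return (size - kernel) // stride + 1


def pool_output_size(size, pool=2):
    """Side length after max pooling whose stride equals the pool size."""
    return size // pool


def conv_params(kernel, in_channels, filters):
    """Kernel weights plus one bias for each filter."""
    return (kernel * kernel * in_channels + 1) * filters


def trace_network(input_size=150, in_channels=3, filters_per_layer=(32, 64, 128, 128)):
    """Return (conv size, pooled size, parameter count) for each conv + pooling stage."""
    stages = []
    size, channels = input_size, in_channels
    for filters in filters_per_layer:
        conv_size = conv_output_size(size)
        pooled = pool_output_size(conv_size)
        stages.append((conv_size, pooled, conv_params(3, channels, filters)))
        size, channels = pooled, filters
    return stages


stages = trace_network()
for conv_size, pooled, params in stages:
    print(f"conv {conv_size}x{conv_size} -> pool {pooled}x{pooled}, params {params}")

# Number of values passed from Flatten to the first Dense layer
flattened = stages[-1][1] ** 2 * 128
print("flattened features:", flattened)
```

The values printed for the four stages can be compared line by line with the output of model.summary().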
IV. Summary
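This project drew on the MNIST example and the cat-vs-dog recognition example. Some problems came up while building the dataset directories during data processing: the approach was not thought through at the start, so data processing took too much time and the code ended up overly long. There were also problems importing some required libraries, such as PIL; according to what I found online, the installed Pillow (PIL) version needed to be updated. Going forward, I should practice more to become familiar with the environments these tools require.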
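As one possible way to shorten the directory-building code, the repeated copy loops could be folded into a single helper. This is a sketch and not the code used for the results in this article: the names flatten_class_folders and build_base_dirs are hypothetical, and it assumes the layout used above, where each group folder (fruit, vegetable) contains train, validation, and test, each with one sub-folder per class. Unlike the original loops, a file whose name already exists in the destination is skipped rather than overwritten.

```python
import shutil
from pathlib import Path


def flatten_class_folders(src_root, dst_dir):
    """Copy every file found in the sub-folders of src_root into dst_dir.

    Assumes file names are unique across sub-folders (for example apple_41.jpg);
    a name that already exists in dst_dir is skipped instead of overwritten.
    """
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    copied = 0
    for class_dir in sorted(Path(src_root).iterdir()):
        if not class_dir.is_dir():
            continue
        for src_file in sorted(class_dir.iterdir()):
            target = dst / src_file.name
            if src_file.is_file() and not target.exists():
                shutil.copyfile(src_file, target)
                copied += 1
    return copied


def build_base_dirs(desktop):
    """Build base/<split>/<group> from <group>/<split>/<class>/ for all six combinations."""
    desktop = Path(desktop)
    counts = {}
    for group in ("fruit", "vegetable"):
        for split in ("train", "validation", "test"):
            counts[(split, group)] = flatten_class_folders(
                desktop / group / split, desktop / "base" / split / group)
    return counts


# Example call with the Desktop path used in this article (not executed here):
# build_base_dirs("C:/Users/linyicheng/Desktop")
```

The returned dictionary maps each (split, group) pair to the number of files copied, which makes it easy to confirm that every folder was populated.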
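In this project, both the first training run and the training run with data augmentation and Dropout achieved the expected results. However, during the design process I found that I had not clearly worked out the program design beforehand, which cost a lot of time in corrections. In the future, I should clarify the approach before starting similar work.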