Flower Recognition

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
from tensorflow.keras import Model
from tensorflow.keras.layers import MaxPool2D, BatchNormalization, Dropout, Activation, Conv2D, Flatten
# Learning point 1: use a compressed dataset from the web -- download it and extract it.
# Signature of tf.keras.utils.get_file:
# tf.keras.utils.get_file(
#     fname, origin, untar=False, md5_hash=None,
#     file_hash=None, cache_subdir='datasets',
#     hash_algorithm='auto', extract=False,
#     archive_format='auto', cache_dir=None
# )
'''
Parameter notes --
fname: file name; if an absolute path such as "/path/to/file.txt" is given, the file is saved to that location (optional)

origin: URL of the file

untar: boolean, whether the file should be decompressed

md5_hash: MD5 hash of the file, used for verification (file_hash supports sha256 and md5)

cache_subdir: subfolder in which the data is cached; if an absolute path "/path/to/folder" is given, the file is stored there

hash_algorithm: hash algorithm used to verify the file; options are 'md5', 'sha256' and 'auto'.
The default 'auto' detects the algorithm automatically

extract: if True, try to extract the file, e.g. a tar or zip archive

archive_format: archive format to try to extract; options are 'auto', 'tar', 'zip'
and None. 'tar' covers tar, tar.gz and tar.bz files. The default 'auto' is ['tar', 'zip']. None or an empty list returns no match

cache_dir: directory where the downloaded file is cached; if None, it defaults to ~/.keras

'''
# Download the data
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"

data_dir = tf.keras.utils.get_file(fname    = 'flower_photos', # local file name after download
                                   origin   = dataset_url,     # URL of the dataset
                                   untar    = True,            # whether to extract the archive
                                   cache_dir= '/content/drive/MyDrive/DL/DL 100例/花朵数据')


data_dirs='/content/drive/MyDrive/DL/DL 100例/花朵数据/datasets/flower_photos'
import pathlib
import PIL.Image
data_dir = pathlib.Path(data_dirs)
data_dir
PosixPath('/content/drive/MyDrive/DL/DL 100例/花朵数据/datasets/flower_photos')
# Count all jpg files in the class folders
len(list(data_dir.glob('*/*.jpg')))
3670
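To see how those 3,670 images are split across the class folders, a per-class count can be done with the same glob pattern (a quick sketch, not part of the original post; it assumes each subdirectory of data_dir is one class):

for class_dir in sorted(d for d in data_dir.iterdir() if d.is_dir()):
    print(class_dir.name, len(list(class_dir.glob('*.jpg'))))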
roses = list(data_dir.glob('roses/*'))
PIL.Image.open(str(roses[0]))

(Output: the first image in the roses folder is displayed.)

# Data preprocessing
# Use image_dataset_from_directory to load the images on disk into a tf.data.Dataset
batch_size = 32
img_height = 180
img_width = 180
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset='training',
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
class_names=train_ds.class_names
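class_names holds the subdirectory names in alphabetical order, so printing it is a quick way to confirm which five classes were found (the expected values in the comment are an assumption based on the standard flower_photos layout):

print(class_names)
# expected: ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']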
plt.figure(figsize=(20, 10))

for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

(Output: a grid of 20 sample training images titled with their class names.)

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
(32, 180, 180, 3)
(32,)
# Configure the dataset
# shuffle(): shuffle the data
# prefetch(): prefetch data to speed up execution
# cache(): cache the dataset in memory to speed up execution
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

num_classes = 5

"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995
"""

model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),

    layers.Conv2D(16, (3, 3), activation='relu'),  # Conv layer 1, 3x3 kernels
    layers.MaxPooling2D((2, 2)),                   # Pooling layer 1, 2x2 downsampling
    layers.Conv2D(32, (3, 3), activation='relu'),  # Conv layer 2, 3x3 kernels
    layers.MaxPooling2D((2, 2)),                   # Pooling layer 2, 2x2 downsampling
    layers.Conv2D(64, (3, 3), activation='relu'),  # Conv layer 3, 3x3 kernels

    layers.Flatten(),                       # Flatten layer, connects the conv layers to the dense layers
    layers.Dense(128, activation='relu'),   # Fully connected layer, further feature extraction
    layers.Dense(num_classes)               # Output layer, produces the prediction logits
])

model.summary()  # Print the network architecture

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
rescaling (Rescaling)        (None, 180, 180, 3)       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 178, 178, 16)      448       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 89, 89, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 87, 87, 32)        4640      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 43, 43, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 41, 41, 64)        18496     
_________________________________________________________________
flatten (Flatten)            (None, 107584)            0         
_________________________________________________________________
dense (Dense)                (None, 128)               13770880  
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 645       
=================================================================
Total params: 13,795,109
Trainable params: 13,795,109
Non-trainable params: 0
_________________________________________________________________
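As a sanity check on the numbers above (this sketch is not part of the original post): a 3x3 Conv2D with 'valid' padding shrinks each spatial dimension by 2, MaxPooling2D((2, 2)) halves it (rounding down), and a conv layer has (kh*kw*in_channels + 1)*out_channels parameters. The summary values can be reproduced by hand:

# Feature-map sizes: 180 -> 178 -> 89 -> 87 -> 43 -> 41
print((3 * 3 * 3 + 1) * 16)       # 448      params in conv2d
print((3 * 3 * 16 + 1) * 32)      # 4640     params in conv2d_1
print((3 * 3 * 32 + 1) * 64)      # 18496    params in conv2d_2
print(41 * 41 * 64)               # 107584   units after Flatten
print((41 * 41 * 64 + 1) * 128)   # 13770880 params in dense
print((128 + 1) * 5)              # 645      params in dense_1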
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=3
)

Epoch 1/3
92/92 [==============================] - 375s 2s/step - loss: 1.5899 - accuracy: 0.3607 - val_loss: 1.1485 - val_accuracy: 0.5041
Epoch 2/3
92/92 [==============================] - 94s 1s/step - loss: 1.0672 - accuracy: 0.5698 - val_loss: 1.0633 - val_accuracy: 0.5736
Epoch 3/3
92/92 [==============================] - 92s 997ms/step - loss: 0.8770 - accuracy: 0.6614 - val_loss: 0.9982 - val_accuracy: 0.6185
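A possible follow-up step (not in the original post): the history object returned by fit records the metrics configured above ('accuracy' and the loss, plus their 'val_' counterparts), so the training curves can be plotted with matplotlib, which is already imported:

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Loss')

plt.show()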
