九思不出-PyTorch 图像识别实战

今天我们要来做一个进阶的花分类问题. 不同于之前做过的鸢尾花, 这次我们会分析 102 中不同的花. 是不是很上头呀.

预处理
导包
常规操作, 没什么好解释的. 缺模块的同学自行pip -install.

import numpy as np
import time
from matplotlib import pyplot as plt
import json
import copy
import os
import torch
from torch import nn
from torch import optim
from torchvision import transforms, models, datasets
数据增强: torchvision 中 transforms 模块自带功能, 用于扩充数据样本
数据预处理: torchvision 中 transforms 也帮我们实现好了
数据分批: DataLoader 模块直接读取 batch 数据

# ----------------1. 数据读取与预处理------------------

# 路径
data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

# 制作数据源
data_transforms = {
'train': transforms.Compose([transforms.RandomRotation(45), #随机旋转,-45到45度之间随机选
transforms.CenterCrop(224), #从中心开始裁剪
transforms.RandomHorizontalFlip(p=0.5), #随机水平翻转 选择一个概率概率
transforms.RandomVerticalFlip(p=0.5), #随机垂直翻转
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1), #参数1为亮度, 参数2为对比度,参数3为饱和度,参数4为色相
transforms.RandomGrayscale(p=0.025), #概率转换成灰度率, 3通道就是R=G=B
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) #均值, 标准差
]),
'valid': transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}

batch_size = 8

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# 调试输出
print(image_datasets)
print(dataloaders)
print(dataset_sizes)
print(class_names)

# 读取标签对应的实际名字
with open('cat_to_name.json', 'r') as f:
cat_to_name = json.load(f)

print(cat_to_name)

输出结果:
{'train': Dataset ImageFolder
Number of datapoints: 6552
Root location: ./flower_data/train
StandardTransform
Transform: Compose(
RandomRotation(degrees=(-45, 45), resample=False, expand=False)
CenterCrop(size=(224, 224))
RandomHorizontalFlip(p=0.5)
RandomVerticalFlip(p=0.5)
ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
RandomGrayscale(p=0.025)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
), 'valid': Dataset ImageFolder
Number of datapoints: 818
Root location: ./flower_data/valid
StandardTransform
Transform: Compose(
Resize(size=256, interpolation=PIL.Image.BILINEAR)
CenterCrop(size=(224, 224))
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)}

{'train': <torch.utils.data.dataloader.DataLoader object at 0x000001B718A277F0>, 'valid': <torch.utils.data.dataloader.DataLoader object at 0x000001B718A27898>}

{'train': 6552, 'valid': 818}

['1', '10', '100', '101', '102', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '4', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '5', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '6', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '7', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '8', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '9', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99']

{'21': 'fire lily', '3': 'canterbury bells', '45': 'bolero deep blue', '1': 'pink primrose', '34': 'mexican aster', '27': 'prince of wales feathers', '7': 'moon orchid', '16': 'globe-flower', '25': 'grape hyacinth', '26': 'corn poppy', '79': 'toad lily', '39': 'siam tulip', '24': 'red ginger', '67': 'spring crocus', '35': 'alpine sea holly', '32': 'garden phlox', '10': 'globe thistle', '6': 'tiger lily', '93': 'ball moss', '33': 'love in the mist', '9': 'monkshood', '102': 'blackberry lily', '14': 'spear thistle', '19': 'balloon flower', '100': 'blanket flower', '13': 'king protea', '49': 'oxeye daisy', '15': 'yellow iris', '61': 'cautleya spicata', '31': 'carnation', '64': 'silverbush', '68': 'bearded iris', '63': 'black-eyed susan', '69': 'windflower', '62': 'japanese anemone', '20': 'giant white arum lily', '38': 'great masterwort', '4': 'sweet pea', '86': 'tree mallow', '101': 'trumpet creeper', '42': 'daffodil', '22': 'pincushion flower', '2': 'hard-leaved pocket orchid', '54': 'sunflower', '66': 'osteospermum', '70': 'tree poppy', '85': 'desert-rose', '99': 'bromelia', '87': 'magnolia', '5': 'english marigold', '92': 'bee balm', '28': 'stemless gentian', '97': 'mallow', '57': 'gaura', '40': 'lenten rose', '47': 'marigold', '59': 'orange dahlia', '48': 'buttercup', '55': 'pelargonium', '36': 'ruby-lipped cattleya', '91': 'hippeastrum', '29': 'artichoke', '71': 'gazania', '90': 'canna lily', '18': 'peruvian lily', '98': 'mexican petunia', '8': 'bird of paradise', '30': 'sweet william', '17': 'purple coneflower', '52': 'wild pansy', '84': 'columbine', '12': "colt's foot", '11': 'snapdragon', '96': 'camellia', '23': 'fritillary', '50': 'common dandelion', '44': 'poinsettia', '53': 'primula', '72': 'azalea', '65': 'californian poppy', '80': 'anthurium', '76': 'morning glory', '37': 'cape flower', '56': 'bishop of llandaff', '60': 'pink-yellow dahlia', '82': 'clematis', '58': 'geranium', '75': 'thorn apple', '41': 'barbeton daisy', '95': 'bougainvillea', '43': 'sword lily', '83': 'hibiscus', '78': 'lotus lotus', '88': 'cyclamen', '94': 'foxglove', '81': 'frangipani', '74': 'rose', '89': 'watercress', '73': 'water lily', '46': 'wallflower', '77': 'passion flower', '51': 'petunia'}
数据可视化
虽然我也不知道这些都是什么花, 但是还是一起来看一下. 有知道的大佬可以评论区留个言.

# ----------------2. 展示下数据------------------
def im_convert(tensor):
""" 展示数据"""

image = tensor.to("cpu").clone().detach()
image = image.numpy().squeeze()
image = image.transpose(1, 2, 0)
image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
image = image.clip(0, 1)

return image


def im_convert(tensor):
""" 展示数据"""

image = tensor.to("cpu").clone().detach()
image = image.numpy().squeeze()
image = image.transpose(1, 2, 0)
image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
image = image.clip(0, 1)

return image

fig=plt.figure(figsize=(20, 12))
columns = 4
rows = 2

dataiter = iter(dataloaders['valid'])
inputs, classes = dataiter.next()

for idx in range (columns*rows):
ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
plt.imshow(im_convert(inputs[idx]))
plt.show()
————————————————
版权声明:本文为CSDN博主「我是小白呀」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_46274168/article/details/114309226

打开终端输入一下代码就可以啦!
while(True):
str=input("用户::");
print("假AI::"+str.strip("吗??")+"!");
一个例子:

http://jintianxuesha.com/?id=3834
Python strip()方法
Python strip() 方法用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列。
注意:该方法只能删除开头或是结尾的字符,不能删除中间部分的字符。
strip()方法语法:str.strip([chars]);
参数:chars ,移除字符串头尾指定的字符序列。
返回值:返回移除字符串头尾指定的字符生成的新字符串。
例如:

str = "123abcrunoob321"
print (str.strip( '12' )) # 字符序列为 12

以上实例输出结果如下:

 

posted @ 2021-03-07 15:36  九思不出  阅读(634)  评论(0)    收藏  举报