PyTorch图像分类全流程实战--训练得到的模型预测图像04

教程

  1. 同济子豪兄 https://space.bilibili.com/1900783
  2. https://www.bilibili.com/video/BV1qe4y1D7zD
  3. Github:Train_Custom_Dataset/图像分类/4

配置环境

数据处理:numpy pandas

可视化:matplotlib

HTTP请求:requests

进度条:tqdm

图像处理:opencv-python pillow(PIL)

Python Pillow 官方文档:https://pillow.readthedocs.io/en/latest/
Pillow 库提供了非常丰富的功能【1】,主要有以下几点:
Pillow 库能够很轻松的读取和保存各种格式的图片;
Pillow 库提供了简洁易用的 API 接口,可以让您轻松地完成许多图像处理任务;
Pillow 库能够配合 GUI(图形用户界面) 软件包 Tkinter 一起使用;
Pillow 库中的 Image 对象能够与 NumPy ndarray 数组实现相互转换。

Pytorch工具包:torch torchvision torchaudio

计算机视觉的基础库:mmcv-full(本教程中主要为了处理视频)【3】

实验材料

需要有测试图片和视频(mp4格式);

需要保存:训练结果、训练得到的模型权重。

预测新图像

import torch
import torchvision
import torch.nn.functional as F

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


from PIL import Image, ImageFont, ImageDraw
# 导入中文字体,指定字号
font = ImageFont.truetype('SimHei.ttf', 32)

#导入训练好的模型
model = torch.load('checkpoints/fruit30_pytorch_20220814.pth')
model = model.eval().to(device)
#预处理
from torchvision import transforms
# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])
img_path = 'test_img/watermelon1.jpg'
img_pil = Image.open(img_path)
input_img = test_transform(img_pil) # 预处理
input_img = input_img.unsqueeze(0).to(device)
# 执行前向预测,得到所有类别的 logit 预测分数
pred_logits = model(input_img) 
pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算

unsqueeze(0)

画图

plt.figure(figsize=(22, 10))

x = idx_to_labels.values()
y = pred_softmax.cpu().detach().numpy()[0] * 100
width = 0.45 # 柱状图宽度

ax = plt.bar(x, y, width)

plt.bar_label(ax, fmt='%.2f', fontsize=15) # 置信度数值
plt.tick_params(labelsize=20) # 设置坐标文字大小

plt.title(img_path, fontsize=30)
plt.xticks(rotation=45) # 横轴文字旋转
plt.xlabel('类别', fontsize=20)
plt.ylabel('置信度', fontsize=20)
plt.show()

置信度最大的前 n 个结果

n = 10
top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
confs = top_n[0].cpu().detach().numpy().squeeze() # 解析出置信度

.cpu().detach().numpy().squeeze()
.cpu()将数据移至CPU中;
.detach()作用:阻断反向传播的;
.numpy()将cpu上的tensor转为numpy数据;【4】
.squeeze()从数组的形状中删除单维度条目,即把shape中为1的维度去掉.【5】

torch.topk:取一个tensor的topk元素(降序后的前k个大小的元素值及索引)【6】

#图像分类结果写在原图上
draw = ImageDraw.Draw(img_pil)
for i in range(n):
    class_name = idx_to_labels[pred_ids[i]] # 获取类别名称
    confidence = confs[i] * 100 # 获取置信度
    text = '{:<15} {:>.4f}'.format(class_name, confidence)
    print(text)
    
    # 文字坐标,中文字符串,字体,rgba颜色
    draw.text((50, 100 + 50 * i), text, font=font, fill=(255, 0, 0, 1))
fig = plt.figure(figsize=(18,6))

# 绘制左图-预测图
ax1 = plt.subplot(1,2,1)
ax1.imshow(img_pil)
ax1.axis('off')

# 绘制右图-柱状图
ax2 = plt.subplot(1,2,2)
x = idx_to_labels.values()
y = pred_softmax.cpu().detach().numpy()[0] * 100
ax2.bar(x, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
plt.bar_label(ax, fmt='%.2f', fontsize=10) # 置信度数值

plt.title('{} 图像分类预测结果'.format(img_path), fontsize=30)
plt.xlabel('类别', fontsize=20)
plt.ylabel('置信度', fontsize=20)
plt.ylim([0, 110]) # y轴取值范围
ax2.tick_params(labelsize=16) # 坐标文字大小
plt.xticks(rotation=90) # 横轴文字旋转

plt.tight_layout()
fig.savefig('output/预测图+柱状图.jpg')
#预测结果输出
pred_df = pd.DataFrame() # 预测结果表格
for i in range(n):
    class_name = idx_to_labels[pred_ids[i]] # 获取类别名称
    label_idx = int(pred_ids[i]) # 获取类别号
    confidence = confs[i] * 100 # 获取置信度
    pred_df = pred_df.append({'Class':class_name, 'Class_ID':label_idx, 'Confidence(%)':confidence}, ignore_index=True) # 预测结果表格添加一行
display(pred_df) # 展示预测结果表格

参考文献

【1】Pillow(PIL)入门教程(非常详细)
【2】Python Pillow 官方文档
【3】介绍 MMCV
【4】PyTorch关于以下方法使用:detach() cpu() numpy() 以及item()
【5】Numpy库学习—squeeze()函数
【6】PyTorch torch.topk() 函数详解

posted on 2023-01-27 00:02  琢磨亿下  阅读(132)  评论(0编辑  收藏  举报

导航