Pytorch实战-CIFAR-10图像分类
Pytorch实战-CIFAR-10图像分类
近年来,卷积神经网络(CNN)在图像分类领域取得了显著成效。本文将以 CIFAR-10 数据集为例,详细讲解如何使用 PyTorch 构建一个完整的图像分类项目。涵盖数据处理、模型构建、训练与验证、性能评估及可视化分析。
一、CIFAR-10 数据集简介
CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的数据集,它包含 60000 张 32 X 32 的 RGB 彩色图片,总共 10 个分类。其中,包括 50000 张用于训练集,10000 张用于测试集。
数据集形式:
文件名 | 文件类型 | 大小 | 用途 |
---|---|---|---|
readme.html | HTML文件 | 1 KB | 数据集介绍文件 |
batches.meta | META文件 | 1 KB | 存储了每个类别的英文名称,可用记事本或其他文本文件阅读器打开查看 |
data_batch_1 | 二进制文件 | 30,309 KB | 训练数据第1批,包含10000张32×32彩色图像及其类别标签 |
data_batch_2 | 二进制文件 | 30,308 KB | 训练数据第2批,包含10000张32×32彩色图像及其类别标签 |
data_batch_3 | 二进制文件 | 30,309 KB | 训练数据第3批,包含10000张32×32彩色图像及其类别标签 |
data_batch_4 | 二进制文件 | 30,309 KB | 训练数据第4批,包含10000张32×32彩色图像及其类别标签 |
data_batch_5 | 二进制文件 | 30,309 KB | 训练数据第5批,包含10000张32×32彩色图像及其类别标签 |
test_batch | 二进制文件 | 30,309 KB | 测试数据,包含10000张测试图像及其类别标签 |
数据集特点:
属性 | 值 |
---|---|
图像数量 | 60,000 张 |
训练集大小 | 50,000 张 |
测试集大小 | 10,000 张 |
图像尺寸 | 32×32 像素 |
通道数 | 3(RGB) |
分类数量 | 10 类 |
每类样本数量 | 6,000 张 |
标签格式 | 整数(0~9) |
常用用途 | 图像分类、迁移学习、图像增强实验 |
具体类别:
类别编号 | 类别名称(英文) | 含义说明 |
---|---|---|
0 | airplane | 飞机 |
1 | automobile | 汽车 |
2 | bird | 鸟 |
3 | cat | 猫 |
4 | deer | 鹿 |
5 | dog | 狗 |
6 | frog | 青蛙 |
7 | horse | 马 |
8 | ship | 船 |
9 | truck | 卡车(货车) |
CIFAR-10数据集下载:
- Pytorch:
torchvision.datasets.CIFAR10(root='./dataset', train=True/False, download=True)
- 官网下载:http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
- Github:https://github.com/wmt319/CIFAR-10_CNN/dataset
二、环境配置与依赖
# 创建虚拟环境并安装依赖
pip install torch torchvision matplotlib numpy pandas scikit-learn
import os
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
提示:推荐在 GPU 环境下执行(CUDA 可用时优先使用)。
三、数据预处理与可视化
在训练 CNN 模型之前,数据预处理是至关重要的一步。合理的归一化和标准化可以使得模型更快收敛、性能更稳定;而数据可视化不仅能帮助我们直观了解数据分布,还能及时发现异常样本。
3.1 归一化与标准化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 全局均值与方差
])
- ToTensor():将 PIL 图像或 NumPy 数组转换为形状为 [C, H, W] 的浮点型张量,并将像素值归一化到 [0,1] 区间;
- Normalize():对每个通道执行 (x - mean) / std,这些值是根据整个训练集计算得到的,全局标准化有助于梯度更稳定。
Tip:如果你使用自己的数据集,需要先利用 torch.mean 和 torch.std 统计均值与方差,或在可视化工具中直观估计数据分布。
3.2 数据加载
train_dataset = datasets.CIAFR10(root='./dataset', train=True, download=False, transform=transform)
test_dataset = datasets.CIFAR10(root='./dataset', train=False, download=False, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False,num_workers=0)
- root:指定本地数据存储路径。例如 ./dataset 文件夹下会生成 MNIST 子目录;
- train:布尔值,True 表示加载训练集,False 表示加载测试集;
- download:布尔值,若本地 root 路径下不存在数据,则自动从网络下载(可能需要魔法上网);若已下载可设为 False;
- shuffle:训练时打乱数据顺序,有助于打破样本间关联性,提升泛化;
- batch_size:每批次读取样本数,平衡训练稳定性和显存使用;
- num_workers:用于数据加载的子进程数量,可加速 I/O,但需与系统资源匹配。
注意:在 Windows 下使用 num_workers > 0 有时会导致多进程启动延迟,可设置为 0 避免兼容性问题。
3.3 样本可视化
在正式训练前,通过可视化部分样本,可以帮助我们确认标签是否正确、图像是否有噪声或畸变,以及预处理是否生效。
# 展示 10 个样本
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
data_iter = iter(train_loader)
images, labels = next(data_iter)
plt.figure(figsize=(12, 6))
for i in range(10):
img = (images[i] * 0.5 + 0.5) # 反归一化
img = img.permute(1, 2, 0).numpy() # CHW -> HWC
plt.subplot(2, 5, i+1)
plt.imshow(img)
plt.title(class_names[labels[i].item()])
plt.axis('off')
plt.suptitle("CIFAR-10", fontsize=16)
plt.tight_layout()
plt.show()
四、CNN 模型架构与实现
class CIFAR10_CNN(nn.Module):
def __init__(self):
super(CIFAR10_CNN, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 6,5),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Flatten(),
nn.Linear(16 * 5 * 5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, 10),
)
def forward(self, x):
x = self.net(x)
return x
# 模型实例化与设备部署
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CIFAR10_CNN().to(device)
# 使用 torchinfo 输出模型摘要
from torchinfo import summary
summary(model, (64, 3, 32, 32))
-
卷积层 conv1:
nn.Conv2d(3, 6, 5)
:输入为 3 通道彩色图像,使用 6 个 5×5 的卷积核进行特征提取;nn.ReLU()
:添加非线性激活函数,引入模型表达能力;nn.MaxPool2d(2, 2)
:使用 2×2 的池化窗口下采样,将图像尺寸从 32×32 缩小为14×14。
📐 卷积输出尺寸计算(不使用 padding,stride=1):
- Conv1 输出:$(32 - 5 + 1) = 28 \Rightarrow 28 × 28 × 6$
- Pool1 输出:$28 / 2 = 14 \Rightarrow 14 × 14 × 6$
-
卷积层 conv2:
nn.Conv2d(6, 16, 5)
:将前一层的 6 通道输出扩展为 16 个高层特征通道;nn.ReLU()
:继续添加非线性;nn.MaxPool2d(2, 2)
:再一次下采样,进一步降低空间维度。
📐 尺寸变化:
- Conv2 输出:$(14 - 5 + 1) = 10 \Rightarrow 10 × 10 × 16$
- Pool2 输出:$10 / 2 = 5 \Rightarrow 5 × 5 × 16$
-
全连接模块 fc:
nn.Flatten()
:将 16 × 5 × 5 的三维张量展平成 400 维向量;nn.Linear(400, 120)
:第一层全连接层,映射到 120 维;nn.ReLU()
:继续使用 ReLU 激活;nn.Linear(120, 84)
:第二层全连接层,进一步学习紧凑表示;nn.Linear(84, 10)
:输出 10 维向量,每一维对应一个 CIFAR-10 类别的得分。
-
模型摘要:
summary(model, input_size=(64,1,28,28))
可直观打印各层输出形状与参数量,便于调试与性能评估;
📋 模型各层形状与参数如下:
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
CIFAR10_CNN [64, 10] --
├─Sequential: 1-1 [64, 6, 14, 14] --
│ └─Conv2d: 2-1 [64, 6, 28, 28] 456
│ └─ReLU: 2-2 [64, 6, 28, 28] --
│ └─MaxPool2d: 2-3 [64, 6, 14, 14] --
├─Sequential: 1-2 [64, 16, 5, 5] --
│ └─Conv2d: 2-4 [64, 16, 10, 10] 2,416
│ └─ReLU: 2-5 [64, 16, 10, 10] --
│ └─MaxPool2d: 2-6 [64, 16, 5, 5] --
├─Sequential: 1-3 [64, 10] --
│ └─Flatten: 2-7 [64, 400] --
│ └─Linear: 2-8 [64, 120] 48,120
│ └─ReLU: 2-9 [64, 120] --
│ └─Linear: 2-10 [64, 84] 10,164
│ └─ReLU: 2-11 [64, 84] --
│ └─Linear: 2-12 [64, 10] 850
==========================================================================================
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
Total mult-adds (M): 42.13
==========================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 3.34
Params size (MB): 0.25
Estimated Total Size (MB): 4.37
==========================================================================================
五、训练策略与超参数
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
- 损失函数:交叉熵损失是多分类任务的标准选择,可衡量模型输出的概率分布与真实标签分布之间的差异;
- 优化器:SGD具有良好的稳定性和可解释性;
- 学习率调度:每 20 个 epoch 衰减 10 倍;
- 动量:momentum=0.9 能加速收敛并抑制震荡。
六、训练与评估
6.1 训练循环
def train(model, train_loader, criterion, optimizer, epoch):
model.train()
for i in range(epoch):
running_loss = 0.0
for batch_idx, data in enumerate(train_loader):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if batch_idx % 100 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
i, batch_idx * len(inputs), len(train_loader.dataset),
100. * batch_idx / len(train_loader), running_loss / 100))
running_loss = 0.0
if i % 10 == 0:
torch.save(model.state_dict(), './checkpoints/model_{}.pth'.format(i))
print('Model saved as model_{}.pth'.format(i))
train(model, train_loader, criterion, optimizer, 100)
6.2 精确率
def test(model, test_loader):
model.eval()
class_correct = list(0. for _ in range(10))
class_total = list(0. for _ in range(10))
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(len(images)):
label = labels[i].item() # 转换为 Python 整数
class_correct[label] += c[i].item()
class_total[label] += 1
for i in range(10):
if class_total[i] > 0:
print(f'Accuracy of {classes[i]:5s} : {100 * class_correct[i]/class_total[i]:2.0f}%')
else:
print(f'Accuracy of {classes[i]:5s} : N/A')
test(model, test_loader)
Accuracy of plane : 70%
Accuracy of car : 76%
Accuracy of bird : 45%
Accuracy of cat : 46%
Accuracy of deer : 44%
Accuracy of dog : 48%
Accuracy of frog : 81%
Accuracy of horse : 69%
Accuracy of ship : 79%
Accuracy of truck : 69%
6.3 混淆矩阵与分类报告
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
# 1) 在 test 函数里收集
y_trues, y_preds = [], []
model.eval()
with torch.no_grad():
for x, y in test_loader:
x = x.to(device)
logits = model(x)
preds = logits.argmax(dim=1).cpu()
y_preds.extend(preds.numpy())
y_trues.extend(y.numpy())
# 2) 计算并打印报告
print(classification_report(y_trues, y_preds, target_names=classes))
# 3) 混淆矩阵
cm = confusion_matrix(y_trues, y_preds)
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
display_labels=classes)
disp.plot(cmap=plt.cm.Blues)
plt.xticks(rotation=45)
plt.title("Confusion Matrix")
plt.show()
precision recall f1-score support
plane 0.67 0.70 0.69 1000
car 0.72 0.76 0.74 1000
bird 0.60 0.45 0.52 1000
cat 0.42 0.46 0.44 1000
deer 0.62 0.44 0.52 1000
dog 0.55 0.48 0.51 1000
frog 0.60 0.81 0.69 1000
horse 0.70 0.69 0.70 1000
ship 0.70 0.79 0.74 1000
truck 0.68 0.69 0.68 1000
accuracy 0.63 10000
macro avg 0.63 0.63 0.62 10000
weighted avg 0.63 0.63 0.62 10000
6.4 ROC曲线图
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
# 假设有 10 类
n_classes = 10
y_trues_bin = label_binarize(y_trues, classes=list(range(n_classes)))
y_scores = [] # 用 softmax 后的概率
model.eval()
with torch.no_grad():
for x, _ in test_loader:
x = x.to(device)
logits = model(x)
probs = torch.softmax(logits, dim=1).cpu().numpy()
y_scores.append(probs)
y_scores = np.vstack(y_scores)
# 针对每一类画 ROC
plt.figure()
for i in range(n_classes):
fpr, tpr, _ = roc_curve(y_trues_bin[:, i], y_scores[:, i])
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, label=f'Class {classes[i]} (AUC = {roc_auc:.2f})')
plt.plot([0,1], [0,1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve (one-vs-rest)')
plt.legend(loc='lower right')
plt.show()
📌 解释:
- 采用 One-vs-Rest 方法,对每个类分别绘制 ROC 曲线。
AUC
越高,说明该类的分类能力越强。
七、完整代码
完整代码可以在我的GitHub仓库查看:https://github.com/wmt319/CIIFAR-10_CNN
💻 提示:如果需要数据集可私信或者评论!!!
如有问题,欢迎在评论区留言讨论!