PyTorch 实例 - 猫狗分类项目
猫狗分类项目(含批量预测功能)
入门计算机视觉的经典项目——猫狗图片分类,并且重点实现批量预测文件夹内图片的功能。这个项目适合刚接触PyTorch和深度学习的小伙伴,通过实操能快速掌握数据预处理、迁移学习、模型训练与批量预测的核心流程。
一、项目介绍
本项目的核心目标是:训练一个能区分猫和狗的图像分类模型,并实现对指定文件夹(含cats、dogs子文件夹)内所有图片的批量预测,同时统计整体预测准确率。
技术亮点:
-
使用ResNet18预训练模型(迁移学习),无需从零训练,兼顾效率和准确率;
-
修复训练过程中准确率计算的常见bug,确保训练指标准确;
-
批量预测支持自动遍历子文件夹,包含异常处理(避免单张图片损坏导致程序中断);
-
输出详细预测结果(单张图片预测情况、整体准确率、错误图片列表),便于排查问题。
二、环境准备
首先需要配置好Python和相关依赖库,建议使用Anaconda创建虚拟环境,避免版本冲突。
2.1 安装依赖库
打开终端/命令提示符,执行以下命令安装所需库:
# 安装PyTorch和torchvision(根据自己的CUDA版本选择,无GPU则安装CPU版本)
# 官网地址:https://pytorch.org/get-started/locally/
pip install torch torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple #使用清华镜像源下载
# 安装其他依赖
pip install pillow # 处理图片
pip install numpy # 可选,用于后续拓展
2.2 验证环境
运行以下代码,若没有报错则环境配置成功:
import torch
import torchvision
from PIL import Image
print("PyTorch版本:", torch.__version__)
print("Torchvision版本:", torchvision.__version__)
三、数据准备
本项目需要训练集和测试集,目录结构有严格要求(便于代码自动识别类别),请大家按以下结构准备。
3.1 数据目录结构
catdog/ # 项目根目录
├─ data/ # 数据目录
│ ├─ train/ # 训练集目录(用于训练模型)
│ │ ├─ cats/ # 猫的训练图片(若干张.jpg/.png)
│ │ └─ dogs/ # 狗的训练图片(若干张.jpg/.png)
│ └─ test/ # 测试集目录(用于批量预测)
│ ├─ cats/ # 猫的测试图片(若干张.jpg/.png)
│ └─ dogs/ # 狗的测试图片(若干张.jpg/.png)
└─ cat_dog_model.pth # 后续训练好的模型文件(自动生成)
3.2 获取数据集
推荐使用Kaggle的经典猫狗数据集(含25000张训练图、12500张测试图),下载地址:Dogs vs. Cats(需要注册Kaggle账号)。
如果只是测试功能,也可以自己找几十张猫和狗的图片,按上述目录结构存放即可。
四、核心代码解析
完整代码已优化并修复bug,我们分模块逐行解析,理解每个部分的作用。
4.1 导入所需库
import torch
import torchvision
from torchvision import datasets, transforms
from torch import nn, optim
from torch.utils.data import DataLoader
from PIL import Image
from torchvision.models import ResNet18_Weights
import os # 用于遍历文件夹、处理文件路径
关键库说明:
-
torch/torchvision:核心深度学习框架,提供模型、数据加载、图像预处理工具;
-
PIL.Image:处理图片读取、格式转换;
-
os:用于批量遍历文件夹内的图片文件。
4.2 图像预处理(统一格式)
# 定义图像转换流程:统一大小→转为张量→标准化
transform = transforms.Compose([
transforms.Resize((128, 128)), # 所有图片缩放到128×128(ResNet输入可灵活调整)
transforms.ToTensor(), # 转为PyTorch张量(维度:C×H×W,值0-1)
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化
])
为什么要预处理?
1. 不同图片大小不一,模型无法直接处理,需统一尺寸;
2. 张量是PyTorch的核心数据格式,必须转换;
3. 标准化能让模型更快收敛(使用ImageNet数据集的均值/标准差,适配预训练模型)。
4.3 模型训练函数
def train_model(data_dir, num_epochs=10, batch_size=32, save_path='cat_dog_model.pth'):
"""
训练模型并保存
:param data_dir: 训练数据路径(含cats/dogs子文件夹)
:param num_epochs: 训练轮数(默认10)
:param batch_size: 批次大小(默认32,GPU内存小则改小)
:param save_path: 模型保存路径
"""
# 1. 加载训练数据(自动按文件夹划分类别)
train_data = datasets.ImageFolder(
root=data_dir,
transform=transform
)
print(f"类别映射关系: {train_data.class_to_idx}") # 输出:{'cats':0, 'dogs':1}(0=猫,1=狗)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True) # 打乱数据
# 2. 加载ResNet18预训练模型(迁移学习核心)
model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
# 修改输出层:ResNet18默认输出1000类(ImageNet),这里改为2类(猫/狗)
model.fc = nn.Linear(model.fc.in_features, 2)
# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失(适合分类任务)
optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器(学习率0.001)
# 4. 开始训练
for epoch in range(num_epochs):
model.train() # 设为训练模式(启用dropout/batchnorm)
running_loss = 0.0 # 累计损失
correct = 0 # 正确预测数
total = 0 # 总样本数
# 遍历训练集的每个批次
for images, labels in train_loader:
optimizer.zero_grad() # 清除上一轮梯度(避免累积)
outputs = model(images) # 前向传播:输入图片→得到预测结果
loss = criterion(outputs, labels) # 计算损失(预测结果vs真实标签)
loss.backward() # 反向传播:计算梯度
optimizer.step() # 更新模型参数
running_loss += loss.item() # 累计损失
_, predicted = torch.max(outputs, 1) # 获取预测类别(0或1)
total += labels.size(0) # 累计总样本数
correct += (predicted == labels).sum().item() # 累计正确数
# 计算并输出当前轮的结果
avg_loss = running_loss / len(train_loader) # 平均损失
accuracy = 100 * correct / total # 整体准确率
print(f'轮次 [{epoch + 1}/{num_epochs}], 平均损失: {avg_loss:.4f}, 准确率: {accuracy:.2f}%')
# 准确率100%时提前停止(避免过拟合,可选)
if accuracy == 100:
break
# 保存模型(只保存参数,不保存整个模型,占用空间小)
torch.save(model.state_dict(), save_path)
print(f"模型已保存至: {save_path}")
4.4 单张图片预测辅助函数
def predict_single_image(model, img_path):
"""
辅助函数:用加载好的模型预测单张图片
:param model: 加载好的模型实例
:param img_path: 单张图片路径
:return: 预测标签(0/1)、预测文本(猫/狗)
"""
try:
# 读取图片并转为RGB(避免灰度图报错)
img = Image.open(img_path).convert('RGB')
img_tensor = transform(img).unsqueeze(0) # 增加batch维度(模型要求输入是批量数据)
# 预测(关闭梯度计算,加快速度)
with torch.no_grad():
outputs = model(img_tensor)
_, predicted = torch.max(outputs, 1)
pred_label = predicted.item()
pred_text = "猫" if pred_label == 0 else "狗"
return pred_label, pred_text
except Exception as e:
# 异常处理:避免单张图片损坏导致整个批量预测中断
print(f"⚠️ 处理图片 {img_path} 时出错: {str(e)}")
return None, None
为什么要封装成辅助函数?
批量预测需要重复调用单张预测逻辑,封装后代码更简洁、可复用,同时便于集中处理异常。
4.5 批量预测核心函数(重点!)
def batch_predict_images(model_path, test_dir):
"""
批量预测指定文件夹下的所有图片(含cats、dogs子文件夹)
:param model_path: 训练好的模型路径
:param test_dir: 测试数据根目录(含cats、dogs子文件夹)
:return: 整体预测准确率
"""
# 1. 加载训练好的模型
model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load(model_path)) # 加载模型参数
model.eval() # 设为评估模式(关闭dropout/batchnorm)
print(f"✅ 模型加载完成: {model_path}")
# 2. 类别映射(与训练时一致)
class_to_idx = {'cats': 0, 'dogs': 1}
idx_to_class = {v: k for k, v in class_to_idx.items()} # 反向映射:0→cats,1→dogs
# 3. 初始化统计指标
total_samples = 0 # 总图片数
correct_predictions = 0 # 正确预测数
error_samples = [] # 处理失败的图片路径
# 4. 遍历test_dir下的cats和dogs子文件夹
for class_name in ['cats', 'dogs']:
class_dir = os.path.join(test_dir, class_name) # 拼接子文件夹路径
if not os.path.exists(class_dir):
print(f"⚠️ 未找到文件夹: {class_dir},跳过")
continue
# 筛选出图片文件(避免非图片文件干扰)
img_files = [f for f in os.listdir(class_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
if not img_files:
print(f"⚠️ 文件夹 {class_dir} 内无图片")
continue
true_label = class_to_idx[class_name] # 当前文件夹的真实标签(0=猫,1=狗)
print(f"\n========== 开始预测 {class_name} 文件夹(真实标签: {true_label})==========")
# 逐个预测该文件夹下的图片
for img_file in img_files:
img_path = os.path.join(class_dir, img_file) # 拼接图片完整路径
pred_label, pred_text = predict_single_image(model, img_path)
total_samples += 1
# 处理预测结果
if pred_label is None:
error_samples.append(img_path)
continue
# 输出单张图片预测结果
print(f"图片: {img_file} | 真实类别: {class_name} | 预测类别: {pred_text} | "
f"{'✅ 正确' if pred_label == true_label else '❌ 错误'}")
# 累计正确数
if pred_label == true_label:
correct_predictions += 1
# 5. 输出批量预测汇总
if total_samples == 0:
print("\n❌ 未找到任何可预测的图片")
return 0.0
# 计算准确率(排除处理失败的图片)
valid_samples = total_samples - len(error_samples)
overall_accuracy = 100 * correct_predictions / valid_samples
print(f"\n========== 批量预测结果汇总 ==========")
print(f"总图片数: {total_samples}")
print(f"有效预测数: {valid_samples}")
print(f"正确预测数: {correct_predictions}")
print(f"整体准确率: {overall_accuracy:.2f}%")
if error_samples:
print(f"处理失败图片数: {len(error_samples)}")
print(f"失败图片列表: {error_samples}")
return overall_accuracy
批量预测核心逻辑:
-
加载训练好的模型,设为评估模式;
-
遍历test_dir下的cats和dogs子文件夹;
-
筛选图片文件,避免非图片(如.txt)干扰;
-
调用辅助函数预测单张图片,输出详细结果;
-
统计整体准确率,记录失败图片(便于排查)。
4.6 主函数(程序入口)
if __name__ == "__main__":
# ===================== 可选:训练模型(已训练可注释) =====================
# data_dir = r"E:\Desktop\AI_structure\猫狗实训\catdog\data\train" # 你的训练数据路径
# train_model(data_dir, num_epochs=10, batch_size=32)
# ===================== 批量预测(核心功能) =====================
model_path = 'cat_dog_model.pth' # 训练好的模型路径
test_dir = r"E:\Desktop\AI_structure\猫狗实训\catdog\data\test" # 你的测试数据根目录(含cats/dogs)
batch_predict_images(model_path, test_dir)
说明:训练模型只需运行一次,生成cat_dog_model.pth后,后续预测可注释训练代码,直接运行批量预测。
五、运行步骤(手把手操作)
5.1 第一步:准备数据
按第三部分的目录结构,将训练图放入train/cats、train/dogs,测试图放入test/cats、test/dogs。
5.2 第二步:修改路径
在主函数中,修改以下两个路径为自己的实际路径(注意路径中的反斜杠“\”要保留,或用“/”替代):
# 训练数据路径(训练时启用)
data_dir = r"你的训练数据路径,如E:\catdog\data\train"
# 测试数据路径(批量预测时必须修改)
test_dir = r"你的测试数据路径,如E:\catdog\data\test"
5.3 第三步:训练模型(首次运行)
取消主函数中训练相关代码的注释,运行程序。控制台会输出:
类别映射关系: {'cats': 0, 'dogs': 1}
轮次 [1/10], 平均损失: 0.3256, 准确率: 85.23%
轮次 [2/10], 平均损失: 0.1872, 准确率: 92.15%
...
模型已保存至: cat_dog_model.pth
训练完成后,项目根目录会生成cat_dog_model.pth文件。
5.4 第四步:批量预测
注释训练代码,运行批量预测部分。控制台会输出:
✅ 模型加载完成: cat_dog_model.pth
========== 开始预测 cats 文件夹(真实标签: 0)==========
图片: cat_1.jpg | 真实类别: cats | 预测类别: 猫 | ✅ 正确
图片: cat_2.jpg | 真实类别: cats | 预测类别: 猫 | ✅ 正确
图片: cat_3.jpg | 真实类别: cats | 预测类别: 狗 | ❌ 错误
========== 开始预测 dogs 文件夹(真实标签: 1)==========
图片: dog_1.jpg | 真实类别: dogs | 预测类别: 狗 | ✅ 正确
图片: dog_2.jpg | 真实类别: dogs | 预测类别: 狗 | ✅ 正确
========== 批量预测结果汇总 ==========
总图片数: 5
有效预测数: 5
正确预测数: 4
整体准确率: 80.00%
处理失败图片数: 0
六、结果解读
-
训练时:“平均损失”越低越好,“准确率”越高越好(一般训练10轮后准确率能达到90%以上);
-
预测时:每张图片会显示“真实类别”和“预测类别”,用✅/❌标记是否正确;
-
汇总结果:重点看“整体准确率”,若准确率较低,可通过“错误图片列表”查看哪些图片预测错了,分析原因(如图片模糊、角度特殊等)。
七、常见问题与解决方案
7.1 报错“找不到文件/文件夹”
解决方案:检查data_dir、test_dir路径是否正确,确保文件夹结构符合要求(含cats、dogs子文件夹)。
7.2 训练时显存不足
解决方案:减小batch_size(如改为16、8),或使用CPU训练(PyTorch会自动适配)。
7.3 部分图片处理失败
解决方案:查看失败图片列表,检查图片是否损坏、格式是否正确(建议用主流的.jpg/.png格式)。
7.4 准确率过低(低于80%)
解决方案:
-
增加训练轮数(如改为20);
-
增加训练数据量(使用Kaggle完整数据集);
-
调整图像预处理(如增加数据增强:transforms.RandomCrop、transforms.RandomFlip)。
八、项目拓展方向
学会这个项目后,可以尝试以下拓展,提升技能:
-
优化模型:使用ResNet50/ResNet101等更复杂的预训练模型,进一步提高准确率;
-
数据增强:添加随机裁剪、翻转、旋转等操作,提升模型泛化能力;
-
添加验证集:在训练时划分验证集,监控模型是否过拟合;
-
模型部署:用Flask/Django搭建简单网页,实现上传图片预测;
-
多类别分类:扩展到区分更多动物(如猫、狗、鸡、鸭)。
九、总结
本项目通过“数据准备-模型训练-批量预测”的完整流程,带大家入门了计算机视觉的核心技术——迁移学习。重点修复了训练时的准确率计算bug,实现了实用的批量预测功能,同时给出了详细的操作步骤和问题解决方案,确保初学者能顺利上手。
完整代码已整理好,大家可以直接复制到本地运行。
点击查看完整代码
import torch
import torchvision
from torchvision import datasets, transforms
from torch import nn, optim
from torch.utils.data import DataLoader
from PIL import Image
from torchvision.models import ResNet18_Weights
import os
# 定义transform类(视觉转换类,将图片格式转化为张量格式)
transform = transforms.Compose([
transforms.Resize((128, 128)), # 将图片缩放到统一大小
transforms.ToTensor(), # 转换为Tensor格式
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化处理
])
def train_model(data_dir, num_epochs=10, batch_size=32, save_path='cat_dog_model.pth'):
"""
训练模型并保存(修复了原代码准确率计算错误)。
:param data_dir: 数据路径,包含cat和dog文件夹
:param num_epochs: 训练周期,默认为10
:param batch_size: 批次大小,默认为32
:param save_path: 模型保存路径,默认为'cat_dog_model.pth'
"""
# 1. 加载训练数据
train_data = datasets.ImageFolder(
root=data_dir, # 数据路径
transform=transform
)
print(f"类别映射关系: {train_data.class_to_idx}") # 输出文件夹编号,例如{'cat': 0, 'dog': 1}
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# 2. 使用预训练的ResNet18模型
model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 2) # 修改输出层以适应2类分类(猫、狗)
# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 4. 开始训练模型
for epoch in range(num_epochs):
model.train() # 设置模型为训练模式
running_loss = 0.0 # 初始化损失值
correct = 0 # 模型预测准确数
total = 0 # 模型预测总数
for images, labels in train_loader:
optimizer.zero_grad() # 清除之前的梯度
outputs = model(images) # 前向传播,得出预测结果
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
running_loss += loss.item()
# 【修复】将准确率计算移到batch循环内,累计所有样本的结果
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# 计算并输出该epoch的损失和准确率
accuracy = 100 * correct / total
avg_loss = running_loss / len(train_loader)
print(f'周期 [{epoch + 1}/{num_epochs}], 平均损失: {avg_loss:.4f}, 准确率: {accuracy:.2f}%')
# 准确率100%时提前停止(实际场景中100%大概率是过拟合,仅作演示)
if accuracy == 100:
break
# 保存训练模型
torch.save(model.state_dict(), save_path)
print(f"模型已保存至: {save_path}")
def predict_single_image(model, img_path):
"""
辅助函数:使用加载好的模型预测单张图片
:param model: 加载完成的模型实例
:param img_path: 图片路径
:return: 预测标签(0=猫,1=狗)、预测结果文本
"""
try:
img = Image.open(img_path).convert('RGB') # 确保图片为RGB格式(处理灰度图)
img_tensor = transform(img).unsqueeze(0) # 增加batch维度
with torch.no_grad():
outputs = model(img_tensor)
_, predicted = torch.max(outputs, 1)
pred_label = predicted.item()
pred_text = "猫" if pred_label == 0 else "狗"
return pred_label, pred_text
except Exception as e:
print(f"⚠️ 处理图片 {img_path} 时出错: {str(e)}")
return None, None
def batch_predict_images(model_path, test_dir):
"""
批量预测指定文件夹下的所有图片(目录结构需包含cats/dogs子文件夹)
:param model_path: 训练好的模型路径
:param test_dir: 测试数据根目录(包含cats、dogs子文件夹)
:return: 整体预测准确率
"""
# 1. 加载模型
model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load(model_path))
model.eval() # 设置为评估模式
print(f"✅ 模型加载完成: {model_path}")
# 2. 构建类别映射(反向映射,用于从标签获取类别名)
class_to_idx = {'cats': 0, 'dogs': 1}
idx_to_class = {v: k for k, v in class_to_idx.items()}
# 3. 遍历所有图片并预测
total_samples = 0
correct_predictions = 0
error_samples = []
# 遍历cats和dogs子文件夹
for class_name in ['cats', 'dogs']:
class_dir = os.path.join(test_dir, class_name)
if not os.path.exists(class_dir):
print(f"⚠️ 未找到文件夹: {class_dir},跳过该类别")
continue
# 获取该类别下的所有图片文件
img_files = [f for f in os.listdir(class_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
if not img_files:
print(f"⚠️ 文件夹 {class_dir} 内无图片文件")
continue
true_label = class_to_idx[class_name]
print(f"\n========== 开始预测 {class_name} 文件夹(真实标签: {true_label})==========")
# 逐个预测图片
for img_file in img_files:
img_path = os.path.join(class_dir, img_file)
pred_label, pred_text = predict_single_image(model, img_path)
total_samples += 1
# 统计结果
if pred_label is None:
error_samples.append(img_path)
continue
# 输出单张图片预测结果
print(f"图片: {img_file} | 真实类别: {class_name} | 预测类别: {pred_text} | "
f"{'✅ 正确' if pred_label == true_label else '❌ 错误'}")
if pred_label == true_label:
correct_predictions += 1
# 4. 计算整体准确率
if total_samples == 0:
print("\n❌ 未找到任何可预测的图片")
return 0.0
overall_accuracy = 100 * correct_predictions / (total_samples - len(error_samples))
print(f"\n========== 批量预测结果汇总 ==========")
print(f"总图片数: {total_samples}")
print(f"有效预测数: {total_samples - len(error_samples)}")
print(f"预测正确数: {correct_predictions}")
print(f"整体准确率: {overall_accuracy:.2f}%")
if error_samples:
print(f"处理失败的图片数: {len(error_samples)}")
print(f"失败图片列表: {error_samples}")
return overall_accuracy
if __name__ == "__main__":
# ===================== 可选:训练模型(已训练可注释) =====================
# data_dir = r"E:\Desktop\AI_structure\猫狗实训\catdog\data\train" # 训练数据路径
# train_model(data_dir, num_epochs=10, batch_size=32)
# ===================== 批量预测(核心功能) =====================
model_path = 'cat_dog_model.pth' # 训练好的模型路径
test_dir = r"E:\Desktop\AI_structure\猫狗实训\catdog\data\test" # 测试数据根目录(含cats/dogs子文件夹)
batch_predict_images(model_path, test_dir)

浙公网安备 33010602011771号