人工智能大作业:植物病害检测系统

本项目使用卷积神经网络算法实现了植物病害检测系统,下面我将以代码来详细说明实现思路

首先,本项目核心算法就是Resnet 50+迁移学习+数据增强

我使用了公共数据集PlantVillage / New Plant Diseases,在该数据集上进行训练,实现植物叶子病害的自动诊断。

数据预处理:dataset.py

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def get_transforms():
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return train_transform, val_transform


def get_dataloaders(data_dir='data', batch_size=32, num_workers=4):
    train_transform, val_transform = get_transforms()
    
    # ==================== 情况B 的正确路径 ====================
    train_path = f'{data_dir}/train'
    val_path   = f'{data_dir}/valid'
    # ======================================================
    
    train_dataset = datasets.ImageFolder(root=train_path, transform=train_transform)
    val_dataset = datasets.ImageFolder(root=val_path, transform=val_transform)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                              num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, 
                            num_workers=num_workers, pin_memory=True)
    
    print(f" 数据集加载完成")
    print(f"类别数量: {len(train_dataset.classes)}")
    print(f"训练集图片数量: {len(train_dataset)}")
    print(f"验证集图片数量: {len(val_dataset)}")
    print(f"第一个类别示例: {train_dataset.classes[0]}")
    
    return train_loader, val_loader, train_dataset.classes

  • RandomResizedCrop(224):随机裁剪并缩放到 224×224(ResNet 输入标准尺寸)
  • RandomHorizontalFlip():随机水平翻转,模拟叶子不同方向
  • RandomRotation(15):随便挑一个小角度旋转
  • ColorJitter:调整亮度、对比度、饱和度、色调,增加对光照变化的鲁棒性(鲁棒性是指是指一个计算机系统在执行过程中处理错误,以及算法在遭遇输入、运算等异常时维持正常运行的能力。简单来说就是稳定性,参数来自于维基百科)
  • Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):使用 ImageNet 预训练均值和标准差(迁移学习必须一致)
  • 训练时随机变换,保证训练质量,同时防止过拟合
  • 但是验证的时候要使用固定裁剪模式,保证验证成果的稳定性

 

定义获取神经网络模型:model.py

 

import torch.nn as nn
from torchvision import models

def get_model(num_classes=38,model_name='resnet50'):
    if model_name == 'resnet50':
        model = models.resnet50(weights='IMAGENET1K_V1')
        model.fc = nn.Linear(model.fc.in_features,num_classes)
    elif model_name == 'mobilenet_v3':
        model = models.mobilenet_v3_small(weights='IMAGENET1K_V1')
        model.classifier[3] = nn.Linear(model.classifier[3].in_features,num_classes)
    return model    

 

  • 残差网络(Residual Network):核心创新是残差连接(Shortcut Connection),解决深度网络的“退化问题”(Degradation Problem)和梯度消失。

    残差连接的核心思想是引入一个“快捷连接”(shortcut connection)或“跳跃连接”(skip connection),允许数据绕过一些层直接传播。这样,网络中的一部分可以直接学习到输入与输出之间的残差(即差异),而不是直接学习到映射本身。具体来说,如果我们希望学习的目标映射是 H(x),我们让网络学习残差映射 F(x)=H(x)−x。因此,原始的目标映射可以表示为 F(x)+x。

  • 公式:y = F(x) + x(残差块),让网络更容易学习恒等映射。
  • ResNet50 有 50 层,包含多个 Bottleneck 残差块。
  • 迁移学习:利用在 ImageNet(1400 万张图片)上预训练的权重,只替换最后的全连接层(fc),大幅减少训练时间和数据需求,同时获得优秀的特征提取能力。

为什么选 ResNet50?

  • 精度高、结构成熟、在医学/农业图像任务中表现优秀
  • 参数量适中(约 25M),适合大多数显卡

 

具体训练过程:train.py

import torch
from torch import nn, optim
from tqdm import tqdm
import torchmetrics
from src.dataset import get_dataloaders
from src.model import get_model
import os

def train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    
    train_loader, val_loader, class_names = get_dataloaders(batch_size=64)
    model = get_model(num_classes=len(class_names), model_name='resnet50')
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
    
    best_acc = 0.0
    os.makedirs('models', exist_ok=True)
    
    for epoch in range(20):  # 可根据需要增加 epochs
        model.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        # 验证
        model.eval()
        accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=len(class_names)).to(device)
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                preds = outputs.argmax(dim=1)
                accuracy.update(preds, labels)
        
        acc = accuracy.compute().item()
        scheduler.step()
        
        print(f"Epoch {epoch+1}/20 | Loss: {running_loss/len(train_loader):.4f} | Val Acc: {acc:.4f}")
        
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'models/best_plant_disease.pth')
            print(" 保存最佳模型")
    
    print(f"训练完成!最佳验证准确率: {best_acc:.4f}")
    return model, class_names
    import json
    class_names_path = 'models/class_names.json'
    with open(class_names_path, 'w', encoding='utf-8') as f:
        json.dump(class_names, f, ensure_ascii=False, indent=2)
    
    print(f"class_names 已保存到 {class_names_path}")
    return model, class_names
# 在 src/train.py 的 train() 函数末尾(训练完成后)添加:

    # 保存 class_names

    

 

作用:完整训练流程、模型保存、日志记录。

 

使用的算法和技术

 

  1. 损失函数:CrossEntropyLoss(交叉熵损失)
    • 多分类任务的标准损失
  2. 优化器:AdamW
    • Adam + Weight Decay(权重衰减)
    • 比传统 Adam 更好地处理正则化,防止过拟合
  3. 学习率调度器:CosineAnnealingLR
    • 余弦退火:学习率按余弦曲线下降,能帮助模型在后期更精细收敛(模拟退火:模拟物理降温过程参数的变化,基于能量最低原理,可以让参数以随机且靠近最小的方法改变,适合解决单峰问题,不适合多峰问题)
  4. 评估指标:torchmetrics.Accuracy(多分类准确率)
  5. 训练流程
    • model.train() / model.eval() —— 切换模式(影响 BatchNorm、Dropout)
    • torch.no_grad() —— 验证时关闭梯度计算,节省显存
    • optimizer.zero_grad() → loss.backward() → optimizer.step()(标准反向传播)
  6. Early Saving:只保存验证准确率最好的模型(防止过拟合)

外部模型接口:train.py

from src.train import train

if __name__ == "__main__":
    model, classes = train()

外部模型接口,方便调用

 

部署模块:app.py

import gradio as gr
import torch
from PIL import Image
import json

from src.model import get_model
from src.dataset import get_transforms

# ==================== 全局加载模型和类别 ====================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载 class_names
with open('models/class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)

# 加载模型
def load_model():
    num_classes = len(class_names)
    model = get_model(num_classes=num_classes, model_name='resnet50')
    model.load_state_dict(torch.load('models/best_plant_disease.pth', map_location=device))
    model.to(device)
    model.eval()
    return model

model = load_model()
_, val_transform = get_transforms()

# ==================== 预测函数 ====================
def predict_image(image):
    if image is None:
        return "请上传图片"
    
    # 预处理
    input_tensor = val_transform(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.softmax(output[0], dim=0)
        confidence, predicted_idx = torch.max(probabilities, 0)
    
    predicted_class = class_names[predicted_idx.item()]
    confidence_pct = confidence.item() * 100
    
    if "healthy" in predicted_class.lower():
        result = f" **健康**\n\n类别:{predicted_class}\n置信度:{confidence_pct:.2f}%"
    else:
        result = f"**疑似病害**\n\n类别:{predicted_class}\n置信度:{confidence_pct:.2f}%"
    
    return result

# ==================== Gradio 界面 ====================
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="上传植物叶子照片"),
    outputs=gr.Textbox(label="诊断结果"),
    title=" 植物病害智能检测系统",
    description="上传一张叶子照片,AI 将帮助你判断是否生病及病害类型",
    examples=[["examples/healthy.jpg"], ["examples/diseased.jpg"]],  # 可选
    allow_flagging="never"
)

if __name__ == "__main__":
    interface.launch(share=False)   # share=True 可生成公网链接

提供使用模型入口

 

posted @ 2026-05-16 21:50  Noname_min  阅读(25)  评论(0)    收藏  举报