基于 DeepLabV3 的玉米植株语义分割与提取

目录


一、背景与动机

在农业图像分析领域,准确分割植物是一个基础性工作。玉米植株的自动分割能为品种选育、病虫害检测、产量预测等提供数据支持。

  • 传统的阈值分割容易受到光照、背景干扰
  • 深度学习模型(如本文的DeepLabV3)可获得更具鲁棒性的分割结果
  • 利用辅助条件(本文用骨架线)进一步剔除误分区域,提高分割精度

💡 提示:确保骨架线与原图一一对应,才能有效辅助后处理。


二、环境配置与技术支持

  • Python:3.9.13
  • 深度学习框架:PyTorch + torchvision
  • 分割模型:DeepLabV3 + ResNet50
  • 辅助技术:骨架线过滤、连通域分析
  • 图像处理:OpenCV
  • 硬件加速:CUDA + GPU

Python 库

pip install torch torchvision opencv-python matplotlib numpy os glob

🛠️ 注意:GPU 与 CUDA 版本需与 PyTorch 官方兼容,安装前务必核对。


三、数据准备

项目目录结构示例:

project_root/
├─ train/                # 训练相关脚本
│  └─ train.py
├─ inference/            # 推理相关脚本
│  └─ inference.py
├─ data/
│  ├─ image/             # 训练图像(RGB,512×512)
│  ├─ mask/              # 训练掩码(二值图,同名文件)
│  └─ skeleton/          # 骨架图(灰度,用于后处理,辅助条件可不用)
└─ output/               # 分割输出目录

📂 注意:image、mask、skeleton 中的文件名必须一一对应,否则会报错。


四、数据集定义

train/train.py 中定义自定义数据集 CornDataset

class CornDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # 图像和mask同名
        self.image_list = os.listdir(image_dir)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        filename = self.image_list[idx]
        image_path = os.path.join(self.image_dir, filename)
        mask_path = os.path.join(self.mask_dir, filename)
        # 读取原图和mask
        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # 将BGR转换为RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # 掩码二值化
        mask = (mask // 255).astype(np.uint8)
        # 转换为Tensor
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        mask_tensor = torch.from_numpy(mask).long()
        return image_tensor, mask_tensor

📦 说明:上述代码创建了一个自定义Dataset类,处理图像和对应的掩码。主要步骤包括图像读取、颜色空间转换、掩码二值化和张量转换。


五、模型构建与微调

train/train.py 中加载并修改预训练模型:

# 加载预训练的DeepLabV3模型
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)
# 修改输出类别为2(背景和植株)
model.classifier[4] = torch.nn.Conv2d(256, 2, kernel_size=1)
# 设备选择
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 冻结BatchNorm层
for m in model.modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        m.eval()
        for param in m.parameters():
            param.requires_grad = False

⚙️ 技术要点:冻结BatchNorm层是小批量训练时的常用技巧,可以提高训练稳定性。因为BatchNorm需要足够多的样本来计算统计量,而小批量训练时样本数不足,容易导致统计量不稳定。


六、数据加载

train/train.py 中编写数据加载:

# 设置数据路径
image_dir = "path/to/your/images"
mask_dir = "path/to/your/masks"
# 创建数据集和加载器
dataset = CornDataset(image_dir, mask_dir)
train_loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

🔄 提示:根据 GPU 显存适当调整 batch_size,以避免 OOM 错误。


七、模型训练

train/train.py 中编写训练循环:

# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
num_epochs = 200

for epoch in range(num_epochs):
    model.train()  # 设置为训练模式
    # 确保BatchNorm层保持在评估模式
    model.apply(lambda m: m.eval() if isinstance(m, torch.nn.BatchNorm2d) else None)
    total_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        # 前向传播
        outputs = model(images)["out"]
        loss = criterion(outputs, masks)
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{num_epochs}: Loss = {total_loss:.4f}")
    # 每10个epoch保存一次模型
    if ((epoch + 1) % 10 == 0):
        torch.save(model.state_dict(), f"corn_deeplabv3_1024x1024_{epoch+1}.pth")

训练结果输出以及模型保存如下图所示:
训练输出图
模型保存图

🚀 建议:建议同时使用验证集,保存验证指标最优的模型。
💡提示:训练完成后,可使用 Matplotlib 绘制 Loss 曲线并保存为 loss_curve.png
📈 训练技巧:这里每10个epoch保存一次模型,便于选择最佳模型。实际应用中可以设置验证集,保存验证效果最好的模型。


八、模型应用

训练完成后,我们需要应用模型进行玉米植株分割。在 inference/inference.py 中实现:

8.1 模型加载函数

def load_model(weights_path):
    """加载DeepLabV3语义分割模型,并载入权重"""
    model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False)
    model.classifier[4] = torch.nn.Conv2d(256, 2, kernel_size=1)  # 设定类别数

    # 加载权重到指定设备
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    return model

💡 提示:确保 weights_path 与本地路径一致,否则会报错。

8.2 图像预处理函数

def preprocess_image(image_path):
    """读取图像,转换为PyTorch tensor"""
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"无法读取图像: {image_path}")

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # BGR -> RGB
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0  # 归一化
    image_tensor = image_tensor.unsqueeze(0).to(device)  # 增加batch维度
    return image_tensor

🧪 提示:连通域分析可有效过滤误检小区域,保留与骨架重叠部分。

8.3 预测与骨架线辅助优化函数

这是本文的核心创新点,通过骨架线信息过滤预测结果,提高分割质量:

def predict_and_refine(model, image_tensor, skeleton_path, min_overlap=1):
    """预测掩码,并结合骨架线进行后处理"""
    # 1) 模型预测
    with torch.no_grad():
        output = model(image_tensor)["out"]
        predicted_mask = torch.argmax(output, dim=1).squeeze().cpu().numpy()

    # 2) 读取骨架线
    skeleton_img = cv2.imread(skeleton_path, cv2.IMREAD_GRAYSCALE)
    if skeleton_img is None:
        raise FileNotFoundError(f"无法读取骨架图: {skeleton_path}")

    # 二值化处理
    _, skeleton_bin = cv2.threshold(skeleton_img, 10, 1, cv2.THRESH_BINARY)
    pred_bin = (predicted_mask == 1).astype(np.uint8)  # 预测的玉米区域

    # 3) 连通域分析 + 骨架线筛选
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(pred_bin, connectivity=8)
    refined_mask = np.zeros_like(pred_bin, dtype=np.uint8)
    
    for label_id in range(1, num_labels):
        component_mask = (labels == label_id)
        overlap = np.logical_and(component_mask, skeleton_bin)
        if np.sum(overlap) >= min_overlap:
            refined_mask[component_mask] = 1  # 通过筛选的区域

    return refined_mask

🔍 技术解析:该方法首先使用模型预测出可能的玉米区域,然后通过连通域分析将预测结果分解为独立区域,最后检查每个区域是否与骨架线有足够重叠。这种方法可以有效过滤误检区域,保留真正的玉米植株区域。

8.4 主处理流程

if __name__ == "__main__":
    # 路径设置
    weights_path = "path/to/your/model.pth"     # 模型权重
    image_dir = "path/to/your/images"           # 原始图像文件夹
    skeleton_dir = "path/to/skeleton/images"    # 骨架线文件夹
    save_mask_dir = "path/to/save/masks"        # 输出掩码保存文件夹
    save_extract_dir = "path/to/save/extracts"  # 玉米植株提取后保存文件夹

    # 创建输出目录
    os.makedirs(save_mask_dir, exist_ok=True)
    os.makedirs(save_extract_dir, exist_ok=True)

    # 加载模型
    model = load_model(weights_path)

    # 获取图像列表
    image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")))

    # 逐张处理
    for i, img_path in enumerate(image_paths, start=1):
        file_name = os.path.basename(img_path)
        skeleton_path = os.path.join(skeleton_dir, file_name)
        print(f"[{i}/{len(image_paths)}] 正在处理: {file_name}")

        if not os.path.exists(skeleton_path):
            print(f"   - 找不到骨架图: {skeleton_path}, 跳过。")
            continue

        # 预处理
        image_tensor = preprocess_image(img_path)

        # 预测与优化
        refined_mask = predict_and_refine(model, image_tensor, skeleton_path, min_overlap=1)

        # 保存掩码
        mask_save_path = os.path.join(save_mask_dir, file_name)
        cv2.imwrite(mask_save_path, refined_mask * 255)

        # 抠出玉米区域
        orig_img = cv2.imread(img_path)
        if orig_img is None:
            print("   - 读取原图失败, 跳过。")
            continue
            
        plant_extract = cv2.bitwise_and(orig_img, orig_img, mask=refined_mask)
        extract_save_path = os.path.join(save_extract_dir, file_name)
        cv2.imwrite(extract_save_path, plant_extract)

        # 可视化(可选)
        # 在这里添加可视化代码,例:visualize_results(orig_img, refined_mask, plant_extract)

九、可视化效果

下面展示一些处理结果:

玉米分割效果图

左侧为原始图像,中间为分割掩膜,右侧为提取出的玉米植株

使用以下代码可以实现简单的可视化:

def visualize_results(original, mask, extracted):
    """可视化原图、掩码和提取结果"""
    plant_extract_rgb = cv2.cvtColor(plant_extract, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 3, 1)
    plt.title("原始图像")
    plt.imshow(cv2.cvtColor(original, cv2.COLOR_BGR2RGB))
    plt.axis("off")
    
    plt.subplot(1, 3, 2)
    plt.title("分割掩码")
    plt.imshow(mask, cmap="gray")
    plt.axis("off")
    
    plt.subplot(1, 3, 3)
    plt.title("提取结果")
    plt.imshow(cv2.cvtColor(extracted, cv2.COLOR_BGR2RGB))
    plt.axis("off")
    
    plt.tight_layout()
    plt.show()

📷 提示:保存高分辨率截图。


十、实现难点与解决方案

10.1 小批量训练的稳定性问题

难点:在小批量训练时,BatchNorm层的统计量更新不稳定。

解决方案:冻结预训练模型的BatchNorm层,保持其在评估模式,同时允许其他层正常训练。代码实现:

# 训练前冻结BatchNorm
for m in model.modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        m.eval()
        for param in m.parameters():
            param.requires_grad = False

# 训练中保持BatchNorm冻结
model.train()  
model.apply(lambda m: m.eval() if isinstance(m, torch.nn.BatchNorm2d) else None)

10.2 误检区域过滤

难点:模型可能会将一些非玉米区域误判为玉米。

解决方案:利用骨架线信息进行后处理,只保留与骨架线有交集的区域。这种方法基于一个假设:真正的玉米植株区域应当与骨架线有所重叠。

🔧 提示:可将骨架筛选逻辑封装为独立函数,方便复用。


十一、完整代码

完整代码可以在我的GitHub仓库查看:GitHub链接

💻 提示:建议在本地先跑通小样本,再批量处理大数据集,以节省调试时间。
如果需要数据集可私信或者评论!!!
如有问题,欢迎在评论区留言讨论!

posted @ 2025-04-30 14:09  执笔献江山  阅读(110)  评论(0)    收藏  举报