基于 DeepLabV3 的玉米植株语义分割与提取
目录
一、背景与动机
在农业图像分析领域,准确分割植物是一个基础性工作。玉米植株的自动分割能为品种选育、病虫害检测、产量预测等提供数据支持。
- 传统的阈值分割容易受到光照、背景干扰
- 深度学习模型(如本文的DeepLabV3)可获得更具鲁棒性的分割结果
- 利用辅助条件(本文用骨架线)进一步剔除误分区域,提高分割精度
💡 提示:确保骨架线与原图一一对应,才能有效辅助后处理。
二、环境配置与技术支持
- Python:3.9.13
- 深度学习框架:PyTorch + torchvision
- 分割模型:DeepLabV3 + ResNet50
- 辅助技术:骨架线过滤、连通域分析
- 图像处理:OpenCV
- 硬件加速:CUDA + GPU
Python 库
pip install torch torchvision opencv-python matplotlib numpy os glob
🛠️ 注意:GPU 与 CUDA 版本需与 PyTorch 官方兼容,安装前务必核对。
三、数据准备
项目目录结构示例:
project_root/
├─ train/ # 训练相关脚本
│ └─ train.py
├─ inference/ # 推理相关脚本
│ └─ inference.py
├─ data/
│ ├─ image/ # 训练图像(RGB,512×512)
│ ├─ mask/ # 训练掩码(二值图,同名文件)
│ └─ skeleton/ # 骨架图(灰度,用于后处理,辅助条件可不用)
└─ output/ # 分割输出目录
📂 注意:image、mask、skeleton 中的文件名必须一一对应,否则会报错。
四、数据集定义
在 train/train.py
中定义自定义数据集 CornDataset
:
class CornDataset(Dataset):
def __init__(self, image_dir, mask_dir):
self.image_dir = image_dir
self.mask_dir = mask_dir
# 图像和mask同名
self.image_list = os.listdir(image_dir)
def __len__(self):
return len(self.image_list)
def __getitem__(self, idx):
filename = self.image_list[idx]
image_path = os.path.join(self.image_dir, filename)
mask_path = os.path.join(self.mask_dir, filename)
# 读取原图和mask
image = cv2.imread(image_path)
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
# 将BGR转换为RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 掩码二值化
mask = (mask // 255).astype(np.uint8)
# 转换为Tensor
image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
mask_tensor = torch.from_numpy(mask).long()
return image_tensor, mask_tensor
📦 说明:上述代码创建了一个自定义Dataset类,处理图像和对应的掩码。主要步骤包括图像读取、颜色空间转换、掩码二值化和张量转换。
五、模型构建与微调
在 train/train.py
中加载并修改预训练模型:
# 加载预训练的DeepLabV3模型
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)
# 修改输出类别为2(背景和植株)
model.classifier[4] = torch.nn.Conv2d(256, 2, kernel_size=1)
# 设备选择
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 冻结BatchNorm层
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.eval()
for param in m.parameters():
param.requires_grad = False
⚙️ 技术要点:冻结BatchNorm层是小批量训练时的常用技巧,可以提高训练稳定性。因为BatchNorm需要足够多的样本来计算统计量,而小批量训练时样本数不足,容易导致统计量不稳定。
六、数据加载
在 train/train.py
中编写数据加载:
# 设置数据路径
image_dir = "path/to/your/images"
mask_dir = "path/to/your/masks"
# 创建数据集和加载器
dataset = CornDataset(image_dir, mask_dir)
train_loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
🔄 提示:根据 GPU 显存适当调整 batch_size,以避免 OOM 错误。
七、模型训练
在 train/train.py
中编写训练循环:
# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
num_epochs = 200
for epoch in range(num_epochs):
model.train() # 设置为训练模式
# 确保BatchNorm层保持在评估模式
model.apply(lambda m: m.eval() if isinstance(m, torch.nn.BatchNorm2d) else None)
total_loss = 0.0
for images, masks in train_loader:
images, masks = images.to(device), masks.to(device)
# 前向传播
outputs = model(images)["out"]
loss = criterion(outputs, masks)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}/{num_epochs}: Loss = {total_loss:.4f}")
# 每10个epoch保存一次模型
if ((epoch + 1) % 10 == 0):
torch.save(model.state_dict(), f"corn_deeplabv3_1024x1024_{epoch+1}.pth")
训练结果输出以及模型保存如下图所示:
🚀 建议:建议同时使用验证集,保存验证指标最优的模型。
💡提示:训练完成后,可使用 Matplotlib 绘制 Loss 曲线并保存为loss_curve.png
。
📈 训练技巧:这里每10个epoch保存一次模型,便于选择最佳模型。实际应用中可以设置验证集,保存验证效果最好的模型。
八、模型应用
训练完成后,我们需要应用模型进行玉米植株分割。在 inference/inference.py
中实现:
8.1 模型加载函数
def load_model(weights_path):
"""加载DeepLabV3语义分割模型,并载入权重"""
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False)
model.classifier[4] = torch.nn.Conv2d(256, 2, kernel_size=1) # 设定类别数
# 加载权重到指定设备
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict, strict=False)
model.to(device)
model.eval()
return model
💡 提示:确保 weights_path 与本地路径一致,否则会报错。
8.2 图像预处理函数
def preprocess_image(image_path):
"""读取图像,转换为PyTorch tensor"""
image = cv2.imread(image_path)
if image is None:
raise FileNotFoundError(f"无法读取图像: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # BGR -> RGB
image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 # 归一化
image_tensor = image_tensor.unsqueeze(0).to(device) # 增加batch维度
return image_tensor
🧪 提示:连通域分析可有效过滤误检小区域,保留与骨架重叠部分。
8.3 预测与骨架线辅助优化函数
这是本文的核心创新点,通过骨架线信息过滤预测结果,提高分割质量:
def predict_and_refine(model, image_tensor, skeleton_path, min_overlap=1):
"""预测掩码,并结合骨架线进行后处理"""
# 1) 模型预测
with torch.no_grad():
output = model(image_tensor)["out"]
predicted_mask = torch.argmax(output, dim=1).squeeze().cpu().numpy()
# 2) 读取骨架线
skeleton_img = cv2.imread(skeleton_path, cv2.IMREAD_GRAYSCALE)
if skeleton_img is None:
raise FileNotFoundError(f"无法读取骨架图: {skeleton_path}")
# 二值化处理
_, skeleton_bin = cv2.threshold(skeleton_img, 10, 1, cv2.THRESH_BINARY)
pred_bin = (predicted_mask == 1).astype(np.uint8) # 预测的玉米区域
# 3) 连通域分析 + 骨架线筛选
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(pred_bin, connectivity=8)
refined_mask = np.zeros_like(pred_bin, dtype=np.uint8)
for label_id in range(1, num_labels):
component_mask = (labels == label_id)
overlap = np.logical_and(component_mask, skeleton_bin)
if np.sum(overlap) >= min_overlap:
refined_mask[component_mask] = 1 # 通过筛选的区域
return refined_mask
🔍 技术解析:该方法首先使用模型预测出可能的玉米区域,然后通过连通域分析将预测结果分解为独立区域,最后检查每个区域是否与骨架线有足够重叠。这种方法可以有效过滤误检区域,保留真正的玉米植株区域。
8.4 主处理流程
if __name__ == "__main__":
# 路径设置
weights_path = "path/to/your/model.pth" # 模型权重
image_dir = "path/to/your/images" # 原始图像文件夹
skeleton_dir = "path/to/skeleton/images" # 骨架线文件夹
save_mask_dir = "path/to/save/masks" # 输出掩码保存文件夹
save_extract_dir = "path/to/save/extracts" # 玉米植株提取后保存文件夹
# 创建输出目录
os.makedirs(save_mask_dir, exist_ok=True)
os.makedirs(save_extract_dir, exist_ok=True)
# 加载模型
model = load_model(weights_path)
# 获取图像列表
image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")))
# 逐张处理
for i, img_path in enumerate(image_paths, start=1):
file_name = os.path.basename(img_path)
skeleton_path = os.path.join(skeleton_dir, file_name)
print(f"[{i}/{len(image_paths)}] 正在处理: {file_name}")
if not os.path.exists(skeleton_path):
print(f" - 找不到骨架图: {skeleton_path}, 跳过。")
continue
# 预处理
image_tensor = preprocess_image(img_path)
# 预测与优化
refined_mask = predict_and_refine(model, image_tensor, skeleton_path, min_overlap=1)
# 保存掩码
mask_save_path = os.path.join(save_mask_dir, file_name)
cv2.imwrite(mask_save_path, refined_mask * 255)
# 抠出玉米区域
orig_img = cv2.imread(img_path)
if orig_img is None:
print(" - 读取原图失败, 跳过。")
continue
plant_extract = cv2.bitwise_and(orig_img, orig_img, mask=refined_mask)
extract_save_path = os.path.join(save_extract_dir, file_name)
cv2.imwrite(extract_save_path, plant_extract)
# 可视化(可选)
# 在这里添加可视化代码,例:visualize_results(orig_img, refined_mask, plant_extract)
九、可视化效果
下面展示一些处理结果:
左侧为原始图像,中间为分割掩膜,右侧为提取出的玉米植株
使用以下代码可以实现简单的可视化:
def visualize_results(original, mask, extracted):
"""可视化原图、掩码和提取结果"""
plant_extract_rgb = cv2.cvtColor(plant_extract, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.title("原始图像")
plt.imshow(cv2.cvtColor(original, cv2.COLOR_BGR2RGB))
plt.axis("off")
plt.subplot(1, 3, 2)
plt.title("分割掩码")
plt.imshow(mask, cmap="gray")
plt.axis("off")
plt.subplot(1, 3, 3)
plt.title("提取结果")
plt.imshow(cv2.cvtColor(extracted, cv2.COLOR_BGR2RGB))
plt.axis("off")
plt.tight_layout()
plt.show()
📷 提示:保存高分辨率截图。
十、实现难点与解决方案
10.1 小批量训练的稳定性问题
难点:在小批量训练时,BatchNorm层的统计量更新不稳定。
解决方案:冻结预训练模型的BatchNorm层,保持其在评估模式,同时允许其他层正常训练。代码实现:
# 训练前冻结BatchNorm
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.eval()
for param in m.parameters():
param.requires_grad = False
# 训练中保持BatchNorm冻结
model.train()
model.apply(lambda m: m.eval() if isinstance(m, torch.nn.BatchNorm2d) else None)
10.2 误检区域过滤
难点:模型可能会将一些非玉米区域误判为玉米。
解决方案:利用骨架线信息进行后处理,只保留与骨架线有交集的区域。这种方法基于一个假设:真正的玉米植株区域应当与骨架线有所重叠。
🔧 提示:可将骨架筛选逻辑封装为独立函数,方便复用。
十一、完整代码
完整代码可以在我的GitHub仓库查看:GitHub链接
💻 提示:建议在本地先跑通小样本,再批量处理大数据集,以节省调试时间。
如果需要数据集可私信或者评论!!!
如有问题,欢迎在评论区留言讨论!