例子2
检测画框,并且合并框,并且合并mask
先按框面积从大到小排序,然后融合重叠的框
没有合并mask前

合并后

没有合并mask前

合并后

没有合并mask前

合并后

import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.patches as patches
import time # 添加时间模块
#################################### For Image ####################################
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# Start of the whole-script timer (read again when the timing summary is computed).
total_start_time = time.time()
# Start of the model-loading timer.
model_load_start_time = time.time()
# Load the model
model = build_sam3_image_model(
checkpoint_path="/home/r9000k/v2_project/sam/sam3/assets/model/sam3.pt"
)
processor = Sam3Processor(model, confidence_threshold=0.5)
# End of the model-loading timer: covers building the model and the processor.
model_load_end_time = time.time()
model_load_time = model_load_end_time - model_load_start_time
print(f"模型加载时间: {model_load_time:.3f} 秒")
# Start of the single-image detection timer (image loading is included in it).
detection_start_time = time.time()
# NOTE: the first assignment is overwritten by the next line; only "3.jpg" is used.
image_path = "testimage/微信图片_20251120225838_38.jpg"
image_path = "3.jpg"
# Load an image
image = Image.open(image_path)
inference_state = processor.set_image(image)
# Prompt the model with text
output = processor.set_text_prompt(state=inference_state, prompt="building") # other prompts noted by the author: building, road and playground, building, car, people, bicycle
# Get the masks, bounding boxes, and scores
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
# End of the single-image detection timer.
detection_end_time = time.time()
detection_time = detection_end_time - detection_start_time
print(f"检测单张时间: {detection_time:.3f} 秒")
print(f"原始检测到 {len(masks)} 个分割结果")
print(f"掩码形状: {masks.shape}")
def calculate_iou(box1, box2):
    """Return how much of the smaller box is covered by the intersection.

    Despite its name, this is not the classic intersection-over-union: the
    value is ``max(inter / area(box1), inter / area(box2))``, i.e. the
    intersection divided by the smaller of the two box areas. A small box lying
    entirely inside a large one therefore scores 1.0.

    Args:
        box1: Four numbers ``(x1, y1, x2, y2)`` of the first box.
        box2: Four numbers ``(x1, y1, x2, y2)`` of the second box.

    Returns:
        The overlap ratio, or 0.0 when either box has no positive area.
    """
    x1_1, y1_1, x1_2, y1_2 = box1
    x2_1, y2_1, x2_2, y2_2 = box2

    # Width/height of the intersection rectangle, clamped at 0 for disjoint boxes.
    inter_w = max(0, min(x1_2, x2_2) - max(x1_1, x2_1))
    inter_h = max(0, min(y1_2, y2_2) - max(y1_1, y2_1))
    inter_area = inter_w * inter_h

    box1_area = (x1_2 - x1_1) * (y1_2 - y1_1)
    box2_area = (x2_2 - x2_1) * (y2_2 - y2_1)

    # Each area is used as a divisor below, so both must be checked individually;
    # guarding only the union still divides by zero when a single box is degenerate.
    if box1_area <= 0 or box2_area <= 0:
        return 0.0

    return max(inter_area / box1_area, inter_area / box2_area)
def calculate_mask_overlap(mask1, mask2):
    """Return the fraction of ``mask1``'s foreground that ``mask2`` also covers.

    Both masks are moved to the CPU, squeezed and cast to bool, so any non-zero
    value counts as foreground. The ratio is relative to ``mask1`` only and is
    therefore not symmetric.

    Returns:
        ``|mask1 AND mask2| / |mask1|``, or 0.0 when ``mask1`` is empty.
    """
    reference, other = (
        m.cpu().numpy().squeeze().astype(bool) for m in (mask1, mask2)
    )
    shared = np.logical_and(reference, other)
    reference_pixels = reference.sum()
    if reference_pixels == 0:
        return 0.0
    return shared.sum() / reference_pixels
def fuse_overlapping_masks_use_scores(masks, boxes, scores, iou_threshold=0.5, overlap_threshold=0.6):
    """Suppress overlapping detections, keeping the higher-scoring one.

    Detections are visited in descending score order. Every detection that has
    not been suppressed is kept, and each later detection whose box overlap with
    it (the value returned by ``calculate_iou``) exceeds ``iou_threshold`` is
    suppressed. Masks are not merged; suppressed detections are simply dropped.

    Args:
        masks: Mask tensor indexed by detection (described by the author as
            [N, 1, H, W]).
        boxes: Box tensor with one row of four coordinates per detection.
        scores: Score tensor with one entry per detection.
        iou_threshold: Overlap value above which the later detection is suppressed.
        overlap_threshold: Currently unused (the mask-overlap check is disabled);
            kept so existing calls stay valid.

    Returns:
        Tuple ``(final_masks, final_boxes, final_scores)`` holding the kept
        detections in descending score order. Empty input is returned unchanged.
    """
    if len(masks) == 0:
        return masks, boxes, scores

    boxes_np = boxes.cpu().numpy().copy()
    scores_np = scores.cpu().numpy().copy()

    # Positions of the detections from highest to lowest score.
    sorted_indices = np.argsort(scores_np)[::-1]
    boxes_sorted = boxes_np[sorted_indices]
    scores_sorted = scores_np[sorted_indices]

    keep_indices = []   # positions (in sorted order) of detections to keep
    suppressed = set()  # positions (in sorted order) already absorbed
    n_boxes = len(boxes_sorted)

    for i in range(n_boxes):
        print('=====================', i)
        if i in suppressed:
            print('1 跳过', i)
            continue
        keep_indices.append(i)
        for j in range(i + 1, n_boxes):
            if j in suppressed:
                # Report the candidate being skipped (j), not the outer index.
                print('2 跳过', j)
                continue
            iou = calculate_iou(boxes_sorted[i], boxes_sorted[j])
            if iou > iou_threshold:
                suppressed.add(j)
                print(f"融合检测结果: 索引 {sorted_indices[i]} (得分: {scores_sorted[i]:.3f}) 覆盖索引 {sorted_indices[j]} (得分: {scores_sorted[j]:.3f})", " iou:", iou)
            else:
                # Nothing is suppressed in this branch, so the log must not say "fused".
                print(f"xxxxxx未融合: 索引 {sorted_indices[i]} (得分: {scores_sorted[i]:.3f}) 和索引 {sorted_indices[j]} (得分: {scores_sorted[j]:.3f})", " iou:", iou)

    # Map kept positions back to indices into the original tensors.
    final_indices = [int(sorted_indices[i]) for i in keep_indices]
    final_masks = torch.stack([masks[i] for i in final_indices])
    final_boxes = boxes[final_indices]
    final_scores = scores[final_indices]
    print(f"融合后剩余 {len(final_masks)} 个分割结果 (减少了 {len(masks) - len(final_masks)} 个)")
    return final_masks, final_boxes, final_scores
def fuse_overlapping_masks_justIou(masks, boxes, scores, iou_threshold=0.5, overlap_threshold=0.6):
    """Suppress overlapping detections, keeping the one with the larger box.

    Boxes are ranked by area (largest first). Each box that has not been
    dropped is kept, and every smaller box whose ``calculate_iou`` value against
    it exceeds ``iou_threshold`` is dropped. Masks are not merged.

    Args:
        masks: Mask tensor indexed by detection (described by the author as
            [N, 1, H, W]).
        boxes: [N, 4] box tensor, rows unpacked as (x1, y1, x2, y2).
        scores: Score tensor with one entry per detection.
        iou_threshold: Overlap value above which the smaller box is dropped.
        overlap_threshold: Not read by this function.

    Returns:
        Tuple ``(final_masks, final_boxes, final_scores)`` for the kept
        detections, largest box first. Empty input is returned unchanged.
    """
    if len(masks) == 0:
        return masks, boxes, scores

    box_array = boxes.cpu().numpy().copy()

    # Area of every box as width * height.
    widths = box_array[:, 2] - box_array[:, 0]
    heights = box_array[:, 3] - box_array[:, 1]
    box_areas = widths * heights

    # Rank from largest to smallest area.
    order = np.argsort(box_areas)[::-1]
    ranked_boxes = box_array[order]
    ranked_areas = box_areas[order]

    kept = []
    dropped = set()
    total = len(ranked_boxes)

    for big in range(total):
        print('=====================', big)
        if big in dropped:
            print('1 跳过', big)
            continue
        kept.append(big)
        for small in range(big + 1, total):
            if small in dropped:
                print('2 跳过', small)
                continue
            iou = calculate_iou(ranked_boxes[big], ranked_boxes[small])
            if not iou > iou_threshold:
                print(f"xxxxxx未融合: 索引 {order[big]} (面积: {ranked_areas[big]:.1f}) 和索引 {order[small]} (面积: {ranked_areas[small]:.1f})", " iou:", iou)
                continue
            dropped.add(small)
            print(f"融合检测结果: 索引 {order[big]} (面积: {ranked_areas[big]:.1f}) 覆盖索引 {order[small]} (面积: {ranked_areas[small]:.1f})", " iou:", iou)

    # Translate ranked positions back to indices into the input tensors.
    final_indices = [order[k] for k in kept]
    final_masks = torch.stack([masks[int(idx)] for idx in final_indices])
    final_boxes = boxes[final_indices]
    final_scores = scores[final_indices]
    print(f"融合后剩余 {len(final_masks)} 个分割结果 (减少了 {len(masks) - len(final_masks)} 个)")
    return final_masks, final_boxes, final_scores
def fuse_overlapping_masks(masks, boxes, scores, iou_threshold=0.5, overlap_threshold=0.6):
    """Merge the masks of overlapping detections into the larger-box detection.

    Boxes are ranked by area (largest first). Each detection that has not been
    absorbed is kept; every smaller detection whose ``calculate_iou`` value
    against it exceeds ``iou_threshold`` is absorbed, and its mask is OR-ed into
    the kept detection's mask.

    Args:
        masks: Mask tensor indexed by detection (described by the author as
            [N, 1, H, W]).
        boxes: [N, 4] box tensor, rows unpacked as (x1, y1, x2, y2).
        scores: Score tensor with one entry per detection.
        iou_threshold: Overlap value above which the smaller detection is absorbed.
        overlap_threshold: Currently unused (the mask-overlap ratio is only
            logged); kept so existing calls stay valid.

    Returns:
        Tuple ``(final_masks, final_boxes, final_scores)`` for the kept
        detections, largest box first. ``final_masks`` has the same dtype as
        ``masks`` whether or not a merge happened. Empty input is returned
        unchanged.
    """
    if len(masks) == 0:
        return masks, boxes, scores

    boxes_np = boxes.cpu().numpy().copy()

    # Area of every box as width * height.
    areas_np = (boxes_np[:, 2] - boxes_np[:, 0]) * (boxes_np[:, 3] - boxes_np[:, 1])

    # Rank from largest to smallest area.
    sorted_indices = np.argsort(areas_np)[::-1]
    boxes_sorted = boxes_np[sorted_indices]
    areas_sorted = areas_np[sorted_indices]
    masks_sorted = [masks[int(i)] for i in sorted_indices]

    keep_indices = []                 # ranked positions of kept detections
    suppressed = set()                # ranked positions already absorbed
    fused_masks = list(masks_sorted)  # per-position mask, updated as merges happen

    for i in range(len(boxes_sorted)):
        print('=====================', i)
        if i in suppressed:
            print('1 跳过', i)
            continue
        keep_indices.append(i)
        current_mask = fused_masks[i]
        for j in range(i + 1, len(boxes_sorted)):
            if j in suppressed:
                print('2 跳过', j)
                continue
            iou = calculate_iou(boxes_sorted[i], boxes_sorted[j])
            if iou > iou_threshold:
                # Logged for diagnostics only; it does not gate the merge.
                mask_overlap = calculate_mask_overlap(current_mask, masks_sorted[j])
                suppressed.add(j)
                # fuse_two_masks returns float32. Cast back to the input dtype so
                # merged and untouched masks share one dtype when stacked below.
                current_mask = fuse_two_masks(current_mask, masks_sorted[j]).to(masks.dtype)
                fused_masks[i] = current_mask
                print(f"融合检测结果: 索引 {sorted_indices[i]} (面积: {areas_sorted[i]:.1f}) 融合索引 {sorted_indices[j]} (面积: {areas_sorted[j]:.1f})",
                      " iou:", iou, " mask重叠:", mask_overlap)
            else:
                print(f"IoU不足: 索引 {sorted_indices[i]} 和索引 {sorted_indices[j]}",
                      " iou:", iou)

    # NOTE(review): the returned box is the kept detection's own box; it is not
    # enlarged to enclose the masks merged into it.
    final_indices = [int(sorted_indices[i]) for i in keep_indices]
    final_masks = torch.stack([fused_masks[i] for i in keep_indices])
    final_boxes = boxes[final_indices]
    final_scores = scores[final_indices]
    print(f"融合后剩余 {len(final_masks)} 个分割结果 (减少了 {len(masks) - len(final_masks)} 个)")
    return final_masks, final_boxes, final_scores
def fuse_two_masks(mask1, mask2):
    """Return the union of two masks as a float32 tensor.

    A position is 1.0 in the result when it is non-zero in either input and
    0.0 otherwise.

    Args:
        mask1: First mask tensor (described by the author as [1, H, W]).
        mask2: Second mask tensor, combinable element-wise with ``mask1``.

    Returns:
        The element-wise logical OR of the inputs, cast to float32.
    """
    union = torch.logical_or(mask1, mask2)
    return union.to(torch.float32)
# Run the fusion step on the raw detections and time it.
print("\n开始融合重叠的检测结果...")
fusion_start_time = time.time()
# Call the fusion function (the area-ordered, mask-merging variant defined above).
fused_masks, fused_boxes, fused_scores = fuse_overlapping_masks(
masks, boxes, scores,
iou_threshold=0.5, # tunable: box-overlap value above which detections are merged
overlap_threshold=0.6 # tunable: mask-overlap threshold accepted by the function
)
fusion_time = time.time() - fusion_start_time
print(f"融合完成时间: {fusion_time:.3f} 秒")
def overlay_masks_with_info(image, masks, boxes, scores, fusion_mode=False):
    """Draw every mask, its box and an ID/score label onto a copy of ``image``.

    Args:
        image: PIL image to annotate; it is converted to RGB, the caller's
            object is not modified.
        masks: Mask tensor of shape [N, 1, H, W]; axis 1 is squeezed away.
        boxes: [N, 4] box tensor, rows unpacked as (x1, y1, x2, y2).
        scores: Score tensor with one entry per detection.
        fusion_mode: When True, use the "viridis" colormap and a "Fused-ID"
            label; otherwise "rainbow" and "ID".

    Returns:
        A new RGB PIL image with all overlays drawn.
    """
    image = image.convert("RGB")

    # Try a CJK-capable font first, then Arial, then PIL's built-in font.
    # A missing font file raises OSError; a Pillow build without FreeType raises
    # ImportError. Anything else is a real error and should surface.
    font = None
    for font_name in ("SimHei.ttf", "Arial.ttf"):
        try:
            font = ImageFont.truetype(font_name, 20)
            break
        except (OSError, ImportError):
            continue
    if font is None:
        font = ImageFont.load_default()

    masks_np = masks.cpu().numpy().astype(np.uint8).squeeze(1)  # [N, H, W]
    boxes_np = boxes.cpu().numpy()                               # [N, 4]
    scores_np = scores.cpu().numpy()                             # [N]
    n_masks = masks_np.shape[0]

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # recent Matplotlib releases. max(..., 1) avoids requesting a colormap with
    # zero entries when there are no detections.
    cmap = plt.get_cmap("viridis" if fusion_mode else "rainbow", max(n_masks, 1))
    label_prefix = "Fused-ID" if fusion_mode else "ID"

    for i, (mask, box, score) in enumerate(zip(masks_np, boxes_np, scores_np)):
        color = tuple(int(c * 255) for c in cmap(i)[:3])

        if mask.ndim == 3:
            mask = mask.squeeze(0)

        # Alpha 128 (about 50% opacity) inside the mask, fully transparent outside.
        alpha = Image.fromarray((mask * 128).astype(np.uint8), mode='L')
        overlay = Image.new("RGBA", image.size, color + (128,))
        overlay.putalpha(alpha)
        image = Image.alpha_composite(image.convert("RGBA"), overlay).convert("RGB")
        # Compositing produced a new image, so the drawing context must be recreated.
        draw = ImageDraw.Draw(image)

        # Clamp the box to the image bounds before drawing it.
        x1, y1, x2, y2 = box
        x1 = max(0, min(x1, image.width))
        y1 = max(0, min(y1, image.height))
        x2 = max(0, min(x2, image.width))
        y2 = max(0, min(y2, image.height))
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        text = f"{label_prefix}:{i} Score:{score:.3f}"
        try:
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
            text_width = right - left
            text_height = bottom - top
        except AttributeError:
            # Older Pillow without ImageDraw.textbbox.
            text_width, text_height = draw.textsize(text, font=font)

        # Place the label just above the box, clamped to the top edge.
        text_x = x1
        text_y = max(0, y1 - text_height - 5)
        draw.rectangle([text_x, text_y, text_x + text_width + 10, text_y + text_height + 5],
                       fill=color)
        draw.text((text_x + 5, text_y + 2), text, fill="white", font=font)

    return image
# Time the rendering and saving of the two annotated images.
visualization_start_time = time.time()
# Render the raw detections, then the fused detections, onto the source image.
original_image = Image.open(image_path)
result_image_original = overlay_masks_with_info(original_image, masks, boxes, scores, fusion_mode=False)
result_image_fused = overlay_masks_with_info(original_image, fused_masks, fused_boxes, fused_scores, fusion_mode=True)
# Save both annotated images.
output_path_original = "segmentation_result_original.png"
output_path_fused = "segmentation_result_fused.png"
result_image_original.save(output_path_original)
result_image_fused.save(output_path_fused)
visualization_end_time = time.time()
visualization_time = visualization_end_time - visualization_start_time
print(f"可视化时间: {visualization_time:.3f} 秒")
print(f"原始分割结果已保存到: {output_path_original}")
print(f"融合后分割结果已保存到: {output_path_fused}")

# Best-effort font setup so the Chinese titles render. Only the errors rcParams
# raises for a bad key or value are ignored; anything else should surface.
try:
    plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False
except (KeyError, ValueError):
    pass

# Side-by-side comparison figure: raw detections on the left, fused on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
ax1.imshow(result_image_original)
ax1.set_title(f"原始结果: 检测到 {len(masks)} 个分割结果", fontsize=14)
ax1.axis('off')
ax2.imshow(result_image_fused)
ax2.set_title(f"融合后结果: 剩余 {len(fused_masks)} 个分割结果", fontsize=14)
ax2.axis('off')
plt.tight_layout()
plt.savefig("segmentation_comparison.png", bbox_inches='tight', dpi=300, facecolor='white')

# Stop the overall timer before plt.show(): with an interactive backend show()
# blocks until the window is closed, which would otherwise be counted as run time.
total_end_time = time.time()
total_time = total_end_time - total_start_time
plt.show()

# Print the timing summary.
print("\n" + "="*50)
print("运行时间统计:")
print("="*50)
print(f"模型加载时间: {model_load_time:.3f} 秒")
print(f"检测单张时间: {detection_time:.3f} 秒")
print(f"融合处理时间: {fusion_time:.3f} 秒")
print(f"可视化时间: {visualization_time:.3f} 秒")
print("-"*50)
print(f"总运行时间: {total_time:.3f} 秒")
print("="*50)
def save_mask(masks_to_save, boxes_to_save, scores_to_save, prefix="mask"):
    """Save one annotated image per mask as ``{prefix}_with_info_{i:02d}.png``.

    Each output is the image at the module-level ``image_path`` with a single
    mask overlaid, its box drawn and an "ID/Score" label above the box.

    Args:
        masks_to_save: Sized iterable of mask tensors, one per detection.
        boxes_to_save: Iterable of box tensors, each unpacked as (x1, y1, x2, y2).
        scores_to_save: Iterable of scores, formatted with ``:.3f``.
        prefix: File-name prefix for the saved images.
    """
    print(f"\n保存{prefix}的单个掩码...")
    n_masks = len(masks_to_save)
    if n_masks == 0:
        return

    # The font, colormap and background are the same for every mask, so they are
    # prepared once instead of on each iteration.
    single_font = None
    for font_name in ("SimHei.ttf", "Arial.ttf"):
        try:
            single_font = ImageFont.truetype(font_name, 24)
            break
        except (OSError, ImportError):
            # Missing font file, or a Pillow build without FreeType.
            continue
    if single_font is None:
        single_font = ImageFont.load_default()

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # recent Matplotlib releases.
    cmap = plt.get_cmap("viridis", n_masks)

    # The context manager closes the file once the pixel data has been converted.
    with Image.open(image_path) as source:
        background = source.convert("RGBA")

    for i, (mask, box, score) in enumerate(zip(masks_to_save, boxes_to_save, scores_to_save)):
        mask_np = mask.cpu().numpy().squeeze().astype(np.uint8)
        color = tuple(int(c * 255) for c in cmap(i)[:3])

        # Alpha 128 (about 50% opacity) inside the mask, fully transparent outside.
        alpha = Image.fromarray((mask_np * 128).astype(np.uint8), mode='L')
        overlay = Image.new("RGBA", background.size, color + (128,))
        overlay.putalpha(alpha)
        # alpha_composite returns a new image, so the shared background stays clean.
        base_image = Image.alpha_composite(background, overlay).convert("RGB")
        single_draw = ImageDraw.Draw(base_image)

        x1, y1, x2, y2 = box.cpu().numpy()
        single_draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        text = f"ID:{i} Score:{score:.3f}"
        try:
            left, top, right, bottom = single_draw.textbbox((0, 0), text, font=single_font)
            text_width = right - left
            text_height = bottom - top
        except AttributeError:
            # Older Pillow without ImageDraw.textbbox.
            text_width, text_height = single_draw.textsize(text, font=single_font)

        # Place the label just above the box, clamped to the top edge.
        text_x = x1
        text_y = max(0, y1 - text_height - 5)
        single_draw.rectangle([text_x, text_y, text_x + text_width + 10, text_y + text_height + 5],
                              fill=color)
        single_draw.text((text_x + 5, text_y + 2), text, fill="white", font=single_font)

        base_image.save(f"{prefix}_with_info_{i:02d}.png")
        print(f"保存{prefix} {i:02d}.png (得分: {score:.3f})")
# Optional: save one annotated image per mask for the original and fused results.
# Uncomment the two calls below to enable it.
# save_mask(masks, boxes, scores, "original_mask")
# save_mask(fused_masks, fused_boxes, fused_scores, "fused_mask")
print("所有处理完成!")
例子1
检测画框,并且合并
先按得分从高到低排序,然后融合重叠的框
缺点:被抑制的框会直接丢失



import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.patches as patches
import time # 添加时间模块
#################################### For Image ####################################
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# Start of the whole-script timer (read again when the timing summary is computed).
total_start_time = time.time()
# Start of the model-loading timer.
model_load_start_time = time.time()
# Load the model
model = build_sam3_image_model(
checkpoint_path="/home/r9000k/v2_project/sam/sam3/assets/model/sam3.pt"
)
processor = Sam3Processor(model, confidence_threshold=0.5)
# End of the model-loading timer: covers building the model and the processor.
model_load_end_time = time.time()
model_load_time = model_load_end_time - model_load_start_time
print(f"模型加载时间: {model_load_time:.3f} 秒")
# Start of the single-image detection timer (image loading is included in it).
detection_start_time = time.time()
# NOTE: the first assignment is overwritten by the next line; only "3.jpg" is used.
image_path = "testimage/微信图片_20251120225838_38.jpg"
image_path = "3.jpg"
# Load an image
image = Image.open(image_path)
inference_state = processor.set_image(image)
# Prompt the model with text
output = processor.set_text_prompt(state=inference_state, prompt="building") # other prompts noted by the author: building, road and playground, building, car, people, bicycle
# Get the masks, bounding boxes, and scores
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
# End of the single-image detection timer.
detection_end_time = time.time()
detection_time = detection_end_time - detection_start_time
print(f"检测单张时间: {detection_time:.3f} 秒")
print(f"原始检测到 {len(masks)} 个分割结果")
print(f"掩码形状: {masks.shape}")
def calculate_iou(box1, box2):
    """Return how much of the smaller box is covered by the intersection.

    Despite its name, this is not the classic intersection-over-union: the
    value is ``max(inter / area(box1), inter / area(box2))``, i.e. the
    intersection divided by the smaller of the two box areas. A small box lying
    entirely inside a large one therefore scores 1.0.

    Args:
        box1: Four numbers ``(x1, y1, x2, y2)`` of the first box.
        box2: Four numbers ``(x1, y1, x2, y2)`` of the second box.

    Returns:
        The overlap ratio, or 0.0 when either box has no positive area.
    """
    x1_1, y1_1, x1_2, y1_2 = box1
    x2_1, y2_1, x2_2, y2_2 = box2

    # Width/height of the intersection rectangle, clamped at 0 for disjoint boxes.
    inter_w = max(0, min(x1_2, x2_2) - max(x1_1, x2_1))
    inter_h = max(0, min(y1_2, y2_2) - max(y1_1, y2_1))
    inter_area = inter_w * inter_h

    box1_area = (x1_2 - x1_1) * (y1_2 - y1_1)
    box2_area = (x2_2 - x2_1) * (y2_2 - y2_1)

    # Each area is used as a divisor below, so both must be checked individually;
    # guarding only the union still divides by zero when a single box is degenerate.
    if box1_area <= 0 or box2_area <= 0:
        return 0.0

    return max(inter_area / box1_area, inter_area / box2_area)
def calculate_mask_overlap(mask1, mask2):
    """Return the fraction of ``mask1``'s foreground that ``mask2`` also covers.

    Both masks are moved to the CPU, squeezed and cast to bool, so any non-zero
    value counts as foreground. The ratio is relative to ``mask1`` only and is
    therefore not symmetric.

    Returns:
        ``|mask1 AND mask2| / |mask1|``, or 0.0 when ``mask1`` is empty.
    """
    reference, other = (
        m.cpu().numpy().squeeze().astype(bool) for m in (mask1, mask2)
    )
    shared = np.logical_and(reference, other)
    reference_pixels = reference.sum()
    if reference_pixels == 0:
        return 0.0
    return shared.sum() / reference_pixels
def fuse_overlapping_masks(masks, boxes, scores, iou_threshold=0.5, overlap_threshold=0.6):
    """Suppress overlapping detections, keeping the higher-scoring one.

    Detections are visited in descending score order. Every detection that has
    not been suppressed is kept, and each later detection whose box overlap with
    it (the value returned by ``calculate_iou``) exceeds ``iou_threshold`` is
    suppressed. Masks are not merged; suppressed detections are simply dropped.

    Args:
        masks: Mask tensor indexed by detection (described by the author as
            [N, 1, H, W]).
        boxes: Box tensor with one row of four coordinates per detection.
        scores: Score tensor with one entry per detection.
        iou_threshold: Overlap value above which the later detection is suppressed.
        overlap_threshold: Currently unused (the mask-overlap check is disabled);
            kept so existing calls stay valid.

    Returns:
        Tuple ``(final_masks, final_boxes, final_scores)`` holding the kept
        detections in descending score order. Empty input is returned unchanged.
    """
    if len(masks) == 0:
        return masks, boxes, scores

    boxes_np = boxes.cpu().numpy().copy()
    scores_np = scores.cpu().numpy().copy()

    # Positions of the detections from highest to lowest score.
    sorted_indices = np.argsort(scores_np)[::-1]
    boxes_sorted = boxes_np[sorted_indices]
    scores_sorted = scores_np[sorted_indices]

    keep_indices = []   # positions (in sorted order) of detections to keep
    suppressed = set()  # positions (in sorted order) already absorbed
    n_boxes = len(boxes_sorted)

    for i in range(n_boxes):
        print('=====================', i)
        if i in suppressed:
            print('1 跳过', i)
            continue
        keep_indices.append(i)
        for j in range(i + 1, n_boxes):
            if j in suppressed:
                # Report the candidate being skipped (j), not the outer index.
                print('2 跳过', j)
                continue
            iou = calculate_iou(boxes_sorted[i], boxes_sorted[j])
            if iou > iou_threshold:
                suppressed.add(j)
                print(f"融合检测结果: 索引 {sorted_indices[i]} (得分: {scores_sorted[i]:.3f}) 覆盖索引 {sorted_indices[j]} (得分: {scores_sorted[j]:.3f})", " iou:", iou)
            else:
                # Nothing is suppressed in this branch, so the log must not say "fused".
                print(f"xxxxxx未融合: 索引 {sorted_indices[i]} (得分: {scores_sorted[i]:.3f}) 和索引 {sorted_indices[j]} (得分: {scores_sorted[j]:.3f})", " iou:", iou)

    # Map kept positions back to indices into the original tensors.
    final_indices = [int(sorted_indices[i]) for i in keep_indices]
    final_masks = torch.stack([masks[i] for i in final_indices])
    final_boxes = boxes[final_indices]
    final_scores = scores[final_indices]
    print(f"融合后剩余 {len(final_masks)} 个分割结果 (减少了 {len(masks) - len(final_masks)} 个)")
    return final_masks, final_boxes, final_scores
# Run the fusion step on the raw detections and time it.
print("\n开始融合重叠的检测结果...")
fusion_start_time = time.time()
# Call the fusion function (the score-ordered suppression variant defined above).
fused_masks, fused_boxes, fused_scores = fuse_overlapping_masks(
masks, boxes, scores,
iou_threshold=0.5, # tunable: box-overlap value above which detections are suppressed
overlap_threshold=0.6 # tunable: mask-overlap threshold accepted by the function
)
fusion_time = time.time() - fusion_start_time
print(f"融合完成时间: {fusion_time:.3f} 秒")
def overlay_masks_with_info(image, masks, boxes, scores, fusion_mode=False):
    """Draw every mask, its box and an ID/score label onto a copy of ``image``.

    Args:
        image: PIL image to annotate; it is converted to RGB, the caller's
            object is not modified.
        masks: Mask tensor of shape [N, 1, H, W]; axis 1 is squeezed away.
        boxes: [N, 4] box tensor, rows unpacked as (x1, y1, x2, y2).
        scores: Score tensor with one entry per detection.
        fusion_mode: When True, use the "viridis" colormap and a "Fused-ID"
            label; otherwise "rainbow" and "ID".

    Returns:
        A new RGB PIL image with all overlays drawn.
    """
    image = image.convert("RGB")

    # Try a CJK-capable font first, then Arial, then PIL's built-in font.
    # A missing font file raises OSError; a Pillow build without FreeType raises
    # ImportError. Anything else is a real error and should surface.
    font = None
    for font_name in ("SimHei.ttf", "Arial.ttf"):
        try:
            font = ImageFont.truetype(font_name, 20)
            break
        except (OSError, ImportError):
            continue
    if font is None:
        font = ImageFont.load_default()

    masks_np = masks.cpu().numpy().astype(np.uint8).squeeze(1)  # [N, H, W]
    boxes_np = boxes.cpu().numpy()                               # [N, 4]
    scores_np = scores.cpu().numpy()                             # [N]
    n_masks = masks_np.shape[0]

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # recent Matplotlib releases. max(..., 1) avoids requesting a colormap with
    # zero entries when there are no detections.
    cmap = plt.get_cmap("viridis" if fusion_mode else "rainbow", max(n_masks, 1))
    label_prefix = "Fused-ID" if fusion_mode else "ID"

    for i, (mask, box, score) in enumerate(zip(masks_np, boxes_np, scores_np)):
        color = tuple(int(c * 255) for c in cmap(i)[:3])

        if mask.ndim == 3:
            mask = mask.squeeze(0)

        # Alpha 128 (about 50% opacity) inside the mask, fully transparent outside.
        alpha = Image.fromarray((mask * 128).astype(np.uint8), mode='L')
        overlay = Image.new("RGBA", image.size, color + (128,))
        overlay.putalpha(alpha)
        image = Image.alpha_composite(image.convert("RGBA"), overlay).convert("RGB")
        # Compositing produced a new image, so the drawing context must be recreated.
        draw = ImageDraw.Draw(image)

        # Clamp the box to the image bounds before drawing it.
        x1, y1, x2, y2 = box
        x1 = max(0, min(x1, image.width))
        y1 = max(0, min(y1, image.height))
        x2 = max(0, min(x2, image.width))
        y2 = max(0, min(y2, image.height))
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        text = f"{label_prefix}:{i} Score:{score:.3f}"
        try:
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
            text_width = right - left
            text_height = bottom - top
        except AttributeError:
            # Older Pillow without ImageDraw.textbbox.
            text_width, text_height = draw.textsize(text, font=font)

        # Place the label just above the box, clamped to the top edge.
        text_x = x1
        text_y = max(0, y1 - text_height - 5)
        draw.rectangle([text_x, text_y, text_x + text_width + 10, text_y + text_height + 5],
                       fill=color)
        draw.text((text_x + 5, text_y + 2), text, fill="white", font=font)

    return image
# Time the rendering and saving of the two annotated images.
visualization_start_time = time.time()
# Render the raw detections, then the fused detections, onto the source image.
original_image = Image.open(image_path)
result_image_original = overlay_masks_with_info(original_image, masks, boxes, scores, fusion_mode=False)
result_image_fused = overlay_masks_with_info(original_image, fused_masks, fused_boxes, fused_scores, fusion_mode=True)
# Save both annotated images.
output_path_original = "segmentation_result_original.png"
output_path_fused = "segmentation_result_fused.png"
result_image_original.save(output_path_original)
result_image_fused.save(output_path_fused)
visualization_end_time = time.time()
visualization_time = visualization_end_time - visualization_start_time
print(f"可视化时间: {visualization_time:.3f} 秒")
print(f"原始分割结果已保存到: {output_path_original}")
print(f"融合后分割结果已保存到: {output_path_fused}")

# Best-effort font setup so the Chinese titles render. Only the errors rcParams
# raises for a bad key or value are ignored; anything else should surface.
try:
    plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False
except (KeyError, ValueError):
    pass

# Side-by-side comparison figure: raw detections on the left, fused on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
ax1.imshow(result_image_original)
ax1.set_title(f"原始结果: 检测到 {len(masks)} 个分割结果", fontsize=14)
ax1.axis('off')
ax2.imshow(result_image_fused)
ax2.set_title(f"融合后结果: 剩余 {len(fused_masks)} 个分割结果", fontsize=14)
ax2.axis('off')
plt.tight_layout()
plt.savefig("segmentation_comparison.png", bbox_inches='tight', dpi=300, facecolor='white')

# Stop the overall timer before plt.show(): with an interactive backend show()
# blocks until the window is closed, which would otherwise be counted as run time.
total_end_time = time.time()
total_time = total_end_time - total_start_time
plt.show()

# Print the timing summary.
print("\n" + "="*50)
print("运行时间统计:")
print("="*50)
print(f"模型加载时间: {model_load_time:.3f} 秒")
print(f"检测单张时间: {detection_time:.3f} 秒")
print(f"融合处理时间: {fusion_time:.3f} 秒")
print(f"可视化时间: {visualization_time:.3f} 秒")
print("-"*50)
print(f"总运行时间: {total_time:.3f} 秒")
print("="*50)
def save_mask(masks_to_save, boxes_to_save, scores_to_save, prefix="mask"):
    """Save one annotated image per mask as ``{prefix}_with_info_{i:02d}.png``.

    Each output is the image at the module-level ``image_path`` with a single
    mask overlaid, its box drawn and an "ID/Score" label above the box.

    Args:
        masks_to_save: Sized iterable of mask tensors, one per detection.
        boxes_to_save: Iterable of box tensors, each unpacked as (x1, y1, x2, y2).
        scores_to_save: Iterable of scores, formatted with ``:.3f``.
        prefix: File-name prefix for the saved images.
    """
    print(f"\n保存{prefix}的单个掩码...")
    n_masks = len(masks_to_save)
    if n_masks == 0:
        return

    # The font, colormap and background are the same for every mask, so they are
    # prepared once instead of on each iteration.
    single_font = None
    for font_name in ("SimHei.ttf", "Arial.ttf"):
        try:
            single_font = ImageFont.truetype(font_name, 24)
            break
        except (OSError, ImportError):
            # Missing font file, or a Pillow build without FreeType.
            continue
    if single_font is None:
        single_font = ImageFont.load_default()

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # recent Matplotlib releases.
    cmap = plt.get_cmap("viridis", n_masks)

    # The context manager closes the file once the pixel data has been converted.
    with Image.open(image_path) as source:
        background = source.convert("RGBA")

    for i, (mask, box, score) in enumerate(zip(masks_to_save, boxes_to_save, scores_to_save)):
        mask_np = mask.cpu().numpy().squeeze().astype(np.uint8)
        color = tuple(int(c * 255) for c in cmap(i)[:3])

        # Alpha 128 (about 50% opacity) inside the mask, fully transparent outside.
        alpha = Image.fromarray((mask_np * 128).astype(np.uint8), mode='L')
        overlay = Image.new("RGBA", background.size, color + (128,))
        overlay.putalpha(alpha)
        # alpha_composite returns a new image, so the shared background stays clean.
        base_image = Image.alpha_composite(background, overlay).convert("RGB")
        single_draw = ImageDraw.Draw(base_image)

        x1, y1, x2, y2 = box.cpu().numpy()
        single_draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        text = f"ID:{i} Score:{score:.3f}"
        try:
            left, top, right, bottom = single_draw.textbbox((0, 0), text, font=single_font)
            text_width = right - left
            text_height = bottom - top
        except AttributeError:
            # Older Pillow without ImageDraw.textbbox.
            text_width, text_height = single_draw.textsize(text, font=single_font)

        # Place the label just above the box, clamped to the top edge.
        text_x = x1
        text_y = max(0, y1 - text_height - 5)
        single_draw.rectangle([text_x, text_y, text_x + text_width + 10, text_y + text_height + 5],
                              fill=color)
        single_draw.text((text_x + 5, text_y + 2), text, fill="white", font=single_font)

        base_image.save(f"{prefix}_with_info_{i:02d}.png")
        print(f"保存{prefix} {i:02d}.png (得分: {score:.3f})")
# Optional: save one annotated image per mask for the original and fused results.
# Uncomment the two calls below to enable it.
# save_mask(masks, boxes, scores, "original_mask")
# save_mask(fused_masks, fused_boxes, fused_scores, "fused_mask")
print("所有处理完成!")
浙公网安备 33010602011771号