• 博客园logo
  • 会员
  • 众包
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • HarmonyOS
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录
MKT-porter
博客园    首页    新随笔    联系   管理    订阅  订阅
sam3 (2)开发

 

例子2

检测画框,并且合并框,并且合并mask

按照框大小,然后融合重叠的框

 

 没有合并mask前

segmentation_comparison

 合并后

segmentation_comparison

 

  没有合并mask前

segmentation_comparison

合并后

segmentation_comparison

 

  

没有合并mask前

segmentation_comparison

 合并后

 

segmentation_comparison

 

import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.patches as patches
import time  # 添加时间模块

#################################### For Image ####################################
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

# 记录总开始时间
total_start_time = time.time()

# 记录模型加载开始时间
model_load_start_time = time.time()

# Load the model
model = build_sam3_image_model(
    checkpoint_path="/home/r9000k/v2_project/sam/sam3/assets/model/sam3.pt"
)
processor = Sam3Processor(model, confidence_threshold=0.5)

# 记录模型加载结束时间
model_load_end_time = time.time()
model_load_time = model_load_end_time - model_load_start_time
print(f"模型加载时间: {model_load_time:.3f} 秒")

# 记录单张检测开始时间
detection_start_time = time.time()

image_path = "testimage/微信图片_20251120225838_38.jpg"
image_path = "3.jpg"
# Load an image
image = Image.open(image_path)
inference_state = processor.set_image(image)

# Prompt the model with text
output = processor.set_text_prompt(state=inference_state, prompt="building") #building,road and playground building  car、people、bicycle

# Get the masks, bounding boxes, and scores
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]

# 记录单张检测结束时间
detection_end_time = time.time()
detection_time = detection_end_time - detection_start_time
print(f"检测单张时间: {detection_time:.3f} 秒")
print(f"原始检测到 {len(masks)} 个分割结果")
print(f"掩码形状: {masks.shape}")

def calculate_iou(box1, box2):
    """计算两个边界框的IoU(交并比)"""
    # 解包坐标
    x1_1, y1_1, x1_2, y1_2 = box1
    x2_1, y2_1, x2_2, y2_2 = box2
    
    # 计算交集区域
    xi1 = max(x1_1, x2_1)
    yi1 = max(y1_1, y2_1)
    xi2 = min(x1_2, x2_2)
    yi2 = min(y1_2, y2_2)
    
    # 计算交集面积
    inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
    
    # 计算并集面积
    box1_area = (x1_2 - x1_1) * (y1_2 - y1_1)
    box2_area = (x2_2 - x2_1) * (y2_2 - y2_1)
    union_area = box1_area + box2_area - inter_area
    
    # 避免除以零
    if union_area == 0:
        return 0.0
    #iou_= inter_area / union_area
    iou_2= inter_area / box2_area
    iou_1= inter_area / box1_area

    iou_=max(iou_2,iou_1)


    return iou_

def calculate_mask_overlap(mask1, mask2):
    """计算两个掩码的重叠比例(基于mask1)"""
    mask1_np = mask1.cpu().numpy().squeeze().astype(bool)
    mask2_np = mask2.cpu().numpy().squeeze().astype(bool)
    
    # 计算交集和mask1的面积
    intersection = np.logical_and(mask1_np, mask2_np)
    mask1_area = np.sum(mask1_np)
    
    if mask1_area == 0:
        return 0.0
    
    overlap_ratio = np.sum(intersection) / mask1_area
    return overlap_ratio

def fuse_overlapping_masks_use_scores(masks, boxes, scores, iou_threshold=0.5, overlap_threshold=0.6):
    """
    融合重叠的掩码和边界框
    参数:
        masks: 形状为 [N, 1, H, W] 的掩码张量
        boxes: 形状为 [N, 4] 的边界框张量
        scores: 形状为 [N] 的得分张量
        iou_threshold: IoU阈值,用于判定边界框是否重叠
        overlap_threshold: 掩码重叠阈值,用于判定是否融合
    """
    if len(masks) == 0:
        return masks, boxes, scores
    
    # 转换为numpy数组进行处理(使用copy()避免负步长问题)
    boxes_np = boxes.cpu().numpy().copy()
    scores_np = scores.cpu().numpy().copy()
    
    # 按得分降序排序
    # 降序索引
    sorted_indices = np.argsort(scores_np)[::-1]

    # 根据索引重新调整顺序
    boxes_sorted = boxes_np[sorted_indices]
    scores_sorted = scores_np[sorted_indices]
    
    # 处理masks:先转换为列表,然后按排序索引重新组织
    masks_list = [masks[i] for i in range(len(masks))]
    # 根据索引重新调整顺序
    masks_sorted = [masks_list[i] for i in sorted_indices]
    
    # 初始化保留索引
    keep_indices = []
    suppressed = set()
    
    for i in range(len(boxes_sorted)):
        print('=====================',i)
        if i in suppressed:
            print('1 跳过',i)
            continue
        
        keep_indices.append(i)
        
        for j in range(i + 1, len(boxes_sorted)):
            if j in suppressed:
                print('2 跳过',i)
                continue
            
            # 计算IoU
            iou = calculate_iou(boxes_sorted[i], boxes_sorted[j])
            
            if iou > iou_threshold:
                # 计算掩码重叠比例
                #overlap_ratio = calculate_mask_overlap(masks_sorted[i], masks_sorted[j])
                
                #if overlap_ratio > overlap_threshold:
                suppressed.add(j)
                
                print(f"融合检测结果: 索引 {sorted_indices[i]} (得分: {scores_sorted[i]:.3f}) 覆盖索引 {sorted_indices[j]} (得分: {scores_sorted[j]:.3f})"," iou:",iou)
                #print(f"  - IoU: {iou:.3f}, 掩码重叠比例: {overlap_ratio:.3f}")
            else:
                #keep_indices.append(i)
                print(f"xxxxxx融合检测结果: 索引 {sorted_indices[i]} (得分: {scores_sorted[i]:.3f}) 覆盖索引 {sorted_indices[j]} (得分: {scores_sorted[j]:.3f})"," iou:",iou)
    # 获取保留的检测结果
    final_indices = [sorted_indices[i] for i in keep_indices]
    
    # 使用PyTorch的索引操作来获取最终结果
    final_masks = torch.stack([masks_list[i] for i in final_indices])
    final_boxes = boxes[final_indices]
    final_scores = scores[final_indices]
    
    print(f"融合后剩余 {len(final_masks)} 个分割结果 (减少了 {len(masks) - len(final_masks)} 个)")
    
    return final_masks, final_boxes, final_scores




def fuse_overlapping_masks_justIou(masks, boxes, scores, iou_threshold=0.5, overlap_threshold=0.6):
    """
    融合重叠的掩码和边界框
    参数:
        masks: 形状为 [N, 1, H, W] 的掩码张量
        boxes: 形状为 [N, 4] 的边界框张量
        scores: 形状为 [N] 的得分张量
        iou_threshold: IoU阈值,用于判定边界框是否重叠
        overlap_threshold: 掩码重叠阈值,用于判定是否融合
    """
    if len(masks) == 0:
        return masks, boxes, scores
    
    # 转换为numpy数组进行处理(使用copy()避免负步长问题)
    boxes_np = boxes.cpu().numpy().copy()
    scores_np = scores.cpu().numpy().copy()
    
    # 计算每个边界框的面积 (w * h)
    areas = []
    for box in boxes_np:
        x1, y1, x2, y2 = box
        width = x2 - x1
        height = y2 - y1
        area = width * height
        areas.append(area)
    
    areas_np = np.array(areas)
    
    # 按面积降序排序(面积大的优先)
    sorted_indices = np.argsort(areas_np)[::-1]
    
    # 根据索引重新调整顺序
    boxes_sorted = boxes_np[sorted_indices]
    scores_sorted = scores_np[sorted_indices]
    areas_sorted = areas_np[sorted_indices]
    
    # 处理masks:先转换为列表,然后按排序索引重新组织
    masks_list = [masks[i] for i in range(len(masks))]
    # 根据索引重新调整顺序
    masks_sorted = [masks_list[i] for i in sorted_indices]
    
    # 初始化保留索引
    keep_indices = []
    suppressed = set()
    
    for i in range(len(boxes_sorted)):
        print('=====================', i)
        if i in suppressed:
            print('1 跳过', i)
            continue
        
        keep_indices.append(i)
        
        for j in range(i + 1, len(boxes_sorted)):
            if j in suppressed:
                print('2 跳过', j)
                continue
            
            # 计算IoU
            iou = calculate_iou(boxes_sorted[i], boxes_sorted[j])
            
            if iou > iou_threshold:
                # 计算掩码重叠比例
                # overlap_ratio = calculate_mask_overlap(masks_sorted[i], masks_sorted[j])
                
                # if overlap_ratio > overlap_threshold:
                suppressed.add(j)
                
                print(f"融合检测结果: 索引 {sorted_indices[i]} (面积: {areas_sorted[i]:.1f}) 覆盖索引 {sorted_indices[j]} (面积: {areas_sorted[j]:.1f})", " iou:", iou)
                # print(f"  - IoU: {iou:.3f}, 掩码重叠比例: {overlap_ratio:.3f}")
            else:
                print(f"xxxxxx未融合: 索引 {sorted_indices[i]} (面积: {areas_sorted[i]:.1f}) 和索引 {sorted_indices[j]} (面积: {areas_sorted[j]:.1f})", " iou:", iou)
    
    # 获取保留的检测结果
    final_indices = [sorted_indices[i] for i in keep_indices]
    
    # 使用PyTorch的索引操作来获取最终结果
    final_masks = torch.stack([masks_list[i] for i in final_indices])
    final_boxes = boxes[final_indices]
    final_scores = scores[final_indices]
    
    print(f"融合后剩余 {len(final_masks)} 个分割结果 (减少了 {len(masks) - len(final_masks)} 个)")
    
    return final_masks, final_boxes, final_scores


def fuse_overlapping_masks(masks, boxes, scores, iou_threshold=0.5, overlap_threshold=0.6):
    """
    融合重叠的掩码和边界框
    参数:
        masks: 形状为 [N, 1, H, W] 的掩码张量
        boxes: 形状为 [N, 4] 的边界框张量
        scores: 形状为 [N] 的得分张量
        iou_threshold: IoU阈值,用于判定边界框是否重叠
        overlap_threshold: 掩码重叠阈值,用于判定是否融合
    """
    if len(masks) == 0:
        return masks, boxes, scores
    
    # 转换为numpy数组进行处理(使用copy()避免负步长问题)
    boxes_np = boxes.cpu().numpy().copy()
    scores_np = scores.cpu().numpy().copy()
    
    # 计算每个边界框的面积 (w * h)
    areas = []
    for box in boxes_np:
        x1, y1, x2, y2 = box
        width = x2 - x1
        height = y2 - y1
        area = width * height
        areas.append(area)
    
    areas_np = np.array(areas)
    
    # 按面积降序排序(面积大的优先)
    sorted_indices = np.argsort(areas_np)[::-1]
    
    # 根据索引重新调整顺序
    boxes_sorted = boxes_np[sorted_indices]
    scores_sorted = scores_np[sorted_indices]
    areas_sorted = areas_np[sorted_indices]
    
    # 处理masks:先转换为列表,然后按排序索引重新组织
    masks_list = [masks[i] for i in range(len(masks))]
    # 根据索引重新调整顺序
    masks_sorted = [masks_list[i] for i in sorted_indices]
    
    # 初始化保留索引和融合后的masks
    keep_indices = []
    suppressed = set()
    fused_masks = masks_sorted.copy()  # 用于存储融合后的masks
    
    for i in range(len(boxes_sorted)):
        print('=====================', i)
        if i in suppressed:
            print('1 跳过', i)
            continue
        
        keep_indices.append(i)
        current_mask = fused_masks[i]  # 当前要融合的大mask
        
        for j in range(i + 1, len(boxes_sorted)):
            if j in suppressed:
                print('2 跳过', j)
                continue
            
            # 计算IoU
            iou = calculate_iou(boxes_sorted[i], boxes_sorted[j])
            
            if iou > iou_threshold:


                # 计算掩码重叠比例
                mask_overlap = calculate_mask_overlap(current_mask, masks_sorted[j])


                suppressed.add(j)
                
                # 将小mask合并到大mask上
                fused_mask = fuse_two_masks(current_mask, masks_sorted[j])
                fused_masks[i] = fused_mask
                current_mask = fused_mask  # 更新当前mask
                
                print(f"融合检测结果: 索引 {sorted_indices[i]} (面积: {areas_sorted[i]:.1f}) 融合索引 {sorted_indices[j]} (面积: {areas_sorted[j]:.1f})", 
                        " iou:", iou, " mask重叠:", mask_overlap)


                
                # if mask_overlap > overlap_threshold:
                #     suppressed.add(j)
                    
                #     # 将小mask合并到大mask上
                #     fused_mask = fuse_two_masks(current_mask, masks_sorted[j])
                #     fused_masks[i] = fused_mask
                #     current_mask = fused_mask  # 更新当前mask
                    
                #     print(f"融合检测结果: 索引 {sorted_indices[i]} (面积: {areas_sorted[i]:.1f}) 融合索引 {sorted_indices[j]} (面积: {areas_sorted[j]:.1f})", 
                #           " iou:", iou, " mask重叠:", mask_overlap)
                # else:
                #     print(f"mask重叠不足: 索引 {sorted_indices[i]} 和索引 {sorted_indices[j]}", 
                #           " iou:", iou, " mask重叠:", mask_overlap)
            else:
                print(f"IoU不足: 索引 {sorted_indices[i]} 和索引 {sorted_indices[j]}", 
                      " iou:", iou)
    
    # 获取保留的检测结果(使用融合后的masks)
    final_indices = [sorted_indices[i] for i in keep_indices]
    final_masks_list = [fused_masks[i] for i in keep_indices]
    
    # 使用PyTorch的索引操作来获取最终结果
    final_masks = torch.stack(final_masks_list)
    final_boxes = boxes[final_indices]
    final_scores = scores[final_indices]
    
    print(f"融合后剩余 {len(final_masks)} 个分割结果 (减少了 {len(masks) - len(final_masks)} 个)")
    
    return final_masks, final_boxes, final_scores


def fuse_two_masks(mask1, mask2):
    """
    将两个mask融合(取并集)
    参数:
        mask1: 第一个mask张量 [1, H, W]
        mask2: 第二个mask张量 [1, H, W]
    返回:
        融合后的mask张量 [1, H, W]
    """
    # 使用逻辑或操作合并两个mask
    fused_mask = torch.logical_or(mask1, mask2).float()
    return fused_mask


# 应用融合函数
print("\n开始融合重叠的检测结果...")
fusion_start_time = time.time()

# 调用融合函数
fused_masks, fused_boxes, fused_scores = fuse_overlapping_masks(
    masks, boxes, scores, 
    iou_threshold=0.5,  # 可以调整这个阈值
    overlap_threshold=0.6  # 可以调整这个阈值
)

fusion_time = time.time() - fusion_start_time
print(f"融合完成时间: {fusion_time:.3f} 秒")

def overlay_masks_with_info(image, masks, boxes, scores, fusion_mode=False):
    """
    在图像上叠加掩码,并添加ID、得分和矩形框
    masks: 形状为 [N, 1, H, W] 的四维张量
    boxes: 形状为 [N, 4] 的边界框张量 [x1, y1, x2, y2]
    scores: 形状为 [N] 的得分张量
    fusion_mode: 是否为融合后的模式(使用不同颜色)
    """
    # 转换为RGB模式以便绘制
    image = image.convert("RGB")
    draw = ImageDraw.Draw(image)
    
    # 尝试加载字体,如果失败则使用默认字体
    try:
        # 尝试使用系统中文字体
        font = ImageFont.truetype("SimHei.ttf", 20)
    except:
        try:
            font = ImageFont.truetype("Arial.ttf", 20)
        except:
            font = ImageFont.load_default()
    
    # 将掩码转换为numpy数组并去除通道维度
    masks_np = masks.cpu().numpy().astype(np.uint8)  # 形状: [N, 1, H, W]
    masks_np = masks_np.squeeze(1)  # 移除通道维度,形状: [N, H, W]
    boxes_np = boxes.cpu().numpy()  # 形状: [N, 4]
    scores_np = scores.cpu().numpy()  # 形状: [N]
    
    n_masks = masks_np.shape[0]
    
    # 根据是否为融合模式选择不同的颜色映射
    if fusion_mode:
        cmap = plt.cm.get_cmap("viridis", n_masks)  # 融合模式使用viridis配色
    else:
        cmap = plt.cm.get_cmap("rainbow", n_masks)  # 原始模式使用rainbow配色
    
    for i, (mask, box, score) in enumerate(zip(masks_np, boxes_np, scores_np)):
        # 获取颜色
        color = tuple(int(c * 255) for c in cmap(i)[:3])
        
        # 确保掩码是二维的
        if mask.ndim == 3:
            mask = mask.squeeze(0)
        
        # 创建透明度掩码
        alpha_mask = (mask * 128).astype(np.uint8)  # 0.5透明度
        
        # 创建彩色覆盖层
        overlay = Image.new("RGBA", image.size, color + (128,))
        
        # 应用alpha通道
        alpha = Image.fromarray(alpha_mask, mode='L')
        overlay.putalpha(alpha)
        
        # 叠加到图像上
        image = Image.alpha_composite(image.convert("RGBA"), overlay).convert("RGB")
        draw = ImageDraw.Draw(image)
        
        # 绘制边界框
        x1, y1, x2, y2 = box
        # 确保坐标在图像范围内
        x1 = max(0, min(x1, image.width))
        y1 = max(0, min(y1, image.height))
        x2 = max(0, min(x2, image.width))
        y2 = max(0, min(y2, image.height))
        
        # 绘制矩形框
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        
        # 准备文本信息
        if fusion_mode:
            text = f"Fused-ID:{i} Score:{score:.3f}"
        else:
            text = f"ID:{i} Score:{score:.3f}"
        
        # 计算文本位置(在框的上方)
        try:
            # 新版本的PIL
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
            text_width = right - left
            text_height = bottom - top
        except:
            # 旧版本的PIL
            text_width, text_height = draw.textsize(text, font=font)
        
        text_x = x1
        text_y = max(0, y1 - text_height - 5)
        
        # 绘制文本背景
        draw.rectangle([text_x, text_y, text_x + text_width + 10, text_y + text_height + 5], 
                      fill=color)
        
        # 绘制文本
        draw.text((text_x + 5, text_y + 2), text, fill="white", font=font)
    
    return image

# 记录可视化开始时间
visualization_start_time = time.time()

# 应用掩码叠加(原始结果)
original_image = Image.open(image_path)
result_image_original = overlay_masks_with_info(original_image, masks, boxes, scores, fusion_mode=False)

# 应用掩码叠加(融合后结果)
result_image_fused = overlay_masks_with_info(original_image, fused_masks, fused_boxes, fused_scores, fusion_mode=True)

# 保存结果图像
output_path_original = "segmentation_result_original.png"
output_path_fused = "segmentation_result_fused.png"
result_image_original.save(output_path_original)
result_image_fused.save(output_path_fused)

# 记录可视化结束时间
visualization_end_time = time.time()
visualization_time = visualization_end_time - visualization_start_time
print(f"可视化时间: {visualization_time:.3f} 秒")

print(f"原始分割结果已保存到: {output_path_original}")
print(f"融合后分割结果已保存到: {output_path_fused}")

# 设置中文字体或使用英文避免警告
try:
    # 尝试设置中文字体
    plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False
except:
    pass

# 显示对比图像
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

# 原始结果
ax1.imshow(result_image_original)
ax1.set_title(f"原始结果: 检测到 {len(masks)} 个分割结果", fontsize=14)
ax1.axis('off')

# 融合后结果
ax2.imshow(result_image_fused)
ax2.set_title(f"融合后结果: 剩余 {len(fused_masks)} 个分割结果", fontsize=14)
ax2.axis('off')

plt.tight_layout()
plt.savefig("segmentation_comparison.png", bbox_inches='tight', dpi=300, facecolor='white')
plt.show()

# 记录总结束时间
total_end_time = time.time()
total_time = total_end_time - total_start_time

# 打印详细的时间统计
print("\n" + "="*50)
print("运行时间统计:")
print("="*50)
print(f"模型加载时间: {model_load_time:.3f} 秒")
print(f"检测单张时间: {detection_time:.3f} 秒")
print(f"融合处理时间: {fusion_time:.3f} 秒")
print(f"可视化时间:   {visualization_time:.3f} 秒")
print("-"*50)
print(f"总运行时间:   {total_time:.3f} 秒")
print("="*50)

def save_mask(masks_to_save, boxes_to_save, scores_to_save, prefix="mask"):
    """保存单个掩码的通用函数"""
    print(f"\n保存{prefix}的单个掩码...")
    for i, (mask, box, score) in enumerate(zip(masks_to_save, boxes_to_save, scores_to_save)):
        # 创建单个掩码的可视化
        base_image = Image.open(image_path).convert("RGB")
        single_draw = ImageDraw.Draw(base_image)
        
        # 尝试加载字体
        try:
            single_font = ImageFont.truetype("SimHei.ttf", 24)
        except:
            try:
                single_font = ImageFont.truetype("Arial.ttf", 24)
            except:
                single_font = ImageFont.load_default()
        
        # 处理掩码
        mask_np = mask.cpu().numpy().squeeze().astype(np.uint8)
        color = tuple(int(c * 255) for c in plt.cm.get_cmap("viridis", len(masks_to_save))(i)[:3])
        
        # 创建透明度掩码
        alpha_mask = (mask_np * 128).astype(np.uint8)
        overlay = Image.new("RGBA", base_image.size, color + (128,))
        alpha = Image.fromarray(alpha_mask, mode='L')
        overlay.putalpha(alpha)
        base_image = Image.alpha_composite(base_image.convert("RGBA"), overlay).convert("RGB")
        single_draw = ImageDraw.Draw(base_image)
        
        # 绘制边界框和文本
        x1, y1, x2, y2 = box.cpu().numpy()
        single_draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        
        text = f"ID:{i} Score:{score:.3f}"
        try:
            # 新版本的PIL
            left, top, right, bottom = single_draw.textbbox((0, 0), text, font=single_font)
            text_width = right - left
            text_height = bottom - top
        except:
            # 旧版本的PIL
            text_width, text_height = single_draw.textsize(text, font=single_font)
        
        text_x = x1
        text_y = max(0, y1 - text_height - 5)
        
        single_draw.rectangle([text_x, text_y, text_x + text_width + 10, text_y + text_height + 5], 
                            fill=color)
        single_draw.text((text_x + 5, text_y + 2), text, fill="white", font=single_font)
        
        base_image.save(f"{prefix}_with_info_{i:02d}.png")
        print(f"保存{prefix} {i:02d}.png (得分: {score:.3f})")

# # 保存原始和融合后的单个掩码
# save_mask(masks, boxes, scores, "original_mask")
# save_mask(fused_masks, fused_boxes, fused_scores, "fused_mask")

print("所有处理完成!")

  

 

例子1

检测画框,并且合并

按照分数排序,然后融合重叠的框

缺点 丢失框

segmentation_comparison

 

segmentation_comparison

 

segmentation_comparison

 

import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.patches as patches
import time  # 添加时间模块

#################################### For Image ####################################
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

# 记录总开始时间
total_start_time = time.time()

# 记录模型加载开始时间
model_load_start_time = time.time()

# Load the model
model = build_sam3_image_model(
    checkpoint_path="/home/r9000k/v2_project/sam/sam3/assets/model/sam3.pt"
)
processor = Sam3Processor(model, confidence_threshold=0.5)

# 记录模型加载结束时间
model_load_end_time = time.time()
model_load_time = model_load_end_time - model_load_start_time
print(f"模型加载时间: {model_load_time:.3f} 秒")

# 记录单张检测开始时间
detection_start_time = time.time()

image_path = "testimage/微信图片_20251120225838_38.jpg"
image_path = "3.jpg"
# Load an image
image = Image.open(image_path)
inference_state = processor.set_image(image)

# Prompt the model with text
output = processor.set_text_prompt(state=inference_state, prompt="building") #building,road and playground building  car、people、bicycle

# Get the masks, bounding boxes, and scores
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]

# 记录单张检测结束时间
detection_end_time = time.time()
detection_time = detection_end_time - detection_start_time
print(f"检测单张时间: {detection_time:.3f} 秒")
print(f"原始检测到 {len(masks)} 个分割结果")
print(f"掩码形状: {masks.shape}")

def calculate_iou(box1, box2):
    """计算两个边界框的IoU(交并比)"""
    # 解包坐标
    x1_1, y1_1, x1_2, y1_2 = box1
    x2_1, y2_1, x2_2, y2_2 = box2
    
    # 计算交集区域
    xi1 = max(x1_1, x2_1)
    yi1 = max(y1_1, y2_1)
    xi2 = min(x1_2, x2_2)
    yi2 = min(y1_2, y2_2)
    
    # 计算交集面积
    inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
    
    # 计算并集面积
    box1_area = (x1_2 - x1_1) * (y1_2 - y1_1)
    box2_area = (x2_2 - x2_1) * (y2_2 - y2_1)
    union_area = box1_area + box2_area - inter_area
    
    # 避免除以零
    if union_area == 0:
        return 0.0
    #iou_= inter_area / union_area
    iou_2= inter_area / box2_area
    iou_1= inter_area / box1_area

    iou_=max(iou_2,iou_1)


    return iou_

def calculate_mask_overlap(mask1, mask2):
    """计算两个掩码的重叠比例(基于mask1)"""
    mask1_np = mask1.cpu().numpy().squeeze().astype(bool)
    mask2_np = mask2.cpu().numpy().squeeze().astype(bool)
    
    # 计算交集和mask1的面积
    intersection = np.logical_and(mask1_np, mask2_np)
    mask1_area = np.sum(mask1_np)
    
    if mask1_area == 0:
        return 0.0
    
    overlap_ratio = np.sum(intersection) / mask1_area
    return overlap_ratio

def fuse_overlapping_masks(masks, boxes, scores, iou_threshold=0.5, overlap_threshold=0.6):
    """
    融合重叠的掩码和边界框
    参数:
        masks: 形状为 [N, 1, H, W] 的掩码张量
        boxes: 形状为 [N, 4] 的边界框张量
        scores: 形状为 [N] 的得分张量
        iou_threshold: IoU阈值,用于判定边界框是否重叠
        overlap_threshold: 掩码重叠阈值,用于判定是否融合
    """
    if len(masks) == 0:
        return masks, boxes, scores
    
    # 转换为numpy数组进行处理(使用copy()避免负步长问题)
    boxes_np = boxes.cpu().numpy().copy()
    scores_np = scores.cpu().numpy().copy()
    
    # 按得分降序排序
    # 降序索引
    sorted_indices = np.argsort(scores_np)[::-1]
    # 根据索引重新调整顺序
    boxes_sorted = boxes_np[sorted_indices]
    scores_sorted = scores_np[sorted_indices]
    
    # 处理masks:先转换为列表,然后按排序索引重新组织
    masks_list = [masks[i] for i in range(len(masks))]
    # 根据索引重新调整顺序
    masks_sorted = [masks_list[i] for i in sorted_indices]
    
    # 初始化保留索引
    keep_indices = []
    suppressed = set()
    
    for i in range(len(boxes_sorted)):
        print('=====================',i)
        if i in suppressed:
            print('1 跳过',i)
            continue
        
        keep_indices.append(i)
        
        for j in range(i + 1, len(boxes_sorted)):
            if j in suppressed:
                print('2 跳过',i)
                continue
            
            # 计算IoU
            iou = calculate_iou(boxes_sorted[i], boxes_sorted[j])
            
            if iou > iou_threshold:
                # 计算掩码重叠比例
                #overlap_ratio = calculate_mask_overlap(masks_sorted[i], masks_sorted[j])
                
                #if overlap_ratio > overlap_threshold:
                suppressed.add(j)
                
                print(f"融合检测结果: 索引 {sorted_indices[i]} (得分: {scores_sorted[i]:.3f}) 覆盖索引 {sorted_indices[j]} (得分: {scores_sorted[j]:.3f})"," iou:",iou)
                #print(f"  - IoU: {iou:.3f}, 掩码重叠比例: {overlap_ratio:.3f}")
            else:
                #keep_indices.append(i)
                print(f"xxxxxx融合检测结果: 索引 {sorted_indices[i]} (得分: {scores_sorted[i]:.3f}) 覆盖索引 {sorted_indices[j]} (得分: {scores_sorted[j]:.3f})"," iou:",iou)
    # 获取保留的检测结果
    final_indices = [sorted_indices[i] for i in keep_indices]
    
    # 使用PyTorch的索引操作来获取最终结果
    final_masks = torch.stack([masks_list[i] for i in final_indices])
    final_boxes = boxes[final_indices]
    final_scores = scores[final_indices]
    
    print(f"融合后剩余 {len(final_masks)} 个分割结果 (减少了 {len(masks) - len(final_masks)} 个)")
    
    return final_masks, final_boxes, final_scores

# 应用融合函数
print("\n开始融合重叠的检测结果...")
fusion_start_time = time.time()

# 调用融合函数
fused_masks, fused_boxes, fused_scores = fuse_overlapping_masks(
    masks, boxes, scores, 
    iou_threshold=0.5,  # 可以调整这个阈值
    overlap_threshold=0.6  # 可以调整这个阈值
)

fusion_time = time.time() - fusion_start_time
print(f"融合完成时间: {fusion_time:.3f} 秒")

def overlay_masks_with_info(image, masks, boxes, scores, fusion_mode=False):
    """
    在图像上叠加掩码,并添加ID、得分和矩形框
    masks: 形状为 [N, 1, H, W] 的四维张量
    boxes: 形状为 [N, 4] 的边界框张量 [x1, y1, x2, y2]
    scores: 形状为 [N] 的得分张量
    fusion_mode: 是否为融合后的模式(使用不同颜色)
    """
    # 转换为RGB模式以便绘制
    image = image.convert("RGB")
    draw = ImageDraw.Draw(image)
    
    # 尝试加载字体,如果失败则使用默认字体
    try:
        # 尝试使用系统中文字体
        font = ImageFont.truetype("SimHei.ttf", 20)
    except:
        try:
            font = ImageFont.truetype("Arial.ttf", 20)
        except:
            font = ImageFont.load_default()
    
    # 将掩码转换为numpy数组并去除通道维度
    masks_np = masks.cpu().numpy().astype(np.uint8)  # 形状: [N, 1, H, W]
    masks_np = masks_np.squeeze(1)  # 移除通道维度,形状: [N, H, W]
    boxes_np = boxes.cpu().numpy()  # 形状: [N, 4]
    scores_np = scores.cpu().numpy()  # 形状: [N]
    
    n_masks = masks_np.shape[0]
    
    # 根据是否为融合模式选择不同的颜色映射
    if fusion_mode:
        cmap = plt.cm.get_cmap("viridis", n_masks)  # 融合模式使用viridis配色
    else:
        cmap = plt.cm.get_cmap("rainbow", n_masks)  # 原始模式使用rainbow配色
    
    for i, (mask, box, score) in enumerate(zip(masks_np, boxes_np, scores_np)):
        # 获取颜色
        color = tuple(int(c * 255) for c in cmap(i)[:3])
        
        # 确保掩码是二维的
        if mask.ndim == 3:
            mask = mask.squeeze(0)
        
        # 创建透明度掩码
        alpha_mask = (mask * 128).astype(np.uint8)  # 0.5透明度
        
        # 创建彩色覆盖层
        overlay = Image.new("RGBA", image.size, color + (128,))
        
        # 应用alpha通道
        alpha = Image.fromarray(alpha_mask, mode='L')
        overlay.putalpha(alpha)
        
        # 叠加到图像上
        image = Image.alpha_composite(image.convert("RGBA"), overlay).convert("RGB")
        draw = ImageDraw.Draw(image)
        
        # 绘制边界框
        x1, y1, x2, y2 = box
        # 确保坐标在图像范围内
        x1 = max(0, min(x1, image.width))
        y1 = max(0, min(y1, image.height))
        x2 = max(0, min(x2, image.width))
        y2 = max(0, min(y2, image.height))
        
        # 绘制矩形框
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        
        # 准备文本信息
        if fusion_mode:
            text = f"Fused-ID:{i} Score:{score:.3f}"
        else:
            text = f"ID:{i} Score:{score:.3f}"
        
        # 计算文本位置(在框的上方)
        try:
            # 新版本的PIL
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
            text_width = right - left
            text_height = bottom - top
        except:
            # 旧版本的PIL
            text_width, text_height = draw.textsize(text, font=font)
        
        text_x = x1
        text_y = max(0, y1 - text_height - 5)
        
        # 绘制文本背景
        draw.rectangle([text_x, text_y, text_x + text_width + 10, text_y + text_height + 5], 
                      fill=color)
        
        # 绘制文本
        draw.text((text_x + 5, text_y + 2), text, fill="white", font=font)
    
    return image

# 记录可视化开始时间
visualization_start_time = time.time()

# 应用掩码叠加(原始结果)
original_image = Image.open(image_path)
result_image_original = overlay_masks_with_info(original_image, masks, boxes, scores, fusion_mode=False)

# 应用掩码叠加(融合后结果)
result_image_fused = overlay_masks_with_info(original_image, fused_masks, fused_boxes, fused_scores, fusion_mode=True)

# 保存结果图像
output_path_original = "segmentation_result_original.png"
output_path_fused = "segmentation_result_fused.png"
result_image_original.save(output_path_original)
result_image_fused.save(output_path_fused)

# 记录可视化结束时间
visualization_end_time = time.time()
visualization_time = visualization_end_time - visualization_start_time
print(f"可视化时间: {visualization_time:.3f} 秒")

print(f"原始分割结果已保存到: {output_path_original}")
print(f"融合后分割结果已保存到: {output_path_fused}")

# 设置中文字体或使用英文避免警告
try:
    # 尝试设置中文字体
    plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False
except:
    pass

# 显示对比图像
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

# 原始结果
ax1.imshow(result_image_original)
ax1.set_title(f"原始结果: 检测到 {len(masks)} 个分割结果", fontsize=14)
ax1.axis('off')

# 融合后结果
ax2.imshow(result_image_fused)
ax2.set_title(f"融合后结果: 剩余 {len(fused_masks)} 个分割结果", fontsize=14)
ax2.axis('off')

plt.tight_layout()
plt.savefig("segmentation_comparison.png", bbox_inches='tight', dpi=300, facecolor='white')
plt.show()

# 记录总结束时间
total_end_time = time.time()
total_time = total_end_time - total_start_time

# 打印详细的时间统计
print("\n" + "="*50)
print("运行时间统计:")
print("="*50)
print(f"模型加载时间: {model_load_time:.3f} 秒")
print(f"检测单张时间: {detection_time:.3f} 秒")
print(f"融合处理时间: {fusion_time:.3f} 秒")
print(f"可视化时间:   {visualization_time:.3f} 秒")
print("-"*50)
print(f"总运行时间:   {total_time:.3f} 秒")
print("="*50)

def save_mask(masks_to_save, boxes_to_save, scores_to_save, prefix="mask"):
    """保存单个掩码的通用函数"""
    print(f"\n保存{prefix}的单个掩码...")
    for i, (mask, box, score) in enumerate(zip(masks_to_save, boxes_to_save, scores_to_save)):
        # 创建单个掩码的可视化
        base_image = Image.open(image_path).convert("RGB")
        single_draw = ImageDraw.Draw(base_image)
        
        # 尝试加载字体
        try:
            single_font = ImageFont.truetype("SimHei.ttf", 24)
        except:
            try:
                single_font = ImageFont.truetype("Arial.ttf", 24)
            except:
                single_font = ImageFont.load_default()
        
        # 处理掩码
        mask_np = mask.cpu().numpy().squeeze().astype(np.uint8)
        color = tuple(int(c * 255) for c in plt.cm.get_cmap("viridis", len(masks_to_save))(i)[:3])
        
        # 创建透明度掩码
        alpha_mask = (mask_np * 128).astype(np.uint8)
        overlay = Image.new("RGBA", base_image.size, color + (128,))
        alpha = Image.fromarray(alpha_mask, mode='L')
        overlay.putalpha(alpha)
        base_image = Image.alpha_composite(base_image.convert("RGBA"), overlay).convert("RGB")
        single_draw = ImageDraw.Draw(base_image)
        
        # 绘制边界框和文本
        x1, y1, x2, y2 = box.cpu().numpy()
        single_draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        
        text = f"ID:{i} Score:{score:.3f}"
        try:
            # 新版本的PIL
            left, top, right, bottom = single_draw.textbbox((0, 0), text, font=single_font)
            text_width = right - left
            text_height = bottom - top
        except:
            # 旧版本的PIL
            text_width, text_height = single_draw.textsize(text, font=single_font)
        
        text_x = x1
        text_y = max(0, y1 - text_height - 5)
        
        single_draw.rectangle([text_x, text_y, text_x + text_width + 10, text_y + text_height + 5], 
                            fill=color)
        single_draw.text((text_x + 5, text_y + 2), text, fill="white", font=single_font)
        
        base_image.save(f"{prefix}_with_info_{i:02d}.png")
        print(f"保存{prefix} {i:02d}.png (得分: {score:.3f})")

# # 保存原始和融合后的单个掩码
# save_mask(masks, boxes, scores, "original_mask")
# save_mask(fused_masks, fused_boxes, fused_scores, "fused_mask")

print("所有处理完成!")

  

posted on 2025-11-21 23:38  MKT-porter  阅读(1)  评论(0)    收藏  举报
刷新页面返回顶部
博客园  ©  2004-2025
浙公网安备 33010602011771号 浙ICP备2021040463号-3