• 博客园logo
  • 会员
  • 众包
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • HarmonyOS
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录
MKT-porter
博客园    首页    新随笔    联系   管理    订阅  订阅
SAM2 图像分割(2)鼠标单个框选位置 实时分割显示

image

 

import cv2
import torch
import time
import numpy as np
import os
import sys

# 检查文件是否存在
image_path = "npu2pm.JPG"
if not os.path.exists(image_path):
    print(f"错误:图像文件 '{image_path}' 不存在!")
    print("请确保图像文件在当前目录下")
    sys.exit(1)

print("图像文件存在,继续执行...")

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 性能优化配置
torch.backends.cudnn.benchmark = True
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# 图像和模型配置
mode_test = 'tiny'
scale = 1.0

model_config = {
    "tiny": ("sam2.1_hiera_t.yaml", "sam2.1_hiera_tiny.pt"),
    "small": ("sam2.1_hiera_s.yaml", "sam2.1_hiera_small.pt"),
    "base": ("sam2.1_hiera_b.yaml", "sam2.1_hiera_base_plus.pt"),
    "large": ("sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt")
}

model_type, model_path = model_config[mode_test]
checkpoint = f"../checkpoints/{model_path}"
model_cfg = f"../sam2/configs/sam2.1/{model_type}"

print(f"模型配置: {model_cfg}")
print(f"检查点: {checkpoint}")

# 检查模型文件是否存在
if not os.path.exists(checkpoint.replace("../checkpoints/", "")) and not os.path.exists(checkpoint):
    print(f"警告:模型文件可能不存在于: {checkpoint}")

# 加载并预处理图像
print("正在加载图像...")
image_cv = cv2.imread(image_path)
if image_cv is None:
    raise ValueError("无法加载图像!")

height, width = image_cv.shape[:2]
print(f"原始图像尺寸: {width}x{height}")

new_width = int(width * scale)
new_height = int(height * scale)

image = cv2.resize(image_cv, (new_width, new_height))
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

print(f"调整后图像尺寸: {new_width}x{new_height}")

# 构建模型
print("正在加载SAM2模型...")
try:
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    
    start_load = time.time()
    sam2 = build_sam2(model_cfg, checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam2)
    end_load = time.time()
    print(f"模型加载完成,耗时: {end_load - start_load:.2f} 秒")
    
except ImportError as e:
    print(f"导入错误: {e}")
    print("请确保sam2模块在Python路径中")
    sys.exit(1)
except Exception as e:
    print(f"模型加载错误: {e}")
    print("创建模拟模式进行测试...")
    # 创建模拟模式进行测试
    class MockPredictor:
        def set_image(self, image):
            print("模拟: set_image调用")
            
        def predict(self, **kwargs):
            print("模拟: predict调用")
            # 返回模拟的mask
            mask = np.zeros((new_height, new_width), dtype=bool)
            h, w = mask.shape
            # 创建一个简单的矩形mask
            mask[h//4:3*h//4, w//4:3*w//4] = True
            return [mask], [0.95], None
    
    predictor = MockPredictor()
    print("使用模拟模式进行测试")

# 清理GPU缓存
if device.type == "cuda":
    torch.cuda.empty_cache()

# 交互式选择框
class BoxSelector:
    def __init__(self, image):
        self.image = image
        self.drawing = False
        self.box = []
        self.temp_image = image.copy()
        self.result_image = None
        self.mask_window_name = "Segmentation Mask"
        self.result_window_name = "Segmentation Result"
        
        # 创建显示窗口
        cv2.namedWindow("Select Box", cv2.WINDOW_NORMAL)
        cv2.resizeWindow("Select Box", 800, 600)
        cv2.namedWindow(self.mask_window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(self.mask_window_name, 800, 600)
        cv2.namedWindow(self.result_window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(self.result_window_name, 800, 600)
        
        cv2.setMouseCallback("Select Box", self.draw_box)
        
    def draw_box(self, event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDOWN:
            print(f"鼠标按下: ({x}, {y})")
            self.drawing = True
            self.box = [x, y, x, y]
            self.update_display()
            
        elif event == cv2.EVENT_MOUSEMOVE:
            if self.drawing:
                self.box[2:] = [x, y]
                self.update_display()
                
        elif event == cv2.EVENT_LBUTTONUP:
            print(f"鼠标释放: ({x}, {y})")
            self.drawing = False
            self.box[2:] = [x, y]
            # 确保x_min < x_max, y_min < y_max
            self.box = [min(self.box[0], self.box[2]), min(self.box[1], self.box[3]), 
                       max(self.box[0], self.box[2]), max(self.box[1], self.box[3])]
            
            # 检查框的有效性
            box_width = self.box[2] - self.box[0]
            box_height = self.box[3] - self.box[1]
            
            print(f"选择的框: {self.box}, 尺寸: {box_width}x{box_height}")
            
            if box_width < 5 or box_height < 5:
                print("框太小,请重新选择")
                self.temp_image = self.image.copy()
                cv2.imshow("Select Box", self.temp_image)
                return
                
            self.update_display()
            print("开始分割处理...")
            
            # 执行分割
            self.perform_segmentation(self.box)
    
    def update_display(self):
        """更新显示图像"""
        self.temp_image = self.image.copy()
        if len(self.box) == 4:
            cv2.rectangle(self.temp_image, (self.box[0], self.box[1]), 
                         (self.box[2], self.box[3]), (0, 255, 0), 2)
        cv2.imshow("Select Box", self.temp_image)

    def perform_segmentation(self, box):
        print("执行分割...")
        start_time = time.time()

        try:
            predictor.set_image(image_rgb)
            
            masks, scores, logits = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=np.array(box),
                multimask_output=False
            )

            end_time = time.time()
            processing_time = end_time - start_time

            print(f"分割完成,耗时: {processing_time:.2f} 秒")
            print(f"生成掩码数量: {len(masks)}")

            if len(masks) > 0:
                # 创建基础图像
                base_image = self.image.copy()
                cv2.rectangle(base_image, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
                
                # 创建纯mask显示图像
                mask_display = np.zeros_like(base_image)
                
                for i, mask in enumerate(masks):
                    print(f"处理掩码 {i+1}, 分数: {scores[i]:.3f}")
                    
                    if hasattr(mask, 'cpu'):
                        mask = mask.cpu().numpy()
                    
                    if mask.dtype != bool:
                        mask = mask.astype(bool)
                    
                    # 生成鲜艳的颜色
                    color = [
                        np.random.randint(150, 256),
                        np.random.randint(150, 256),
                        np.random.randint(150, 256)
                    ]
                    
                    # 增强mask显示效果
                    mask_alpha = 0.7  # mask透明度
                    border_width = 2   # 边界宽度
                    
                    # 1. 创建纯mask显示
                    mask_only = np.zeros_like(base_image)
                    for c in range(3):
                        mask_only[:, :, c][mask] = color[c]
                    cv2.imshow(self.mask_window_name, mask_only)
                    
                    # 2. 创建结果图像(带原图)
                    result_image = base_image.copy()
                    
                    # 创建彩色掩码
                    colored_mask = np.zeros_like(result_image)
                    for c in range(3):
                        colored_mask[:, :, c][mask] = color[c]
                    
                    # 混合图像(增强mask效果)
                    result_image = cv2.addWeighted(result_image, 1 - mask_alpha, colored_mask, mask_alpha, 0)
                    
                    # 添加边界
                    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                    cv2.drawContours(result_image, contours, -1, color, border_width)
                    
                    # 显示分数
                    y_coords, x_coords = np.where(mask)
                    if len(x_coords) > 0:
                        center_x, center_y = np.mean(x_coords), np.mean(y_coords)
                        cv2.putText(result_image, f"Score: {scores[i]:.3f}", 
                                   (int(center_x), int(center_y)), 
                                   cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
                    
                    # 显示结果
                    cv2.imshow(self.result_window_name, result_image)
                    self.result_image = result_image

                # 保存结果
                result_path = "box_result_interactive.jpg"
                cv2.imwrite(result_path, self.result_image)
                print(f"结果已保存: {result_path}")
                
            else:
                print("未生成掩码")
                # 显示空白结果
                cv2.imshow(self.mask_window_name, np.zeros_like(self.image))
                cv2.imshow(self.result_window_name, self.image.copy())
                
        except Exception as e:
            print(f"分割错误: {e}")
            import traceback
            traceback.print_exc()
            # 显示错误信息
            error_img = np.zeros_like(self.image)
            cv2.putText(error_img, f"Error: {str(e)}", (10, 30), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
            cv2.imshow(self.result_window_name, error_img)

# 创建选择器实例
selector = BoxSelector(image)

print("\n=== 交互式分割程序 ===")
print("操作说明:")
print("1. 在'Select Box'窗口上按住鼠标左键并拖动选择框")
print("2. 释放鼠标后自动执行分割")
print("3. 'Segmentation Mask'窗口显示纯mask效果")
print("4. 'Segmentation Result'窗口显示混合结果")
print("5. 按 'r' 键重置选择")
print("6. 按 ESC 或 'q' 键退出")
print("=" + "="*30)

# 初始显示
cv2.imshow("Select Box", image)
cv2.imshow(selector.mask_window_name, np.zeros_like(image))
cv2.imshow(selector.result_window_name, image.copy())
print("窗口已显示,等待用户输入...")

# 主循环
try:
    while True:
        key = cv2.waitKey(100) & 0xFF  # 100ms延迟,减少CPU使用
        
        if key == 27 or key == ord('q'):  # ESC或q键
            print("用户请求退出")
            break
        elif key == ord('r'):  # 重置键
            print("重置选择")
            selector.box = []
            selector.temp_image = image.copy()
            cv2.imshow("Select Box", selector.temp_image)
            cv2.imshow(selector.mask_window_name, np.zeros_like(image))
            cv2.imshow(selector.result_window_name, image.copy())
            
except KeyboardInterrupt:
    print("程序被中断")

finally:
    cv2.destroyAllWindows()
    print("程序结束")

# 清理内存
if device.type == "cuda":
    torch.cuda.empty_cache()

  

鼠标选取多个框

image

 

image

 

'''

使用方法:
在"Select Box"窗口上拖动鼠标选择多个区域
按's'键执行分割
查看"Segmentation Mask"和"Segmentation Result"窗口
按'c'键可以清除所有框重新选择
按'r'键完全重置
按ESC或'q'键退出

'''


import cv2
import torch
import time
import numpy as np
import os
import sys

# 检查文件是否存在
image_path = "npu2pm.JPG"
if not os.path.exists(image_path):
    print(f"错误:图像文件 '{image_path}' 不存在!")
    print("请确保图像文件在当前目录下")
    sys.exit(1)

print("图像文件存在,继续执行...")

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 性能优化配置
torch.backends.cudnn.benchmark = True
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# 图像和模型配置
mode_test = 'tiny'
scale = 1.0

model_config = {
    "tiny": ("sam2.1_hiera_t.yaml", "sam2.1_hiera_tiny.pt"),
    "small": ("sam2.1_hiera_s.yaml", "sam2.1_hiera_small.pt"),
    "base": ("sam2.1_hiera_b.yaml", "sam2.1_hiera_base_plus.pt"),
    "large": ("sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt")
}

model_type, model_path = model_config[mode_test]
checkpoint = f"../checkpoints/{model_path}"
model_cfg = f"../sam2/configs/sam2.1/{model_type}"

print(f"模型配置: {model_cfg}")
print(f"检查点: {checkpoint}")

# 检查模型文件是否存在
if not os.path.exists(checkpoint.replace("../checkpoints/", "")) and not os.path.exists(checkpoint):
    print(f"警告:模型文件可能不存在于: {checkpoint}")

# 加载并预处理图像
print("正在加载图像...")
image_cv = cv2.imread(image_path)
if image_cv is None:
    raise ValueError("无法加载图像!")

height, width = image_cv.shape[:2]
print(f"原始图像尺寸: {width}x{height}")

new_width = int(width * scale)
new_height = int(height * scale)

image = cv2.resize(image_cv, (new_width, new_height))
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

print(f"调整后图像尺寸: {new_width}x{new_height}")

# 构建模型
print("正在加载SAM2模型...")
try:
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    
    start_load = time.time()
    sam2 = build_sam2(model_cfg, checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam2)
    end_load = time.time()
    print(f"模型加载完成,耗时: {end_load - start_load:.2f} 秒")
    
except ImportError as e:
    print(f"导入错误: {e}")
    print("请确保sam2模块在Python路径中")
    sys.exit(1)
except Exception as e:
    print(f"模型加载错误: {e}")
    print("创建模拟模式进行测试...")
    # 创建模拟模式进行测试
    class MockPredictor:
        def set_image(self, image):
            print("模拟: set_image调用")
            
        def predict(self, **kwargs):
            print("模拟: predict调用")
            # 返回模拟的mask
            mask = np.zeros((new_height, new_width), dtype=bool)
            h, w = mask.shape
            # 创建一个简单的矩形mask
            mask[h//4:3*h//4, w//4:3*w//4] = True
            return [mask], [0.95], None
    
    predictor = MockPredictor()
    print("使用模拟模式进行测试")

# 清理GPU缓存
if device.type == "cuda":
    torch.cuda.empty_cache()

# 交互式选择框
class BoxSelector:
    def __init__(self, image):
        self.image = image
        self.drawing = False
        self.current_box = []
        self.boxes = []  # 存储所有框
        self.temp_image = image.copy()
        self.result_image = None
        self.mask_window_name = "Segmentation Mask"
        self.result_window_name = "Segmentation Result"
        
        # 创建显示窗口
        cv2.namedWindow("Select Box", cv2.WINDOW_NORMAL)
        cv2.resizeWindow("Select Box", 800, 600)
        cv2.namedWindow(self.mask_window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(self.mask_window_name, 800, 600)
        cv2.namedWindow(self.result_window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(self.result_window_name, 800, 600)
        
        cv2.setMouseCallback("Select Box", self.draw_box)
        
    def draw_box(self, event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDOWN:
            print(f"鼠标按下: ({x}, {y})")
            self.drawing = True
            self.current_box = [x, y, x, y]
            self.update_display()
            
        elif event == cv2.EVENT_MOUSEMOVE:
            if self.drawing:
                self.current_box[2:] = [x, y]
                self.update_display()
                
        elif event == cv2.EVENT_LBUTTONUP:
            print(f"鼠标释放: ({x}, {y})")
            self.drawing = False
            self.current_box[2:] = [x, y]
            # 确保x_min < x_max, y_min < y_max
            self.current_box = [min(self.current_box[0], self.current_box[2]), 
                               min(self.current_box[1], self.current_box[3]), 
                               max(self.current_box[0], self.current_box[2]), 
                               max(self.current_box[1], self.current_box[3])]
            
            # 检查框的有效性
            box_width = self.current_box[2] - self.current_box[0]
            box_height = self.current_box[3] - self.current_box[1]
            
            print(f"添加的框: {self.current_box}, 尺寸: {box_width}x{box_height}")
            
            if box_width >= 5 and box_height >= 5:
                self.boxes.append(self.current_box.copy())
                print(f"当前框数量: {len(self.boxes)}")
            
            self.current_box = []
            self.update_display()
    
    def update_display(self):
        """更新显示图像"""
        self.temp_image = self.image.copy()
        
        # 绘制所有已保存的框
        for i, box in enumerate(self.boxes):
            color = (0, 255, 0) if i == len(self.boxes) - 1 else (0, 200, 200)
            cv2.rectangle(self.temp_image, (box[0], box[1]), (box[2], box[3]), color, 2)
            cv2.putText(self.temp_image, str(i+1), (box[0]+5, box[1]+20), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
        
        # 绘制当前正在绘制的框
        if len(self.current_box) == 4:
            cv2.rectangle(self.temp_image, (self.current_box[0], self.current_box[1]), 
                         (self.current_box[2], self.current_box[3]), (0, 0, 255), 2)
        
        cv2.imshow("Select Box", self.temp_image)

    def perform_segmentation(self):
        if not self.boxes:
            print("没有选择任何框,请先选择框")
            return
            
        print(f"开始处理 {len(self.boxes)} 个框的分割...")
        start_time = time.time()

        try:
            predictor.set_image(image_rgb)
            
            # 为每个框执行分割
            all_masks = []
            all_scores = []
            
            for i, box in enumerate(self.boxes):
                print(f"处理框 {i+1}: {box}")
                masks, scores, _ = predictor.predict(
                    point_coords=None,
                    point_labels=None,
                    box=np.array(box),
                    multimask_output=False
                )
                
                if len(masks) > 0:
                    all_masks.append(masks[0])
                    all_scores.append(scores[0])
            
            if not all_masks:
                print("未生成任何掩码")
                return
                
            end_time = time.time()
            processing_time = end_time - start_time

            print(f"分割完成,耗时: {processing_time:.2f} 秒")
            print(f"生成掩码数量: {len(all_masks)}")

            # 创建基础图像
            base_image = self.image.copy()
            
            # 绘制所有框
            for i, box in enumerate(self.boxes):
                color = (0, 255, 0) if i == len(self.boxes) - 1 else (0, 200, 200)
                cv2.rectangle(base_image, (box[0], box[1]), (box[2], box[3]), color, 2)
                cv2.putText(base_image, str(i+1), (box[0]+5, box[1]+20), 
                          cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
            
            # 创建纯mask显示图像
            mask_display = np.zeros_like(base_image)
            
            # 创建结果图像 - 初始化为原图
            result_image = base_image.copy()
            
            for i, (mask, score) in enumerate(zip(all_masks, all_scores)):
                print(f"处理掩码 {i+1}, 分数: {score:.3f}")
                
                if hasattr(mask, 'cpu'):
                    mask = mask.cpu().numpy()
                
                if mask.dtype != bool:
                    mask = mask.astype(bool)
                
                # 生成鲜艳的颜色
                color = [
                    np.random.randint(150, 256),
                    np.random.randint(150, 256),
                    np.random.randint(150, 256)
                ]
                
                # mask透明度
                mask_alpha = 0.5  # 50%透明度
                border_width = 2   # 边界宽度
                
                # 1. 更新纯mask显示
                for c in range(3):
                    mask_display[:, :, c][mask] = color[c]
                
                # 2. 更新结果图像 - 只在mask区域进行叠加
                # 创建彩色掩码
                colored_mask = np.zeros_like(result_image)
                for c in range(3):
                    colored_mask[:, :, c][mask] = color[c]
                
                # 创建alpha通道 - mask区域为0.5,其他区域为0
                alpha = np.zeros((result_image.shape[0], result_image.shape[1]), dtype=np.float32)
                alpha[mask] = mask_alpha
                alpha = np.dstack([alpha]*3)  # 转换为3通道
                
                # 只在mask区域进行混合
                result_image = (result_image * (1 - alpha) + colored_mask * alpha).astype(np.uint8)
                
                # 添加边界
                contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(result_image, contours, -1, color, border_width)
                
                # 显示分数
                y_coords, x_coords = np.where(mask)
                if len(x_coords) > 0:
                    center_x, center_y = np.mean(x_coords), np.mean(y_coords)
                    cv2.putText(result_image, f"{i+1}:{score:.3f}", 
                               (int(center_x), int(center_y)), 
                               cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            
            # 显示结果
            cv2.imshow(self.mask_window_name, mask_display)
            cv2.imshow(self.result_window_name, result_image)
            self.result_image = result_image

            # 保存结果
            result_path = "multi_box_result.jpg"
            cv2.imwrite(result_path, self.result_image)
            print(f"结果已保存: {result_path}")
                
        except Exception as e:
            print(f"分割错误: {e}")
            import traceback
            traceback.print_exc()
            # 显示错误信息
            error_img = np.zeros_like(self.image)
            cv2.putText(error_img, f"Error: {str(e)}", (10, 30), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
            cv2.imshow(self.result_window_name, error_img)

# 创建选择器实例
selector = BoxSelector(image)

print("\n=== 交互式多框分割程序 ===")
print("操作说明:")
print("1. 在'Select Box'窗口上按住鼠标左键并拖动选择框")
print("2. 可以连续选择多个框")
print("3. 按 's' 键执行分割")
print("4. 按 'c' 键清除所有框")
print("5. 按 'r' 键重置所有内容")
print("6. 按 ESC 或 'q' 键退出")
print("=" + "="*30)

# 初始显示
cv2.imshow("Select Box", image)
cv2.imshow(selector.mask_window_name, np.zeros_like(image))
cv2.imshow(selector.result_window_name, image.copy())
print("窗口已显示,等待用户输入...")

# 主循环
try:
    while True:
        key = cv2.waitKey(100) & 0xFF  # 100ms延迟,减少CPU使用
        
        if key == 27 or key == ord('q'):  # ESC或q键
            print("用户请求退出")
            break
        elif key == ord('s'):  # 执行分割
            print("执行分割...")
            selector.perform_segmentation()
        elif key == ord('c'):  # 清除所有框
            print("清除所有框")
            selector.boxes = []
            selector.current_box = []
            selector.update_display()
            cv2.imshow(selector.mask_window_name, np.zeros_like(image))
            cv2.imshow(selector.result_window_name, image.copy())
        elif key == ord('r'):  # 重置所有内容
            print("重置所有内容")
            selector.boxes = []
            selector.current_box = []
            selector.temp_image = image.copy()
            selector.result_image = None
            cv2.imshow("Select Box", image)
            cv2.imshow(selector.mask_window_name, np.zeros_like(image))
            cv2.imshow(selector.result_window_name, image.copy())
            
except KeyboardInterrupt:
    print("程序被中断")

finally:
    cv2.destroyAllWindows()
    print("程序结束")

# 清理内存
if device.type == "cuda":
    torch.cuda.empty_cache()

  

 

posted on 2025-10-26 07:11  MKT-porter  阅读(1)  评论(0)    收藏  举报
刷新页面返回顶部
博客园  ©  2004-2025
浙公网安备 33010602011771号 浙ICP备2021040463号-3