import cv2
import torch
import time
import numpy as np
import os
import sys
import sys
sys.path.append('/home/r9000k/v2_project/v5_samyolo/2分割/sam2-main/test')
# Select the compute device: CUDA when torch reports it available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"使用设备: {device}")

# Performance-related backend switches.
torch.backends.cudnn.benchmark = True
if device.type == "cuda":
    # Permit TF32 in both cuDNN and CUDA matmul when running on the GPU.
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
# Model selection: `mode_test` picks one entry of `model_config`.
mode_test = 'tiny'
scale = 1.0

# Maps a variant name to (config file name, checkpoint file name).
model_config = dict(
    tiny=("sam2.1_hiera_t.yaml", "sam2.1_hiera_tiny.pt"),
    small=("sam2.1_hiera_s.yaml", "sam2.1_hiera_small.pt"),
    base=("sam2.1_hiera_b.yaml", "sam2.1_hiera_base_plus.pt"),
    large=("sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt"),
)

# Resolve the selected variant into relative checkpoint / config paths.
model_type, model_path = model_config[mode_test]
checkpoint = "../checkpoints/" + model_path
model_cfg = "../sam2/configs/sam2.1/" + model_type

print(f"模型配置: {model_cfg}")
print(f"检查点: {checkpoint}")
# Build the model and expose it through the module-level `predictor`.
# - If the sam2 package cannot be imported, the script exits with status 1.
# - If anything else fails while building the model, a MockPredictor is
#   installed instead so the rest of the program can still be exercised.
print("正在加载SAM2模型...")
try:
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    start_load = time.time()
    sam2 = build_sam2(model_cfg, checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam2)
    end_load = time.time()
    print(f"模型加载完成,耗时: {end_load - start_load:.2f} 秒")
except ImportError as e:
    print(f"导入错误: {e}")
    print("请确保sam2模块在Python路径中")
    sys.exit(1)
except Exception as e:
    print(f"模型加载错误: {e}")
    print("创建模拟模式进行测试...")

    class MockPredictor:
        """Stand-in predictor used when the real model cannot be built.

        It mimics the two calls the rest of the script makes
        (`set_image` and `predict`) and always "segments" the central
        rectangle of the most recently supplied image.
        """

        def __init__(self):
            # (height, width) of the last image; defaults to the resolution
            # the script requests from the camera.
            self._image_hw = (480, 640)

        def set_image(self, image):
            """Remember the image size so the mock mask matches the frame."""
            print("模拟: set_image调用")
            shape = getattr(image, "shape", None)
            if shape is not None and len(shape) >= 2:
                self._image_hw = (int(shape[0]), int(shape[1]))

        def predict(self, **kwargs):
            """Return `([mask], [score], None)` with a centred rectangle mask.

            All keyword arguments are accepted and ignored.
            """
            print("模拟: predict调用")
            h, w = self._image_hw
            # The mask must have the frame's height/width, otherwise boolean
            # indexing of the frame with it fails.
            mask = np.zeros((h, w), dtype=bool)
            mask[h // 4:3 * h // 4, w // 4:3 * w // 4] = True
            return [mask], [0.95], None

    predictor = MockPredictor()
    print("使用模拟模式进行测试")

# Release cached GPU memory after loading.
if device.type == "cuda":
    torch.cuda.empty_cache()
# Interactive box selection on a live camera feed.
class CameraBoxSelector:
    """Draw boxes on a webcam preview and segment them with `predictor`.

    Workflow (see `run` for the key bindings):
      1. Selection mode: drag with the left mouse button in the
         "Camera Feed" window to add boxes.
      2. Monitoring mode: every `process_interval` seconds the current frame
         is segmented once per stored box and the results are shown in the
         mask and result windows.

    The class relies on the module-level `predictor` object, which must
    provide `set_image(rgb_image)` and `predict(...)`.
    """

    # Boxes narrower or shorter than this many pixels are discarded.
    MIN_BOX_SIZE = 5
    # Blend weight of the mask colour over the original frame.
    MASK_ALPHA = 0.5
    # Thickness in pixels of the contour drawn around each mask.
    BORDER_WIDTH = 2

    def __init__(self):
        """Open the default camera and create the three display windows.

        Raises:
            RuntimeError: if the camera cannot be opened.
        """
        self.cap = cv2.VideoCapture(0)  # default camera
        if not self.cap.isOpened():
            # Do not leave a half-initialised capture handle behind.
            self.cap.release()
            raise RuntimeError("无法打开摄像头!")

        # Requested capture resolution (the driver may not honour it).
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

        self.drawing = False          # True while a drag is in progress
        self.current_box = []         # [x0, y0, x1, y1] of the drag, or []
        self.boxes = []               # all accepted boxes
        self.current_frame = None     # most recent frame read in `run`
        self.result_image = None      # last overlay produced
        self.feed_window_name = "Camera Feed"
        self.mask_window_name = "Segmentation Mask"
        self.result_window_name = "Segmentation Result"
        # Kept for compatibility; set by mouse events, cleared on redraw.
        self.need_update = True
        self.monitoring_mode = False  # True once 's' starts live segmentation
        self.last_process_time = 0
        self.process_interval = 0.2   # seconds between segmentations

        for window in (self.feed_window_name,
                       self.mask_window_name,
                       self.result_window_name):
            cv2.namedWindow(window, cv2.WINDOW_NORMAL)
            cv2.resizeWindow(window, 800, 600)
        cv2.setMouseCallback(self.feed_window_name, self.draw_box)

    @staticmethod
    def _normalize_box(x0, y0, x1, y1, width=None, height=None):
        """Return `[x_min, y_min, x_max, y_max]` for two opposite corners.

        When `width` / `height` are given, coordinates are first clamped to
        `[0, width - 1]` / `[0, height - 1]`.
        """
        if width is not None:
            x0 = min(max(x0, 0), width - 1)
            x1 = min(max(x1, 0), width - 1)
        if height is not None:
            y0 = min(max(y0, 0), height - 1)
            y1 = min(max(y1, 0), height - 1)
        return [min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)]

    @staticmethod
    def _mask_color(index):
        """Return a bright colour (three ints in 150..255) for mask `index`.

        The colour depends only on `index`, so a given box keeps the same
        colour from frame to frame instead of flickering.
        """
        rng = np.random.RandomState(index)
        return tuple(int(v) for v in rng.randint(150, 256, size=3))

    def draw_box(self, event, x, y, flags, param):
        """Mouse callback for the feed window: build boxes from drags.

        Ignored entirely while in monitoring mode.
        """
        if self.monitoring_mode:
            return

        if event == cv2.EVENT_LBUTTONDOWN:
            print(f"鼠标按下: ({x}, {y})")
            self.drawing = True
            self.current_box = [x, y, x, y]
            self.need_update = True
        elif event == cv2.EVENT_MOUSEMOVE:
            # Only extend a box that was actually started by a button press.
            if self.drawing and len(self.current_box) == 4:
                self.current_box[2:] = [x, y]
                self.need_update = True
        elif event == cv2.EVENT_LBUTTONUP:
            print(f"鼠标释放: ({x}, {y})")
            drag_active = self.drawing and len(self.current_box) == 4
            self.drawing = False
            if not drag_active:
                # Release without a matching press (or after the selection
                # was cleared mid-drag): there is no box to finish.
                self.current_box = []
                self.need_update = True
                return

            width = height = None
            if self.current_frame is not None:
                height, width = self.current_frame.shape[:2]
            # Order the corners and keep them inside the frame.
            self.current_box = self._normalize_box(
                self.current_box[0], self.current_box[1], x, y, width, height
            )

            box_width = self.current_box[2] - self.current_box[0]
            box_height = self.current_box[3] - self.current_box[1]
            print(f"添加的框: {self.current_box}, 尺寸: {box_width}x{box_height}")
            if box_width >= self.MIN_BOX_SIZE and box_height >= self.MIN_BOX_SIZE:
                self.boxes.append(self.current_box.copy())
                print(f"当前框数量: {len(self.boxes)}")
            self.current_box = []
            self.need_update = True

    def update_display(self, frame):
        """Show `frame` in the feed window, with boxes in selection mode."""
        display_frame = frame.copy()
        if not self.monitoring_mode:
            last_index = len(self.boxes) - 1
            for i, box in enumerate(self.boxes):
                # The most recently added box is highlighted in green.
                color = (0, 255, 0) if i == last_index else (0, 200, 200)
                cv2.rectangle(display_frame, (box[0], box[1]),
                              (box[2], box[3]), color, 2)
                cv2.putText(display_frame, str(i + 1),
                            (box[0] + 5, box[1] + 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
            # Box currently being dragged.
            if len(self.current_box) == 4:
                cv2.rectangle(display_frame,
                              (self.current_box[0], self.current_box[1]),
                              (self.current_box[2], self.current_box[3]),
                              (0, 0, 255), 2)
        cv2.imshow(self.feed_window_name, display_frame)
        self.need_update = False

    def _predict_masks(self, frame):
        """Run `predictor` once per stored box; return (masks, scores) lists."""
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        predictor.set_image(frame_rgb)
        found_masks = []
        found_scores = []
        for i, box in enumerate(self.boxes):
            print(f"处理框 {i+1}: {box}")
            masks, scores, _ = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=np.array(box),
                multimask_output=False,
            )
            if len(masks) > 0:
                found_masks.append(masks[0])
                found_scores.append(scores[0])
        return found_masks, found_scores

    def _render_masks(self, frame, masks, scores):
        """Build the mask-only image and the blended overlay for `frame`."""
        mask_display = np.zeros_like(frame)
        result_image = frame.copy()
        alpha = np.float32(self.MASK_ALPHA)

        for i, (mask, score) in enumerate(zip(masks, scores)):
            print(f"处理掩码 {i+1}, 分数: {score:.3f}")
            if hasattr(mask, 'cpu'):
                mask = mask.cpu().numpy()
            mask = np.asarray(mask)
            if mask.dtype != bool:
                mask = mask.astype(bool)

            color = self._mask_color(i)

            # 1. Mask-only view: paint the mask region with its colour.
            mask_display[mask] = color

            # 2. Overlay: blend the colour into the masked pixels only;
            #    pixels outside the mask are left untouched.
            region = result_image[mask].astype(np.float32)
            tint = np.array(color, dtype=np.float32)
            result_image[mask] = (region * (1 - alpha) + tint * alpha).astype(np.uint8)

            # Outline of the mask.
            contours, _ = cv2.findContours(mask.astype(np.uint8),
                                           cv2.RETR_EXTERNAL,
                                           cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(result_image, contours, -1, color, self.BORDER_WIDTH)

            # Index and score label at the mask centroid.
            y_coords, x_coords = np.nonzero(mask)
            if x_coords.size > 0:
                cv2.putText(result_image, f"{i+1}:{score:.3f}",
                            (int(x_coords.mean()), int(y_coords.mean())),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        return mask_display, result_image

    def perform_segmentation(self, frame):
        """Segment every stored box in `frame`.

        Returns:
            `(mask_display, result_image)`: the mask-only image and the
            frame with masks blended in, or `(None, None)` when there are no
            boxes, no masks were produced, or an error occurred (the
            traceback is printed in that case).
        """
        if not self.boxes:
            print("没有选择任何框,请先选择框")
            return None, None

        print(f"处理 {len(self.boxes)} 个框的分割...")
        start_time = time.time()
        try:
            all_masks, all_scores = self._predict_masks(frame)
            if not all_masks:
                print("未生成任何掩码")
                return None, None

            processing_time = time.time() - start_time
            print(f"分割完成,耗时: {processing_time:.2f} 秒")
            print(f"生成掩码数量: {len(all_masks)}")
            return self._render_masks(frame, all_masks, all_scores)
        except Exception as e:
            # Best effort: a failed frame must not stop the live loop.
            print(f"分割错误: {e}")
            import traceback
            traceback.print_exc()
            return None, None

    def _segment_and_show(self, frame):
        """Segment `frame` and, on success, display and store the results."""
        mask_display, result_image = self.perform_segmentation(frame)
        if mask_display is not None:
            cv2.imshow(self.mask_window_name, mask_display)
            cv2.imshow(self.result_window_name, result_image)
            self.result_image = result_image

    def _reset_selection(self, frame, refresh_feed=False):
        """Drop all boxes, leave monitoring mode and blank the result views."""
        self.boxes = []
        self.current_box = []
        self.drawing = False  # abandon any drag that was in progress
        self.monitoring_mode = False
        self.need_update = True
        if refresh_feed:
            cv2.imshow(self.feed_window_name, frame)
        cv2.imshow(self.mask_window_name, np.zeros_like(frame))
        cv2.imshow(self.result_window_name, frame.copy())

    def run(self):
        """Main loop: read frames, refresh the preview, handle key presses.

        Keys: 's' start monitoring, 'c' clear boxes, 'r' reset,
        ESC / 'q' quit. The camera and all windows are released on exit.
        """
        print("\n=== 实时摄像头交互式分割程序 ===")
        print("操作说明:")
        print("1. 在'Camera Feed'窗口上按住鼠标左键并拖动选择框")
        print("2. 可以连续选择多个框")
        print("3. 按 's' 键开始实时监测")
        print("4. 按 'c' 键清除所有框并返回选择模式")
        print("5. 按 'r' 键重置所有内容")
        print("6. 按 ESC 或 'q' 键退出")
        print("=" + "="*30)
        try:
            while True:
                ret, frame = self.cap.read()
                if not ret:
                    print("无法从摄像头读取帧")
                    break
                self.current_frame = frame

                # Redraw on every frame so the preview stays live rather
                # than only changing when a mouse event occurs.
                self.update_display(frame)

                # In monitoring mode, segment at most once per interval.
                if self.monitoring_mode and (
                        time.time() - self.last_process_time > self.process_interval):
                    self._segment_and_show(frame)
                    self.last_process_time = time.time()

                key = cv2.waitKey(1) & 0xFF
                if key == 27 or key == ord('q'):
                    print("用户请求退出")
                    break
                elif key == ord('s'):
                    if not self.boxes:
                        print("请先选择至少一个框")
                    else:
                        print("进入实时监测模式...")
                        self.monitoring_mode = True
                        self.last_process_time = time.time()
                        # Process one frame immediately.
                        self._segment_and_show(frame)
                elif key == ord('c'):
                    print("清除所有框并返回选择模式")
                    self._reset_selection(frame)
                elif key == ord('r'):
                    print("重置所有内容")
                    self._reset_selection(frame, refresh_feed=True)
        except KeyboardInterrupt:
            print("程序被中断")
        finally:
            self.cap.release()
            cv2.destroyAllWindows()
            print("程序结束")
# Entry point: build the selector and hand control to its main loop.
# Any exception from construction or from the loop is reported and turned
# into exit status 1.
try:
    selector = CameraBoxSelector()
    selector.run()
except Exception as exc:
    print(f"初始化错误: {exc}")
    raise SystemExit(1)

# Final cleanup: drop cached GPU memory when running on CUDA.
if device.type == "cuda":
    torch.cuda.empty_cache()