• 博客园logo
  • 会员
  • 众包
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • HarmonyOS
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录
MKT-porter
博客园    首页    新随笔    联系   管理    订阅  订阅
GroundingDino目标跟踪+sam2分割

说明

1 sam2的代码文件依赖是相对于sam2工程的,所以在sam2下面构建

2 需要训练空中数据集

3 不能直接用sam2对整幅图像做全图分割,否则分割结果会很混乱

 

安装

环境 rtx 3070  ubuntu20  cuda11.8

python3.10

1安装 GroundingDino

2安装sam2

3 以sam2为根目录创建新工程,在其中创建下面的代码,并将GroundingDino工程代码拷贝过来或者通过路径引用

反过来使用全局路径引用sam2不行,因为它是按照自己内部的相对路径查找依赖文件的

image

 

 

 

 

image

 

image

 

image

 

image

 

 

image

 

import os
import sys
import time
import warnings
import numpy as np
import torch
import cv2
from PIL import Image, ImageDraw, ImageFont
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.vl_utils import create_positive_map_from_span
from groundingdino.util.inference import load_model, load_image, predict, annotate

# Configure warning filters: silence FutureWarning and UserWarning output.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)


# Star-import of the local API_SAM2 module. CameraBoxSelector (used below) is
# expected to come from it. NOTE(review): importing that module executes its
# module-level code, so this line is not side-effect free.
from API_SAM2 import *
# import sys
# path_sam2= '/home/r9000k/v2_project/v5_samyolo/2分割/sam2-main/test'
# sys.path.append(path_sam2)

# Module-level selector instance; process_video / process_folder assign
# detection boxes to it and call it for segmentation.
selector = CameraBoxSelector()




# Candidate target phrases for drone-to-ground detection (trees and people
# are deliberately left out). Every phrase appears exactly once so that a
# prompt built by joining this list does not repeat a category.
DRONE_TARGETS = [
    # Vehicles
    "vehicle", "car", "truck", "bus", "van", "SUV", "motorcycle", "bicycle",
    "construction vehicle", "excavator", "bulldozer", "crane", "forklift",
    "tractor", "trailer", "ambulance", "fire truck", "police car",

    # Buildings and structures
    "building", "house", "apartment", "commercial building", "factory",
    "warehouse", "shed", "garage", "roof", "chimney",
    "bridge", "overpass", "tunnel", "dam", "power plant",

    # Roads and traffic facilities
    "road", "highway", "street", "pavement", "crosswalk", "roundabout",
    "traffic light", "street light", "road sign", "billboard",
    "parking lot", "gas station", "bus stop",

    # Water-related
    "river", "lake", "pond", "reservoir", "swimming pool", "fountain",
    "boat", "ship", "yacht", "speedboat", "dock", "pier", "harbor",

    # Agriculture
    "farmland", "crop field", "greenhouse", "barn", "silo", "windmill",
    "irrigation system", "livestock pen",

    # Energy facilities
    "solar panel", "wind turbine", "power line", "transformer",
    "oil rig", "oil tank", "gas pipeline",

    # Sports grounds ("swimming pool" is already listed under water-related)
    "soccer field", "basketball court", "tennis court", "baseball field",
    "stadium", "running track",

    # Infrastructure
    "airport", "runway", "hangar", "airplane", "helicopter",
    "railway", "train", "railroad track", "train station",
    "cell tower", "communication tower", "satellite dish",

    # Military and security facilities (optional)
    "military vehicle", "barracks", "checkpoint", "fence", "gate",

    # Other notable targets
    "container", "shipping container", "cargo", "construction material",
    "playground equipment", "park bench", "statue", "monument",
]


# Reduced target list: the phrases actually used to build the detection prompt.
DRONE_TARGETS_min = [
    # Buildings and structures
    "building",
    "house",
    "apartment",
    "commercial building",
    "factory",
    "warehouse",
    "shed",
    "garage",
    "roof",
    "chimney",
    "bridge",
    "overpass",
    # Appearance-qualified variants
    "gray building",
    "white building",
    "large building",
    "red playground",
    "dark brown building",
    # Vehicles
    "car",
    # Roads and traffic facilities
    "road",
    "highway",
    "street",
    "pavement",
    "crosswalk",
    # Sports grounds
    "soccer field",
    "basketball court",
]





class Config:
    """Hard-coded run settings for the detection + segmentation demo.

    Every value is assigned in ``__init__``; edit them there to change a run.
    """

    def __init__(self):
        # ---- Model ----
        # Backbone selector: "SwinB" (938mb) or "SwinT" (600mb).
        self.model_type = "SwinB"
        # Detection prompt: the reduced target list joined with ", ".
        # Earlier hand-written prompt: "building, person, door, cap"
        self.text_prompt = ", ".join(DRONE_TARGETS_min)

        # Official reference values are BOX_TRESHOLD = 0.35 and
        # TEXT_TRESHOLD = 0.25; both values used here are below those.
        self.box_threshold = 0.2
        self.text_threshold = 0.2
        # When True, inference is forced onto the CPU.
        self.cpu_only = False

        # ---- Input source ----
        # "video" or "folder".
        self.input_type = "folder"
        # Video file path or camera id (used when input_type == "video").
        self.video_path = 0
        # Image folder (used when input_type == "folder"). Other folders tried:
        #   /media/r9000k/DD_XS/2数据/2RTK/data_4_city/300_map_2pm/images
        #   /home/r9000k/v0_data/rtk/nwpu_1130_12pm
        self.folder_path = "/media/r9000k/DD_XS/2数据/2RTK/data_4_city/460_500/images"
        # Image down-scaling divisor; 1 keeps the original size.
        self.img_scale = 1

        # ---- Output ----
        self.output_dir = "outputs"
        self.save_results = True
        self.show_results = True

        # ---- Post-processing ----
        # Minimum target area in pixels; smaller detections are dropped.
        self.min_target_area = 0
        # Passed to get_image_files_from_folder as its sort flag.
        self.sort_by_timestamp = True

def plot_boxes_to_image_cv2(image_cv2, boxes, labels):
    """
    Draw detection boxes and their labels onto an OpenCV image, in place.

    Args:
        image_cv2: image array; its first two dimensions are read as (H, W).
        boxes: iterable of 4-element tensors holding normalised
            (cx, cy, w, h) values in 0..1.
        labels: iterable of label strings, paired with ``boxes`` by position.

    Returns:
        ``(image_cv2, pixel_boxes)`` where ``pixel_boxes`` is a list of
        ``[x0, y0, x1, y1]`` integer pixel coordinates, one per drawn box.
    """
    img_h, img_w = image_cv2.shape[:2]
    pixel_boxes = []

    for norm_box, caption in zip(boxes, labels):
        # Scale 0..1 values to pixels, then turn centre/size into corners.
        scaled = norm_box * torch.Tensor([img_w, img_h, img_w, img_h])
        top_left = scaled[:2] - scaled[2:] / 2
        bottom_right = scaled[2:] + top_left
        x0, y0 = (int(v) for v in top_left.tolist())
        x1, y1 = (int(v) for v in bottom_right.tolist())
        pixel_boxes.append([x0, y0, x1, y1])

        # One random colour per box, shared by the outline and label patch.
        color = tuple(int(channel) for channel in np.random.randint(0, 255, size=3))
        cv2.rectangle(image_cv2, (x0, y0), (x1, y1), color, 2)

        # Filled patch sized to the text, placed just above the box corner.
        font, font_scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
        (text_w, text_h), _ = cv2.getTextSize(caption, font, font_scale, thickness)
        cv2.rectangle(image_cv2, (x0, y0 - text_h - 5), (x0 + text_w, y0), color, -1)

        # White label text over the patch.
        cv2.putText(image_cv2, caption, (x0, y0 - 5), font,
                    font_scale, (255, 255, 255), thickness, cv2.LINE_AA)

    return image_cv2, pixel_boxes

def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
    """
    Build a GroundingDINO model from a config file and load its checkpoint.

    NOTE: this definition shadows the ``load_model`` imported from
    ``groundingdino.util.inference`` at the top of the file.

    Args:
        model_config_path: path of the model config file read by ``SLConfig``.
        model_checkpoint_path: path of the checkpoint; its ``"model"`` entry
            is loaded as the state dict.
        cpu_only: when True, ``args.device`` is set to "cpu" even if CUDA
            is available.

    Returns:
        The model in eval mode. Weights are loaded with
        ``map_location="cpu"``; this function does not move the model.

    Raises:
        RuntimeError: if any step fails. The original exception is attached
            as ``__cause__`` so the underlying traceback is preserved.
    """
    try:
        args = SLConfig.fromfile(model_config_path)
        args.device = "cuda" if not cpu_only and torch.cuda.is_available() else "cpu"
        model = build_model(args)
        checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
        # strict=False: missing/unexpected keys are reported, not fatal.
        load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
        print("模型加载结果:", load_res)
        model.eval()
        return model
    except Exception as e:
        # Chain explicitly so callers can inspect the root cause.
        raise RuntimeError(f"加载模型失败: {str(e)}") from e


def get_grounding_output(model, image, caption, box_threshold, text_threshold=None,
                        with_logits=True, cpu_only=False, min_area=0):
    """
    Run the detector on one image tensor and return filtered boxes and phrases.

    Args:
        model: detection model; it is called as
            ``model(image[None], captions=[caption])`` and must expose a
            ``tokenizer`` attribute.
        image: image tensor; ``shape[1]`` and ``shape[2]`` are used as the
            two spatial sizes for the area filter (assumes (C, H, W) —
            TODO confirm against the transform output).
        caption: text prompt; it is lower-cased, stripped, and given a
            trailing "." if missing.
        box_threshold: a query is kept when its maximum logit exceeds this.
        text_threshold: per-token threshold used to extract the phrase.
            Must not be None.
        with_logits: when True, the max logit is appended to each phrase.
        cpu_only: when True, inference runs on the CPU.
        min_area: minimum box area; boxes below it are dropped. The area is
            measured in pixels of ``image`` (the tensor passed in), not of
            the original frame.

    Returns:
        ``(boxes, phrases)``: ``boxes`` is an (N, 4) CPU tensor of the kept
        boxes as produced by the model, and ``phrases`` the N matching
        strings. When nothing is kept, ``boxes`` has shape (0, 4) so it can
        still be indexed and broadcast like a non-empty result.

    Raises:
        ValueError: if ``text_threshold`` is None.
    """
    if text_threshold is None:
        raise ValueError("text_threshold不能为None")

    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption += "."

    device = "cuda" if not cpu_only and torch.cuda.is_available() else "cpu"
    model = model.to(device)
    image = image.to(device)

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"][0]  # (nq, 4)

    # Filter on the CPU. Boolean-mask indexing already yields fresh tensors
    # and nothing below mutates them, so no explicit clone is needed.
    logits_cpu = logits.cpu()
    boxes_cpu = boxes.cpu()
    keep = logits_cpu.max(dim=1)[0] > box_threshold
    logits_filt = logits_cpu[keep]
    boxes_filt = boxes_cpu[keep]

    # Phrase extraction needs the tokenised caption.
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    # Loop-invariant: pixel count of the input tensor.
    pixel_count = image.shape[2] * image.shape[1]

    pred_phrases = []
    valid_boxes = []
    for logit, box in zip(logits_filt, boxes_filt):
        # box[2] * box[3] is the normalised w*h; convert to a pixel area.
        area = (box[2] * box[3]) * pixel_count
        if area < min_area:
            continue

        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrases.append(pred_phrase + f"({logit.max().item():.2f})")
        else:
            pred_phrases.append(pred_phrase)
        valid_boxes.append(box)

    if not valid_boxes:
        # Keep the (N, 4) layout (and dtype) even for an empty result.
        return boxes_filt.new_empty((0, 4)), pred_phrases
    return torch.stack(valid_boxes), pred_phrases

def preprocess_cv2_image(image_cv2):
    """
    Convert an OpenCV (BGR) image into the detector's input format.

    Returns:
        ``(image_pil, tensor)``: the RGB PIL image and the result of the
        resize / to-tensor / normalise transform applied to it.
    """
    # OpenCV stores BGR; PIL and the transform pipeline expect RGB.
    rgb_array = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
    image_pil = Image.fromarray(rgb_array)

    resize_step = T.RandomResize([800], max_size=1333)
    normalize_step = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    pipeline = T.Compose([resize_step, T.ToTensor(), normalize_step])

    # The transform takes (image, target); there is no target here.
    tensor, _ = pipeline(image_pil, None)
    return image_pil, tensor

def get_image_files_from_folder(folder_path, sort_by_number=True):
    """
    Collect all image files under a folder (recursively).

    Args:
        folder_path: root directory to walk.
        sort_by_number: when True, sort by the number formed from all decimal
            digits in the file name (without extension), e.g.
            ``DJI_0004.JPG`` -> 4. Names without digits sort as 0. Ties are
            broken by full path so the order is reproducible regardless of
            the directory listing order. When False, files are returned in
            ``os.walk`` order.

    Returns:
        List of file paths whose extension (case-insensitive) is one of
        .jpg, .jpeg, .png, .bmp, .tiff, .gif.
    """
    supported_formats = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif')

    image_files = [
        os.path.join(root, file)
        for root, _, files in os.walk(folder_path)
        for file in files
        if file.lower().endswith(supported_formats)
    ]

    if sort_by_number:
        def extract_number(filename):
            stem = os.path.splitext(os.path.basename(filename))[0]
            # isdecimal (not isdigit): isdigit also accepts characters such
            # as superscripts that int() cannot parse.
            digits = ''.join(ch for ch in stem if ch.isdecimal())
            return int(digits) if digits else 0

        image_files.sort(key=lambda path: (extract_number(path), path))

    return image_files

def process_video(model, config):
    """
    Run detection on a video file or camera stream until it ends or ESC is pressed.

    For every frame: preprocess, detect, draw boxes and an FPS overlay, then
    optionally show and/or save the annotated frame. Detected boxes are
    stored on the module-level ``selector``; segmentation is not run here.

    Args:
        model: detection model passed through to ``get_grounding_output``.
        config: settings object; reads ``video_path``, ``text_prompt``,
            ``box_threshold``, ``text_threshold``, ``cpu_only``,
            ``min_target_area``, ``show_results``, ``save_results`` and
            ``output_dir``.

    Raises:
        RuntimeError: if the video source cannot be opened.
    """
    cap = cv2.VideoCapture(config.video_path)
    if not cap.isOpened():
        raise RuntimeError(f"无法打开视频源: {config.video_path}")

    print("开始实时检测,按ESC键退出...")

    cv2.namedWindow('Video_Detection', cv2.WINDOW_NORMAL)
    cv2.resizeWindow('Video_Detection', 640, 480)

    # Running frame counter; keeps saved file names unique within one second.
    frame_index = 0

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("无法获取视频帧")
                break

            _, image_tensor = preprocess_cv2_image(frame)

            # Run the detector and time it.
            start_time = time.time()
            boxes_filt, pred_phrases = get_grounding_output(
                model, image_tensor, config.text_prompt,
                config.box_threshold, config.text_threshold,
                cpu_only=config.cpu_only,
                min_area=config.min_target_area
            )
            elapsed_time = time.time() - start_time

            # Draw detections and hand the pixel boxes to the selector.
            if len(boxes_filt) > 0:
                frame, opencv_boxes = plot_boxes_to_image_cv2(frame, boxes_filt, pred_phrases)
                selector.boxes = opencv_boxes

            # FPS overlay (detector time only).
            fps = 1 / elapsed_time if elapsed_time > 0 else 0
            cv2.putText(frame, f"FPS: {fps:.1f}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

            if config.show_results:
                cv2.imshow("Video_Detection", frame)

            if config.save_results:
                # A second-resolution timestamp alone would make frames that
                # arrive within the same second overwrite each other.
                output_path = os.path.join(
                    config.output_dir,
                    f"frame_{int(time.time())}_{frame_index:06d}.jpg",
                )
                cv2.imwrite(output_path, frame)

            frame_index += 1

            # ESC quits.
            if cv2.waitKey(1) == 27:
                break
    finally:
        # Release the capture and windows even if a frame raised.
        cap.release()
        if config.show_results:
            cv2.destroyAllWindows()

def process_folder(model, config):
    """
    Run detection and box-prompted segmentation on every image in a folder.

    For each image: read, optionally down-scale, detect, draw the boxes,
    pass them to the module-level ``selector`` for segmentation, then show
    the results and wait for a key (ESC stops the loop). An error on one
    image is printed and the loop moves on to the next image.

    NOTE: ``config.save_results`` is not honoured in folder mode; nothing is
    written to disk here.

    Args:
        model: detection model passed through to ``get_grounding_output``.
        config: settings object; reads ``folder_path``, ``sort_by_timestamp``,
            ``img_scale``, ``text_prompt``, ``box_threshold``,
            ``text_threshold``, ``cpu_only``, ``min_target_area`` and
            ``show_results``.
    """
    image_files = get_image_files_from_folder(config.folder_path, config.sort_by_timestamp)
    if not image_files:
        print(f"在文件夹 {config.folder_path} 中未找到图像文件")
        return

    print(f"找到 {len(image_files)} 张图像,开始处理...")

    cv2.namedWindow('Image_Detection', cv2.WINDOW_NORMAL)
    cv2.resizeWindow('Image_Detection', 640, 480)

    for i, image_path in enumerate(image_files):
        print(f"处理图像 {i+1}/{len(image_files)}: {image_path}")

        try:
            frame = cv2.imread(image_path)
            # Check for a failed read BEFORE touching frame.shape.
            if frame is None:
                print(f"无法读取图像: {image_path}")
                continue

            if config.img_scale != 1:
                h, w = frame.shape[:2]
                target_size = (int(w/config.img_scale), int(h/config.img_scale))
                frame = cv2.resize(frame, target_size)

            _, image_tensor = preprocess_cv2_image(frame)

            # Detection (timed together with segmentation below).
            start_time = time.time()
            boxes_filt, pred_phrases = get_grounding_output(
                model, image_tensor, config.text_prompt,
                config.box_threshold, config.text_threshold,
                cpu_only=config.cpu_only,
                min_area=config.min_target_area
            )

            # Reset per image so an image without detections neither raises
            # NameError nor re-displays the previous image's masks.
            mask_display, result_image = None, None

            if len(boxes_filt) > 0:
                frame, opencv_boxes = plot_boxes_to_image_cv2(frame, boxes_filt, pred_phrases)
                selector.boxes = opencv_boxes

                # Segment using the detected boxes as prompts.
                mask_display, result_image = selector.perform_segmentation(frame)

            elapsed_time = time.time() - start_time
            print("目标检测和跟踪总处理时间",elapsed_time)

            # Progress / FPS overlay.
            fps = 1 / elapsed_time if elapsed_time > 0 else 0
            info_text = f"Image {i+1}/{len(image_files)} - FPS: {fps:.1f}"
            cv2.putText(frame, info_text, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

            if mask_display is not None:
                cv2.imshow(selector.mask_window_name, mask_display)
                cv2.imshow(selector.result_window_name, result_image)
                selector.result_image = result_image

            if config.show_results:
                cv2.imshow("Image_Detection", frame)
                # Block until a key is pressed; ESC ends the run.
                if cv2.waitKey(0) == 27:
                    break

        except Exception as e:
            # Best effort: report and continue with the next image.
            print(f"处理图像 {image_path} 时出错: {str(e)}")

    if config.show_results:
        cv2.destroyAllWindows()

def main():
    """
    Script entry point: build the config, load the detector, run the chosen mode.

    Any exception raised while loading or processing is printed and the
    process exits with status 1.
    """
    config = Config()

    # Make sure the output directory exists.
    os.makedirs(config.output_dir, exist_ok=True)

    gdino_path="/home/r9000k/v2_project/v5_samyolo/1目标检测/GroundingDINO-main"

    # "SwinB" selects the SwinB files; any other value selects SwinT.
    use_swin_b = config.model_type == "SwinB"
    config_file = gdino_path + (
        "/groundingdino/config/GroundingDINO_SwinB_cfg.py"
        if use_swin_b
        else "/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    )
    checkpoint_path = gdino_path + (
        "/weights/groundingdino_swinb_cogcoor.pth"
        if use_swin_b
        else "/weights/groundingdino_swint_ogc.pth"
    )

    try:
        print(f"正在加载 {config.model_type} 模型...")
        model = load_model(config_file, checkpoint_path, config.cpu_only)
        print("模型加载完成")

        # Dispatch on the configured input type.
        handlers = {"video": process_video, "folder": process_folder}
        handler = handlers.get(config.input_type)
        if handler is None:
            raise ValueError(f"不支持的输入类型: {config.input_type}")
        handler(model, config)

    except Exception as e:
        print(f"发生错误: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    main()

  

API_SAM2.py

import cv2
import torch
import time
import numpy as np
import os
import sys


# # 检查文件是否存在
# image_path = "npu2pm.JPG"
# if not os.path.exists(image_path):
#     print(f"错误:图像文件 '{image_path}' 不存在!")
#     print("请确保图像文件在当前目录下")
#     sys.exit(1)

# print("图像文件存在,继续执行...")

# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# Performance settings: cuDNN benchmark mode is always enabled; TF32 is
# enabled for matmul and cuDNN only when running on CUDA.
torch.backends.cudnn.benchmark = True
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Image and model configuration.
# mode_test picks one entry of model_config below.
# NOTE(review): scale is only referenced by commented-out code in this file.
mode_test = 'tiny'
scale = 1.0

# Maps a size name to (config file name, checkpoint file name).
model_config = {
    "tiny": ("sam2.1_hiera_t.yaml", "sam2.1_hiera_tiny.pt"),
    "small": ("sam2.1_hiera_s.yaml", "sam2.1_hiera_small.pt"),
    "base": ("sam2.1_hiera_b.yaml", "sam2.1_hiera_base_plus.pt"),
    "large": ("sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt")
}

# Both paths are relative ("../..."), so they resolve against the current
# working directory at import time.
model_type, model_path = model_config[mode_test]
checkpoint = f"../checkpoints/{model_path}"
model_cfg = f"../sam2/configs/sam2.1/{model_type}"

print(f"模型配置: {model_cfg}")
print(f"检查点: {checkpoint}")

# Warn only when the checkpoint is found neither as a bare file name in the
# current directory nor at the relative path above. This is a warning only;
# execution continues either way.
if not os.path.exists(checkpoint.replace("../checkpoints/", "")) and not os.path.exists(checkpoint):
    print(f"警告:模型文件可能不存在于: {checkpoint}")

# # 加载并预处理图像
# print("正在加载图像...")
# image_cv = cv2.imread(image_path)
# if image_cv is None:
#     raise ValueError("无法加载图像!")

# height, width = image_cv.shape[:2]
# print(f"原始图像尺寸: {width}x{height}")

# new_width = int(width * scale)
# new_height = int(height * scale)

# image = cv2.resize(image_cv, (new_width, new_height))
# image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# print(f"调整后图像尺寸: {new_width}x{new_height}")

# Build the SAM2 model and wrap it in an image predictor. This runs at import
# time; the resulting module-level `predictor` is what
# CameraBoxSelector.perform_segmentation uses.
print("正在加载SAM2模型...")
try:
    # Imported here so a missing sam2 package is reported by the handler below.
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    
    start_load = time.time()
    sam2 = build_sam2(model_cfg, checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam2)
    end_load = time.time()
    print(f"模型加载完成,耗时: {end_load - start_load:.2f} 秒")
    
except ImportError as e:
    # sam2 is not importable: report and terminate the process.
    print(f"导入错误: {e}")
    print("请确保sam2模块在Python路径中")
    sys.exit(1)
except Exception as e:
    # Any other failure while building the model also terminates the process.
    print(f"模型加载错误: {e}")
    sys.exit(1)
    

# Release cached GPU memory after loading.
if device.type == "cuda":
    torch.cuda.empty_cache()

# 交互式选择框
class CameraBoxSelector:
    """
    Holds a list of prompt boxes and segments a frame with them.

    Segmentation uses the module-level ``predictor``. Constructing an
    instance creates two OpenCV windows (mask view and result view).

    Attributes:
        boxes: list of ``[x0, y0, x1, y1]`` boxes used as prompts.
        result_image: last result image stored by the caller (initially None).
        mask_window_name / result_window_name: names of the two windows.
    """

    def __init__(self):
        self.boxes = []  # prompt boxes, set by the caller before segmenting
        self.result_image = None
        self.mask_window_name = "Segmentation Mask"
        self.result_window_name = "Segmentation Result"

        # Create the display windows up front.
        cv2.namedWindow(self.mask_window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(self.mask_window_name, 800, 600)
        cv2.namedWindow(self.result_window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(self.result_window_name, 800, 600)

    def perform_segmentation(self, frame):
        """
        Segment ``frame`` once per box in ``self.boxes`` and render the masks.

        Args:
            frame: BGR image array (assumed uint8, as returned by OpenCV
                readers — TODO confirm for other callers).

        Returns:
            ``(mask_display, result_image)``: a black image with each mask
            filled in a random colour, and a copy of ``frame`` with the masks
            blended at 50%, outlined, and labelled ``index:score``.
            ``(None, None)`` when there are no boxes, no masks were
            produced, or any error occurred (the error is printed).
        """
        if not self.boxes:
            print("没有选择任何框,请先选择框")
            return None, None

        try:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            predictor.set_image(frame_rgb)

            # One single-mask prediction per prompt box.
            all_masks = []
            all_scores = []
            for box in self.boxes:
                masks, scores, _ = predictor.predict(
                    point_coords=None,
                    point_labels=None,
                    box=np.array(box),
                    multimask_output=False
                )
                if len(masks) > 0:
                    all_masks.append(masks[0])
                    all_scores.append(scores[0])

            if not all_masks:
                print("未生成任何掩码")
                return None, None

            # Pure-mask view starts black; result view starts as the frame.
            mask_display = np.zeros_like(frame)
            result_image = frame.copy()

            mask_alpha = 0.5   # blend weight of the mask colour
            border_width = 2   # contour thickness in pixels

            for i, (mask, score) in enumerate(zip(all_masks, all_scores)):
                if hasattr(mask, 'cpu'):
                    mask = mask.cpu().numpy()
                if mask.dtype != bool:
                    mask = mask.astype(bool)

                # Bright random colour (each channel in 150..255).
                color = [np.random.randint(150, 256) for _ in range(3)]

                # 1. Fill the mask region in the pure-mask view.
                mask_display[mask] = color

                # 2. Blend the colour into the result, touching only the
                #    masked pixels instead of allocating full-frame float
                #    temporaries for every mask.
                region = result_image[mask].astype(np.float32)
                tint = np.asarray(color, dtype=np.float32)
                blended = region * (1 - mask_alpha) + tint * mask_alpha
                result_image[mask] = blended.astype(np.uint8)

                # 3. Outline the mask.
                contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(result_image, contours, -1, color, border_width)

                # 4. Label with index and score at the mask centroid.
                y_coords, x_coords = np.where(mask)
                if len(x_coords) > 0:
                    center_x, center_y = np.mean(x_coords), np.mean(y_coords)
                    cv2.putText(result_image, f"{i+1}:{score:.3f}",
                                (int(center_x), int(center_y)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

            return mask_display, result_image

        except Exception as e:
            # Best effort: report with traceback and signal failure to the caller.
            print(f"分割错误: {e}")
            import traceback
            traceback.print_exc()
            return None, None

    


# # 创建选择器实例并运行
# try:
#     selector = CameraBoxSelector()
#     selector.run()
# except Exception as e:
#     print(f"初始化错误: {e}")
#     sys.exit(1)

# # 清理内存
# if device.type == "cuda":
#     torch.cuda.empty_cache()

  

 

posted on 2025-10-28 00:59  MKT-porter  阅读(4)  评论(0)    收藏  举报
刷新页面返回顶部
博客园  ©  2004-2025
浙公网安备 33010602011771号 浙ICP备2021040463号-3