• 博客园logo
  • 会员
  • 众包
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • HarmonyOS
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录
MKT-porter
博客园    首页    新随笔    联系   管理    订阅  订阅
语言-目标检测 MM Grounding Dino Large (1) 实验对比 -GroundingDino针对航空图像检测的改进

 

image

 

测试代码

image

 

import os
import cv2
import torch
import numpy as np
import time
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
from transformers.image_utils import load_image
import matplotlib.pyplot as plt

# 设置CUDA内存配置
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
torch.cuda.empty_cache()


# 在预处理时添加resize操作
def preprocess_image(image, scale):
    # 保持宽高比缩放,短边=target_size
    width, height = image.size
    #scale = target_size / min(width, height)
    new_size = (int(width / scale), int(height / scale))
    return image.resize(new_size)



# 初始化模型和处理器
def initialize_model(model_path):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained(model_path)
    model = AutoModelForZeroShotObjectDetection.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto"
    ).to(device)
    return processor, model, device

# 执行目标检测
def detect_objects(image, processor, model, device, text_labels):
    inputs = processor(images=image, text=text_labels, return_tensors="pt").to(device)
    
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(**inputs)
    
    results = processor.post_process_grounded_object_detection(
        outputs,
        threshold=0.3,
        target_sizes=[(image.height, image.width)]
    )
    return results[0]

# 可视化检测结果(添加FPS显示)
def visualize_detection(image, result, fps=None):
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
    
    # 绘制检测结果
    for i, (box, score, label) in enumerate(zip(result["boxes"], result["scores"], result["labels"])):
        if score < 0.3:  # 使用阈值过滤
            continue
            
        xmin, ymin, xmax, ymax = [int(round(coord)) for coord in box.tolist()]
        color = colors[i % len(colors)]
        
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 2)
        
        label_text = f"{label}: {score.item():.2f}"
        (text_width, text_height), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(img, (xmin, ymin - text_height - 10), (xmin + text_width, ymin), color, -1)
        cv2.putText(img, label_text, (xmin, ymin - 5), 
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    
    # 添加FPS显示
    if fps is not None:
        fps_text = f"FPS: {fps:.1f}"
        cv2.putText(img, fps_text, (10, 30), 
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# 主函数:处理文件夹中的图像(添加FPS计算)
def process_folder_images(folder_path, model_path,img_scale=1):
    # 获取并排序所有DJI_*.JPG文件
    image_files = sorted([f for f in os.listdir(folder_path) 
                         if f.startswith('DJI_') and f.lower().endswith('.jpg')])
    
    if not image_files:
        print("未找到DJI_*.JPG格式的图像文件")
        return
    
    # 初始化模型
    processor, model, device = initialize_model(model_path)
    text_labels = ["vehicle", "person", "building", "tree", 
                   "power line", "agricultural machinery", "water body"]
    
    # 创建可调整大小的窗口
    cv2.namedWindow('Zero-Shot Object Detection', cv2.WINDOW_NORMAL)
    
    current_index = 0
    total_images = len(image_files)
    
    # FPS计算变量
    fps = 0
    prev_time = 0
    curr_time = 0
    
    while True:
        # 开始计时
        start_time = time.time()
        
        # 加载当前图像
        image_path = os.path.join(folder_path, image_files[current_index])
        image = load_image(image_path)

        image = preprocess_image(image,img_scale) # 缩放2倍
        
        # 执行检测
        result = detect_objects(image, processor, model, device, text_labels)
        
        # 计算处理时间
        inference_time = time.time() - start_time
        fps = 1.0 / inference_time if inference_time > 0 else 0
        
        # 可视化结果(传入FPS)
        result_img = visualize_detection(image, result, fps)
        
        # 显示结果
        cv2.imshow('Zero-Shot Object Detection', cv2.cvtColor(result_img, cv2.COLOR_RGB2BGR))
        
        # 打印处理信息(包含FPS)
        print(f"处理: {image_files[current_index]} ({current_index + 1}/{total_images}) | FPS: {fps:.1f}")
        #print(torch.cuda.memory_summary())  # 打印显存分配情况
        # 等待按键
        key = cv2.waitKey(0) & 0xFF
        
        # 按键处理
        if key == 27 or key == ord('q'):  # ESC或q退出
            break
        elif key == ord('n') or key == 32 or key == 83 or key == 2:  # 下一张
            current_index = (current_index + 1) % total_images
        elif key == ord('p') or key == 81 or key == 3:  # 上一张
            current_index = (current_index - 1) % total_images
    
    cv2.destroyAllWindows()

# 使用示例
if __name__ == "__main__":
    folder_path = "/media/r9000k/DD_XS/2数据/2RTK/data_1_nwpuUp/data3_1130_13pm/300_location_14pm/images"  # 图像文件夹路径
    model_path = "./"   # 模型路径
    img_scale=1 # 缩放
    process_folder_images(folder_path, model_path,img_scale)

  

 

image

 

image

image

image

 

image

 

image

 

posted on 2025-10-28 05:01  MKT-porter  阅读(2)  评论(0)    收藏  举报
刷新页面返回顶部
博客园  ©  2004-2025
浙公网安备 33010602011771号 浙ICP备2021040463号-3