
# Test code: zero-shot object detection viewer for a folder of DJI_*.JPG images.

import os
import cv2
import torch
import numpy as np
import time
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
from transformers.image_utils import load_image
import matplotlib.pyplot as plt
# CUDA memory configuration: set the allocator option through the
# PYTORCH_CUDA_ALLOC_CONF environment variable, then release cached GPU memory.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")
torch.cuda.empty_cache()
def preprocess_image(image, scale):
    """Downscale ``image`` by the divisor ``scale``, keeping its aspect ratio.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height))`` method (e.g. a PIL image).
        scale: Positive divisor applied to both dimensions; ``2`` halves the
            width and height, ``1`` keeps the original dimensions.

    Returns:
        Whatever ``image.resize`` returns for the computed size.

    Raises:
        ValueError: If ``scale`` is not strictly positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale!r}")
    width, height = image.size
    # Clamp to one pixel so a large divisor can never produce a zero-sized image.
    new_size = (max(1, int(width / scale)), max(1, int(height / scale)))
    return image.resize(new_size)
def initialize_model(model_path):
    """Load the zero-shot detection processor and model from ``model_path``.

    Args:
        model_path: Path or identifier accepted by ``from_pretrained``.

    Returns:
        tuple: ``(processor, model, device)`` where ``device`` is ``"cuda"``
        when CUDA is available and ``"cpu"`` otherwise.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision is only requested on the GPU; on CPU the weights stay in
    # float32 so they match the float32 tensors produced by the processor.
    dtype = torch.float16 if device == "cuda" else torch.float32
    processor = AutoProcessor.from_pretrained(model_path)
    # Placement is decided by the explicit .to(device); device_map="auto" is
    # not passed as well, since two placement mechanisms are redundant and
    # can conflict.
    model = AutoModelForZeroShotObjectDetection.from_pretrained(
        model_path,
        torch_dtype=dtype,
    ).to(device)
    return processor, model, device
def detect_objects(image, processor, model, device, text_labels, threshold=0.3):
    """Run zero-shot object detection on a single image.

    Args:
        image: Input image; must expose ``height`` and ``width`` attributes
            and be accepted by ``processor``.
        processor: Processor used both to build the model inputs and to
            post-process the raw outputs.
        model: Detection model called with the processor's inputs.
        device: Device (string or ``torch.device``) the inputs are moved to.
        text_labels: Text queries passed to the processor as ``text``.
        threshold: Score threshold forwarded to the post-processing step.

    Returns:
        The first (and only) entry of the post-processed results, i.e. the
        detections for ``image``, rescaled to its ``(height, width)``.
    """
    import contextlib

    inputs = processor(images=image, text=text_labels, return_tensors="pt").to(device)
    # Mixed precision is only meaningful on CUDA; on any other device run the
    # model as-is instead of requesting CUDA autocast.
    if torch.device(device).type == "cuda":
        amp_context = torch.autocast(device_type="cuda", dtype=torch.float16)
    else:
        amp_context = contextlib.nullcontext()
    with torch.no_grad(), amp_context:
        outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        threshold=threshold,
        target_sizes=[(image.height, image.width)],
    )
    return results[0]
def visualize_detection(image, result, fps=None, score_threshold=0.3):
    """Draw detection boxes, labels and an optional FPS counter on ``image``.

    Args:
        image: RGB image convertible with ``np.array`` (e.g. a PIL image).
        result: Mapping with parallel ``"boxes"``, ``"scores"`` and
            ``"labels"`` entries. Each box must provide ``tolist()`` yielding
            ``(xmin, ymin, xmax, ymax)``.
        fps: Frames-per-second value to overlay, or ``None`` to omit it.
        score_threshold: Detections scoring below this value are not drawn.

    Returns:
        numpy.ndarray: The annotated image in RGB channel order.
    """
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
    font = cv2.FONT_HERSHEY_SIMPLEX

    detections = zip(result["boxes"], result["scores"], result["labels"])
    for i, (box, score, label) in enumerate(detections):
        # float() accepts both 0-dim tensors and plain Python numbers.
        score_value = float(score)
        if score_value < score_threshold:
            continue
        xmin, ymin, xmax, ymax = (int(round(coord)) for coord in box.tolist())
        color = colors[i % len(colors)]
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 2)

        label_text = f"{label}: {score_value:.2f}"
        (text_width, text_height), _ = cv2.getTextSize(label_text, font, 0.5, 1)
        # The label normally sits just above the box. When the box touches the
        # top of the image there is no room there, so draw it just inside the
        # box instead of letting it fall outside the image.
        label_top = ymin - text_height - 10
        if label_top < 0:
            label_top = max(ymin, 0)
        label_bottom = label_top + text_height + 10
        cv2.rectangle(img, (xmin, label_top), (xmin + text_width, label_bottom), color, -1)
        cv2.putText(img, label_text, (xmin, label_bottom - 5),
                    font, 0.5, (255, 255, 255), 1)

    # Overlay the FPS counter in the top-left corner.
    if fps is not None:
        fps_text = f"FPS: {fps:.1f}"
        cv2.putText(img, fps_text, (10, 30), font, 1, (0, 255, 0), 2)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def process_folder_images(folder_path, model_path, img_scale=1):
    """Interactively browse the DJI images in a folder with detections drawn.

    Every ``DJI_*.jpg`` file (extension matched case-insensitively) in
    ``folder_path`` is shown, in sorted order, in an OpenCV window together
    with its detections and the processing rate. Navigation wraps around at
    both ends of the list.

    Keys:
        ESC / q                  quit
        n / space / right arrow  next image
        p / left arrow           previous image
        Any other key re-processes the current image.

    Args:
        folder_path: Directory containing the images.
        model_path: Model location passed to ``initialize_model``.
        img_scale: Divisor passed to ``preprocess_image`` before detection.
    """
    # Collect and sort all DJI_*.JPG files.
    image_files = sorted(
        f for f in os.listdir(folder_path)
        if f.startswith('DJI_') and f.lower().endswith('.jpg')
    )
    if not image_files:
        print("未找到DJI_*.JPG格式的图像文件")
        return

    processor, model, device = initialize_model(model_path)
    # NOTE(review): passed as a flat list of strings; some processors expect
    # one list of labels per image ([[...]]) - confirm against the model in use.
    text_labels = ["vehicle", "person", "building", "tree",
                   "power line", "agricultural machinery", "water body"]

    window_name = 'Zero-Shot Object Detection'
    # Key codes after masking with 0xFF. Arrow keys differ per platform:
    # 81/83 are left/right on Linux (GTK), 2/3 are left/right on macOS.
    quit_keys = {27, ord('q')}
    next_keys = {ord('n'), 32, 83, 3}
    prev_keys = {ord('p'), 81, 2}

    # Resizable window.
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    current_index = 0
    total_images = len(image_files)
    try:
        while True:
            # perf_counter is monotonic, which suits interval measurement.
            start_time = time.perf_counter()

            image_name = image_files[current_index]
            image = load_image(os.path.join(folder_path, image_name))
            image = preprocess_image(image, img_scale)
            result = detect_objects(image, processor, model, device, text_labels)

            # The measured interval covers loading, resizing and detection.
            elapsed = time.perf_counter() - start_time
            fps = 1.0 / elapsed if elapsed > 0 else 0

            # visualize_detection returns RGB; imshow expects BGR.
            result_img = visualize_detection(image, result, fps)
            cv2.imshow(window_name, cv2.cvtColor(result_img, cv2.COLOR_RGB2BGR))
            print(f"处理: {image_name} ({current_index + 1}/{total_images}) | FPS: {fps:.1f}")

            key = cv2.waitKey(0) & 0xFF
            if key in quit_keys:
                break
            if key in next_keys:
                current_index = (current_index + 1) % total_images
            elif key in prev_keys:
                current_index = (current_index - 1) % total_images
    finally:
        # Close the window even when loading or detection raises.
        cv2.destroyAllWindows()
# Example usage: launch the interactive viewer on a local image folder.
if __name__ == "__main__":
    process_folder_images(
        folder_path="/media/r9000k/DD_XS/2数据/2RTK/data_1_nwpuUp/data3_1130_13pm/300_location_14pm/images",  # image folder
        model_path="./",  # model directory
        img_scale=1,  # downscale divisor
    )






# (web page footer) 浙公网安备 33010602011771号