说明
1. sam2 的代码文件依赖是相对于 sam2 工程路径查找的,所以本工程需要在 sam2 目录下构建。
2. 需要针对空中(无人机航拍)数据集进行训练。
3. 不能直接用 sam2 对整幅图像做全图分割,那样分割结果会很混乱;应先用 GroundingDINO 检测出目标框,再以这些框作为提示让 sam2 分割。
安装
环境 rtx 3070 ubuntu20 cuda11.8
python3.10
1. 安装 GroundingDINO
2. 安装 sam2
3. 以 sam2 为根目录创建新工程,在其中创建本代码,并将 GroundingDINO 工程代码拷贝过来或通过路径引用。
反过来(在其他位置通过全局路径引用 sam2)是不行的,因为 sam2 是按照自身内部的相对路径来查找依赖文件的。






import os
import sys
import time
import warnings
import numpy as np
import torch
import cv2
from PIL import Image, ImageDraw, ImageFont
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.vl_utils import create_positive_map_from_span
from groundingdino.util.inference import load_model, load_image, predict, annotate
# Silence FutureWarning and UserWarning output. The filters are installed
# before importing API_SAM2 because that module builds the SAM2 model at
# import time and would otherwise emit those warnings.
for _warning_category in (FutureWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_warning_category)
del _warning_category
# If sam2 cannot be imported from the working directory, add its path to
# sys.path before the import below, e.g.:
# sys.path.append('/home/r9000k/v2_project/v5_samyolo/2分割/sam2-main/test')
from API_SAM2 import *
# Shared SAM2 helper; process_video/process_folder write detection boxes into
# it. Constructing it opens the segmentation display windows.
selector = CameraBoxSelector()
# Full list of ground targets for drone (aerial, top-down) detection.
# Trees and people are deliberately excluded. Every entry is unique so the
# GroundingDINO text prompt built from it carries no repeated phrases.
DRONE_TARGETS = [
    # Vehicles
    "vehicle", "car", "truck", "bus", "van", "SUV", "motorcycle", "bicycle",
    "construction vehicle", "excavator", "bulldozer", "crane", "forklift",
    "tractor", "trailer", "ambulance", "fire truck", "police car",
    # Buildings and structures
    "building", "house", "apartment", "commercial building", "factory",
    "warehouse", "shed", "garage", "roof", "chimney",
    "bridge", "overpass", "tunnel", "dam", "power plant",
    # Roads and traffic infrastructure
    "road", "highway", "street", "pavement", "crosswalk", "roundabout",
    "traffic light", "street light", "road sign", "billboard",
    "parking lot", "gas station", "bus stop",
    # Water-related ("swimming pool" lives here only)
    "river", "lake", "pond", "reservoir", "swimming pool", "fountain",
    "boat", "ship", "yacht", "speedboat", "dock", "pier", "harbor",
    # Agriculture
    "farmland", "crop field", "greenhouse", "barn", "silo", "windmill",
    "irrigation system", "livestock pen",
    # Energy facilities
    "solar panel", "wind turbine", "power line", "transformer",
    "oil rig", "oil tank", "gas pipeline",
    # Sports venues (the duplicate "swimming pool" was removed from here)
    "soccer field", "basketball court", "tennis court", "baseball field",
    "stadium", "running track",
    # Transport infrastructure
    "airport", "runway", "hangar", "airplane", "helicopter",
    "railway", "train", "railroad track", "train station",
    "cell tower", "communication tower", "satellite dish",
    # Military and security facilities (optional)
    "military vehicle", "barracks", "checkpoint", "fence", "gate",
    # Other notable targets
    "container", "shipping container", "cargo", "construction material",
    "playground equipment", "park bench", "statue", "monument",
]
# Reduced target list actually used for the GroundingDINO prompt (see Config).
# Color/size-qualified building phrases help the model on aerial imagery.
DRONE_TARGETS_min = [
    # Buildings and structures
    "building",
    "house",
    "apartment",
    "commercial building",
    "factory",
    "warehouse",
    "shed",
    "garage",
    "roof",
    "chimney",
    "bridge",
    "overpass",
    "gray building",
    "white building",
    "large building",
    "red playground",
    "dark brown building",
    "car",
    # Roads and traffic infrastructure
    "road",
    "highway",
    "street",
    "pavement",
    "crosswalk",
    # Sports venues
    "soccer field",
    "basketball court",
]
class Config:
    """Runtime settings for the GroundingDINO + SAM2 detection script.

    Every attribute below can be overridden by keyword, e.g.
    ``Config(input_type="video", video_path=0)``. Unknown keywords raise
    ``TypeError`` so typos are not silently ignored.
    """

    def __init__(self, **overrides):
        # --- Model ---
        self.model_type = "SwinB"  # "SwinB" (~938 MB weights) or "SwinT" (~600 MB)
        # self.text_prompt = "building, person, door, cap"  # example prompt
        # Comma-separated class names fed to GroundingDINO as the text prompt.
        self.text_prompt = ", ".join(DRONE_TARGETS_min)
        # The official GroundingDINO demo uses BOX_TRESHOLD = 0.35 and
        # TEXT_TRESHOLD = 0.25. Both are LOWERED to 0.2 here, accepting more
        # false positives in exchange for better recall of small targets.
        self.box_threshold = 0.2
        self.text_threshold = 0.2
        self.cpu_only = False  # force CPU inference even if CUDA is available
        # --- Input source ---
        self.input_type = "folder"  # "video" or "folder"
        self.video_path = 0  # video file path or camera index
        self.folder_path = "/media/r9000k/DD_XS/2数据/2RTK/data_4_city/460_500/images"  # image folder
        # Other folders used previously:
        # /media/r9000k/DD_XS/2数据/2RTK/data_4_city/300_map_2pm/images
        # /home/r9000k/v0_data/rtk/nwpu_1130_12pm
        self.img_scale = 1  # folder images are resized to (w / img_scale, h / img_scale)
        # --- Output ---
        self.output_dir = "outputs"  # created by main()
        self.save_results = True  # save annotated frames (video mode)
        self.show_results = True  # show OpenCV windows
        # --- Post-processing ---
        # Minimum box area; smaller detections are dropped. Measured in pixels
        # of the resized model-input tensor, not of the original image.
        self.min_target_area = 0
        # Sort folder images by the digits in their file names.
        self.sort_by_timestamp = True

        for key, value in overrides.items():
            if not hasattr(self, key):
                raise TypeError(f"Config got an unexpected keyword argument {key!r}")
            setattr(self, key, value)
def plot_boxes_to_image_cv2(image_cv2, boxes, labels):
    """Draw detection boxes and labels onto an OpenCV image (in place).

    Args:
        image_cv2: BGR image array (H, W, 3); it is drawn on in place.
        boxes: iterable of normalized (cx, cy, w, h) boxes in [0, 1], as
            returned by ``get_grounding_output``.
        labels: one label string per box.

    Returns:
        (image_cv2, opencv_boxes): the same image object, and a list of
        ``[x0, y0, x1, y1]`` integer pixel boxes clipped to the image
        bounds (suitable as SAM2 box prompts).
    """
    H, W = image_cv2.shape[:2]
    # Hoisted out of the loop: scales normalized cxcywh to pixels.
    scale = torch.Tensor([W, H, W, H])
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1

    opencv_boxes = []
    for box, label in zip(boxes, labels):
        # Normalized cxcywh -> pixel xyxy.
        box = box * scale
        box[:2] -= box[2:] / 2
        box[2:] += box[:2]
        x0, y0, x1, y1 = map(int, box.tolist())
        # Clip to the image so both the drawing and the SAM2 prompt stay in bounds.
        x0 = min(max(x0, 0), W - 1)
        x1 = min(max(x1, 0), W - 1)
        y0 = min(max(y0, 0), H - 1)
        y1 = min(max(y1, 0), H - 1)
        opencv_boxes.append([x0, y0, x1, y1])

        color = tuple(map(int, np.random.randint(0, 255, size=3)))
        cv2.rectangle(image_cv2, (x0, y0), (x1, y1), color, 2)

        (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
        # Put the label above the box; if the box touches the top edge there
        # is no room, so put it just inside the box instead of off-screen.
        if y0 - text_height - 5 >= 0:
            label_top, label_bottom = y0 - text_height - 5, y0
        else:
            label_top, label_bottom = y0, y0 + text_height + 5
        cv2.rectangle(image_cv2, (x0, label_top), (x0 + text_width, label_bottom), color, -1)
        cv2.putText(image_cv2, label, (x0, label_bottom - 5), font,
                    font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
    return image_cv2, opencv_boxes
def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
    """Build a GroundingDINO model from a config file and load its weights.

    This deliberately shadows ``groundingdino.util.inference.load_model``
    imported at the top of the file.

    Args:
        model_config_path: path to the GroundingDINO ``*_cfg.py`` config.
        model_checkpoint_path: path to the ``.pth`` checkpoint; its
            ``"model"`` entry holds the state dict.
        cpu_only: build for CPU even if CUDA is available.

    Returns:
        The model in eval mode. Its weights are loaded on the CPU; moving it
        to the device is left to the caller.

    Raises:
        RuntimeError: if building or loading fails. The original exception
            is chained as ``__cause__``.
    """
    try:
        args = SLConfig.fromfile(model_config_path)
        args.device = "cuda" if not cpu_only and torch.cuda.is_available() else "cpu"
        model = build_model(args)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
        # strict=False: missing/unexpected keys are reported in load_res, not raised.
        load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
        print("模型加载结果:", load_res)
        model.eval()
        return model
    except Exception as e:
        raise RuntimeError(f"加载模型失败: {str(e)}") from e
def get_grounding_output(model, image, caption, box_threshold, text_threshold=None,
                         with_logits=True, cpu_only=False, min_area=0):
    """Run GroundingDINO on one image and return the filtered boxes and phrases.

    Args:
        model: GroundingDINO model. It is called as
            ``model(images, captions=[...])`` and must expose ``.tokenizer``.
        image: normalized image tensor (C, H, W) from ``preprocess_cv2_image``.
        caption: text prompt. It is lower-cased, stripped and terminated with ".".
        box_threshold: keep queries whose best token score exceeds this.
        text_threshold: token score threshold used to extract phrases.
            Required; a ``None`` value raises.
        with_logits: append the query's max score, e.g. ``"car(0.42)"``.
        cpu_only: run on CPU even if CUDA is available.
        min_area: drop boxes whose area is below this. The area is measured in
            pixels of the resized input tensor (H * W of ``image``), not the
            original frame.

    Returns:
        (boxes, phrases): ``boxes`` is a CPU tensor of shape (k, 4) holding
        normalized (cx, cy, w, h) boxes (shape (0, 4) when nothing is kept).
        ``phrases`` is a list of k strings.

    Raises:
        ValueError: if ``text_threshold`` is None.
    """
    if text_threshold is None:
        raise ValueError("text_threshold不能为None")
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption += "."
    device = "cuda" if not cpu_only and torch.cuda.is_available() else "cpu"
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].sigmoid()[0].cpu()  # (nq, 256)
    boxes = outputs["pred_boxes"][0].cpu()  # (nq, 4)

    # Score filter and area filter, both vectorized over all queries.
    keep = logits.max(dim=1)[0] > box_threshold
    img_h, img_w = image.shape[1], image.shape[2]
    areas = boxes[:, 2] * boxes[:, 3] * (img_w * img_h)
    keep &= areas >= min_area
    logits_filt = logits[keep]
    boxes_filt = boxes[keep]  # always (k, 4), even when k == 0

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    pred_phrases = []
    for logit in logits_filt:
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrase += f"({logit.max().item():.2f})"
        pred_phrases.append(pred_phrase)
    return boxes_filt, pred_phrases
def preprocess_cv2_image(image_cv2):
    """Turn a BGR OpenCV frame into GroundingDINO input.

    Returns:
        (image_pil, image_tensor): the RGB PIL image, and the tensor produced
        by GroundingDINO's transform pipeline:
        ``RandomResize([800], max_size=1333)`` -> ``ToTensor`` ->
        ImageNet-mean/std ``Normalize``. With a single size, the resize
        presumably always targets 800 (confirm in
        groundingdino.datasets.transforms).
    """
    rgb_array = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb_array)
    pipeline = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    tensor, _target = pipeline(pil_image, None)
    return pil_image, tensor
def get_image_files_from_folder(folder_path, sort_by_number=True):
    """Recursively collect image file paths under *folder_path*.

    Args:
        folder_path: root directory; subdirectories are included.
        sort_by_number: if True, order files by the integer formed from all
            digits in the file name (``DJI_0004.JPG`` -> 4; no digits -> 0).
            Ties are broken by full path so the order is deterministic.
            If False, files come back in ``os.walk`` order.

    Returns:
        list of matching file paths (matched by case-insensitive extension).
    """
    supported_formats = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.gif')
    image_files = [
        os.path.join(root, name)
        for root, _, files in os.walk(folder_path)
        for name in files
        if name.lower().endswith(supported_formats)
    ]

    def extract_number(path):
        """Return the integer made of every digit in the file's base name (0 if none)."""
        stem = os.path.splitext(os.path.basename(path))[0]
        digits = ''.join(ch for ch in stem if ch.isdigit())
        return int(digits) if digits else 0

    if sort_by_number:
        image_files.sort(key=lambda path: (extract_number(path), path))
    return image_files
def process_video(model, config):
    """Run GroundingDINO detection on a video file or camera stream.

    Each frame is detected, annotated with boxes, labels and an FPS counter,
    then optionally shown and/or saved to ``config.output_dir``. The pixel
    boxes of the current frame are stored in the module-level ``selector``.
    SAM2 segmentation is not run in this mode. Press ESC to stop.

    Args:
        model: GroundingDINO model returned by ``load_model``.
        config: a ``Config`` instance.

    Raises:
        RuntimeError: if the video source cannot be opened.
    """
    cap = cv2.VideoCapture(config.video_path)
    if not cap.isOpened():
        raise RuntimeError(f"无法打开视频源: {config.video_path}")
    print("开始实时检测,按ESC键退出...")
    if config.show_results:
        cv2.namedWindow('Video_Detection', cv2.WINDOW_NORMAL)
        cv2.resizeWindow('Video_Detection', 640, 480)

    frame_idx = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("无法获取视频帧")
                break
            _, image_tensor = preprocess_cv2_image(frame)

            start_time = time.time()
            boxes_filt, pred_phrases = get_grounding_output(
                model, image_tensor, config.text_prompt,
                config.box_threshold, config.text_threshold,
                cpu_only=config.cpu_only,
                min_area=config.min_target_area,
            )
            elapsed_time = time.time() - start_time

            if len(boxes_filt) > 0:
                frame, opencv_boxes = plot_boxes_to_image_cv2(frame, boxes_filt, pred_phrases)
            else:
                opencv_boxes = []
            # Always overwrite, so boxes from an earlier frame are never reused.
            selector.boxes = opencv_boxes
            # SAM2 is not run here. To enable it, call
            # selector.perform_segmentation(...) on an unannotated copy of the frame.

            fps = 1 / elapsed_time if elapsed_time > 0 else 0
            cv2.putText(frame, f"FPS: {fps:.1f}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            if config.show_results:
                cv2.imshow("Video_Detection", frame)
            if config.save_results:
                # A seconds-resolution timestamp alone would make every frame
                # processed within the same second overwrite the previous one.
                output_path = os.path.join(
                    config.output_dir, f"frame_{int(time.time())}_{frame_idx:06d}.jpg"
                )
                cv2.imwrite(output_path, frame)
            frame_idx += 1

            if cv2.waitKey(1) == 27:  # ESC
                break
    finally:
        # Release the capture even if detection raises.
        cap.release()
        if config.show_results:
            cv2.destroyAllWindows()
def process_folder(model, config):
    """Detect with GroundingDINO and segment with SAM2 every image in a folder.

    For each image:
    1. Read it and optionally downscale it by ``config.img_scale``.
    2. Run detection.
    3. Draw the boxes on the display frame.
    4. Let the module-level ``selector`` run SAM2 with those boxes on an
       unannotated copy of the image.

    When ``config.show_results`` is set, results are shown and the loop waits
    for a key per image (ESC stops). Errors on one image are printed and that
    image is skipped.

    Args:
        model: GroundingDINO model returned by ``load_model``.
        config: a ``Config`` instance.
    """
    image_files = get_image_files_from_folder(config.folder_path, config.sort_by_timestamp)
    if not image_files:
        print(f"在文件夹 {config.folder_path} 中未找到图像文件")
        return
    print(f"找到 {len(image_files)} 张图像,开始处理...")
    if config.show_results:
        cv2.namedWindow('Image_Detection', cv2.WINDOW_NORMAL)
        cv2.resizeWindow('Image_Detection', 640, 480)

    for i, image_path in enumerate(image_files):
        print(f"处理图像 {i+1}/{len(image_files)}: {image_path}")
        try:
            frame = cv2.imread(image_path)
            # Check for a failed read BEFORE touching frame.shape.
            if frame is None:
                print(f"无法读取图像: {image_path}")
                continue
            if config.img_scale != 1:
                h, w = frame.shape[:2]
                target_size = (int(w / config.img_scale), int(h / config.img_scale))
                frame = cv2.resize(frame, target_size)
            # plot_boxes_to_image_cv2 draws in place; SAM2 must see the clean image.
            clean_frame = frame.copy()
            _, image_tensor = preprocess_cv2_image(frame)

            start_time = time.time()
            boxes_filt, pred_phrases = get_grounding_output(
                model, image_tensor, config.text_prompt,
                config.box_threshold, config.text_threshold,
                cpu_only=config.cpu_only,
                min_area=config.min_target_area,
            )
            mask_display = result_image = None
            if len(boxes_filt) > 0:
                frame, opencv_boxes = plot_boxes_to_image_cv2(frame, boxes_filt, pred_phrases)
                selector.boxes = opencv_boxes
                mask_display, result_image = selector.perform_segmentation(clean_frame)
            else:
                # Reset so boxes from the previous image are never reused.
                selector.boxes = []
            elapsed_time = time.time() - start_time
            print("目标检测和跟踪总处理时间", elapsed_time)

            fps = 1 / elapsed_time if elapsed_time > 0 else 0
            info_text = f"Image {i+1}/{len(image_files)} - FPS: {fps:.1f}"
            cv2.putText(frame, info_text, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            if mask_display is not None:
                selector.result_image = result_image

            if config.show_results:
                if mask_display is not None:
                    cv2.imshow(selector.mask_window_name, mask_display)
                    cv2.imshow(selector.result_window_name, result_image)
                cv2.imshow("Image_Detection", frame)
                if cv2.waitKey(0) == 27:  # ESC
                    break
            # Saving is currently disabled in folder mode:
            # if config.save_results:
            #     output_path = os.path.join(config.output_dir, os.path.basename(image_path))
            #     cv2.imwrite(output_path, frame)
            #     print(f"结果已保存到: {output_path}")
        except Exception as e:
            # Best effort: report and move on to the next image.
            print(f"处理图像 {image_path} 时出错: {str(e)}")
    if config.show_results:
        cv2.destroyAllWindows()
def main():
    """Load GroundingDINO and process the input source configured in ``Config``.

    The GroundingDINO checkout location defaults to the author's path and can
    be overridden with the ``GDINO_PATH`` environment variable. On any error
    the message and traceback are printed and the process exits with status 1.
    """
    import traceback

    config = Config()
    os.makedirs(config.output_dir, exist_ok=True)

    gdino_path = os.environ.get(
        "GDINO_PATH", "/home/r9000k/v2_project/v5_samyolo/1目标检测/GroundingDINO-main"
    )
    # Pick the config/weights pair matching the model type.
    if config.model_type == "SwinB":
        cfg_name, weights_name = "GroundingDINO_SwinB_cfg.py", "groundingdino_swinb_cogcoor.pth"
    else:
        cfg_name, weights_name = "GroundingDINO_SwinT_OGC.py", "groundingdino_swint_ogc.pth"
    config_file = os.path.join(gdino_path, "groundingdino", "config", cfg_name)
    checkpoint_path = os.path.join(gdino_path, "weights", weights_name)

    try:
        print(f"正在加载 {config.model_type} 模型...")
        model = load_model(config_file, checkpoint_path, config.cpu_only)
        print("模型加载完成")
        if config.input_type == "video":
            process_video(model, config)
        elif config.input_type == "folder":
            process_folder(model, config)
        else:
            raise ValueError(f"不支持的输入类型: {config.input_type}")
    except Exception as e:
        print(f"发生错误: {str(e)}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
API_SAM2.py
import cv2
import torch
import time
import numpy as np
import os
import sys
# # 检查文件是否存在
# image_path = "npu2pm.JPG"
# if not os.path.exists(image_path):
# print(f"错误:图像文件 '{image_path}' 不存在!")
# print("请确保图像文件在当前目录下")
# sys.exit(1)
# print("图像文件存在,继续执行...")
# Choose the compute device once; model construction below uses it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"使用设备: {device}")
# Inference speed-ups. cuDNN autotunes convolution algorithms for fixed input
# sizes; on GPU, TF32 is allowed for matmuls and convolutions, trading a
# little precision for speed.
torch.backends.cudnn.benchmark = True
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# SAM2.1 variant to load: "tiny", "small", "base" or "large".
mode_test = 'tiny'
scale = 1.0  # image scale factor; not used by the active code in this module
# variant -> (model config file name, checkpoint file name)
model_config = dict(
    tiny=("sam2.1_hiera_t.yaml", "sam2.1_hiera_tiny.pt"),
    small=("sam2.1_hiera_s.yaml", "sam2.1_hiera_small.pt"),
    base=("sam2.1_hiera_b.yaml", "sam2.1_hiera_base_plus.pt"),
    large=("sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt"),
)
model_type, model_path = model_config[mode_test]
# Both paths are relative to the current working directory. They presumably
# assume the script runs from a sub-folder of the sam2 repository, whose root
# holds checkpoints/ and sam2/configs/ — confirm for your layout.
checkpoint = "../checkpoints/" + model_path
model_cfg = "../sam2/configs/sam2.1/" + model_type
print(f"模型配置: {model_cfg}")
print(f"检查点: {checkpoint}")
# Warn (but do not abort) when the checkpoint exists neither under its bare
# file name in the working directory nor at the ../checkpoints/ path.
if not any(os.path.exists(p) for p in (checkpoint.replace("../checkpoints/", ""), checkpoint)):
    print(f"警告:模型文件可能不存在于: {checkpoint}")
# # 加载并预处理图像
# print("正在加载图像...")
# image_cv = cv2.imread(image_path)
# if image_cv is None:
# raise ValueError("无法加载图像!")
# height, width = image_cv.shape[:2]
# print(f"原始图像尺寸: {width}x{height}")
# new_width = int(width * scale)
# new_height = int(height * scale)
# image = cv2.resize(image_cv, (new_width, new_height))
# image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# print(f"调整后图像尺寸: {new_width}x{new_height}")
# Build the SAM2 model and wrap it in an image predictor. Any failure exits the
# process, because CameraBoxSelector.perform_segmentation needs `predictor`.
print("正在加载SAM2模型...")
try:
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    start_load = time.time()
    sam2 = build_sam2(model_cfg, checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam2)
    end_load = time.time()
except ImportError as e:
    print(f"导入错误: {e}")
    print("请确保sam2模块在Python路径中")
    sys.exit(1)
except Exception as e:
    print(f"模型加载错误: {e}")
    sys.exit(1)
else:
    print(f"模型加载完成,耗时: {end_load - start_load:.2f} 秒")

# Return cached but unused GPU memory blocks after model construction.
if device.type == "cuda":
    torch.cuda.empty_cache()
# Box-prompted SAM2 segmentation helper (boxes are supplied by the detector).
class CameraBoxSelector:
    """Hold box prompts and run SAM2 segmentation for them on a frame.

    Callers set ``self.boxes`` to a list of ``[x0, y0, x1, y1]`` pixel boxes.
    ``perform_segmentation`` then produces one mask per box using the
    module-level ``predictor``. Two display windows are created on
    construction.
    """

    def __init__(self):
        self.boxes = []  # box prompts, each [x0, y0, x1, y1] in pixels
        self.result_image = None  # last overlay image; assigned by callers
        self.mask_window_name = "Segmentation Mask"
        self.result_window_name = "Segmentation Result"
        for window_name in (self.mask_window_name, self.result_window_name):
            cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
            cv2.resizeWindow(window_name, 800, 600)

    @staticmethod
    def _random_bright_color():
        """Return a random color as a tuple of three ints in [150, 255]."""
        return tuple(int(c) for c in np.random.randint(150, 256, size=3))

    @staticmethod
    def _blend_mask(image, mask, color, alpha):
        """Alpha-blend *color* into *image* wherever *mask* is True, in place.

        Only masked pixels are touched, instead of building full-frame float
        buffers per object. Results are truncated to uint8, e.g. 0.5 * 11 -> 5.
        Returns *image*.
        """
        region = image[mask].astype(np.float32)
        blended = region * (1 - alpha) + np.asarray(color, dtype=np.float32) * alpha
        image[mask] = blended.astype(np.uint8)
        return image

    def _render_masks(self, frame, masks, scores, mask_alpha=0.5, border_width=2):
        """Build the mask-only view and the overlay view for the given masks."""
        mask_display = np.zeros_like(frame)
        result_image = frame.copy()
        for i, (mask, score) in enumerate(zip(masks, scores)):
            if hasattr(mask, 'cpu'):
                mask = mask.cpu().numpy()
            mask = mask.astype(bool, copy=False)
            color = self._random_bright_color()

            mask_display[mask] = color
            self._blend_mask(result_image, mask, color, mask_alpha)
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL,
                                           cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(result_image, contours, -1, color, border_width)

            # Label "index:score" at the mask centroid (may lie outside a non-convex mask).
            ys, xs = np.nonzero(mask)
            if xs.size > 0:
                cv2.putText(result_image, f"{i+1}:{score:.3f}",
                            (int(xs.mean()), int(ys.mean())),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        return mask_display, result_image

    def perform_segmentation(self, frame):
        """Segment every box in ``self.boxes`` on *frame*.

        Args:
            frame: BGR image array (H, W, 3).

        Returns:
            (mask_display, result_image):
            - ``mask_display`` is a black image with each mask painted in its
              own color.
            - ``result_image`` is a copy of *frame* with each mask blended at
              50% opacity, outlined, and labelled ``"<index>:<score>"``.
            Returns ``(None, None)`` when there are no boxes, when no masks
            were produced, or when segmentation raised (the error is printed).
        """
        if not self.boxes:
            print("没有选择任何框,请先选择框")
            return None, None
        try:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            predictor.set_image(frame_rgb)
            all_masks = []
            all_scores = []
            for box in self.boxes:
                masks, scores, _ = predictor.predict(
                    point_coords=None,
                    point_labels=None,
                    box=np.array(box),
                    multimask_output=False,
                )
                if len(masks) > 0:
                    all_masks.append(masks[0])
                    all_scores.append(scores[0])
            if not all_masks:
                print("未生成任何掩码")
                return None, None
            return self._render_masks(frame, all_masks, all_scores)
        except Exception as e:
            # Best effort: a failed segmentation must not stop the caller's loop.
            print(f"分割错误: {e}")
            import traceback
            traceback.print_exc()
            return None, None
# # 创建选择器实例并运行
# try:
# selector = CameraBoxSelector()
# selector.run()
# except Exception as e:
# print(f"初始化错误: {e}")
# sys.exit(1)
# # 清理内存
# if device.type == "cuda":
# torch.cuda.empty_cache()
浙公网安备 33010602011771号