• 博客园logo
  • 会员
  • 众包
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • HarmonyOS
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录
MKT-porter
博客园    首页    新随笔    联系   管理    订阅  订阅
sam3 (4)mask分割对比

 

 

image

 

image

image

 

image

 

 

查找最小圆和外界举行

'''
 
主要修改内容:
1. 匹配策略改进 (match_images_multi_to_one)
修改前:复杂的竞争解决机制
修改后:简单直接的多对一匹配
每个当前帧目标从所有超过阈值的历史候选匹配中选择相似度最高的
允许多个当前目标匹配到同一个历史目标
简化了匹配逻辑,提高了匹配效率
2. 可视化标签改进 (create_matched_image)
当前帧标签显示:
C{i}→H{j}:当前帧目标i匹配到历史帧目标j
Sim:{similarity:.3f}:特征匹配相似度
Mask:{score:.2f}:mask分割置信度
历史帧标签显示:
H{i}←C{j1}(sim1),C{j2}(sim2):历史帧目标i被多个当前目标匹配
显示最多2个匹配的当前目标及相似度
3. 可视化布局优化
增加图像尺寸以适应更多信息显示
添加图例说明标签含义
减小字体大小以显示完整信息
改进颜色编码和边框样式
4. 匹配逻辑优势
新的匹配策略更加符合多对一的场景需求:
一个历史建筑可能被多个当前视角的建筑匹配
每个当前目标独立选择最佳匹配
避免了复杂的冲突解决,提高匹配成功率
这样的修改使得匹配过程更加直观,可视化结果更加清晰,能够准确显示每个目标的匹配状态和关键指标。
 
. 图例说明
使用英文说明标签含义
说明颜色编码和边框样式含义
增加图例背景提高可读性
6. 可视化效果
颜色一致性:相同匹配对在两帧中显示相同颜色
清晰区分:实线边框(已匹配)vs 虚线边框(未匹配)
信息完整:显示目标ID、匹配关系、相似度、分割分数
可读性强:大字体、英文标签、清晰的视觉层次
 
 
 
 
1 多个图片看到的特征融合
2 保存和加载
3 视角差异
方向度对齐 最小矩形框  选择方向 旋转 然后送入特征检测
平移
尺度对齐
4 阈值过滤
 
'''
 
 
 
 
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.patches as patches
import time
import os
import glob
from scipy.optimize import linear_sum_assignment
import colorsys
import gc
import matplotlib
import json
 
matplotlib.use('TkAgg')
 
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="tkinter")
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore")
 
# 新增导入
import cv2
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
 
# 内存优化设置
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()
 
#################################### For Image ####################################
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
 
class ReIDNetwork(nn.Module):
    """ReID网络用于提取目标外观特征"""
     
    def __init__(self, feature_dim=512):
        super(ReIDNetwork, self).__init__()
        # 使用预训练的ResNet作为骨干网络
        self.backbone = models.resnet50(pretrained=True)
        # 移除分类层
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])
         
        # 全局平均池化
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
         
        # 特征降维
        self.feature_reduction = nn.Sequential(
            nn.Linear(2048, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim // 2),
            nn.BatchNorm1d(feature_dim // 2),
            nn.ReLU(inplace=True),
        )
         
        self.feature_dim = feature_dim // 2
         
        # 图像预处理
        self.transform = transforms.Compose([
            transforms.Resize((256, 128)),  # ReID标准尺寸
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])
     
    def forward(self, x):
        """前向传播"""
        features = self.backbone(x)
        features = self.global_avg_pool(features)
        features = features.view(features.size(0), -1)
        features = self.feature_reduction(features)
        # L2归一化
        features = nn.functional.normalize(features, p=2, dim=1)
        return features
     
    def extract_features_from_crop(self, image, box):
        """从图像裁剪中提取特征"""
        try:
            # 裁剪目标区域
            x1, y1, x2, y2 = map(int, box)
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(image.width, x2)
            y2 = min(image.height, y2)
             
            if x2 <= x1 or y2 <= y1:
                return None
                 
            crop = image.crop((x1, y1, x2, y2))
             
            # 转换为RGB(处理可能的RGBA图像)
            if crop.mode != 'RGB':
                crop = crop.convert('RGB')
             
            # 预处理
            crop_tensor = self.transform(crop).unsqueeze(0)
             
            # 提取特征
            with torch.no_grad():
                features = self.forward(crop_tensor)
             
            return features.squeeze(0).cpu().numpy()
             
        except Exception as e:
            print(f"ReID特征提取错误: {e}")
            return None
 
    def extract_features_from_mask(self, image, mask):
        """从mask区域提取特征"""
        try:
            # 将mask转换为numpy数组
            mask_np = mask.cpu().numpy().squeeze().astype(np.uint8)
             
            # 找到mask的边界框
            rows = np.any(mask_np, axis=1)
            cols = np.any(mask_np, axis=0)
             
            if not np.any(rows) or not np.any(cols):
                return None
                 
            y1, y2 = np.where(rows)[0][[0, -1]]
            x1, x2 = np.where(cols)[0][[0, -1]]
             
            # 扩展边界框
            padding = 5
            x1 = max(0, x1 - padding)
            y1 = max(0, y1 - padding)
            x2 = min(image.width, x2 + padding)
            y2 = min(image.height, y2 + padding)
             
            if x2 <= x1 or y2 <= y1:
                return None
                 
            # 裁剪目标区域
            crop = image.crop((x1, y1, x2, y2))
             
            # 转换为RGB
            if crop.mode != 'RGB':
                crop = crop.convert('RGB')
             
            # 预处理
            crop_tensor = self.transform(crop).unsqueeze(0)
             
            # 提取特征
            with torch.no_grad():
                features = self.forward(crop_tensor)
             
            return features.squeeze(0).cpu().numpy()
             
        except Exception as e:
            print(f"从mask提取ReID特征错误: {e}")
            return None
 
class ImageMatcher:
    """图像匹配器,支持多对一匹配策略"""
     
    def __init__(self, reid_threshold=0.7, matching_strategy="multi_to_one"):
        self.reid_threshold = reid_threshold
        self.matching_strategy = matching_strategy  # "multi_to_one", "one_to_one", "greedy"
        self.reid_net = ReIDNetwork()
        self.reid_net.eval()
        self.match_colors = {}
        self.target_colors = {}
     
    def extract_features_for_image(self, image, masks, boxes):
        """为图像的所有分割结果提取特征"""
        features = []
        for mask, box in zip(masks, boxes):
            # 优先使用mask区域提取特征
            feature = self.reid_net.extract_features_from_mask(image, mask)
            if feature is None:
                # 如果mask提取失败,使用边界框
                feature = self.reid_net.extract_features_from_crop(image, box)
            features.append(feature)
        return features
     
    def calculate_cosine_similarity(self, feat1, feat2):
        """计算两个特征向量的余弦相似度"""
        if feat1 is None or feat2 is None:
            return 0.0
        return np.dot(feat1, feat2) / (np.linalg.norm(feat1) * np.linalg.norm(feat2) + 1e-8)
     
    def match_images_multi_to_one(self, current_image, current_masks, current_boxes,
                                history_image, history_masks, history_boxes):
        """多对一匹配:当前帧目标从所有超过阈值的历史候选匹配选取最高匹配作为匹配结果"""
         
        # 提取特征
        current_features = self.extract_features_for_image(current_image, current_masks, current_boxes)
        history_features = self.extract_features_for_image(history_image, history_masks, history_boxes)
         
        n_current = len(current_features)
        n_history = len(history_features)
         
        if n_current == 0 or n_history == 0:
            return [], current_features, history_features
         
        # 构建相似度矩阵
        similarity_matrix = np.zeros((n_current, n_history))
        for i in range(n_current):
            for j in range(n_history):
                similarity = self.calculate_cosine_similarity(current_features[i], history_features[j])
                similarity_matrix[i, j] = similarity
         
        print(f"相似度矩阵构建完成: {n_current} x {n_history}")
         
        # 新的匹配策略:为每个当前帧目标选择相似度最高的历史帧目标(超过阈值)
        matches = []
        used_history_indices = set()  # 记录已匹配的历史帧目标
         
        for i in range(n_current):
            # 获取当前目标i与所有历史目标的相似度
            similarities = similarity_matrix[i, :]
             
            # 找到超过阈值的最佳匹配
            best_match_idx = -1
            best_similarity = 0.0
             
            for j in range(n_history):
                if similarities[j] >= self.reid_threshold and similarities[j] > best_similarity:
                    # 检查历史目标是否已被匹配(多对一允许重复匹配)
                    # 这里我们允许历史目标被多个当前目标匹配
                    best_match_idx = j
                    best_similarity = similarities[j]
             
            if best_match_idx != -1:
                matches.append((i, best_match_idx, best_similarity))
                used_history_indices.add(best_match_idx)
                print(f"✓ 匹配成功: 当前帧目标{i} → 历史帧目标{best_match_idx}, 相似度: {best_similarity:.3f}")
            else:
                max_sim = np.max(similarities) if n_history > 0 else 0
                print(f"❌ 当前帧目标{i}: 无超过阈值的历史匹配 (最高相似度: {max_sim:.3f}, 阈值: {self.reid_threshold:.3f})")
         
        # 打印详细统计信息
        self._print_matching_statistics(matches, n_current, n_history, similarity_matrix)
         
        return matches, current_features, history_features
     
    def _print_matching_statistics(self, matches, n_current, n_history, similarity_matrix):
        """打印匹配统计信息"""
        matched_current = set(i for i, j, sim in matches)
        matched_history = set(j for i, j, sim in matches)
         
        print(f"\n" + "="*50)
        print(f"🎯 多对一匹配最终结果统计")
        print(f"="*50)
        print(f"当前帧目标总数: {n_current}")
        print(f"历史帧目标总数: {n_history}")
        print(f"成功匹配对数: {len(matches)}")
        print(f"已匹配的当前目标: {len(matched_current)} / {n_current} ({len(matched_current)/n_current*100:.1f}%)")
        print(f"已匹配的历史目标: {len(matched_history)} / {n_history} ({len(matched_history)/n_history*100:.1f}%)")
         
        # 相似度统计
        if matches:
            similarities = [sim for _, _, sim in matches]
            avg_similarity = np.mean(similarities)
            max_similarity = np.max(similarities)
            min_similarity = np.min(similarities)
            print(f"匹配相似度统计:")
            print(f"  平均相似度: {avg_similarity:.3f}")
            print(f"  最高相似度: {max_similarity:.3f}")
            print(f"  最低相似度: {min_similarity:.3f}")
            print(f"  阈值设置: {self.reid_threshold:.3f}")
         
        # 详细匹配结果
        print(f"\n📋 详细匹配结果:")
        for idx, (curr_idx, hist_idx, similarity) in enumerate(matches):
            print(f"  匹配{idx}: 当前帧目标{curr_idx} → 历史帧目标{hist_idx}, 相似度: {similarity:.3f}")
         
        # 未匹配分析
        unmatched_current = n_current - len(matched_current)
        unmatched_history = n_history - len(matched_history)
         
        if unmatched_current > 0:
            print(f"\n❌ 未匹配的当前目标分析:")
            for i in range(n_current):
                if i not in matched_current:
                    max_sim = np.max(similarity_matrix[i]) if n_history > 0 else 0
                    threshold_diff = max_sim - self.reid_threshold
                    status = "超过阈值" if max_sim >= self.reid_threshold else "低于阈值"
                    print(f"  当前帧目标{i}: 最高相似度 {max_sim:.3f} ({status}, 差値: {threshold_diff:+.3f})")
         
        if unmatched_history > 0:
            print(f"\n❌ 未匹配的历史目标分析:")
            for j in range(n_history):
                if j not in matched_history:
                    max_sim = np.max(similarity_matrix[:, j]) if n_current > 0 else 0
                    threshold_diff = max_sim - self.reid_threshold
                    status = "超过阈值" if max_sim >= self.reid_threshold else "低于阈值"
                    print(f"  历史帧目标{j}: 最高相似度 {max_sim:.3f} ({status}, 差値: {threshold_diff:+.3f})")
         
        print(f"="*50)
 
    def get_match_color(self, match_id):
        """为匹配ID生成颜色"""
        if match_id not in self.match_colors:
            hue = (match_id * 0.618033988749895) % 1.0
            r, g, b = colorsys.hsv_to_rgb(hue, 0.8, 0.8)
            self.match_colors[match_id] = (int(r * 255), int(g * 255), int(b * 255))
        return self.match_colors[match_id]
     
    def get_target_color(self, target_id, is_matched=True):
        """为目标ID生成颜色,匹配的目标用彩色,未匹配的用灰色"""
        if is_matched:
            if target_id not in self.target_colors:
                hue = (target_id * 0.618033988749895) % 1.0
                r, g, b = colorsys.hsv_to_rgb(hue, 0.8, 0.8)
                self.target_colors[target_id] = (int(r * 255), int(g * 255), int(b * 255))
            return self.target_colors[target_id]
        else:
            # 未匹配的目标用灰色
            return (128, 128, 128)
 
def load_model(model_path="sam3.pt"):
    """加载模型"""
    model_load_start_time = time.time()
     
    model = build_sam3_image_model(
        checkpoint_path=model_path
    )
    processor = Sam3Processor(model, confidence_threshold=0.5)
     
    model_load_end_time = time.time()
    model_load_time = model_load_end_time - model_load_start_time
    print(f"模型加载时间: {model_load_time:.3f} 秒")
     
    return processor
 
def get_image_mask(processor, image_path):
    """获取图像分割结果"""
    detection_start_time = time.time()
 
    image = Image.open(image_path)
     
    # 确保图像是RGB模式
    if image.mode != 'RGB':
        image = image.convert('RGB')
        print(f"图像已转换为RGB模式: {image.mode}")
     
    inference_state = processor.set_image(image)
     
    output = processor.set_text_prompt(state=inference_state, prompt="building")
    masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
     
    detection_end_time = time.time()
    detection_time = detection_end_time - detection_start_time
    print(f"检测单张时间: {detection_time:.3f} 秒")
    print(f"原始检测到 {len(masks)} 个分割结果")
    print(f"掩码形状: {masks.shape}")
     
    return masks, boxes, scores, image
 
def calculate_iou(box1, box2):
    """计算两个边界框的IoU"""
    x1_1, y1_1, x1_2, y1_2 = box1
    x2_1, y2_1, x2_2, y2_2 = box2
     
    xi1 = max(x1_1, x2_1)
    yi1 = max(y1_1, y2_1)
    xi2 = min(x1_2, x2_2)
    yi2 = min(y1_2, y2_2)
     
    inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
    box1_area = (x1_2 - x1_1) * (y1_2 - y1_1)
    box2_area = (x2_2 - x2_1) * (y2_2 - y2_1)
    union_area = box1_area + box2_area - inter_area
     
    if union_area == 0:
        return 0.0
     
    iou_2 = inter_area / box2_area
    iou_1 = inter_area / box1_area
    iou = max(iou_2, iou_1)
    return iou
 
def calculate_mask_overlap(mask1, mask2):
    """计算两个掩码的重叠比例"""
    mask1_np = mask1.cpu().numpy().squeeze().astype(bool)
    mask2_np = mask2.cpu().numpy().squeeze().astype(bool)
     
    intersection = np.logical_and(mask1_np, mask2_np)
    mask1_area = np.sum(mask1_np)
     
    if mask1_area == 0:
        return 0.0
     
    overlap_ratio = np.sum(intersection) / mask1_area
    return overlap_ratio
 
def fuse_overlapping_masks(masks, boxes, scores, iou_threshold=0.3, overlap_threshold=0.6):
    """融合重叠的掩码和边界框"""
    if len(masks) == 0:
        return masks, boxes, scores
     
    boxes_np = boxes.cpu().numpy().copy()
    scores_np = scores.cpu().numpy().copy()
     
    areas = []
    for box in boxes_np:
        x1, y1, x2, y2 = box
        width = x2 - x1
        height = y2 - y1
        area = width * height
        areas.append(area)
     
    areas_np = np.array(areas)
    sorted_indices = np.argsort(areas_np)[::-1]
     
    boxes_sorted = boxes_np[sorted_indices]
    scores_sorted = scores_np[sorted_indices]
    areas_sorted = areas_np[sorted_indices]
     
    masks_list = [masks[i] for i in range(len(masks))]
    masks_sorted = [masks_list[i] for i in sorted_indices]
     
    keep_indices = []
    suppressed = set()
    fused_masks = masks_sorted.copy()
     
    for i in range(len(boxes_sorted)):
        if i in suppressed:
            continue
         
        keep_indices.append(i)
        current_mask = fused_masks[i]
         
        for j in range(i + 1, len(boxes_sorted)):
            if j in suppressed:
                continue
             
            iou = calculate_iou(boxes_sorted[i], boxes_sorted[j])
             
            if iou > iou_threshold:
                mask_overlap = calculate_mask_overlap(current_mask, masks_sorted[j])
                suppressed.add(j)
                 
                fused_mask = fuse_two_masks(current_mask, masks_sorted[j])
                fused_masks[i] = fused_mask
                current_mask = fused_mask
                 
                print(f"融合检测结果: 索引 {sorted_indices[i]} (面积: {areas_sorted[i]:.1f}) 融合索引 {sorted_indices[j]} (面积: {areas_sorted[j]:.1f})",
                      " iou:", iou, " mask重叠:", mask_overlap)
                 
    final_indices = [sorted_indices[i] for i in keep_indices]
    final_masks_list = [fused_masks[i] for i in keep_indices]
     
    final_masks = torch.stack(final_masks_list)
    final_boxes = boxes[final_indices]
    final_scores = scores[final_indices]
     
    print(f"融合后剩余 {len(final_masks)} 个分割结果 (减少了 {len(masks) - len(final_masks)} 个)")
     
    return final_masks, final_boxes, final_scores
 
def fuse_two_masks(mask1, mask2):
    """将两个mask融合"""
    fused_mask = torch.logical_or(mask1, mask2).float()
    return fused_mask
 
def create_matched_image(image, masks, boxes, scores, matches, matcher, is_current_frame=True, alpha=0.5):
    """在原图上叠加mask,匹配的目标用彩色,未匹配的用灰色"""
    # 创建图像副本
    img_array = np.array(image).astype(np.float32) / 255.0
    img_height, img_width = img_array.shape[:2]
     
    # 获取匹配的目标索引和匹配ID映射
    if is_current_frame:
        matched_indices = set(idx1 for idx1, _, _ in matches)
        # 创建当前目标到匹配ID的映射
        target_to_match_id = {}
        for match_id, (curr_idx, hist_idx, similarity) in enumerate(matches):
            target_to_match_id[curr_idx] = match_id
    else:
        matched_indices = set(idx2 for _, idx2, _ in matches)
        # 创建历史目标到匹配ID的映射(一个历史目标可能对应多个匹配ID,取第一个)
        target_to_match_id = {}
        for match_id, (curr_idx, hist_idx, similarity) in enumerate(matches):
            if hist_idx not in target_to_match_id:  # 只记录第一个匹配ID
                target_to_match_id[hist_idx] = match_id
     
    # 为每个mask创建彩色叠加层
    for i, (mask, box, score) in enumerate(zip(masks, boxes, scores)):
        is_matched = i in matched_indices
         
        # 获取颜色 - 相同匹配使用相同颜色
        if is_matched and i in target_to_match_id:
            match_id = target_to_match_id[i]
            color = matcher.get_match_color(match_id)  # 使用匹配ID的颜色
            color_array = np.array(color) / 255.0
        elif is_matched:
            # 如果没有匹配ID映射,使用目标ID的颜色
            color = matcher.get_target_color(i, True)
            color_array = np.array(color) / 255.0
        else:
            color_array = np.array([0.5, 0.5, 0.5])  # 灰色
         
        # 将mask转换为numpy数组
        mask_np = mask.cpu().numpy().squeeze()
         
        # 创建彩色mask
        colored_mask = np.zeros_like(img_array)
        for c in range(3):
            colored_mask[:, :, c] = mask_np * color_array[c]
         
        # 叠加到原图
        img_array = img_array * (1 - mask_np[..., None] * alpha) + colored_mask * alpha
     
    # 转换回PIL图像
    result_img = Image.fromarray((img_array * 255).astype(np.uint8))
    draw = ImageDraw.Draw(result_img)
     
    try:
        font = ImageFont.truetype("Arial.ttf", 20)  # 稍微减小字体以适应更多内容
    except:
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except:
            font = ImageFont.load_default()
            print("Using default font, Arial not available")
     
    # 绘制边界框和标签
    for i, (mask, box, score) in enumerate(zip(masks, boxes, scores)):
        is_matched = i in matched_indices
         
        # 获取颜色 - 与mask相同的颜色
        if is_matched and i in target_to_match_id:
            match_id = target_to_match_id[i]
            color = matcher.get_match_color(match_id)
        elif is_matched:
            color = matcher.get_target_color(i, True)
        else:
            color = (128, 128, 128)  # 灰色
         
        # 绘制边界框
        x1, y1, x2, y2 = box.cpu().numpy()
        x1, y1, x2, y2 = max(0, x1), max(0, y1), min(img_width, x2), min(img_height, y2)
         
        if is_matched:
            # 实线边框 - 已匹配目标
            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
        else:
            # 虚线边框 - 未匹配目标
            dash_length = 6
            gap_length = 3
            # 上边框
            x = x1
            while x < x2:
                draw.line([x, y1, min(x + dash_length, x2), y1], fill=color, width=2)
                x += dash_length + gap_length
            # 下边框
            x = x1
            while x < x2:
                draw.line([x, y2, min(x + dash_length, x2), y2], fill=color, width=2)
                x += dash_length + gap_length
            # 左边框
            y = y1
            while y < y2:
                draw.line([x1, y, x1, min(y + dash_length, y2)], fill=color, width=2)
                y += dash_length + gap_length
            # 右边框
            y = y1
            while y < y2:
                draw.line([x2, y, x2, min(y + dash_length, y2)], fill=color, width=2)
                y += dash_length + gap_length
         
        # 绘制标签 - 全部使用英文
        if is_current_frame:
            # 当前帧标签
            if is_matched:
                # 找到匹配信息
                for match_id, (curr_idx, hist_idx, similarity) in enumerate(matches):
                    if curr_idx == i:
                        label = f"C{i}->H{hist_idx}\nSim:{similarity:.3f}\nMask:{score:.2f}"
                        break
            else:
                label = f"C{i} Unmatched\nMask:{score:.2f}"
        else:
            # 历史帧标签
            if is_matched:
                # 找到所有匹配到该历史目标的当前目标
                matched_current = []
                match_similarities = []
                for match_id, (curr_idx, hist_idx, similarity) in enumerate(matches):
                    if hist_idx == i:
                        matched_current.append(curr_idx)
                        match_similarities.append(similarity)
                 
                if matched_current:
                    # 显示匹配数量和最高相似度
                    if len(matched_current) == 1:
                        label = f"H{i}<-C{matched_current[0]}\nSim:{match_similarities[0]:.3f}\nMask:{score:.2f}"
                    else:
                        label = f"H{i}<-{len(matched_current)}C\nBest:{max(match_similarities):.3f}\nMask:{score:.2f}"
                else:
                    label = f"H{i} Matched\nMask:{score:.2f}"
            else:
                label = f"H{i} Unmatched\nMask:{score:.2f}"
         
        # 计算标签大小
        text_bbox = draw.textbbox((0, 0), label, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
         
        # 智能标签位置选择
        label_x, label_y = _calculate_optimal_label_position(
            x1, y1, x2, y2, text_width, text_height, img_width, img_height
        )
         
        # 绘制标签背景
        padding = 4
        bg_bbox = [
            label_x - padding,
            label_y - padding,
            label_x + text_width + padding,
            label_y + text_height + padding
        ]
        draw.rectangle(bg_bbox, fill=color)
         
        # 绘制标签文字
        draw.text((label_x, label_y), label, fill="white", font=font)
     
    return result_img
 
def _calculate_optimal_label_position( x1, y1, x2, y2, text_width, text_height, img_width, img_height):
    """
    计算最优的标签位置,确保标签在图像可见区域内
    优先级:右上角 → 左上角 → 右下角 → 左下角 → 框内中心
    """
    padding = 10
    positions = []
     
    # 候选位置(相对于边界框)
    candidates = [
        # 右上角(优先)
        (x2 + padding, y1 - text_height - padding),
        # 左上角
        (x1 - text_width - padding, y1 - text_height - padding),
        # 右下角
        (x2 + padding, y2 + padding),
        # 左下角
        (x1 - text_width - padding, y2 + padding),
        # 框内右上角
        (x2 - text_width - padding, y1 + padding),
        # 框内左上角
        (x1 + padding, y1 + padding),
        # 框内右下角
        (x2 - text_width - padding, y2 - text_height - padding),
        # 框内左下角
        (x1 + padding, y2 - text_height - padding)
    ]
     
    # 评估每个候选位置
    for candidate_x, candidate_y in candidates:
        # 检查是否在图像范围内
        if (0 <= candidate_x <= img_width - text_width and
            0 <= candidate_y <= img_height - text_height):
            # 计算与边界框的距离(优先选择靠近边界框的位置)
            distance = min(
                abs(candidate_x - x1), abs(candidate_x - x2),
                abs(candidate_y - y1), abs(candidate_y - y2)
            )
            positions.append(((candidate_x, candidate_y), distance))
     
    # 如果有合适的位置,选择距离最近的一个
    if positions:
        positions.sort(key=lambda x: x[1])  # 按距离排序
        return positions[0][0]
     
    # 如果没有合适的外部位置,使用框内中心位置(确保可见)
    center_x = (x1 + x2) / 2 - text_width / 2
    center_y = (y1 + y2) / 2 - text_height / 2
     
    # 确保在图像范围内
    center_x = max(0, min(center_x, img_width - text_width))
    center_y = max(0, min(center_y, img_height - text_height))
     
    return center_x, center_y
 
# 同时优化原始图像和融合图像的标签位置函数
def create_original_image_with_boxes(image, masks, boxes, scores, title):
    """在原图上绘制边界框和标签,每个目标用不同颜色"""
    img_copy = image.copy()
    draw = ImageDraw.Draw(img_copy)
    img_width, img_height = img_copy.size
     
    try:
        font = ImageFont.truetype("Arial.ttf", 20)
    except:
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except:
            font = ImageFont.load_default()
     
    # 为每个目标生成不同颜色
    for i, (mask, box, score) in enumerate(zip(masks, boxes, scores)):
        # 为每个目标生成独特的颜色
        hue = (i * 0.618033988749895) % 1.0
        r, g, b = colorsys.hsv_to_rgb(hue, 0.8, 0.8)
        color = (int(r * 255), int(g * 255), int(b * 255))
         
        # 绘制边界框
        x1, y1, x2, y2 = box.cpu().numpy()
        x1, y1, x2, y2 = max(0, x1), max(0, y1), min(img_width, x2), min(img_height, y2)
        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
         
        # 绘制标签
        label = f"ID:{i} Score:{score:.2f}"
        text_bbox = draw.textbbox((0, 0), label, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
         
        # 智能选择标签位置
        label_x, label_y = _calculate_optimal_label_position(
            x1, y1, x2, y2, text_width, text_height, img_width, img_height
        )
         
        # 绘制标签背景
        padding = 4
        bg_bbox = [
            label_x - padding,
            label_y - padding,
            label_x + text_width + padding,
            label_y + text_height + padding
        ]
        draw.rectangle(bg_bbox, fill=color)
         
        # 绘制标签文字
        draw.text((label_x, label_y), label, fill="white", font=font)
     
    return img_copy
 
def create_fused_image_with_boxes(image, masks, boxes, scores, title):
    """在原图上绘制边界框和标签,每个目标用不同颜色,并添加质心、方向和外接圆"""
    img_copy = image.copy()
    draw = ImageDraw.Draw(img_copy)
    img_width, img_height = img_copy.size
    
    try:
        font = ImageFont.truetype("Arial.ttf", 20)
    except:
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except:
            font = ImageFont.load_default()
    
    # 为每个目标生成不同颜色
    for i, (mask, box, score) in enumerate(zip(masks, boxes, scores)):
        # 为每个目标生成独特的颜色
        hue = (i * 0.618033988749895) % 1.0
        r, g, b = colorsys.hsv_to_rgb(hue, 0.8, 0.8)
        color = (int(r * 255), int(g * 255), int(b * 255))
        
        # 绘制边界框
        x1, y1, x2, y2 = box.cpu().numpy()
        x1, y1, x2, y2 = max(0, x1), max(0, y1), min(img_width, x2), min(img_height, y2)
        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
        
        # 将mask转换为numpy数组
        mask_np = mask.cpu().numpy().squeeze().astype(np.uint8)
        
        # 1. 计算并绘制质心(红色点)
        moments = cv2.moments(mask_np)
        if moments["m00"] != 0:
            cx = int(moments["m10"] / moments["m00"])
            cy = int(moments["m01"] / moments["m00"])
            # 绘制红色质心点
            draw.ellipse([cx-3, cy-3, cx+3, cy+3], fill=(255, 0, 0), outline=(0, 0, 0))
        
        # 2. 计算并绘制方向(蓝色线)
        if mask_np.sum() > 0:
            # 计算PCA获取主方向
            y_coords, x_coords = np.where(mask_np > 0)
            coords = np.column_stack((x_coords, y_coords))
            mean = np.mean(coords, axis=0)
            cov = np.cov(coords.T)
            eigvals, eigvecs = np.linalg.eig(cov)
            # 获取主方向向量
            main_dir = eigvecs[:, np.argmax(eigvals)]
            # 计算线的起点和终点
            start_point = (int(mean[0]), int(mean[1]))
            end_point = (int(mean[0] + main_dir[0] * 30), int(mean[1] + main_dir[1] * 30))
            # 绘制蓝色方向线
            draw.line([start_point, end_point], fill=(0, 0, 255), width=2)
        
        # 3. 计算并绘制最小外接圆(黄色圆)
        if mask_np.sum() > 0:
            contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                # 获取最大轮廓
                max_contour = max(contours, key=cv2.contourArea)
                # 计算最小外接圆
                (x, y), radius = cv2.minEnclosingCircle(max_contour)
                center = (int(x), int(y))
                radius = int(radius)
                # 绘制黄色外接圆
                draw.ellipse([center[0]-radius, center[1]-radius, 
                             center[0]+radius, center[1]+radius], 
                            outline=(255, 255, 0), width=2)
        
        # 绘制标签
        label = f"ID:{i} Score:{score:.2f}"
        text_bbox = draw.textbbox((0, 0), label, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        
        # 智能选择标签位置
        label_x, label_y = _calculate_optimal_label_position(
            x1, y1, x2, y2, text_width, text_height, img_width, img_height
        )
        
        # 绘制标签背景
        padding = 4
        bg_bbox = [
            label_x - padding,
            label_y - padding,
            label_x + text_width + padding,
            label_y + text_height + padding
        ]
        draw.rectangle(bg_bbox, fill=color)
        
        # 绘制标签文字
        draw.text((label_x, label_y), label, fill="white", font=font)
    
    return img_copy

 
def visualize_results(current_image, current_masks, current_boxes, current_scores,
                     history_image, history_masks, history_boxes, history_scores,
                     current_fused_masks, current_fused_boxes, current_fused_scores,
                     history_fused_masks, history_fused_boxes, history_fused_scores,
                     matches, current_features, history_features, output_dir="output"):
    """可视化显示所有结果"""
    os.makedirs(output_dir, exist_ok=True)
     
    # 创建匹配器用于颜色管理
    matcher = ImageMatcher()
     
    # 创建大图显示
    fig, axes = plt.subplots(3, 2, figsize=(24, 28))
     
    # 第一行:原始分割结果(只显示边界框)
    img1_original = create_original_image_with_boxes(current_image, current_masks, current_boxes, current_scores, "Original Segmentation")
    img2_original = create_original_image_with_boxes(history_image, history_masks, history_boxes, history_scores, "Original Segmentation")
     
    # 确保图像是RGB模式
    if img1_original.mode != 'RGB':
        img1_original = img1_original.convert('RGB')
    if img2_original.mode != 'RGB':
        img2_original = img2_original.convert('RGB')
     
    axes[0, 0].imshow(img1_original)
    axes[0, 0].set_title(f"Current Frame - Original ({len(current_masks)} targets)", fontsize=16, fontweight='bold')
    axes[0, 0].axis('off')
     
    axes[0, 1].imshow(img2_original)
    axes[0, 1].set_title(f"History Frame - Original ({len(history_masks)} targets)", fontsize=16, fontweight='bold')
    axes[0, 1].axis('off')
     
    # 第二行:融合后的分割结果(只显示边界框)
    img1_fused = create_fused_image_with_boxes(current_image, current_fused_masks, current_fused_boxes, current_fused_scores, "Fused Results")
    img2_fused = create_fused_image_with_boxes(history_image, history_fused_masks, history_fused_boxes, history_fused_scores, "Fused Results")
     
    if img1_fused.mode != 'RGB':
        img1_fused = img1_fused.convert('RGB')
    if img2_fused.mode != 'RGB':
        img2_fused = img2_fused.convert('RGB')
     
    axes[1, 0].imshow(img1_fused)
    axes[1, 0].set_title(f"Current Frame - Fused ({len(current_fused_masks)} targets)", fontsize=16, fontweight='bold')
    axes[1, 0].axis('off')
     
    axes[1, 1].imshow(img2_fused)
    axes[1, 1].set_title(f"History Frame - Fused ({len(history_fused_masks)} targets)", fontsize=16, fontweight='bold')
    axes[1, 1].axis('off')
     
    # 第三行:匹配结果(显示mask和边界框)
    img1_matched = create_matched_image(current_image, current_fused_masks, current_fused_boxes, current_fused_scores,
                                       matches, matcher, is_current_frame=True, alpha=0.5)
    img2_matched = create_matched_image(history_image, history_fused_masks, history_fused_boxes, history_fused_scores,
                                      matches, matcher, is_current_frame=False, alpha=0.5)
     
    if img1_matched.mode != 'RGB':
        img1_matched = img1_matched.convert('RGB')
    if img2_matched.mode != 'RGB':
        img2_matched = img2_matched.convert('RGB')
     
    axes[2, 0].imshow(img1_matched)
    axes[2, 0].set_title(f"Current Frame - Matching Results", fontsize=16, fontweight='bold')
    axes[2, 0].axis('off')
     
    axes[2, 1].imshow(img2_matched)
    axes[2, 1].set_title(f"History Frame - Matching Results ({len(matches)} matches)", fontsize=16, fontweight='bold')
    axes[2, 1].axis('off')
     
    # 添加图例说明(英文)
    fig.text(0.02, 0.02,
             "Label Legend:\n"
             "C{i}: Current frame target ID, H{i}: History frame target ID\n"
             "Sim: Feature similarity score, Mask: Segmentation confidence\n"
             "->: Match direction, Solid box: Matched, Dashed box: Unmatched\n"
             "Same color indicates same match pair",
             fontsize=14, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))
     
    plt.tight_layout()
     
    # 保存结果
    timestamp = int(time.time())
    result_path = os.path.join(output_dir, f"multi_to_one_matching_result_{timestamp}.png")
    plt.savefig(result_path, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Matching results saved: {result_path}")
     
    # 保存特征数据
    save_features(current_fused_masks, current_fused_boxes, current_features, current_image, output_dir, "current_frame")
    save_features(history_fused_masks, history_fused_boxes, history_features, history_image, output_dir, "history_frame")
     
    plt.show()
     
    return result_path
 
 
def save_features(masks, boxes, features, image, output_dir, image_name):
    """保存特征数据"""
    feature_data = {
        'image_name': image_name,
        'num_targets': len(masks),
        'features': [f.tolist() if f is not None else None for f in features],
        'boxes': [box.cpu().numpy().tolist() for box in boxes],
        'mask_shapes': [mask.shape for mask in masks],
        'timestamp': time.time()
    }
     
    feature_path = os.path.join(output_dir, f"{image_name}_features.json")
    with open(feature_path, 'w') as f:
        json.dump(feature_data, f, indent=2)
     
    print(f"{image_name} 特征已保存: {feature_path}")
 
def process_relocalization_with_multi_to_one(processor, current_image_path, history_image_path, output_dir="output"):
    """使用多对一匹配策略进行重定位处理"""
    os.makedirs(output_dir, exist_ok=True)
     
    print("=" * 60)
    print("🚀 开始多对一重定位匹配处理")
    print("=" * 60)
     
    # 处理当前帧
    print("📷 处理当前帧...")
    current_masks, current_boxes, current_scores, current_image = get_image_mask(processor, current_image_path)
     
    print("📷 处理历史帧...")
    history_masks, history_boxes, history_scores, history_image = get_image_mask(processor, history_image_path)
     
    print("=" * 60)
    print("🔄 开始融合重叠掩码...")
     
    # 融合重叠的mask
    current_fused_masks, current_fused_boxes, current_fused_scores = fuse_overlapping_masks(
        current_masks, current_boxes, current_scores, iou_threshold=0.5, overlap_threshold=0.6
    )
     
    history_fused_masks, history_fused_boxes, history_fused_scores = fuse_overlapping_masks(
        history_masks, history_boxes, history_scores, iou_threshold=0.5, overlap_threshold=0.6
    )
     
    print("=" * 60)
    print("🎯 开始多对一重定位匹配...")
     
    # 创建匹配器并使用多对一匹配策略
    matcher = ImageMatcher(reid_threshold=0.7, matching_strategy="multi_to_one")
     
    # 执行多对一匹配
    matches, current_features, history_features = matcher.match_images_multi_to_one(
        current_image=current_image,
        current_masks=current_fused_masks,
        current_boxes=current_fused_boxes,
        history_image=history_image,
        history_masks=history_fused_masks,
        history_boxes=history_fused_boxes
    )
     
    print("=" * 60)
    print("📊 开始可视化结果...")
     
    # 可视化结果
    result_path = visualize_results(
        current_image, current_masks, current_boxes, current_scores,
        history_image, history_masks, history_boxes, history_scores,
        current_fused_masks, current_fused_boxes, current_fused_scores,
        history_fused_masks, history_fused_boxes, history_fused_scores,
        matches, current_features, history_features, output_dir
    )
     
    # 清理内存
    del current_masks, current_boxes, current_scores, history_masks, history_boxes, history_scores
    del current_fused_masks, current_fused_boxes, current_fused_scores
    del history_fused_masks, history_fused_boxes, history_fused_scores
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
     
    return {
        'matches': matches,
        'current_features': current_features,
        'history_features': history_features,
        'result_path': result_path
    }
 
def main():
    """主函数"""
    # 清空GPU缓存
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
 
    processor = load_model("sam3.pt")
     
    # 设置模型为评估模式
    processor.model.eval()
 
    # 示例图像路径 - 请修改为您的实际图像路径
    current_image_path = "/home/r9000k/v0_data/rtk/300_location_1130_15pm/images/DJI_0952.JPG"
    history_image_path = "/home/r9000k/v0_data/rtk/300_location_1130_15pm/images/DJI_0966.JPG"
    output_dir = "multi_to_one_matching_output"
     
    try:
        results = process_relocalization_with_multi_to_one(processor, current_image_path, history_image_path, output_dir)
         
        print(f"\n🎉 多对一重定位匹配处理完成!")
        print(f"📁 匹配结果保存在: {results['result_path']}")
        print(f"🔗 找到 {len(results['matches'])} 个匹配")
         
        # 显示匹配详情
        print(f"\n📋 匹配详情:")
        for i, (current_idx, history_idx, similarity) in enumerate(results['matches']):
            print(f"  匹配{i}: 当前帧目标{current_idx} → 历史帧目标{history_idx}, 相似度: {similarity:.3f}")
         
    except Exception as e:
        print(f"❌ 处理过程中出错: {e}")
        import traceback
        traceback.print_exc()
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
 
if __name__ == "__main__":
    main()

  

 

 原始代码

'''
 
主要修改内容:
1. 匹配策略改进 (match_images_multi_to_one)
修改前:复杂的竞争解决机制
修改后:简单直接的多对一匹配
每个当前帧目标从所有超过阈值的历史候选匹配中选择相似度最高的
允许多个当前目标匹配到同一个历史目标
简化了匹配逻辑,提高了匹配效率
2. 可视化标签改进 (create_matched_image)
当前帧标签显示:
C{i}→H{j}:当前帧目标i匹配到历史帧目标j
Sim:{similarity:.3f}:特征匹配相似度
Mask:{score:.2f}:mask分割置信度
历史帧标签显示:
H{i}←C{j1}(sim1),C{j2}(sim2):历史帧目标i被多个当前目标匹配
显示最多2个匹配的当前目标及相似度
3. 可视化布局优化
增加图像尺寸以适应更多信息显示
添加图例说明标签含义
减小字体大小以显示完整信息
改进颜色编码和边框样式
4. 匹配逻辑优势
新的匹配策略更加符合多对一的场景需求:
一个历史建筑可能被多个当前视角的建筑匹配
每个当前目标独立选择最佳匹配
避免了复杂的冲突解决,提高匹配成功率
这样的修改使得匹配过程更加直观,可视化结果更加清晰,能够准确显示每个目标的匹配状态和关键指标。
 
. 图例说明
使用英文说明标签含义
说明颜色编码和边框样式含义
增加图例背景提高可读性
6. 可视化效果
颜色一致性:相同匹配对在两帧中显示相同颜色
清晰区分:实线边框(已匹配)vs 虚线边框(未匹配)
信息完整:显示目标ID、匹配关系、相似度、分割分数
可读性强:大字体、英文标签、清晰的视觉层次
 
 
 
 
1 多个图片看到的特征融合
2 保存和加载
3 视角差异
方向度对齐 最小矩形框  选择方向 旋转 然后送入特征检测
平移
尺度对齐
4 阈值过滤
 
'''
 
 
 
 
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.patches as patches
import time
import os
import glob
from scipy.optimize import linear_sum_assignment
import colorsys
import gc
import matplotlib
import json
 
matplotlib.use('TkAgg')
 
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="tkinter")
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore")
 
# 新增导入
import cv2
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
 
# 内存优化设置
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()
 
#################################### For Image ####################################
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
 
class ReIDNetwork(nn.Module):
    """ReID网络用于提取目标外观特征"""
     
    def __init__(self, feature_dim=512):
        super(ReIDNetwork, self).__init__()
        # 使用预训练的ResNet作为骨干网络
        self.backbone = models.resnet50(pretrained=True)
        # 移除分类层
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])
         
        # 全局平均池化
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
         
        # 特征降维
        self.feature_reduction = nn.Sequential(
            nn.Linear(2048, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim // 2),
            nn.BatchNorm1d(feature_dim // 2),
            nn.ReLU(inplace=True),
        )
         
        self.feature_dim = feature_dim // 2
         
        # 图像预处理
        self.transform = transforms.Compose([
            transforms.Resize((256, 128)),  # ReID标准尺寸
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])
     
    def forward(self, x):
        """前向传播"""
        features = self.backbone(x)
        features = self.global_avg_pool(features)
        features = features.view(features.size(0), -1)
        features = self.feature_reduction(features)
        # L2归一化
        features = nn.functional.normalize(features, p=2, dim=1)
        return features
     
    def extract_features_from_crop(self, image, box):
        """从图像裁剪中提取特征"""
        try:
            # 裁剪目标区域
            x1, y1, x2, y2 = map(int, box)
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(image.width, x2)
            y2 = min(image.height, y2)
             
            if x2 <= x1 or y2 <= y1:
                return None
                 
            crop = image.crop((x1, y1, x2, y2))
             
            # 转换为RGB(处理可能的RGBA图像)
            if crop.mode != 'RGB':
                crop = crop.convert('RGB')
             
            # 预处理
            crop_tensor = self.transform(crop).unsqueeze(0)
             
            # 提取特征
            with torch.no_grad():
                features = self.forward(crop_tensor)
             
            return features.squeeze(0).cpu().numpy()
             
        except Exception as e:
            print(f"ReID特征提取错误: {e}")
            return None
 
    def extract_features_from_mask(self, image, mask):
        """从mask区域提取特征"""
        try:
            # 将mask转换为numpy数组
            mask_np = mask.cpu().numpy().squeeze().astype(np.uint8)
             
            # 找到mask的边界框
            rows = np.any(mask_np, axis=1)
            cols = np.any(mask_np, axis=0)
             
            if not np.any(rows) or not np.any(cols):
                return None
                 
            y1, y2 = np.where(rows)[0][[0, -1]]
            x1, x2 = np.where(cols)[0][[0, -1]]
             
            # 扩展边界框
            padding = 5
            x1 = max(0, x1 - padding)
            y1 = max(0, y1 - padding)
            x2 = min(image.width, x2 + padding)
            y2 = min(image.height, y2 + padding)
             
            if x2 <= x1 or y2 <= y1:
                return None
                 
            # 裁剪目标区域
            crop = image.crop((x1, y1, x2, y2))
             
            # 转换为RGB
            if crop.mode != 'RGB':
                crop = crop.convert('RGB')
             
            # 预处理
            crop_tensor = self.transform(crop).unsqueeze(0)
             
            # 提取特征
            with torch.no_grad():
                features = self.forward(crop_tensor)
             
            return features.squeeze(0).cpu().numpy()
             
        except Exception as e:
            print(f"从mask提取ReID特征错误: {e}")
            return None
 
class ImageMatcher:
    """图像匹配器,支持多对一匹配策略"""
     
    def __init__(self, reid_threshold=0.7, matching_strategy="multi_to_one"):
        self.reid_threshold = reid_threshold
        self.matching_strategy = matching_strategy  # "multi_to_one", "one_to_one", "greedy"
        self.reid_net = ReIDNetwork()
        self.reid_net.eval()
        self.match_colors = {}
        self.target_colors = {}
     
    def extract_features_for_image(self, image, masks, boxes):
        """为图像的所有分割结果提取特征"""
        features = []
        for mask, box in zip(masks, boxes):
            # 优先使用mask区域提取特征
            feature = self.reid_net.extract_features_from_mask(image, mask)
            if feature is None:
                # 如果mask提取失败,使用边界框
                feature = self.reid_net.extract_features_from_crop(image, box)
            features.append(feature)
        return features
     
    def calculate_cosine_similarity(self, feat1, feat2):
        """计算两个特征向量的余弦相似度"""
        if feat1 is None or feat2 is None:
            return 0.0
        return np.dot(feat1, feat2) / (np.linalg.norm(feat1) * np.linalg.norm(feat2) + 1e-8)
     
    def match_images_multi_to_one(self, current_image, current_masks, current_boxes,
                                history_image, history_masks, history_boxes):
        """多对一匹配:当前帧目标从所有超过阈值的历史候选匹配选取最高匹配作为匹配结果"""
         
        # 提取特征
        current_features = self.extract_features_for_image(current_image, current_masks, current_boxes)
        history_features = self.extract_features_for_image(history_image, history_masks, history_boxes)
         
        n_current = len(current_features)
        n_history = len(history_features)
         
        if n_current == 0 or n_history == 0:
            return [], current_features, history_features
         
        # 构建相似度矩阵
        similarity_matrix = np.zeros((n_current, n_history))
        for i in range(n_current):
            for j in range(n_history):
                similarity = self.calculate_cosine_similarity(current_features[i], history_features[j])
                similarity_matrix[i, j] = similarity
         
        print(f"相似度矩阵构建完成: {n_current} x {n_history}")
         
        # 新的匹配策略:为每个当前帧目标选择相似度最高的历史帧目标(超过阈值)
        matches = []
        used_history_indices = set()  # 记录已匹配的历史帧目标
         
        for i in range(n_current):
            # 获取当前目标i与所有历史目标的相似度
            similarities = similarity_matrix[i, :]
             
            # 找到超过阈值的最佳匹配
            best_match_idx = -1
            best_similarity = 0.0
             
            for j in range(n_history):
                if similarities[j] >= self.reid_threshold and similarities[j] > best_similarity:
                    # 检查历史目标是否已被匹配(多对一允许重复匹配)
                    # 这里我们允许历史目标被多个当前目标匹配
                    best_match_idx = j
                    best_similarity = similarities[j]
             
            if best_match_idx != -1:
                matches.append((i, best_match_idx, best_similarity))
                used_history_indices.add(best_match_idx)
                print(f"✓ 匹配成功: 当前帧目标{i} → 历史帧目标{best_match_idx}, 相似度: {best_similarity:.3f}")
            else:
                max_sim = np.max(similarities) if n_history > 0 else 0
                print(f"❌ 当前帧目标{i}: 无超过阈值的历史匹配 (最高相似度: {max_sim:.3f}, 阈值: {self.reid_threshold:.3f})")
         
        # 打印详细统计信息
        self._print_matching_statistics(matches, n_current, n_history, similarity_matrix)
         
        return matches, current_features, history_features
     
    def _print_matching_statistics(self, matches, n_current, n_history, similarity_matrix):
        """打印匹配统计信息"""
        matched_current = set(i for i, j, sim in matches)
        matched_history = set(j for i, j, sim in matches)
         
        print(f"\n" + "="*50)
        print(f"🎯 多对一匹配最终结果统计")
        print(f"="*50)
        print(f"当前帧目标总数: {n_current}")
        print(f"历史帧目标总数: {n_history}")
        print(f"成功匹配对数: {len(matches)}")
        print(f"已匹配的当前目标: {len(matched_current)} / {n_current} ({len(matched_current)/n_current*100:.1f}%)")
        print(f"已匹配的历史目标: {len(matched_history)} / {n_history} ({len(matched_history)/n_history*100:.1f}%)")
         
        # 相似度统计
        if matches:
            similarities = [sim for _, _, sim in matches]
            avg_similarity = np.mean(similarities)
            max_similarity = np.max(similarities)
            min_similarity = np.min(similarities)
            print(f"匹配相似度统计:")
            print(f"  平均相似度: {avg_similarity:.3f}")
            print(f"  最高相似度: {max_similarity:.3f}")
            print(f"  最低相似度: {min_similarity:.3f}")
            print(f"  阈值设置: {self.reid_threshold:.3f}")
         
        # 详细匹配结果
        print(f"\n📋 详细匹配结果:")
        for idx, (curr_idx, hist_idx, similarity) in enumerate(matches):
            print(f"  匹配{idx}: 当前帧目标{curr_idx} → 历史帧目标{hist_idx}, 相似度: {similarity:.3f}")
         
        # 未匹配分析
        unmatched_current = n_current - len(matched_current)
        unmatched_history = n_history - len(matched_history)
         
        if unmatched_current > 0:
            print(f"\n❌ 未匹配的当前目标分析:")
            for i in range(n_current):
                if i not in matched_current:
                    max_sim = np.max(similarity_matrix[i]) if n_history > 0 else 0
                    threshold_diff = max_sim - self.reid_threshold
                    status = "超过阈值" if max_sim >= self.reid_threshold else "低于阈值"
                    print(f"  当前帧目标{i}: 最高相似度 {max_sim:.3f} ({status}, 差値: {threshold_diff:+.3f})")
         
        if unmatched_history > 0:
            print(f"\n❌ 未匹配的历史目标分析:")
            for j in range(n_history):
                if j not in matched_history:
                    max_sim = np.max(similarity_matrix[:, j]) if n_current > 0 else 0
                    threshold_diff = max_sim - self.reid_threshold
                    status = "超过阈值" if max_sim >= self.reid_threshold else "低于阈值"
                    print(f"  历史帧目标{j}: 最高相似度 {max_sim:.3f} ({status}, 差値: {threshold_diff:+.3f})")
         
        print(f"="*50)
 
    def get_match_color(self, match_id):
        """为匹配ID生成颜色"""
        if match_id not in self.match_colors:
            hue = (match_id * 0.618033988749895) % 1.0
            r, g, b = colorsys.hsv_to_rgb(hue, 0.8, 0.8)
            self.match_colors[match_id] = (int(r * 255), int(g * 255), int(b * 255))
        return self.match_colors[match_id]
     
    def get_target_color(self, target_id, is_matched=True):
        """为目标ID生成颜色,匹配的目标用彩色,未匹配的用灰色"""
        if is_matched:
            if target_id not in self.target_colors:
                hue = (target_id * 0.618033988749895) % 1.0
                r, g, b = colorsys.hsv_to_rgb(hue, 0.8, 0.8)
                self.target_colors[target_id] = (int(r * 255), int(g * 255), int(b * 255))
            return self.target_colors[target_id]
        else:
            # 未匹配的目标用灰色
            return (128, 128, 128)
 
def load_model(model_path="sam3.pt"):
    """加载模型"""
    model_load_start_time = time.time()
     
    model = build_sam3_image_model(
        checkpoint_path=model_path
    )
    processor = Sam3Processor(model, confidence_threshold=0.5)
     
    model_load_end_time = time.time()
    model_load_time = model_load_end_time - model_load_start_time
    print(f"模型加载时间: {model_load_time:.3f} 秒")
     
    return processor
 
def get_image_mask(processor, image_path):
    """获取图像分割结果"""
    detection_start_time = time.time()
 
    image = Image.open(image_path)
     
    # 确保图像是RGB模式
    if image.mode != 'RGB':
        image = image.convert('RGB')
        print(f"图像已转换为RGB模式: {image.mode}")
     
    inference_state = processor.set_image(image)
     
    output = processor.set_text_prompt(state=inference_state, prompt="building")
    masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
     
    detection_end_time = time.time()
    detection_time = detection_end_time - detection_start_time
    print(f"检测单张时间: {detection_time:.3f} 秒")
    print(f"原始检测到 {len(masks)} 个分割结果")
    print(f"掩码形状: {masks.shape}")
     
    return masks, boxes, scores, image
 
def calculate_iou(box1, box2):
    """计算两个边界框的IoU"""
    x1_1, y1_1, x1_2, y1_2 = box1
    x2_1, y2_1, x2_2, y2_2 = box2
     
    xi1 = max(x1_1, x2_1)
    yi1 = max(y1_1, y2_1)
    xi2 = min(x1_2, x2_2)
    yi2 = min(y1_2, y2_2)
     
    inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
    box1_area = (x1_2 - x1_1) * (y1_2 - y1_1)
    box2_area = (x2_2 - x2_1) * (y2_2 - y2_1)
    union_area = box1_area + box2_area - inter_area
     
    if union_area == 0:
        return 0.0
     
    iou_2 = inter_area / box2_area
    iou_1 = inter_area / box1_area
    iou = max(iou_2, iou_1)
    return iou
 
def calculate_mask_overlap(mask1, mask2):
    """计算两个掩码的重叠比例"""
    mask1_np = mask1.cpu().numpy().squeeze().astype(bool)
    mask2_np = mask2.cpu().numpy().squeeze().astype(bool)
     
    intersection = np.logical_and(mask1_np, mask2_np)
    mask1_area = np.sum(mask1_np)
     
    if mask1_area == 0:
        return 0.0
     
    overlap_ratio = np.sum(intersection) / mask1_area
    return overlap_ratio
 
def fuse_overlapping_masks(masks, boxes, scores, iou_threshold=0.3, overlap_threshold=0.6):
    """融合重叠的掩码和边界框"""
    if len(masks) == 0:
        return masks, boxes, scores
     
    boxes_np = boxes.cpu().numpy().copy()
    scores_np = scores.cpu().numpy().copy()
     
    areas = []
    for box in boxes_np:
        x1, y1, x2, y2 = box
        width = x2 - x1
        height = y2 - y1
        area = width * height
        areas.append(area)
     
    areas_np = np.array(areas)
    sorted_indices = np.argsort(areas_np)[::-1]
     
    boxes_sorted = boxes_np[sorted_indices]
    scores_sorted = scores_np[sorted_indices]
    areas_sorted = areas_np[sorted_indices]
     
    masks_list = [masks[i] for i in range(len(masks))]
    masks_sorted = [masks_list[i] for i in sorted_indices]
     
    keep_indices = []
    suppressed = set()
    fused_masks = masks_sorted.copy()
     
    for i in range(len(boxes_sorted)):
        if i in suppressed:
            continue
         
        keep_indices.append(i)
        current_mask = fused_masks[i]
         
        for j in range(i + 1, len(boxes_sorted)):
            if j in suppressed:
                continue
             
            iou = calculate_iou(boxes_sorted[i], boxes_sorted[j])
             
            if iou > iou_threshold:
                mask_overlap = calculate_mask_overlap(current_mask, masks_sorted[j])
                suppressed.add(j)
                 
                fused_mask = fuse_two_masks(current_mask, masks_sorted[j])
                fused_masks[i] = fused_mask
                current_mask = fused_mask
                 
                print(f"融合检测结果: 索引 {sorted_indices[i]} (面积: {areas_sorted[i]:.1f}) 融合索引 {sorted_indices[j]} (面积: {areas_sorted[j]:.1f})",
                      " iou:", iou, " mask重叠:", mask_overlap)
                 
    final_indices = [sorted_indices[i] for i in keep_indices]
    final_masks_list = [fused_masks[i] for i in keep_indices]
     
    final_masks = torch.stack(final_masks_list)
    final_boxes = boxes[final_indices]
    final_scores = scores[final_indices]
     
    print(f"融合后剩余 {len(final_masks)} 个分割结果 (减少了 {len(masks) - len(final_masks)} 个)")
     
    return final_masks, final_boxes, final_scores
 
def fuse_two_masks(mask1, mask2):
    """将两个mask融合"""
    fused_mask = torch.logical_or(mask1, mask2).float()
    return fused_mask
 
def create_matched_image(image, masks, boxes, scores, matches, matcher, is_current_frame=True, alpha=0.5):
    """在原图上叠加mask,匹配的目标用彩色,未匹配的用灰色"""
    # 创建图像副本
    img_array = np.array(image).astype(np.float32) / 255.0
    img_height, img_width = img_array.shape[:2]
     
    # 获取匹配的目标索引和匹配ID映射
    if is_current_frame:
        matched_indices = set(idx1 for idx1, _, _ in matches)
        # 创建当前目标到匹配ID的映射
        target_to_match_id = {}
        for match_id, (curr_idx, hist_idx, similarity) in enumerate(matches):
            target_to_match_id[curr_idx] = match_id
    else:
        matched_indices = set(idx2 for _, idx2, _ in matches)
        # 创建历史目标到匹配ID的映射(一个历史目标可能对应多个匹配ID,取第一个)
        target_to_match_id = {}
        for match_id, (curr_idx, hist_idx, similarity) in enumerate(matches):
            if hist_idx not in target_to_match_id:  # 只记录第一个匹配ID
                target_to_match_id[hist_idx] = match_id
     
    # 为每个mask创建彩色叠加层
    for i, (mask, box, score) in enumerate(zip(masks, boxes, scores)):
        is_matched = i in matched_indices
         
        # 获取颜色 - 相同匹配使用相同颜色
        if is_matched and i in target_to_match_id:
            match_id = target_to_match_id[i]
            color = matcher.get_match_color(match_id)  # 使用匹配ID的颜色
            color_array = np.array(color) / 255.0
        elif is_matched:
            # 如果没有匹配ID映射,使用目标ID的颜色
            color = matcher.get_target_color(i, True)
            color_array = np.array(color) / 255.0
        else:
            color_array = np.array([0.5, 0.5, 0.5])  # 灰色
         
        # 将mask转换为numpy数组
        mask_np = mask.cpu().numpy().squeeze()
         
        # 创建彩色mask
        colored_mask = np.zeros_like(img_array)
        for c in range(3):
            colored_mask[:, :, c] = mask_np * color_array[c]
         
        # 叠加到原图
        img_array = img_array * (1 - mask_np[..., None] * alpha) + colored_mask * alpha
     
    # 转换回PIL图像
    result_img = Image.fromarray((img_array * 255).astype(np.uint8))
    draw = ImageDraw.Draw(result_img)
     
    try:
        font = ImageFont.truetype("Arial.ttf", 20)  # 稍微减小字体以适应更多内容
    except:
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except:
            font = ImageFont.load_default()
            print("Using default font, Arial not available")
     
    # 绘制边界框和标签
    for i, (mask, box, score) in enumerate(zip(masks, boxes, scores)):
        is_matched = i in matched_indices
         
        # 获取颜色 - 与mask相同的颜色
        if is_matched and i in target_to_match_id:
            match_id = target_to_match_id[i]
            color = matcher.get_match_color(match_id)
        elif is_matched:
            color = matcher.get_target_color(i, True)
        else:
            color = (128, 128, 128)  # 灰色
         
        # 绘制边界框
        x1, y1, x2, y2 = box.cpu().numpy()
        x1, y1, x2, y2 = max(0, x1), max(0, y1), min(img_width, x2), min(img_height, y2)
         
        if is_matched:
            # 实线边框 - 已匹配目标
            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
        else:
            # 虚线边框 - 未匹配目标
            dash_length = 6
            gap_length = 3
            # 上边框
            x = x1
            while x < x2:
                draw.line([x, y1, min(x + dash_length, x2), y1], fill=color, width=2)
                x += dash_length + gap_length
            # 下边框
            x = x1
            while x < x2:
                draw.line([x, y2, min(x + dash_length, x2), y2], fill=color, width=2)
                x += dash_length + gap_length
            # 左边框
            y = y1
            while y < y2:
                draw.line([x1, y, x1, min(y + dash_length, y2)], fill=color, width=2)
                y += dash_length + gap_length
            # 右边框
            y = y1
            while y < y2:
                draw.line([x2, y, x2, min(y + dash_length, y2)], fill=color, width=2)
                y += dash_length + gap_length
         
        # 绘制标签 - 全部使用英文
        if is_current_frame:
            # 当前帧标签
            if is_matched:
                # 找到匹配信息
                for match_id, (curr_idx, hist_idx, similarity) in enumerate(matches):
                    if curr_idx == i:
                        label = f"C{i}->H{hist_idx}\nSim:{similarity:.3f}\nMask:{score:.2f}"
                        break
            else:
                label = f"C{i} Unmatched\nMask:{score:.2f}"
        else:
            # 历史帧标签
            if is_matched:
                # 找到所有匹配到该历史目标的当前目标
                matched_current = []
                match_similarities = []
                for match_id, (curr_idx, hist_idx, similarity) in enumerate(matches):
                    if hist_idx == i:
                        matched_current.append(curr_idx)
                        match_similarities.append(similarity)
                 
                if matched_current:
                    # 显示匹配数量和最高相似度
                    if len(matched_current) == 1:
                        label = f"H{i}<-C{matched_current[0]}\nSim:{match_similarities[0]:.3f}\nMask:{score:.2f}"
                    else:
                        label = f"H{i}<-{len(matched_current)}C\nBest:{max(match_similarities):.3f}\nMask:{score:.2f}"
                else:
                    label = f"H{i} Matched\nMask:{score:.2f}"
            else:
                label = f"H{i} Unmatched\nMask:{score:.2f}"
         
        # 计算标签大小
        text_bbox = draw.textbbox((0, 0), label, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
         
        # 智能标签位置选择
        label_x, label_y = _calculate_optimal_label_position(
            x1, y1, x2, y2, text_width, text_height, img_width, img_height
        )
         
        # 绘制标签背景
        padding = 4
        bg_bbox = [
            label_x - padding,
            label_y - padding,
            label_x + text_width + padding,
            label_y + text_height + padding
        ]
        draw.rectangle(bg_bbox, fill=color)
         
        # 绘制标签文字
        draw.text((label_x, label_y), label, fill="white", font=font)
     
    return result_img
 
def _calculate_optimal_label_position( x1, y1, x2, y2, text_width, text_height, img_width, img_height):
    """
    计算最优的标签位置,确保标签在图像可见区域内
    优先级:右上角 → 左上角 → 右下角 → 左下角 → 框内中心
    """
    padding = 10
    positions = []
     
    # 候选位置(相对于边界框)
    candidates = [
        # 右上角(优先)
        (x2 + padding, y1 - text_height - padding),
        # 左上角
        (x1 - text_width - padding, y1 - text_height - padding),
        # 右下角
        (x2 + padding, y2 + padding),
        # 左下角
        (x1 - text_width - padding, y2 + padding),
        # 框内右上角
        (x2 - text_width - padding, y1 + padding),
        # 框内左上角
        (x1 + padding, y1 + padding),
        # 框内右下角
        (x2 - text_width - padding, y2 - text_height - padding),
        # 框内左下角
        (x1 + padding, y2 - text_height - padding)
    ]
     
    # 评估每个候选位置
    for candidate_x, candidate_y in candidates:
        # 检查是否在图像范围内
        if (0 <= candidate_x <= img_width - text_width and
            0 <= candidate_y <= img_height - text_height):
            # 计算与边界框的距离(优先选择靠近边界框的位置)
            distance = min(
                abs(candidate_x - x1), abs(candidate_x - x2),
                abs(candidate_y - y1), abs(candidate_y - y2)
            )
            positions.append(((candidate_x, candidate_y), distance))
     
    # 如果有合适的位置,选择距离最近的一个
    if positions:
        positions.sort(key=lambda x: x[1])  # 按距离排序
        return positions[0][0]
     
    # 如果没有合适的外部位置,使用框内中心位置(确保可见)
    center_x = (x1 + x2) / 2 - text_width / 2
    center_y = (y1 + y2) / 2 - text_height / 2
     
    # 确保在图像范围内
    center_x = max(0, min(center_x, img_width - text_width))
    center_y = max(0, min(center_y, img_height - text_height))
     
    return center_x, center_y
 
# 同时优化原始图像和融合图像的标签位置函数
def create_original_image_with_boxes(image, masks, boxes, scores, title):
    """在原图上绘制边界框和标签,每个目标用不同颜色"""
    img_copy = image.copy()
    draw = ImageDraw.Draw(img_copy)
    img_width, img_height = img_copy.size
     
    try:
        font = ImageFont.truetype("Arial.ttf", 20)
    except:
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except:
            font = ImageFont.load_default()
     
    # 为每个目标生成不同颜色
    for i, (mask, box, score) in enumerate(zip(masks, boxes, scores)):
        # 为每个目标生成独特的颜色
        hue = (i * 0.618033988749895) % 1.0
        r, g, b = colorsys.hsv_to_rgb(hue, 0.8, 0.8)
        color = (int(r * 255), int(g * 255), int(b * 255))
         
        # 绘制边界框
        x1, y1, x2, y2 = box.cpu().numpy()
        x1, y1, x2, y2 = max(0, x1), max(0, y1), min(img_width, x2), min(img_height, y2)
        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
         
        # 绘制标签
        label = f"ID:{i} Score:{score:.2f}"
        text_bbox = draw.textbbox((0, 0), label, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
         
        # 智能选择标签位置
        label_x, label_y = _calculate_optimal_label_position(
            x1, y1, x2, y2, text_width, text_height, img_width, img_height
        )
         
        # 绘制标签背景
        padding = 4
        bg_bbox = [
            label_x - padding,
            label_y - padding,
            label_x + text_width + padding,
            label_y + text_height + padding
        ]
        draw.rectangle(bg_bbox, fill=color)
         
        # 绘制标签文字
        draw.text((label_x, label_y), label, fill="white", font=font)
     
    return img_copy
 
def create_fused_image_with_boxes(image, masks, boxes, scores, title):
    """在原图上绘制融合后的边界框和标签,每个目标用不同颜色"""
    return create_original_image_with_boxes(image, masks, boxes, scores, title)
 
def visualize_results(current_image, current_masks, current_boxes, current_scores,
                     history_image, history_masks, history_boxes, history_scores,
                     current_fused_masks, current_fused_boxes, current_fused_scores,
                     history_fused_masks, history_fused_boxes, history_fused_scores,
                     matches, current_features, history_features, output_dir="output"):
    """可视化显示所有结果"""
    os.makedirs(output_dir, exist_ok=True)
     
    # 创建匹配器用于颜色管理
    matcher = ImageMatcher()
     
    # 创建大图显示
    fig, axes = plt.subplots(3, 2, figsize=(24, 28))
     
    # 第一行:原始分割结果(只显示边界框)
    img1_original = create_original_image_with_boxes(current_image, current_masks, current_boxes, current_scores, "Original Segmentation")
    img2_original = create_original_image_with_boxes(history_image, history_masks, history_boxes, history_scores, "Original Segmentation")
     
    # 确保图像是RGB模式
    if img1_original.mode != 'RGB':
        img1_original = img1_original.convert('RGB')
    if img2_original.mode != 'RGB':
        img2_original = img2_original.convert('RGB')
     
    axes[0, 0].imshow(img1_original)
    axes[0, 0].set_title(f"Current Frame - Original ({len(current_masks)} targets)", fontsize=16, fontweight='bold')
    axes[0, 0].axis('off')
     
    axes[0, 1].imshow(img2_original)
    axes[0, 1].set_title(f"History Frame - Original ({len(history_masks)} targets)", fontsize=16, fontweight='bold')
    axes[0, 1].axis('off')
     
    # 第二行:融合后的分割结果(只显示边界框)
    img1_fused = create_fused_image_with_boxes(current_image, current_fused_masks, current_fused_boxes, current_fused_scores, "Fused Results")
    img2_fused = create_fused_image_with_boxes(history_image, history_fused_masks, history_fused_boxes, history_fused_scores, "Fused Results")
     
    if img1_fused.mode != 'RGB':
        img1_fused = img1_fused.convert('RGB')
    if img2_fused.mode != 'RGB':
        img2_fused = img2_fused.convert('RGB')
     
    axes[1, 0].imshow(img1_fused)
    axes[1, 0].set_title(f"Current Frame - Fused ({len(current_fused_masks)} targets)", fontsize=16, fontweight='bold')
    axes[1, 0].axis('off')
     
    axes[1, 1].imshow(img2_fused)
    axes[1, 1].set_title(f"History Frame - Fused ({len(history_fused_masks)} targets)", fontsize=16, fontweight='bold')
    axes[1, 1].axis('off')
     
    # 第三行:匹配结果(显示mask和边界框)
    img1_matched = create_matched_image(current_image, current_fused_masks, current_fused_boxes, current_fused_scores,
                                       matches, matcher, is_current_frame=True, alpha=0.5)
    img2_matched = create_matched_image(history_image, history_fused_masks, history_fused_boxes, history_fused_scores,
                                      matches, matcher, is_current_frame=False, alpha=0.5)
     
    if img1_matched.mode != 'RGB':
        img1_matched = img1_matched.convert('RGB')
    if img2_matched.mode != 'RGB':
        img2_matched = img2_matched.convert('RGB')
     
    axes[2, 0].imshow(img1_matched)
    axes[2, 0].set_title(f"Current Frame - Matching Results", fontsize=16, fontweight='bold')
    axes[2, 0].axis('off')
     
    axes[2, 1].imshow(img2_matched)
    axes[2, 1].set_title(f"History Frame - Matching Results ({len(matches)} matches)", fontsize=16, fontweight='bold')
    axes[2, 1].axis('off')
     
    # 添加图例说明(英文)
    fig.text(0.02, 0.02,
             "Label Legend:\n"
             "C{i}: Current frame target ID, H{i}: History frame target ID\n"
             "Sim: Feature similarity score, Mask: Segmentation confidence\n"
             "->: Match direction, Solid box: Matched, Dashed box: Unmatched\n"
             "Same color indicates same match pair",
             fontsize=14, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))
     
    plt.tight_layout()
     
    # 保存结果
    timestamp = int(time.time())
    result_path = os.path.join(output_dir, f"multi_to_one_matching_result_{timestamp}.png")
    plt.savefig(result_path, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Matching results saved: {result_path}")
     
    # 保存特征数据
    save_features(current_fused_masks, current_fused_boxes, current_features, current_image, output_dir, "current_frame")
    save_features(history_fused_masks, history_fused_boxes, history_features, history_image, output_dir, "history_frame")
     
    plt.show()
     
    return result_path
 
 
def save_features(masks, boxes, features, image, output_dir, image_name):
    """保存特征数据"""
    feature_data = {
        'image_name': image_name,
        'num_targets': len(masks),
        'features': [f.tolist() if f is not None else None for f in features],
        'boxes': [box.cpu().numpy().tolist() for box in boxes],
        'mask_shapes': [mask.shape for mask in masks],
        'timestamp': time.time()
    }
     
    feature_path = os.path.join(output_dir, f"{image_name}_features.json")
    with open(feature_path, 'w') as f:
        json.dump(feature_data, f, indent=2)
     
    print(f"{image_name} 特征已保存: {feature_path}")
 
def process_relocalization_with_multi_to_one(processor, current_image_path, history_image_path, output_dir="output"):
    """使用多对一匹配策略进行重定位处理"""
    os.makedirs(output_dir, exist_ok=True)
     
    print("=" * 60)
    print("🚀 开始多对一重定位匹配处理")
    print("=" * 60)
     
    # 处理当前帧
    print("📷 处理当前帧...")
    current_masks, current_boxes, current_scores, current_image = get_image_mask(processor, current_image_path)
     
    print("📷 处理历史帧...")
    history_masks, history_boxes, history_scores, history_image = get_image_mask(processor, history_image_path)
     
    print("=" * 60)
    print("🔄 开始融合重叠掩码...")
     
    # 融合重叠的mask
    current_fused_masks, current_fused_boxes, current_fused_scores = fuse_overlapping_masks(
        current_masks, current_boxes, current_scores, iou_threshold=0.5, overlap_threshold=0.6
    )
     
    history_fused_masks, history_fused_boxes, history_fused_scores = fuse_overlapping_masks(
        history_masks, history_boxes, history_scores, iou_threshold=0.5, overlap_threshold=0.6
    )
     
    print("=" * 60)
    print("🎯 开始多对一重定位匹配...")
     
    # 创建匹配器并使用多对一匹配策略
    matcher = ImageMatcher(reid_threshold=0.7, matching_strategy="multi_to_one")
     
    # 执行多对一匹配
    matches, current_features, history_features = matcher.match_images_multi_to_one(
        current_image=current_image,
        current_masks=current_fused_masks,
        current_boxes=current_fused_boxes,
        history_image=history_image,
        history_masks=history_fused_masks,
        history_boxes=history_fused_boxes
    )
     
    print("=" * 60)
    print("📊 开始可视化结果...")
     
    # 可视化结果
    result_path = visualize_results(
        current_image, current_masks, current_boxes, current_scores,
        history_image, history_masks, history_boxes, history_scores,
        current_fused_masks, current_fused_boxes, current_fused_scores,
        history_fused_masks, history_fused_boxes, history_fused_scores,
        matches, current_features, history_features, output_dir
    )
     
    # 清理内存
    del current_masks, current_boxes, current_scores, history_masks, history_boxes, history_scores
    del current_fused_masks, current_fused_boxes, current_fused_scores
    del history_fused_masks, history_fused_boxes, history_fused_scores
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
     
    return {
        'matches': matches,
        'current_features': current_features,
        'history_features': history_features,
        'result_path': result_path
    }
 
def main():
    """主函数"""
    # 清空GPU缓存
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
 
    processor = load_model("sam3.pt")
     
    # 设置模型为评估模式
    processor.model.eval()
 
    # 示例图像路径 - 请修改为您的实际图像路径
    # current_image_path = "/home/r9000k/v0_data/rtk/300_location_1130_15pm/images/DJI_0952.JPG"
    # history_image_path = "/home/r9000k/v0_data/rtk/300_location_1130_15pm/images/DJI_0966.JPG"
    current_image_path = "/home/r9000k/v0_data/rtk/300_location_1130_15pm/images/DJI_0952.JPG"
    history_image_path = "/home/r9000k/v0_data/rtk/300_location_1130_15pm/images/DJI_0966.JPG"
    output_dir = "multi_to_one_matching_output"
     
    try:
        results = process_relocalization_with_multi_to_one(processor, current_image_path, history_image_path, output_dir)
         
        print(f"\n🎉 多对一重定位匹配处理完成!")
        print(f"📁 匹配结果保存在: {results['result_path']}")
        print(f"🔗 找到 {len(results['matches'])} 个匹配")
         
        # 显示匹配详情
        print(f"\n📋 匹配详情:")
        for i, (current_idx, history_idx, similarity) in enumerate(results['matches']):
            print(f"  匹配{i}: 当前帧目标{current_idx} → 历史帧目标{history_idx}, 相似度: {similarity:.3f}")
         
    except Exception as e:
        print(f"❌ 处理过程中出错: {e}")
        import traceback
        traceback.print_exc()
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
 
if __name__ == "__main__":
    main()

  

posted on 2025-12-10 11:13  MKT-porter  阅读(0)  评论(0)    收藏  举报
刷新页面返回顶部
博客园  ©  2004-2025
浙公网安备 33010602011771号 浙ICP备2021040463号-3