Python:使用NMS合并labels文件夹下重叠的重复标签

nms_merge_labels.py

"""
YOLO标签文件NMS合并脚本
对labels文件夹中的每个txt文件进行NMS处理,合并重叠的检测框
"""

import os
import numpy as np
from pathlib import Path


def calculate_iou(box1, box2):
    """
    Return the Intersection-over-Union of two YOLO-format boxes.

    Both arguments are indexable as ``[center_x, center_y, width, height]``
    (only indices 0-3 are read). The result is 0.0 when the boxes do not
    overlap with positive area, or when the union area is not positive.
    """
    def corners(box):
        # (center, size) -> (left, top, right, bottom)
        half_w = box[2] / 2
        half_h = box[3] / 2
        return box[0] - half_w, box[1] - half_h, box[0] + half_w, box[1] + half_h

    left1, top1, right1, bottom1 = corners(box1)
    left2, top2, right2, bottom2 = corners(box2)

    inter_left = max(left1, left2)
    inter_top = max(top1, top2)
    inter_right = min(right1, right2)
    inter_bottom = min(bottom1, bottom2)

    # Empty or degenerate (zero-width / zero-height) intersection.
    if inter_right <= inter_left or inter_bottom <= inter_top:
        return 0.0

    inter_area = (inter_right - inter_left) * (inter_bottom - inter_top)
    union_area = box1[2] * box1[3] + box2[2] * box2[3] - inter_area

    if union_area > 0:
        return inter_area / union_area
    return 0.0


def nms(boxes, iou_threshold=0.5):
    """
    Greedy per-class non-maximum suppression on YOLO boxes.

    boxes: sequence of ``[class_id, center_x, center_y, width, height]``.
    iou_threshold: a box is dropped when its IoU with an already-kept box of
        the same class is >= this value.

    Boxes are grouped by ``int(class_id)``. Within each class they are ranked
    by area (width * height), largest first; no confidence score is used.
    Returns the surviving boxes (the same list objects that were passed in),
    sorted by class id.
    """
    if len(boxes) == 0:
        return []

    groups = {}
    for box in boxes:
        groups.setdefault(int(box[0]), []).append(box)

    survivors = []
    for members in groups.values():
        candidates = sorted(members, key=lambda b: b[3] * b[4], reverse=True)
        while candidates:
            # The largest remaining box always survives ...
            best, *rest = candidates
            survivors.append(best)
            # ... and suppresses every remaining box that overlaps it too much.
            candidates = [
                b for b in rest
                if calculate_iou(best[1:5], b[1:5]) < iou_threshold
            ]

    survivors.sort(key=lambda b: b[0])
    return survivors


def merge_boxes(box1, box2):
    """
    Merge two boxes into their area-weighted average.

    Each box is ``[class_id, center_x, center_y, width, height]``; each
    input's weight is its ``width * height`` share of the combined area.
    The class id is taken from ``box1``. If both areas sum to zero,
    ``box1`` itself is returned unchanged.
    """
    a_area = box1[3] * box1[4]
    b_area = box2[3] * box2[4]
    combined = a_area + b_area

    if combined == 0:
        return box1

    a_weight = a_area / combined
    b_weight = b_area / combined

    # Indices 1-4: center_x, center_y, width, height.
    return [box1[0]] + [box1[k] * a_weight + box2[k] * b_weight for k in range(1, 5)]


def _read_label_boxes(file_path):
    """Parse a YOLO label file into ``[class_id, cx, cy, w, h]`` float lists.

    Lines with fewer than five whitespace-separated fields (including blank
    lines) are skipped; only the first five fields of each line are kept.

    Raises:
        ValueError: if one of the first five fields is not a number; the
            message names the file and the 1-based line number.
    """
    boxes = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, 1):
            parts = line.split()
            if len(parts) < 5:
                continue
            try:
                boxes.append([float(p) for p in parts[:5]])
            except ValueError as err:
                raise ValueError(
                    f"{file_path}:{lineno}: invalid label line {line.strip()!r}"
                ) from err
    return boxes


def _average_boxes_by_area(group):
    """Return the area-weighted mean box of ``group`` (boxes of one class).

    Every box is weighted by its own ``width * height`` in a single pass, so
    each member contributes in proportion to its true area regardless of the
    order the group was collected in. The class id is taken from
    ``group[0]``; if all areas are zero, ``group[0]`` is returned unchanged.
    """
    weights = [box[3] * box[4] for box in group]
    total = sum(weights)
    if total == 0:
        return group[0]
    # Indices 1-4: center_x, center_y, width, height.
    return [group[0][0]] + [
        sum(w * box[k] for w, box in zip(weights, group)) / total
        for k in range(1, 5)
    ]


def _merge_overlapping(boxes, iou_threshold):
    """Greedily collapse same-class boxes that overlap an anchor box.

    Within each class (grouped by ``int(class_id)``), boxes are visited from
    largest to smallest area. The largest remaining box becomes the anchor;
    every remaining box whose IoU with the anchor is >= ``iou_threshold``
    joins its group, and the group is replaced by one area-weighted box.
    The output is sorted by class id.
    """
    by_class = {}
    for box in boxes:
        by_class.setdefault(int(box[0]), []).append(box)

    merged = []
    for members in by_class.values():
        pending = sorted(members, key=lambda b: b[3] * b[4], reverse=True)
        while pending:
            anchor, *rest = pending
            group = [anchor]
            pending = []
            for box in rest:
                if calculate_iou(anchor[1:5], box[1:5]) >= iou_threshold:
                    group.append(box)
                else:
                    pending.append(box)
            merged.append(_average_boxes_by_area(group) if len(group) > 1 else anchor)

    merged.sort(key=lambda b: b[0])
    return merged


def process_label_file(file_path, iou_threshold=0.5, merge_mode=True):
    """
    Deduplicate overlapping boxes in one YOLO label file, rewriting it in place.

    Args:
        file_path: path (str or Path) to a label file whose lines are
            ``class_id center_x center_y width height``.
        iou_threshold: same-class boxes whose IoU with the current anchor box
            reaches this value are treated as duplicates.
        merge_mode: True merges each duplicate group into a single
            area-weighted box; False keeps only the largest box of each group
            (standard NMS via ``nms``).

    Returns:
        ``(original_count, final_count)``. A file with no parsable boxes is
        left untouched and ``(0, 0)`` is returned.

    Raises:
        ValueError: if a label line contains a non-numeric field.
    """
    # NOTE(review): only the first five fields of each line are kept, so any
    # extra columns (e.g. a confidence score) are dropped on rewrite, and
    # polygon/segmentation labels would be misread as boxes. Confirm the
    # folder holds plain detection labels before running.
    boxes = _read_label_boxes(file_path)
    if not boxes:
        return 0, 0

    if merge_mode:
        result = _merge_overlapping(boxes, iou_threshold)
    else:
        result = nms(boxes, iou_threshold)

    lines = [
        f"{int(box[0])} {box[1]:.6f} {box[2]:.6f} {box[3]:.6f} {box[4]:.6f}\n"
        for box in result
    ]
    # Write to a sibling temp file and swap it in, so an interrupted run
    # cannot leave a truncated label file behind.
    tmp_path = f"{file_path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)
    os.replace(tmp_path, file_path)

    return len(boxes), len(result)


def main(labels_dir="labels", iou_threshold=0.5, merge_mode=True):
    """
    Deduplicate overlapping boxes in every ``*.txt`` file of ``labels_dir``.

    Args:
        labels_dir: folder holding the YOLO label files (default ``"labels"``,
            resolved relative to the current working directory).
        iou_threshold: IoU at or above which same-class boxes count as
            duplicates.
        merge_mode: True merges overlapping boxes; False deletes them (NMS).

    Files are rewritten in place. A line is printed for every file whose box
    count changed, a progress line every 50 files (and for the last file),
    and a summary at the end.
    """
    labels_dir = Path(labels_dir)

    if not labels_dir.exists():
        print("错误: 找不到labels文件夹")
        return

    # Path.glob yields files in an unspecified order; sort so the processing
    # order and progress log are reproducible.
    label_files = sorted(labels_dir.glob("*.txt"))

    if not label_files:
        print("错误: labels文件夹中没有找到txt文件")
        return

    file_count = len(label_files)
    print(f"找到 {file_count} 个标签文件")
    print("=" * 60)

    total_original = 0
    total_after = 0

    for i, file_path in enumerate(label_files, 1):
        original_count, after_count = process_label_file(file_path, iou_threshold, merge_mode)
        total_original += original_count
        total_after += after_count

        if original_count != after_count:
            print(f"[{i}/{file_count}] {file_path.name}: {original_count} -> {after_count} (减少 {original_count - after_count})")
        elif i % 50 == 0 or i == file_count:
            print(f"[{i}/{file_count}] 已处理...")

    removed = total_original - total_after
    # Every file may be empty (total_original == 0); avoid ZeroDivisionError.
    ratio = removed / total_original * 100 if total_original else 0.0

    print("=" * 60)
    print("处理完成!")
    print(f"总检测框数量: {total_original} -> {total_after}")
    print(f"减少了: {removed} 个检测框")
    print(f"减少比例: {ratio:.2f}%")


# Run the batch over ./labels only when executed as a script, not on import.
if __name__ == "__main__":
    main()

posted @ 2026-04-17 17:39  小城熊儿  阅读(4)  评论(0)    收藏  举报