Python:使用NMS合并labels文件夹下相互重叠的重复标签
nms_merge_labels.py
"""
YOLO标签文件NMS合并脚本
对labels文件夹中的每个txt文件进行NMS处理,合并重叠的检测框
"""
import os
import numpy as np
from pathlib import Path
def calculate_iou(box1, box2):
    """Compute the intersection-over-union (IoU) of two YOLO-format boxes.

    Args:
        box1: Sequence ``[center_x, center_y, width, height]``.
        box2: Sequence in the same format as ``box1``.

    Returns:
        The IoU as a float. ``0.0`` is returned when the boxes do not
        overlap (boxes that merely touch along an edge count as
        non-overlapping) or when the computed union area is not positive.
    """
    # Convert centre/size form to corner form [x1, y1, x2, y2].
    box1_x1 = box1[0] - box1[2] / 2
    box1_y1 = box1[1] - box1[3] / 2
    box1_x2 = box1[0] + box1[2] / 2
    box1_y2 = box1[1] + box1[3] / 2

    box2_x1 = box2[0] - box2[2] / 2
    box2_y1 = box2[1] - box2[3] / 2
    box2_x2 = box2[0] + box2[2] / 2
    box2_y2 = box2[1] + box2[3] / 2

    # Corners of the intersection rectangle.
    x1 = max(box1_x1, box2_x1)
    y1 = max(box1_y1, box2_y1)
    x2 = min(box1_x2, box2_x2)
    y2 = min(box1_y2, box2_y2)

    # An empty or degenerate intersection means no overlap at all.
    if x2 <= x1 or y2 <= y1:
        return 0.0

    intersection = (x2 - x1) * (y2 - y1)

    # Union = sum of both areas minus the part counted twice.
    box1_area = box1[2] * box1[3]
    box2_area = box2[2] * box2[3]
    union = box1_area + box2_area - intersection

    return intersection / union if union > 0 else 0.0
def nms(boxes, iou_threshold=0.5):
    """Apply per-class non-maximum suppression to a list of label boxes.

    The box entries carry no confidence score, so within each class the
    boxes are ranked by area (``width * height``) and the larger box wins.

    Args:
        boxes: List of ``[class_id, center_x, center_y, width, height]``.
        iou_threshold: A box is dropped when its IoU with an already kept
            box of the same class is greater than or equal to this value.

    Returns:
        A new list holding the kept boxes (the original box objects, not
        copies), ordered by class id and, within a class, by descending
        area.
    """
    # len() rather than truthiness so array-like inputs are also accepted.
    if len(boxes) == 0:
        return []

    # Group the boxes by integer class id.
    class_boxes = {}
    for box in boxes:
        class_boxes.setdefault(int(box[0]), []).append(box)

    result = []
    for boxes_of_class in class_boxes.values():
        # Largest box first, so it is the one that survives an overlap.
        remaining = sorted(boxes_of_class, key=lambda b: b[3] * b[4], reverse=True)
        while remaining:
            current = remaining[0]
            result.append(current)
            # Keep only the boxes that do not overlap the one just kept.
            remaining = [
                box for box in remaining[1:]
                if calculate_iou(current[1:5], box[1:5]) < iou_threshold
            ]

    # Stable sort: keeps the descending-area order inside each class.
    result.sort(key=lambda b: b[0])
    return result
def merge_boxes(box1, box2):
    """Merge two boxes into one using an area-weighted average.

    Args:
        box1: ``[class_id, center_x, center_y, width, height]``.
        box2: Box in the same format as ``box1``.

    Returns:
        A new list ``[class_id, center_x, center_y, width, height]``. The
        class id is taken from ``box1``; every other field is the average
        of the two inputs weighted by each box's area. When both areas sum
        to zero no weighting is possible and a copy of ``box1`` is
        returned.
    """
    area1 = box1[3] * box1[4]
    area2 = box2[3] * box2[4]
    total_area = area1 + area2

    if total_area == 0:
        # Return a copy so the result never aliases the caller's input,
        # matching the normal path which always builds a new list.
        return list(box1)

    weight1 = area1 / total_area
    weight2 = area2 / total_area

    return [
        box1[0],  # class_id is kept from the first box
        box1[1] * weight1 + box2[1] * weight2,  # center_x
        box1[2] * weight1 + box2[2] * weight2,  # center_y
        box1[3] * weight1 + box2[3] * weight2,  # width
        box1[4] * weight1 + box2[4] * weight2,  # height
    ]
def process_label_file(file_path, iou_threshold=0.5, merge_mode=True):
    """De-duplicate the overlapping boxes of one YOLO label file in place.

    Each line with at least five whitespace-separated tokens is read as
    ``class_id center_x center_y width height``; tokens after the fifth and
    lines with fewer than five tokens are ignored. Whenever at least one
    box was read, the file is rewritten with the resulting boxes (class id
    as an integer, coordinates with six decimals, ordered by class id), so
    ignored lines and extra tokens are not written back.

    Args:
        file_path: Path of the label file to process.
        iou_threshold: Two same-class boxes count as overlapping when their
            IoU is greater than or equal to this value.
        merge_mode: ``True`` replaces each cluster of overlapping boxes by
            their area-weighted average; ``False`` keeps only the largest
            box of each cluster (plain NMS).

    Returns:
        Tuple ``(original_count, result_count)``. ``(0, 0)`` is returned,
        and the file is left untouched, when no box could be read.

    Raises:
        ValueError: If one of the first five tokens of a line is not a
            number.
    """
    boxes = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if len(parts) >= 5:
                boxes.append([float(token) for token in parts[:5]])

    if not boxes:
        return 0, 0

    original_count = len(boxes)

    if merge_mode:
        result = _merge_overlapping(boxes, iou_threshold)
    else:
        result = nms(boxes, iou_threshold)

    with open(file_path, 'w', encoding='utf-8') as f:
        f.writelines(
            f"{int(box[0])} {box[1]:.6f} {box[2]:.6f} {box[3]:.6f} {box[4]:.6f}\n"
            for box in result
        )

    return original_count, len(result)


def _merge_overlapping(boxes, iou_threshold):
    """Cluster same-class boxes around the largest remaining box and merge each cluster."""
    class_boxes = {}
    for box in boxes:
        class_boxes.setdefault(int(box[0]), []).append(box)

    result = []
    for boxes_of_class in class_boxes.values():
        # Largest box first: it acts as the seed of the next cluster.
        remaining = sorted(boxes_of_class, key=lambda b: b[3] * b[4], reverse=True)
        while remaining:
            current, candidates = remaining[0], remaining[1:]
            cluster = [current]
            remaining = []
            for box in candidates:
                if calculate_iou(current[1:5], box[1:5]) >= iou_threshold:
                    cluster.append(box)
                else:
                    remaining.append(box)
            result.append(_merge_group(cluster) if len(cluster) > 1 else current)

    # Stable sort keeps the per-class cluster order.
    result.sort(key=lambda b: b[0])
    return result


def _merge_group(cluster):
    """Return the area-weighted average of all boxes in ``cluster`` (class id from the first)."""
    # Averaging the whole cluster in one pass keeps every box weighted by
    # its own area; chaining pairwise merges would instead re-derive the
    # weight of the partial result from its already averaged size.
    areas = [box[3] * box[4] for box in cluster]
    total_area = sum(areas)
    if total_area == 0:
        return cluster[0]

    merged = [cluster[0][0]]
    for field in range(1, 5):
        weighted_sum = sum(area * box[field] for area, box in zip(areas, cluster))
        merged.append(weighted_sum / total_area)
    return merged
def main(*, iou_threshold=0.5, merge_mode=True):
    """Process every ``*.txt`` file in the ``labels`` folder and print a summary.

    The folder is looked up relative to the current working directory.
    Files are handled in sorted name order so the progress output is
    reproducible. A line is printed for every file whose box count changed,
    plus a progress line every 50 files and for the last file.

    Args:
        iou_threshold: IoU threshold forwarded to ``process_label_file``.
        merge_mode: ``True`` merges overlapping boxes, ``False`` deletes
            the smaller ones; forwarded to ``process_label_file``.

    Returns:
        None. Returns early, after printing an error, when the folder is
        missing or contains no ``*.txt`` file.
    """
    labels_dir = Path("labels")

    if not labels_dir.exists():
        print("错误: 找不到labels文件夹")
        return

    # glob() yields entries in an unspecified order; sort for determinism.
    label_files = sorted(labels_dir.glob("*.txt"))

    if not label_files:
        print("错误: labels文件夹中没有找到txt文件")
        return

    file_count = len(label_files)
    print(f"找到 {file_count} 个标签文件")
    print("=" * 60)

    total_original = 0
    total_after = 0

    for i, file_path in enumerate(label_files, 1):
        original_count, after_count = process_label_file(file_path, iou_threshold, merge_mode)
        total_original += original_count
        total_after += after_count

        if original_count != after_count:
            print(f"[{i}/{len(label_files)}] {file_path.name}: {original_count} -> {after_count} (减少 {original_count - after_count})")
        elif i % 50 == 0 or i == file_count:
            print(f"[{i}/{len(label_files)}] 已处理...")

    removed = total_original - total_after
    # Guard against division by zero when no file contained any box.
    reduction = removed / total_original * 100 if total_original else 0.0

    print("=" * 60)
    print("处理完成!")
    print(f"总检测框数量: {total_original} -> {total_after}")
    print(f"减少了: {removed} 个检测框")
    print(f"减少比例: {reduction:.2f}%")
# Run the batch job only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

浙公网安备 33010602011771号