常用代码段-nms操作
非极大值抑制(Non-Maximum Suppression,NMS)是一种常用于目标检测和计算机视觉任务的技术,用于从重叠的检测框中选择最佳的候选框。以下是使用 PyTorch 实现标准的 NMS 算法的示例代码:
import torch def nms(boxes, scores, iou_threshold): sorted_indices = scores.argsort(descending=True) selected_indices = [] while sorted_indices.numel() > 0: if sorted_indices.numel() == 1: selected_indices.append(sorted_indices.item()) break current_index = sorted_indices[0] selected_indices.append(current_index.item()) current_box = boxes[current_index] other_boxes = boxes[sorted_indices[1:]] ious = calculate_iou(current_box, other_boxes) valid_indices = (ious <= iou_threshold).nonzero().squeeze() if valid_indices.numel() == 0: break sorted_indices = sorted_indices[valid_indices + 1] return selected_indices def calculate_iou(box, boxes): x1 = torch.max(box[0], boxes[:, 0]) y1 = torch.max(box[1], boxes[:, 1]) x2 = torch.min(box[2], boxes[:, 2]) y2 = torch.min(box[3], boxes[:, 3]) intersection_area = torch.clamp(x2 - x1 + 1, min=0) * torch.clamp(y2 - y1 + 1, min=0) box_area = (box[2] - box[0] + 1) * (box[3] - box[1] + 1) boxes_area = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1) iou = intersection_area / (box_area + boxes_area - intersection_area) return iou # 示例数据:框坐标和置信度得分 boxes = torch.tensor([[100, 100, 200, 200], [120, 120, 220, 220], [150, 150, 250, 250]]) scores = torch.tensor([0.9, 0.8, 0.7]) # NMS 参数 iou_threshold = 0.5 # 执行 NMS 算法 selected_indices = nms(boxes, scores, iou_threshold) print("选择的索引:", selected_indices)
在此示例中,我们首先定义了 nms 函数来执行 NMS 算法。然后,我们实现了一个简单的 calculate_iou 函数来计算两个框的交并比(IoU)。最后,我们使用示例数据 boxes 和 scores 运行 NMS 算法,并打印出选定的索引。
本文来自博客园,作者:海_纳百川,转载请注明原文链接:https://www.cnblogs.com/chentiao/p/17656004.html,如有侵权联系删除

浙公网安备 33010602011771号