SAM3源码学习;mask_to_polygons优化;mask_decoder的批量掩码格式;

1.SAM3源码学习
直接进源码去看接口,一般开发者都在接口处写明方法的作用,包含大量备注。
SAM3图像的几个后处理方法:
post_process_semantic_segmentation(语义分割)
post_process_object_detection(目标检测)
post_process_instance_segmentation(实力分割)

2.mask_to_polygons优化
在将SAM3应用到遥感图象时,提高置信度后出现了拓扑错误;
一个 polygon 必须满足:
✔ 外环必须闭合,没有自交
✔ 内部不能有乱七八糟的小洞
✔ 洞必须在面内
✔ 洞不能刚好碰到外边界
✔ 多个洞不能乱交叉
✔ 面不能自交叉(像“8”字形)
否则就会出现拓扑错误。

将现有的 mask_to_polygons 函数中加入“过滤太小连通域” + “过滤太小轮廓” 的代码:
✔ 不会引入 Shapely 拓扑错误
✔ 自动过滤面积小的噪声、小洞、小碎片
“过滤太小连通域” :
针对的是掩码中的小碎片、小点、小噪声连通区域
“过滤太小轮廓”
针对的是掩码中轮廓很小,但却是大物体的一部分

def mask_to_polygons(mask, min_region_area=100, min_polygon_area=5):
    """
    Convert mask to polygons with filtering of small regions.
    
    Args:
        mask: torch tensor, shape [H, W], values 0/1
        min_region_area: minimum connected component area to keep
        min_polygon_area: minimum polygon (in pixel units) area to keep
    """
    # ------------ 1. Tensor → uint8 mask ------------
    mask_np = (mask.cpu().numpy().astype(np.uint8))

    # ------------ 2. 过滤太小连通域(关键步骤) ------------
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_np, connectivity=4)

    clean_mask = np.zeros_like(mask_np)
    for i in range(1, num_labels):  # skip background = 0
        area = stats[i, cv2.CC_STAT_AREA]
        if area >= min_region_area:
            clean_mask[labels == i] = 1

    # 转换成 255 图用于 findContours
    mask_img = (clean_mask * 255).astype(np.uint8)

    # ------------ 3. 提取轮廓 ------------
    contours, _ = cv2.findContours(mask_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    polys = []
    for cnt in contours:
        if len(cnt) < 3:
            continue

        poly = Polygon(cnt[:, 0, :])

        # ------------ 4. 过滤太小的 polygon(关键步骤)------------
        if poly.area >= min_polygon_area:
            polys.append(poly)

    if not polys:
        return []

    # ------------ 5. 合并 polygon(union)------------
    merged = unary_union(polys)

    # ------------ 6. 保证输出为列表 ------------
    if isinstance(merged, Polygon):
        return [merged]
    elif isinstance(merged, MultiPolygon):
        return list(merged.geoms)
    else:
        return []

3.mask_decoder的批量掩码格式
mask_decoder 的批量掩码,格式为 (batch_size, num_channels, height, width)。
(1)batch_size:一次 forward 输入的图像数量。batch_size=4,一次性处理了 4 张图片。
(2)num_channels —— 最重要(代表掩码数量),在不同模式下它含义不同:
(实例掩码 / 文本掩码 / 多个 mask 预测)
情况 1:提示式分割(点/框) → num_channels = 多个掩码候选(通常 3 个掩码候选:低、中、高质量)
情况 2:文本语义分割(Text Prompt) → num_channels = 类别数
情况 3:视频跟踪(SAM2/SAM3 Tracker) → num_channels = 跟踪 ID 数量
(3)height, width(特征图尺寸 / 掩码尺寸)

posted @ 2026-01-23 11:57  asphyxiasea  阅读(2)  评论(0)    收藏  举报