SAM3源码学习;mask_to_polygons优化;mask_decoder的批量掩码格式;
1.SAM3源码学习
直接进源码去看接口,一般开发者都在接口处写明方法的作用,包含大量备注。
SAM3图像的几个后处理方法:
post_process_semantic_segmentation(语义分割)
post_process_object_detection(目标检测)
post_process_instance_segmentation(实力分割)
2.mask_to_polygons优化
在将SAM3应用到遥感图象时,提高置信度后出现了拓扑错误;
一个 polygon 必须满足:
✔ 外环必须闭合,没有自交
✔ 内部不能有乱七八糟的小洞
✔ 洞必须在面内
✔ 洞不能刚好碰到外边界
✔ 多个洞不能乱交叉
✔ 面不能自交叉(像“8”字形)
否则就会出现拓扑错误。
将现有的 mask_to_polygons 函数中加入“过滤太小连通域” + “过滤太小轮廓” 的代码:
✔ 不会引入 Shapely 拓扑错误
✔ 自动过滤面积小的噪声、小洞、小碎片
“过滤太小连通域” :
针对的是掩码中的小碎片、小点、小噪声连通区域
“过滤太小轮廓”
针对的是掩码中轮廓很小,但却是大物体的一部分
def mask_to_polygons(mask, min_region_area=100, min_polygon_area=5):
"""
Convert mask to polygons with filtering of small regions.
Args:
mask: torch tensor, shape [H, W], values 0/1
min_region_area: minimum connected component area to keep
min_polygon_area: minimum polygon (in pixel units) area to keep
"""
# ------------ 1. Tensor → uint8 mask ------------
mask_np = (mask.cpu().numpy().astype(np.uint8))
# ------------ 2. 过滤太小连通域(关键步骤) ------------
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_np, connectivity=4)
clean_mask = np.zeros_like(mask_np)
for i in range(1, num_labels): # skip background = 0
area = stats[i, cv2.CC_STAT_AREA]
if area >= min_region_area:
clean_mask[labels == i] = 1
# 转换成 255 图用于 findContours
mask_img = (clean_mask * 255).astype(np.uint8)
# ------------ 3. 提取轮廓 ------------
contours, _ = cv2.findContours(mask_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
polys = []
for cnt in contours:
if len(cnt) < 3:
continue
poly = Polygon(cnt[:, 0, :])
# ------------ 4. 过滤太小的 polygon(关键步骤)------------
if poly.area >= min_polygon_area:
polys.append(poly)
if not polys:
return []
# ------------ 5. 合并 polygon(union)------------
merged = unary_union(polys)
# ------------ 6. 保证输出为列表 ------------
if isinstance(merged, Polygon):
return [merged]
elif isinstance(merged, MultiPolygon):
return list(merged.geoms)
else:
return []
3.mask_decoder的批量掩码格式
mask_decoder 的批量掩码,格式为 (batch_size, num_channels, height, width)。
(1)batch_size:一次 forward 输入的图像数量。batch_size=4,一次性处理了 4 张图片。
(2)num_channels —— 最重要(代表掩码数量),在不同模式下它含义不同:
(实例掩码 / 文本掩码 / 多个 mask 预测)
情况 1:提示式分割(点/框) → num_channels = 多个掩码候选(通常 3 个掩码候选:低、中、高质量)
情况 2:文本语义分割(Text Prompt) → num_channels = 类别数
情况 3:视频跟踪(SAM2/SAM3 Tracker) → num_channels = 跟踪 ID 数量
(3)height, width(特征图尺寸 / 掩码尺寸)
浙公网安备 33010602011771号