完整教程:[sam2图像分割] mask_decoder | TwoWayTransformer

第六章:掩膜解码器

欢迎回来

第五章:提示编码器中,我们看到了SAM-2如何理解你的具体指令(如点击或方框)。现在SAM-2已经能"看"到视觉世界(多亏了图像编码器)并"听懂"你的命令(通过提示编码器),接下来呢?它需要真正执行你要求的任务:绘制分割掩膜!

SAM-2的"艺术家"

想象你有一位技艺高超的艺术家,需要绘制一个特定对象。你向他们展示场景的照片(来自图像编码器的图像特征),并明确告诉他们要画什么以及在哪里画(来自提示编码器的提示嵌入)。艺术家随后综合所有信息,理解你的要求,并细致地绘制出目标对象。

掩膜解码器就是SAM2基础模型中的这位艺术家。它的主要工作是结合图像的高级视觉理解和你的提示指令,智能地**"绘制"实际的分割掩膜**,围绕你指定的对象。它还会给你一个评分,告诉你它对自己的绘制结果有多自信!

解决的问题

掩膜解码器解决的核心问题是**基于复杂的视觉信息和用户提供的提示,生成精确的像素级分割掩膜**。

它是连接AI对图像的内部理解与你的高级请求的桥梁,将抽象的数字转化为对象的具体轮廓。它确保当你点击一只猫时,你会得到一个漂亮且准确的仅包含猫的掩膜,而不是背景或其他对象的一部分。

概念

让我们探索掩膜解码器背后的关键思想:

  1. 信息整合:掩膜解码器是终极的混合器。它接收两种主要成分:

    • 图像特征:来自图像编码器的整个图像的详细视觉理解
    • 提示嵌入:你的具体指令(如点击、方框或掩膜),由提示编码器转化为AI的语言
  2. Transformer"大脑":掩膜解码器的核心是一个强大的神经网络,称为Transformer,具体来说是一个TwoWayTransformer这个Transformer非常擅长让不同的信息(如图像特征和提示嵌入)“对话”。它找出你的提示与图像特征的关系,突出图像中与你的请求相关的部分。

  3. 特殊的绘制"令牌":掩膜解码器使用特殊的数值"绘制工具",称为掩膜令牌IoU令牌

    • 掩膜令牌:这些就像不同的画笔或风格。Transformer处理这些令牌以及你的提示和图像特征,学习用它们生成不同的掩膜预测。SAM-2通常可以为同一对象生成几个略有不同的掩膜,以处理模糊情况(例如重叠对象)。
    • IoU令牌:这是一个特殊令牌,模型用它预测每个生成掩膜的置信度分数。这个分数告诉你模型认为它的掩膜预测有多好(接近1.0的分数表示高置信度)。
  4. 掩膜上采样:Transformer在低分辨率表示上运行以提高效率。生成低分辨率掩膜后,掩膜解码器使用一系列上采样层(如ConvTranspose2d)将掩膜"放大"回输入图像的原始尺寸。这确保最终掩膜细节丰富且完美贴合。

掩膜解码器的使用方式

你不会直接在代码中调用MaskDecoder。相反,它是SAM2基础模型的重要内部组件,当你使用高级工具(如SAM2ImagePredictor)的predict()方法时会被激活

让我们回顾第一章:SAM2图像预测器中的SAM2ImagePredictor示例,看看掩膜解码器在哪里发挥作用:

from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.build_sam import build_sam2_hf
import numpy as np
from PIL import Image
# 1. 加载核心SAM-2模型和预测器(如前几章所示)
sam_model = build_sam2_hf(model_id="facebook/sam2-hiera-large")
predictor = SAM2ImagePredictor(sam_model)
# 2. 设置虚拟图像(图像编码器已内部处理)
my_image = np.zeros((256, 256, 3), dtype=np.uint8)
predictor.set_image(my_image)
# 3. 提供提示(提示编码器已内部翻译)
point_coords = np.array([[128, 128]]) # 中心附近的点击
point_labels = np.array([1])         # 标签1表示前景
# 此调用*内部*激活掩膜解码器!
masks, scores, low_res_masks_logits = predictor.predict(
point_coords=point_coords,
point_labels=point_labels,
multimask_output=False # 我们要求一个主掩膜以简化
)
print(f"预测掩膜形状:{masks.shape}")
print(f"掩膜的置信度分数(IoU):{scores.item():.2f}")
print("掩膜解码器已成功绘制掩膜并给出置信度分数")

说明:当你调用predictor.predict()并传入point_coordspoint_labels时,SAM2ImagePredictor会将这些与预计算的图像特征一起发送给SAM2Base Model

SAM2Base Model内部,掩膜解码器接收这些输入,处理它们,生成原始掩膜预测(low_res_masks_logits),并估计其质量(scores)。

SAM2ImagePredictor随后对这些原始输出进行后处理,给你最终的干净masksscores

幕后揭秘:掩膜解码器的工作原理

让我们揭开层层迷雾,了解掩膜解码器如何协调这一"绘制"过程。

工作流程

SAM2ImagePredictor要求SAM2Base Model生成掩膜时,以下是掩膜解码器内部的简化操作:

  1. 输入到达掩膜解码器接收处理后的image_embeddings(来自图像编码器)、image_pe(图像的位置编码)以及sparse_prompt_embeddingsdense_prompt_embeddings(来自提示编码器)。
  2. 准备Transformer输入:它为内部的TwoWayTransformer准备一组"令牌"。这些令牌包括iou_token(用于置信度预测)、mask_tokens(用于生成掩膜)和你的sparse_prompt_embeddings。它还将dense_prompt_embeddingsimage_pe直接添加到image_embeddings中以丰富它们。
  3. Transformer交互:所有这些输入被送入TwoWayTransformer。这个强大的组件执行多轮"注意力"计算,让提示信息影响图像特征,反之亦然。它有效地在图像特征中找到提示描述的对象。Transformer输出精炼后的令牌和图像特征。
  4. 生成掩膜系数:从精炼的mask_tokens(由Transformer输出)中,掩膜解码器使用小型神经网络(output_hypernetworks_mlps)生成"超网络系数"。可以将其视为绘制掩膜的具体笔触指令。
  5. 上采样与合并:同时,掩膜解码器使用output_upscaling网络获取精炼的图像特征并提高其分辨率。这些上采样后的图像特征就像一张高分辨率画布。生成的"超网络系数"随后与这张高分辨率画布结合,"绘制"出详细的low_res_masks
  6. 预测置信度:精炼的iou_token(也由Transformer输出)被发送到iou_prediction_head(一个小型MLP)。这个头预测iou_predictions,即每个生成掩膜的置信度分数。
  7. 输出掩膜解码器随后将这些原始masksiou_predictions返回给SAM2Base Model,后者将它们传递给SAM2ImagePredictor进行最终调整和清理,然后呈现给你。

以下是此流程的简化序列图:

在这里插入图片描述

代码

在这里插入图片描述

让我们看看sam2/modeling/sam/mask_decoder.pysam2/modeling/sam/transformer.py文件中的关键部分,了解这些步骤如何实现。

  1. 掩膜解码器初始化(__init__
    掩膜解码器创建时(作为SAM2Base Model的一部分),它会设置其主要组件:

    # 摘自sam2/modeling/sam/mask_decoder.py(简化版)
    class MaskDecoder(nn.Module):
    def __init__(
    self,
    transformer_dim: int,
    transformer: nn.Module, # 这是TwoWayTransformer!
    num_multimask_outputs: int = 3,
    # ... 其他参数 ...
    ) -> None:
    super().__init__()
    self.transformer = transformer # 核心逻辑组合器
    self.iou_token = nn.Embedding(1, transformer_dim) # 用于置信度分数的特殊令牌
    self.num_mask_tokens = num_multimask_outputs + 1
    self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) # 用于绘制掩膜的特殊令牌
    self.output_upscaling = nn.Sequential( # 提高掩膜分辨率的网络
    nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
    LayerNorm2d(transformer_dim // 4), # 帮助稳定训练
    nn.GELU(), # 激活函数
    nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
    nn.GELU(),
    )
    self.output_hypernetworks_mlps = nn.ModuleList( # 生成掩膜系数的网络
    [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for i in range(self.num_mask_tokens)]
    )
    self.iou_prediction_head = MLP( # 预测置信度分数的网络
    transformer_dim, 256, self.num_mask_tokens, 3 # 示例维度
    )

    说明掩膜解码器初始化其核心self.transformer(混合器)。它还设置self.iou_tokenself.mask_tokens,这些是学习到的特殊数值,作为Transformer生成置信度分数和各种掩膜输出的提示

    • output_upscaling是一个小型神经网络,用于放大最终掩膜;
    • output_hypernetworks_mlps生成细粒度掩膜细节;
    • iou_prediction_head计算置信度。
  2. MaskDecoder.predict_masks(核心逻辑)
    这是掩膜生成真正发生的主要内部方法。由MaskDecoder.forward方法调用。

    # 摘自sam2/modeling/sam/mask_decoder.py(简化版)
    # 在MaskDecoder类内部
    def predict_masks(
    self,
    image_embeddings: torch.Tensor,       # 来自图像编码器
    image_pe: torch.Tensor,               # 图像的位置编码
    sparse_prompt_embeddings: torch.Tensor, # 来自提示编码器(点/框)
    dense_prompt_embeddings: torch.Tensor,  # 来自提示编码器(掩膜输入)
    # ... 其他参数 ...
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    # 1. 为Transformer的查询输入组合所有"令牌"
    # 包括IoU令牌、掩膜令牌和用户的稀疏提示
    output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
    # 扩展令牌以匹配批次大小
    output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
    tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
    # 2. 为Transformer的键/值输入准备图像特征和位置编码
    b, c, h, w = image_embeddings.shape # 获取图像尺寸
    src = image_embeddings + dense_prompt_embeddings # 将图像特征与密集提示结合
    pos_src = image_pe # 使用图像位置编码
    # 3. 重塑并运行核心TwoWayTransformer以混合所有信息
    # Transformer期望B x N_tokens x C,所以将HxW展平为N_tokens
    src_flat = src.flatten(2).permute(0, 2, 1) # B x (H*W) x C
    pos_src_flat = pos_src.flatten(2).permute(0, 2, 1) # B x (H*W) x C
    hs, src_out = self.transformer(src_flat, pos_src_flat, tokens) # hs是令牌特征,src_out是注意力后的图像特征
    # 4. 分离Transformer输出的IoU和掩膜令牌
    iou_token_out = hs[:, 0, :] # 'hs'中的第一个令牌通常是IoU令牌
    mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] # 其余令牌是掩膜令牌
    # 5. 上采样精炼的图像特征(src_out)并生成掩膜预测
    # 将src_out重塑回类似图像的特征(B, C, H, W)
    src_out = src_out.transpose(1, 2).view(b, c, h, w)
    upscaled_embedding = self.output_upscaling(src_out) # 提高分辨率
    hyper_in_list = []
    for i in range(self.num_mask_tokens):
    # 每个掩膜令牌(来自mask_tokens_out)生成绘制掩膜的系数
    hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
    hyper_in = torch.stack(hyper_in_list, dim=1) # 堆叠这些系数
    # 6. 将系数与上采样图像特征结合以"绘制"掩膜
    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
    # 7. 使用IoU令牌预测置信度分数(IoU)
    iou_pred = self.iou_prediction_head(iou_token_out)
    return masks, iou_pred # 返回原始掩膜及其置信度分数

    说明:此方法是核心。

它接收丰富的image_embeddingsdense_prompt_embeddings,将它们结合并重塑以供self.transformer使用。

  • 它还将你的sparse_prompt_embeddings与特殊的iou_tokenmask_tokens结合为Transformer的查询。
  • Transformer(self.transformer)随后混合所有这些信息,生成精炼的iou_token_outmask_tokens_out
  • mask_tokens_outself.output_hypernetworks_mlps用于生成系数,随后与upscaled_embedding(来自self.output_upscaling)结合以"绘制"masks
  • 最后,iou_token_outself.iou_prediction_head用于预测iou_pred(置信度分数)。

在这里插入图片描述

  1. TwoWayTransformer.forward(掩膜解码器内部)
    TwoWayTransformer是掩膜解码器内部的智能混合器。以下是其主forward方法的简化版:

    # 摘自sam2/modeling/sam/transformer.py(简化版)
    class TwoWayTransformer(nn.Module):
    # ... 初始化设置注意力层 ...
    def forward(
    self,
    image_embedding: Tensor, # 精炼的图像特征(来自图像编码器+密集提示)
    image_pe: Tensor,        # 图像的位置编码
    point_embedding: Tensor, # 所有组合令牌(IoU、掩膜、稀疏提示)
    ) -> Tuple[Tensor, Tensor]:
    # 将image_embedding和image_pe从Bx(H*W)xC重塑为Bx(H*W)xC(已展平)
    # 'queries'是掩膜和提示令牌,'keys'是图像特征
    queries = point_embedding
    keys = image_embedding
    # 应用一系列Transformer块
    # 每层是一个TwoWayAttentionBlock,允许queries和keys交互
    for layer in self.layers:
    queries, keys = layer(
    queries=queries,     # 令牌
    keys=keys,           # 图像特征
    query_pe=point_embedding, # 令牌的位置信息
    key_pe=image_pe,     # 图像特征的位置信息
    )
    # 应用从令牌到图像特征的最终注意力层
    # ...(进一步注意力和归一化)...
    return queries, keys # 返回精炼的令牌和图像特征

    说明TwoWayTransformer接收处理后的image_embedding(来自图像编码器加上密集提示)和point_embedding(你的稀疏提示加上掩膜和IoU令牌)。

  • 它随后通过多个TwoWayAttentionBlock层传递这些输入。这些块包含自注意力(令牌与令牌对话)和交叉注意力(令牌与图像对话,图像与令牌对话),实现深度双向交互。
  • 此过程精炼queries(令牌)和keys(图像特征),使它们准备好用于掩膜解码器中的最终掩膜生成步骤。

总结

掩膜解码器是SAM-2的关键"艺术家",负责将所有信息结合起来精确绘制所需的分割掩膜

通过使用强大的Transformer架构智能结合图像编码器的视觉特征和提示编码器的指令丰富嵌入,它生成准确的掩膜预测及置信度分数。这是将请求转化为具体分割对象的最后一步。

现在我们已经了解SAM-2如何看、听和绘 单张图像,接下来探索它如何在视频中跨时间记忆对象。下一章,我们将深入SAM-2中的"记忆"概念!

下一章:记忆编码器

posted @ 2025-11-23 10:33  gccbuaa  阅读(0)  评论(0)    收藏  举报