Swin Transformer深度解析:从原理到代码实现,让视觉Transformer高效处理高分辨率图像

在计算机视觉领域,Transformer架构的引入曾引发一场革命,但原始Transformer在处理高分辨率图像时面临计算量爆炸的难题。Swin Transformer(Shifted Window Transformer)通过分层结构和移位窗口注意力机制,巧妙地将计算复杂度从平方级降至线性,成为视觉任务的新标杆。本文将从核心概念、数学原理到代码实现,带你全面掌握这一强大模型。

一、Swin Transformer的核心创新:窗口注意力与分层特征

Swin Transformer专为视觉任务设计,其核心创新在于两点:分层特征提取窗口注意力机制。与直接处理整张图像的原始Transformer不同,Swin模仿CNN的层级结构,通过Patch Merging逐步缩小特征图尺寸、提升通道数,从而适配不同尺度的视觉特征。这种设计让模型既能捕捉全局语义,又能保留局部细节。

窗口注意力(Window Attention):将特征图划分为不重叠的窗口,仅在窗口内计算自注意力。这一操作将计算复杂度从 O\left ( HW \right )^{2} 降低到 O\left (\left ( HW/M^{2} \right ) M^{2} \right )=O\left ( HW \right ) ,其中M为窗口大小。这意味着当图像分辨率增大时,计算量不再呈平方级增长,而是线性增长,极大提升了效率。

移位窗口注意力(Shifted Window Attention):窗口划分会导致窗口间信息孤立。Swin通过循环移位生成新窗口,同时引入掩码(Mask)避免无效计算,确保窗口内注意力的正确性。这种机制让信息在相邻窗口间流动,实现了跨窗口交互。

⚠️ 注意:Swin Transformer的设计灵感部分来自CNN的分层特性,但其核心计算仍基于注意力机制。这种混合思路在Python、JavaScript等语言的深度学习框架中均有实现,例如PyTorch和TensorFlow。

二、数学公式与原理详解

Swin Transformer的数学基础并不复杂,但每一步都精心设计。以下是关键公式的解析:

  1. Patch Partition(分块):将原始图像(H×W×3)划分为大小为4×4的不重叠Patch,每个Patch展平为一维向量:Patch(i,j)=Image\left [ 4i:4i+4,4j:4j+4,: \right ]\rightarrow R^{4\times 4\times 3}\rightarrow R ^{48} 。最终得到 (H/4×W/4)×48 的特征图,记为 H_{0},W_{0},C_{0}
  2. Window Attention 计算:在每个窗口内计算自注意力,核心公式与标准自注意力一致,但作用域限制在窗口内:Attention(Q,K,V)=Softmax\left ( \frac{QK^{T}}{\sqrt{d_{k}}}+M \right )V 。其中Q、K、V为查询/键/值矩阵,由输入特征线性变换得到,dk为Q/K的维度;M为掩码矩阵,仅在Shifted Window时生效,用于屏蔽移位后跨原始窗口的无效注意力计算。
  3. Shifted Window 移位操作:设窗口大小为M,特征图尺寸为H×W,移位量为⌊M/2⌋,移位后坐标变换:\left ( x^{'},y^{'} \right )= \left ( x-\left \lfloor M/2 \right \rfloor mod H, y-\left \lfloor M/2 \right \rfloor mod W\right )
  4. Patch Merging(分层下采样):将2×2相邻Patch拼接,通道数翻倍,尺寸减半:Out\left [ i,j,: \right ]=In\left [ 2i:2i+2,2j:2j+2,: \right ]\rightarrow R^{4C}\rightarrow R^{2C} 。最终特征图尺寸变为H/2×W/2×2C。

这些公式在Go、TypeScript等语言中也有类似实现,但主流深度学习框架(如PyTorch for Python)提供了最便捷的API。

三、代码实现与实例解析

理论需结合实践。以下代码展示了Swin Transformer的核心模块,使用Python和PyTorch框架实现:

模块一:核心代码导入

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

模块二:窗口注意力

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        """
        窗口注意力模块初始化
        Args:
            dim: 输入特征的通道数(如 96、192)
            window_size: 窗口大小(M),默认7,代表7×7的窗口
            num_heads: 多头注意力的头数,需满足 dim % num_heads == 0
        """
        super().__init__()
        # 保存基础参数
        self.dim = dim                      # 输入通道数
        self.window_size = window_size      # 窗口大小 M
        self.num_heads = num_heads          # 注意力头数
        self.head_dim = dim // num_heads    # 每个注意力头的维度
        self.scale = self.head_dim ** -0.5  # 缩放因子 1/√d_k,防止内积值过大
        # 线性变换层:将输入特征一次性映射为Q、K、V(效率高于3个独立线性层)
        # 输入dim → 输出3*dim(Q/K/V各占dim)
        self.qkv = nn.Linear(dim, dim * 3)
        # 输出投影层:将注意力计算后的特征映射回原维度
        self.proj = nn.Linear(dim, dim)
        # 相对位置偏置表:解决绝对位置编码的局限性,捕捉窗口内位置关系
        # 尺寸:(2M-1)×(2M-1) × num_heads → 覆盖窗口内所有可能的相对位置
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )
        # -------------------------- 预计算相对位置索引 --------------------------
        # 生成窗口内的坐标网格:[0,1,...,M-1]
        coords = torch.arange(self.window_size)
        # 生成2×M×M的坐标矩阵:coords_grid[0]是行坐标,coords_grid[1]是列坐标
        coords_grid = torch.stack(torch.meshgrid([coords, coords], indexing="ij"))
        # 展平坐标:2 × M²(把M×M的坐标拉成一维)
        coords_flatten = torch.flatten(coords_grid, 1)
        # 计算所有位置对的相对坐标:(2, M², M²) → 每个位置相对于其他位置的偏移
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        # 维度重排:(M², M², 2) → [位置i, 位置j, 行/列偏移]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # 将相对坐标从[-M+1, M-1]映射到[0, 2M-2](避免负数索引)
        relative_coords[:, :, 0] += self.window_size - 1  # 行偏移修正
        relative_coords[:, :, 1] += self.window_size - 1  # 列偏移修正
        # 行索引编码:行偏移 × (2M-1) + 列偏移 → 唯一标识每个相对位置
        relative_coords[:, :, 0] *= 2 * self.window_size - 1
        # 求和得到最终的相对位置索引:(M², M²)
        relative_position_index = relative_coords.sum(-1)
        # 注册为缓冲区(不参与梯度更新)
        self.register_buffer("relative_position_index", relative_position_index)
    def forward(self, x, mask=None):
        """
        窗口注意力前向传播
        Args:
            x: 输入特征,形状 [num_windows*B, M², dim]
               - num_windows: 特征图划分的窗口总数
               - B: batch size
               - M²: 单个窗口的像素数
               - dim: 通道数
            mask: 掩码矩阵(仅移位窗口时使用),形状 [num_windows, M², M²]
        Returns:
            output: 注意力计算后的特征,形状 [num_windows*B, M², dim]
        """
        # 获取输入维度:B_=num_windows*B, N=M², C=dim
        B_, N, C = x.shape
        # -------------------------- 生成Q/K/V --------------------------
        # 1. 线性变换:[B_, N, C] → [B_, N, 3*C]
        # 2. 维度重排:[B_, N, 3, num_heads, head_dim] → 拆分3个维度给Q/K/V
        # 3. 维度置换:[3, B_, num_heads, N, head_dim]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        # 拆分Q/K/V:每个的形状都是 [B_, num_heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]
        # -------------------------- 计算注意力分数 --------------------------
        q = q * self.scale  # 缩放Q,对应公式中的 1/√d_k
        # Q @ K^T:[B_, num_heads, N, head_dim] × [B_, num_heads, head_dim, N] → [B_, num_heads, N, N]
        attn = (q @ k.transpose(-2, -1))
        # -------------------------- 添加相对位置偏置 --------------------------
        # 从偏置表中取出对应位置的偏置:[M²*M², num_heads] → [M², M², num_heads]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size * self.window_size, self.window_size * self.window_size, -1
        )
        # 维度重排:[num_heads, M², M²]
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        # 添加偏置:[B_, num_heads, N, N] + [1, num_heads, N, N] → 广播相加
        attn = attn + relative_position_bias.unsqueeze(0)
        # -------------------------- 应用掩码(移位窗口专用) --------------------------
        if mask is not None:
            nW = mask.shape[0]  # 获取窗口数量
            # 1. 维度适配:将attn拆分为 [B//nW, nW, num_heads, N, N]
            # 2. 掩码广播:mask [nW, N, N] → [1, nW, 1, N, N]
            # 3. 相加:将掩码值(-100)加到跨窗口的注意力分数上
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # 还原维度:[B_, num_heads, N, N]
            attn = attn.view(-1, self.num_heads, N, N)
            # Softmax归一化:掩码位置的-100会被Softmax为0,不参与计算
            attn = F.softmax(attn, dim=-1)
        else:
            # 无掩码时直接Softmax
            attn = F.softmax(attn, dim=-1)
        # -------------------------- 注意力加权求和V --------------------------
        # attn [B_, num_heads, N, N] × v [B_, num_heads, N, head_dim] → [B_, num_heads, N, head_dim]
        # 维度置换:[B_, N, num_heads, head_dim] → 合并头维度 → [B_, N, dim]
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        # -------------------------- 输出投影 --------------------------
        x = self.proj(x)  # 线性变换,保持维度不变
        return x
# ========================== 辅助函数:窗口划分与还原 ==========================
def window_partition(x, window_size):
    """
    将特征图划分为不重叠的窗口
    Args:
        x: 输入特征,形状 [B, H, W, C]
        window_size: 窗口大小 M
    Returns:
        windows: 窗口化特征,形状 [num_windows*B, M, M, C]
                 num_windows = (H/M) × (W/M)
    """
    B, H, W, C = x.shape
    # 维度拆分:[B, H, W, C] → [B, H//M, M, W//M, M, C]
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # 维度置换:[B, H//M, W//M, M, M, C] → 合并前三维 → [num_windows*B, M, M, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
def window_reverse(windows, window_size, H, W):
    """
    将窗口化特征还原为完整特征图(window_partition的逆操作)
    Args:
        windows: 窗口化特征,形状 [num_windows*B, M, M, C]
        window_size: 窗口大小 M
        H, W: 原始特征图的高和宽
    Returns:
        x: 还原后的特征图,形状 [B, H, W, C]
    """
    # 计算batch size:num_windows = (H*W)/(M*M) → B = total_windows / num_windows
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # 维度拆分:[num_windows*B, M, M, C] → [B, H//M, W//M, M, M, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, C)
    # 维度置换:[B, H//M, M, W//M, M, C] → 合并维度 → [B, H, W, C]
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, C)
    return x

模块三:Swin Transformer Block

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=7, shift_size=0):
        """
        Swin Transformer 基础块(包含窗口注意力/移位窗口注意力)
        Args:
            dim: 输入通道数
            num_heads: 注意力头数
            window_size: 窗口大小 M
            shift_size: 移位量(0=普通窗口,M//2=移位窗口)
        """
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size  # 移位量,核心参数
        # 层归一化(Transformer标准操作,放在注意力前)
        self.norm1 = nn.LayerNorm(dim)
        # 实例化窗口注意力模块
        self.attn = WindowAttention(dim, window_size, num_heads)
    def forward(self, x):
        """
        Swin Block 前向传播
        Args:
            x: 输入特征,形状 [B, H, W, C]
        Returns:
            x: 输出特征,形状 [B, H, W, C](残差连接后)
        """
        B, H, W, C = x.shape
        shortcut = x  # 保存残差连接的输入
        # 1. 层归一化
        x = self.norm1(x)
        # -------------------------- 移位操作(Shifted Window) --------------------------
        if self.shift_size > 0:
            # 循环移位:向左、向上移动shift_size个像素(负数表示左/上移)
            # 例如M=7,shift_size=3 → 左移3,上移3
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            # -------------------------- 生成移位窗口的掩码 --------------------------
            # 1. 初始化掩码矩阵:[1, H, W, 1],用于标记不同原始区域
            img_mask = torch.zeros((1, H, W, 1), device=x.device)
            # 2. 划分移位后的区域切片(共3×3=9个区域)
            h_slices = (slice(0, -self.window_size),          # 上半部分
                        slice(-self.window_size, -self.shift_size),  # 中间过渡区
                        slice(-self.shift_size, None))        # 下半部分
            w_slices = (slice(0, -self.window_size),          # 左半部分
                        slice(-self.window_size, -self.shift_size),  # 中间过渡区
                        slice(-self.shift_size, None))        # 右半部分
            # 3. 为每个区域分配唯一标签(0-8)
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            # 4. 将掩码划分为窗口:[num_windows, M, M, 1]
            mask_windows = window_partition(img_mask, self.window_size)
            # 5. 展平掩码:[num_windows, M²]
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # 6. 计算注意力掩码:判断两个位置是否属于同一原始区域
            #    - 同一区域:mask=0 → Softmax后正常计算
            #    - 不同区域:mask=-100 → Softmax后为0,不参与计算
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            # 普通窗口:不移位,无掩码
            shifted_x = x
            attn_mask = None
        # -------------------------- 窗口注意力计算 --------------------------
        # 1. 划分窗口:[B, H, W, C] → [num_windows*B, M, M, C]
        x_windows = window_partition(shifted_x, self.window_size)
        # 2. 展平窗口:[num_windows*B, M², C](适配WindowAttention输入)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        # 3. 窗口注意力前向计算
        attn_windows = self.attn(x_windows, mask=attn_mask)
        # -------------------------- 还原窗口为特征图 --------------------------
        # 1. 还原窗口形状:[num_windows*B, M², C] → [num_windows*B, M, M, C]
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # 2. 窗口还原为特征图:[num_windows*B, M, M, C] → [B, H, W, C]
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        # -------------------------- 逆移位(恢复原始位置) --------------------------
        if self.shift_size > 0:
            # 向右、向下移动shift_size个像素,还原到原始位置
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        # -------------------------- 残差连接 --------------------------
        x = shortcut + x  # 残差相加,提升梯度传播
        return x

模块四:Patch Merging(分层下采样)

class PatchMerging(nn.Module):
    def __init__(self, dim):
        """
        Patch Merging 模块:将2×2相邻Patch合并,实现下采样
        Args:
            dim: 输入通道数
        """
        super().__init__()
        self.dim = dim
        # 线性变换:将4*dim通道压缩为2*dim(下采样后通道数翻倍)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        # 层归一化:放在线性变换前,提升稳定性
        self.norm = nn.LayerNorm(4 * dim)
    def forward(self, x):
        """
        Patch Merging 前向传播
        Args:
            x: 输入特征,形状 [B, H, W, C]
        Returns:
            x: 下采样后的特征,形状 [B, H/2, W/2, 2C]
        """
        B, H, W, C = x.shape
        # -------------------------- 2×2 Patch合并 --------------------------
        # 1. 维度拆分:[B, H, W, C] → [B, H//2, 2, W//2, 2, C]
        x = x.view(B, H // 2, 2, W // 2, 2, C)
        # 2. 维度置换:[B, H//2, W//2, 2, 2, C] → 合并最后三维 → [B, H//2, W//2, 4C]
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H // 2, W // 2, -1)
        # -------------------------- 归一化+通道压缩 --------------------------
        x = self.norm(x)       # 层归一化
        x = self.reduction(x)  # 4C → 2C,通道数翻倍
        return x

模块五:测试代码

if __name__ == "__main__":
    # 模拟输入:batch_size=2,特征图56×56,通道数96(Swin-T的第一层特征)
    x = torch.randn(2, 56, 56, 96)
    # 1. 测试普通窗口注意力块(无移位)
    block1 = SwinTransformerBlock(dim=96, num_heads=8, window_size=7, shift_size=0)
    out1 = block1(x)
    print("普通窗口注意力输出形状:", out1.shape)  # 预期:torch.Size([2, 56, 56, 96])
    # 2. 测试移位窗口注意力块(移位量=3)
    block2 = SwinTransformerBlock(dim=96, num_heads=8, window_size=7, shift_size=3)
    out2 = block2(x)
    print("移位窗口注意力输出形状:", out2.shape)  # 预期:torch.Size([2, 56, 56, 96])
    # 3. 测试Patch Merging下采样
    patch_merge = PatchMerging(dim=96)
    out3 = patch_merge(x)
    print("Patch Merging输出形状:", out3.shape)  # 预期:torch.Size([2, 28, 28, 192])

运行结果示例:

  • 普通窗口注意力输出形状: torch.Size([2, 56, 56, 96])
  • 移位窗口注意力输出形状: torch.Size([2, 56, 56, 96])
  • Patch Merging输出形状: torch.Size([2, 28, 28, 192])

输入一张真实图片(比如猫、狗的照片),模型输出这张图片的“高级特征”——不再是像素,而是能描述“这是猫、那是狗”的特征。再搭配简单的分类/检测头,就能实现图片分类、目标检测、语义分割等视觉任务。

[AFFILIATE_SLOT_1] 如果你对深度学习框架感兴趣,推荐学习PyTorch官方教程,它提供了Swin Transformer的完整示例。

四、关键实现细节与注意事项

在实际开发中,有几个细节需要特别注意:

  • 窗口划分与还原:窗口注意力需要高效的特征图划分和还原操作,否则会成为性能瓶颈。
  • 移位+掩码:移位窗口的实现依赖循环移位和掩码矩阵,确保信息流动同时避免无效计算。
  • Patch Merging:分层下采样模仿CNN的层级特征,但需注意通道数的变化。

✅ 实践建议:在Python中使用PyTorch实现时,建议用view和permute操作代替循环,提升效率。对于JavaScript或TypeScript开发者,ONNX Runtime Web提供了浏览器端推理支持。

⚠️ 常见错误:窗口大小M必须能整除特征图尺寸,否则需要填充或裁剪。此外,移位窗口的掩码矩阵需正确生成,否则会导致注意力计算错误。

五、总结与延伸

核心创新:Swin Transformer用「窗口注意力」降低计算复杂度,用「移位窗口注意力」解决窗口间信息孤立问题,适配高分辨率视觉任务。

数学核心:窗口内自注意力公式中的掩码M是移位窗口的关键。

代码核心:窗口划分/还原是窗口注意力的基础;移位+掩码是实现Shifted Window的核心;Patch Merging实现分层下采样,模仿CNN的层级特征。

[AFFILIATE_SLOT_2] 延伸阅读:Swin Transformer的变体Swin v2进一步优化了训练稳定性,适用于更大规模的视觉模型。在Go或Java中,你可以通过TensorFlow Java API或GoML库进行推理部署。

一句话总结:Swin Transformer用“分窗口算注意力”降低计算量,用“挪窗口+掩码”打通窗口间信息,用“合并像素块”构建分层特征,最终让Transformer能高效处理图片,既快又能学到有用的特征。

posted on 2026-06-13 18:33  wgwyanfs  阅读(22)  评论(0)    收藏  举报

导航