YOLOv11改进 - 卷积Conv | PATConv(Partial Attention Convolution)部分注意力卷积,在减少计算量的同时融合卷积

前言

本文提出部分注意力卷积(PATConv)机制,并将其集成到YOLOv11中。传统神经网络中,卷积计算密集,注意力机制全局计算冗余,此前的“部分卷积”会丢失未计算通道的特征价值。PATConv通过“通道拆分 - 并行处理 - 结果拼接”的逻辑,给不同通道分配“擅长的任务”,兼顾局部与全局特征。基于此,它还衍生出PAT_ch、PAT_sp和PAT_sf三种细分模块。此外,还提出了动态部分卷积(DPConv),并构建了新的混合网络家族PartialNet。我们将PATConv代码集成到YOLOv11中,实验表明,改进后的YOLOv11在I 目标检测表现出色。

文章目录: YOLOv11改进大全:卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总

专栏链接: YOLOv11改进专栏

介绍

image-20251225221831512

摘要

设计一种能够在不牺牲准确率和吞吐量的前提下,使网络保持低参数量和 FLOPs 的模块或机制,仍然是一个挑战。为了解决这一挑战并利用特征图通道内的冗余,我们提出了一种新的解决方案:部分通道机制(Partial Channel Mechanism, PCM)。具体而言,通过分割操作,特征图通道被划分为不同的部分,每一部分对应不同的操作,例如卷积、注意力、池化和恒等映射。基于这一假设,我们引入了一种新颖的部分注意力卷积(Partial Attention Convolution, PATConv),它能够高效地将卷积与视觉注意力结合起来。我们的探索表明,PATConv 可以完全替代常规卷积和常规视觉注意力,同时减少模型参数和 FLOPs。此外,PATConv 还可以衍生出三种新型模块:部分通道注意力模块(PAT ch)部分空间注意力模块(PAT sp)部分自注意力模块(PAT sf)。另外,我们还提出了一种新颖的动态部分卷积(Dynamic Partial Convolution, DPConv),它可以自适应地学习不同层中分割通道的比例,以实现更好的性能权衡。基于 PATConv 和 DPConv,我们提出了一个新的混合网络家族,命名为 PartialNet。在 ImageNet-1K 分类任务上,其 Top-1 准确率和推理速度均优于部分 SOTA 模型,并且在 COCO 数据集的目标检测和分割任务中表现出色。我们的代码已在 https://github.com/haiduo/PartialNet 开源。

文章链接

论文地址:论文地址

代码地址:代码地址

基本原理

PATConv(Partial Attention Convolution,部分注意力卷积) 核心思路是通过“通道拆分-并行处理-结果拼接”的逻辑,在减少计算量的同时融合卷积与注意力的优势,解决传统卷积“计算密集”或传统注意力“全局计算冗余”的问题,具体原理可拆分为以下4个关键步骤:

1. 核心设计背景:解决“计算效率与特征完整性”的矛盾

传统神经网络中,卷积擅长捕捉局部细节但需对所有通道密集计算(参数多、耗资源),注意力机制(如通道注意力、空间注意力)擅长捕捉全局关联但需对全通道计算(冗余度高);而此前的“部分卷积”(如FasterNet的PConv)仅对部分通道做卷积、其余通道直接保留,虽快但丢失了未计算通道的特征价值。

PATConv的核心突破是:不浪费任何通道,而是给不同通道分配“擅长的任务”——让部分通道做卷积(抓局部)、另一部分通道做注意力(抓全局),两者并行计算后合并,既减少总计算量,又兼顾局部与全局特征。

2. 三大核心操作:拆分→并行处理→拼接

PATConv的完整流程围绕“如何高效利用每部分通道”展开,具体分为3步:

(1)通道拆分:按比例分割输入特征通道

首先对输入的特征图(可理解为模型处理后的“图像特征”,形状为“高度×宽度×通道数”)做通道拆分:根据预设或自适应学习的“拆分比例”(如3:1、2:1),将所有通道分成两部分(记为“卷积通道组”和“注意力通道组”)。

  • 例:若输入通道数为64,拆分比例为3/4,则48个通道进入“卷积通道组”,16个通道进入“注意力通道组”;
  • 拆分的关键是“不丢弃任何通道”,确保每部分通道都参与有价值的计算。

(2)并行处理:给两组通道分配专属任务

对拆分后的两组通道,分别用“更高效的算子”处理,避免冗余计算:

  • 卷积通道组:仅对这部分通道应用卷积(如3×3卷积、1×1卷积),负责捕捉局部细节(如图像的边缘、纹理)。由于只处理部分通道,相比“全通道卷积”,计算量(FLOPs)可直接按拆分比例减少(如处理1/4通道,卷积计算量仅为原来的1/4)。
  • 注意力通道组:仅对这部分通道应用轻量化注意力(非全通道注意力),负责捕捉全局关联(如“天空与云朵常同时出现”“眼睛与鼻子的位置关联”)。由于注意力仅作用于部分通道,避免了传统全通道注意力的“二次计算复杂度”(如自注意力的O(n²)复杂度),速度大幅提升。

(3)结果拼接:合并两组通道的特征

将“卷积通道组”处理后的局部特征,与“注意力通道组”处理后的全局特征,通过“通道拼接”操作合并为完整的特征图,作为PATConv的输出。

  • 拼接后通道数与输入通道数一致(拆分的两部分相加),保证特征维度不丢失,同时融合了局部与全局信息,特征表达更全面。

3. 衍生三大模块:适配不同任务的注意力需求

基于上述核心逻辑,PATConv可根据“需要捕捉的注意力类型”,衍生出3种细分模块,分别适配不同的特征学习需求:

(1)PAT_ch(部分通道注意力块):抓“通道间关联”

  • 作用:解决“不同通道的特征重要性差异”(如“红色通道”对识别“苹果”更重要,“绿色通道”对识别“树叶”更重要)。
  • 实现:将“卷积通道组”用3×3卷积处理(抓局部),“注意力通道组”用“增强型高斯通道注意力”处理——通过计算通道的均值和方差(而非仅均值,比传统SE-Net更全面),评估每个通道的重要性,给重要通道分配更高权重。
  • 适用场景:替代传统卷积或深度可分离卷积(DWConv),提升通道特征的区分度。

(2)PAT_sp(部分空间注意力块):抓“空间位置关联”

  • 作用:解决“不同空间位置的特征关联”(如“猫的耳朵”通常在“猫的头部”上方)。
  • 实现:将“卷积通道组”用1×1卷积处理(压缩通道维度、减少计算),“注意力通道组”用“空间注意力”处理——先将通道压缩为1个通道(如用1×1卷积),生成“空间注意力图”(每个位置的数值代表该位置的重要性),再与原特征相乘,强化重要空间位置的特征。
  • 优化细节:可与模型中的MLP层(多层感知机)合并计算,进一步减少推理时的延迟(避免重复计算1×1卷积)。

(3)PAT_sf(部分自注意力块):抓“全局长距离关联”

  • 作用:解决“全图范围的特征关联”(如“汽车”与“道路”的全局构图关联),但避免传统自注意力的高复杂度。

  • 实现:仅在模型的最后一层使用(减少全局计算的次数),“卷积通道组”用常规卷积处理,“注意力通道组”用“轻量化自注意力”处理——加入相对位置编码(RPE),让自注意力更精准捕捉位置关系,同时仅对部分通道计算,大幅降低复杂度。

  • 适用场景:模型的最后阶段,补充全局视野,提升分类、检测等任务的精度。

核心代码



class PATConv(nn.Module):
    def __init__(self, dim, n_div=4, forward_type='split_cat', channel_type='se', patnet_t0=True): #'se' if i_stage <= 2 else 'self',
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.dim = dim
        self.n_div = n_div
        self.dim_untouched = dim - self.dim_conv3
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)
        self.channel_type = channel_type

        if channel_type == 'self':
            self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)
            rpe_config = get_rpe_config(
                ratio=20,
                method="euc",
                mode='bias',
                shared_head=False,
                skip=0,
                rpe_on='k',
            )
            if patnet_t0:
                num_heads = 4
            else:
                num_heads = 6
            self.attn = RPEAttention(self.dim_untouched, num_heads=num_heads, attn_drop=0.1, proj_drop=0.1,
                                     rpe_config=rpe_config)
            self.norm = timm.layers.LayerNorm2d(self.dim_untouched)
            # self.norm = timm.layers.LayerNorm2d(self.dim)
            self.forward = self.forward_atten
        elif channel_type == 'se':
            self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)
            self.attn = SRM(self.dim_untouched)
            self.norm = nn.BatchNorm2d(self.dim_untouched)
            self.forward = self.forward_atten
        else:  # channel_type is empty string or other
            if forward_type == 'slicing':
                self.forward = self.forward_slicing
            elif forward_type == 'split_cat':
                self.forward = self.forward_split_cat
            else:
                raise NotImplementedError

    def forward_atten(self, x: Tensor) -> Tensor:
        if self.channel_type == 'se':
            x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
            x1 = self.partial_conv3(x1)
            # x = self.partial_conv3(x)
            x2 = self.attn(x2)
            x2 = self.norm(x2)
            x = torch.cat((x1, x2), 1)
            # x = self.attn(x)
        else:  # channel_type == 'self'
            x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
            x1 = self.partial_conv3(x1)
            x2 = self.norm(x2)
            x2 = self.attn(x2)
            x = torch.cat((x1, x2), 1)
        return x

    def forward_slicing(self, x: Tensor) -> Tensor:
        x1 = x.clone()  # !!! Keep the original input intact for the residual connection later
        x1[:, :self.dim_conv3, :, :] = self.partial_conv3(x1[:, :self.dim_conv3, :, :])
        return x1

    def forward_split_cat(self, x: Tensor) -> Tensor:
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)
        return x
posted @ 2025-12-25 22:48  魔改工程师  阅读(12)  评论(0)    收藏  举报