An Example of the CBAM (Convolutional Block Attention Module) Attention Mechanism

CBAM refines a convolutional feature map in two sequential stages: channel attention decides which channels to emphasize, and spatial attention decides where in the feature map to focus. Both stages produce multiplicative weights, so the output shape always matches the input shape.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttentionModule(nn.Module):
    # Channel attention: weights each channel using global average- and
    # max-pooled descriptors passed through a shared bottleneck MLP,
    # following the original CBAM paper.
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction_ratio, in_channels)
        )

    def forward(self, x):
        b, c = x.size(0), x.size(1)
        # Both pooled descriptors pass through the same MLP; their outputs are summed
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        channel_att = torch.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        return x * channel_att

class SpatialAttentionModule(nn.Module):
    # Spatial attention: weights each location using channel-wise max and mean
    # maps, concatenated and fused by a single convolution.
    def __init__(self, kernel_size=7):
        super(SpatialAttentionModule, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        max_pool = torch.max(x, dim=1, keepdim=True)[0]  # (B, 1, H, W)
        avg_pool = torch.mean(x, dim=1, keepdim=True)    # (B, 1, H, W)
        spatial_att = torch.cat([max_pool, avg_pool], dim=1)
        spatial_att = torch.sigmoid(self.conv(spatial_att))
        return x * spatial_att

class CBAMModule(nn.Module):
    # Full CBAM block: channel attention followed by spatial attention.
    def __init__(self, in_channels, reduction_ratio=16, spatial_kernel_size=7):
        super(CBAMModule, self).__init__()
        self.channel_att = ChannelAttentionModule(in_channels, reduction_ratio)
        self.spatial_att = SpatialAttentionModule(kernel_size=spatial_kernel_size)

    def forward(self, x):
        x = self.channel_att(x)
        x = self.spatial_att(x)
        return x
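
As a quick sanity check, the block can be run on a random feature map; since both attention stages are multiplicative, the output shape matches the input shape. A minimal sketch (the tensor sizes below are arbitrary, chosen only for illustration):

# Verify that CBAM preserves the feature map shape
feats = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
cbam = CBAMModule(in_channels=64)
print(cbam(feats).shape)  # torch.Size([2, 64, 32, 32])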

# Example of using CBAM in a neural network
class YourModel(nn.Module):
    def __init__(self):
        super(YourModel, self).__init__()
        # Your existing model layers here
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        # ...

        # Insert a CBAM block; match in_channels to the preceding layer's output
        self.cbam = CBAMModule(in_channels=64)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        # ...

        # Apply CBAM to the intermediate feature map
        x = self.cbam(x)

        return x
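
To exercise the example end to end, the model can be run on a dummy RGB batch. The input resolution below is hypothetical; any spatial size works here because every layer shown is convolutional:

model = YourModel()
dummy = torch.randn(1, 3, 224, 224)  # hypothetical RGB input, chosen for illustration
print(model(dummy).shape)  # torch.Size([1, 64, 224, 224])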
