3.27日报

实现了热辐射(红外)引导的人体检测器:可见光图像与红外图像分别经 ResNet-50 提取特征,在 layer3 处用跨模态注意力融合后,与可见光 layer4 特征拼接并送入检测头。

class ThermalGuidedDetector(nn.Module):
    """Innovation 4: thermal-radiation-guided human detector.

    A visible-light image and an infrared image are encoded by two separate
    ResNet-50 backbones. Their layer3 features are fused with
    ``CrossModalAttention``, resized to the visible layer4 resolution,
    concatenated with the visible layer4 features, and passed to a small
    convolutional detection head.

    Args:
        num_classes: Number of classes. The head emits ``num_classes * 5``
            channels per spatial cell (4 box coordinates + 1 confidence).
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # Visible-light branch. Replacing conv1 with a freshly constructed
        # Conv2d discards that layer's pretrained weights; it trains from scratch.
        self.vis_backbone = resnet50(pretrained=True)
        self.vis_backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

        # Infrared branch: same backbone with a single-channel input conv.
        self.ir_backbone = resnet50(pretrained=True)
        self.ir_backbone.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)

        # Cross-modal attention fusion modules, one per stage width.
        # NOTE(review): only cma3 is used in forward(); cma1/cma2 are kept so the
        # module's attributes and state_dict keys stay unchanged.
        self.cma1 = CrossModalAttention(256)
        self.cma2 = CrossModalAttention(512)
        self.cma3 = CrossModalAttention(1024)

        # The head consumes [fused layer3 ; visible layer4] along channels.
        # Assumes CrossModalAttention(1024) returns 1024 channels -- TODO confirm.
        head_in_channels = 1024 + 2048
        self.detection_head = nn.Sequential(
            nn.Conv2d(head_in_channels, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, num_classes * 5, kernel_size=1),  # 5 = 4 box coords + 1 confidence
        )

    @staticmethod
    def _backbone_features(backbone, x, num_stages=4):
        """Run the backbone stem, then return the outputs of its first ``num_stages`` stages."""
        x = backbone.conv1(x)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)

        stages = (backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4)
        features = []
        for stage in stages[:num_stages]:
            x = stage(x)
            features.append(x)
        return features

    def forward(self, vis_img, ir_img):
        """Fuse visible and infrared features and return raw detection maps.

        Args:
            vis_img: Visible-light batch with 3 input channels.
            ir_img: Infrared batch with 1 input channel; presumably spatially
                aligned with ``vis_img`` (same H, W) -- TODO confirm.

        Returns:
            Tensor with ``num_classes * 5`` channels at the spatial size of the
            visible layer4 feature map.
        """
        # The visible branch supplies layer3 (for fusion) and layer4 (for the head).
        _, _, vis_layer3, vis_layer4 = self._backbone_features(self.vis_backbone, vis_img)
        # The infrared branch only contributes at layer3, so its layer4 is skipped.
        _, _, ir_layer3 = self._backbone_features(self.ir_backbone, ir_img, num_stages=3)

        fused_layer3 = self.cma3(vis_layer3, ir_layer3)
        # vis_layer4 comes from a later, downsampling stage, so its spatial size
        # differs from fused_layer3. Pool to its exact size so the channel-wise
        # concat is valid for any input resolution.
        fused_layer3 = nn.functional.adaptive_avg_pool2d(fused_layer3, vis_layer4.shape[-2:])

        combined = torch.cat([fused_layer3, vis_layer4], dim=1)
        return self.detection_head(combined)

# Usage example: build the detector and run one forward pass on random
# visible-light / infrared tensors that share the same spatial size.
detector = ThermalGuidedDetector(num_classes=1)
vis_input = torch.randn((1, 3, 512, 512))  # visible-light image batch
ir_input = torch.randn((1, 1, 512, 512))  # infrared image batch
output = detector(vis_img=vis_input, ir_img=ir_input)

 

posted @ 2025-04-09 11:07  Code13  阅读(8)  评论(0)    收藏  举报