目标检测算法学习:SSD

from sys import gettrace
from typing import ForwardRef
import torch
from torch.nn.modules import flatten
import torchvision
from torch import nn
from torch.nn import functional as F 
from d2l import torch as d2l
#来自于李沐深度学习课程https://zh-v2.d2l.ai/chapter_preface/index.html
#类别预测层
def cls_predictor(num_inputs, num_anchors, num_classes):
    """Build the class-prediction layer for one feature-map scale.

    A 3x3 convolution with padding 1 (spatial size preserved) whose output
    channels hold ``num_classes + 1`` scores (the extra one is background)
    for each of the ``num_anchors`` anchors centred on every pixel.
    """
    out_channels = num_anchors * (num_classes + 1)
    return nn.Conv2d(num_inputs, out_channels, kernel_size=3, padding=1)

#边界框预测层
def bbox_predictor(num_inputs, num_anchors):
    """Build the bounding-box offset prediction layer for one scale.

    A 3x3 convolution with padding 1 that emits four offset values per
    anchor at every pixel.
    """
    offsets_per_anchor = 4
    return nn.Conv2d(num_inputs, num_anchors * offsets_per_anchor,
                     kernel_size=3, padding=1)

#连接多尺度预测层
def forward(x, block):
    """Apply ``block`` to ``x`` and return the result."""
    output = block(x)
    return output

# Y1 = forward(torch.zeros((2,8,20,20)),cls_predictor(8,5,10))

def flatten_pred(pred):
    """Flatten a 4-D prediction map to 2-D with the channel axis moved last.

    Assumes NCHW layout (batch, channels, height, width). Moving channels to
    the end first keeps each pixel's predictions contiguous, so maps of
    different scales can be concatenated consistently.
    """
    channels_last = pred.permute(0, 2, 3, 1)
    return channels_last.flatten(start_dim=1)

def concat_preds(preds):
    """Flatten every per-scale prediction and join them along dimension 1."""
    flattened = list(map(flatten_pred, preds))
    return torch.cat(flattened, dim=1)

#高宽减半

def down_sample_blk(in_channels, out_channels):
    """Two (Conv3x3 -> BatchNorm -> ReLU) stages followed by 2x2 max pooling.

    The convolutions keep height and width (padding 1); the max pool halves
    them. Only the first convolution sees ``in_channels`` inputs.
    """
    layers = []
    channels = in_channels
    for _ in range(2):
        layers += [
            nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        channels = out_channels
    layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

#基本网络块,用于从输入图像中抽取特征,网络块输出的特征图为:32*32(256/(2*2*2)=32)
def base_net():
    """Backbone: three down-sampling blocks taking 3 -> 16 -> 32 -> 64 channels.

    Each block halves height and width, so the output is 1/8 of the input
    size (32x32 for a 256x256 input).
    """
    num_filters = [3, 16, 32, 64]
    stages = [down_sample_blk(c_in, c_out)
              for c_in, c_out in zip(num_filters, num_filters[1:])]
    return nn.Sequential(*stages)

def get_blk(i):
    """Return the i-th feature stage of TinySSD.

    0: the backbone (64 output channels); 1: down-sample 64 -> 128 channels;
    4: global average pooling to 1x1; any other index (2 and 3 in TinySSD):
    down-sample at 128 channels.
    """
    if i == 0:
        return base_net()
    if i == 1:
        return down_sample_blk(64, 128)
    if i == 4:
        return nn.AdaptiveAvgPool2d((1, 1))
    return down_sample_blk(128, 128)

#为每一个块定义前向计算, 输出包括:特征图、生成的锚框,预测的锚框的类别和偏移量
def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):
    """Run one multiscale stage of the detector.

    Returns a 4-tuple: the stage's output feature map, the anchors
    ``d2l.multibox_prior`` generates on that map, and the class and offset
    predictions computed from that map.
    """
    feature_map = blk(X)
    anchors = d2l.multibox_prior(feature_map, sizes=size, ratios=ratio)
    # Tuple elements evaluate left to right: class predictor, then bbox.
    return (feature_map, anchors,
            cls_predictor(feature_map), bbox_predictor(feature_map))

#定义完整的模型

class TinySSD(nn.Module):
    """A small single-shot multibox detector predicting at five scales.

    Args:
        num_classes: number of object classes. Class predictors emit
            ``num_classes + 1`` scores per anchor (the extra one is
            background).
        sizes: optional per-scale anchor sizes, one list per scale (5 total).
            Defaults to ``DEFAULT_SIZES``.
        ratios: optional per-scale anchor aspect ratios, one list per scale
            (5 total). Defaults to ``DEFAULT_RATIOS``.
        **kwargs: forwarded to ``nn.Module.__init__``.

    Raises:
        ValueError: if ``sizes`` or ``ratios`` does not have five entries.
    """

    # Default anchor configuration from the d2l SSD chapter. Sizes grow with
    # depth because deeper feature maps are coarser.
    DEFAULT_SIZES = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619],
                     [0.71, 0.79], [0.88, 0.961]]
    DEFAULT_RATIOS = [[1, 2, 0.5]] * 5

    def __init__(self, num_classes, *, sizes=None, ratios=None, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        # Copy so later mutation of the defaults/arguments cannot affect us
        # (DEFAULT_RATIOS repeats one shared inner list).
        self.sizes = [list(s) for s in
                      (self.DEFAULT_SIZES if sizes is None else sizes)]
        self.ratios = [list(r) for r in
                       (self.DEFAULT_RATIOS if ratios is None else ratios)]
        if len(self.sizes) != 5 or len(self.ratios) != 5:
            raise ValueError('TinySSD needs exactly 5 scales of sizes and '
                             'ratios')
        # Output channels of get_blk(i) for i = 0..4.
        idx_to_in_channels = [64, 128, 128, 128, 128]
        for i in range(5):
            # d2l.multibox_prior generates len(sizes) + len(ratios) - 1
            # anchors per pixel; the predictors must emit that many.
            num_anchors = len(self.sizes[i]) + len(self.ratios[i]) - 1
            # Register submodules as blk_i / cls_i / bbox_i so forward() can
            # look them up by name.
            setattr(self, f'blk_{i}', get_blk(i))
            setattr(self, f'cls_{i}', cls_predictor(idx_to_in_channels[i],
                                                    num_anchors, num_classes))
            setattr(self, f'bbox_{i}', bbox_predictor(idx_to_in_channels[i],
                                                      num_anchors))

    def forward(self, X):
        """Return ``(anchors, cls_preds, bbox_preds)`` for the batch ``X``.

        ``anchors`` concatenates the anchors of all five scales along dim 1.
        ``cls_preds`` is reshaped to (batch, total_anchors, num_classes + 1).
        ``bbox_preds`` is flattened to (batch, total_anchors * 4).
        """
        anchors, cls_preds, bbox_preds = [None] * 5, [None] * 5, [None] * 5
        for i in range(5):
            # Each stage consumes the previous stage's feature map.
            X, anchors[i], cls_preds[i], bbox_preds[i] = blk_forward(
                X, getattr(self, f'blk_{i}'), self.sizes[i], self.ratios[i],
                getattr(self, f'cls_{i}'), getattr(self, f'bbox_{i}'))
        anchors = torch.cat(anchors, dim=1)
        cls_preds = concat_preds(cls_preds)
        cls_preds = cls_preds.reshape(
            cls_preds.shape[0], -1, self.num_classes + 1)
        bbox_preds = concat_preds(bbox_preds)
        return anchors, cls_preds, bbox_preds

#训练模型
    #读取数据集和初始化
# Banana detection data, a single-class TinySSD, and an SGD optimizer.
batch_size = 32
train_iter, _ = d2l.load_data_bananas(batch_size)
device = d2l.try_gpu()
net = TinySSD(num_classes=1)
trainer = torch.optim.SGD(net.parameters(), lr=0.2, weight_decay=5e-4)

#定义损失函数和评价函数

# Per-element losses (reduction='none') so calc_loss can mask and average
# them itself.
cls_loss = nn.CrossEntropyLoss(reduction='none')
bbox_loss = nn.L1Loss(reduction='none')


def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):
    """Per-example SSD loss: mean classification loss plus masked L1 loss.

    Args:
        cls_preds: class scores, 3-D with classes on the last axis (expected
            (batch, num_anchors, num_classes + 1)).
        cls_labels: target class index per anchor; reshaped to line up with
            ``cls_preds`` rows (expected (batch, num_anchors)).
        bbox_preds: predicted offsets, (batch, ...) with at least 2 dims.
        bbox_labels: target offsets, same shape as ``bbox_preds``.
        bbox_masks: multiplier applied to both offsets; presumably 1 for
            anchors assigned to an object and 0 otherwise (as produced by
            d2l.multibox_target).

    Returns:
        A (batch,) tensor: per-example classification loss averaged over
        anchors, plus per-example masked L1 offset loss averaged over dim 1.
    """
    batch_size, num_classes = cls_preds.shape[0], cls_preds.shape[2]
    cls = cls_loss(cls_preds.reshape(-1, num_classes),
                   cls_labels.reshape(-1)).reshape(batch_size, -1).mean(dim=1)
    # L1Loss needs both input and target. Masking both sides zeroes the
    # contribution of anchors whose offset labels should be ignored.
    bbox = bbox_loss(bbox_preds * bbox_masks,
                     bbox_labels * bbox_masks).mean(dim=1)
    return cls + bbox

#使用准确率评价类别预测结果(cls_eval),使用平均绝对误差评价边界框的预测结果(bbox_eval)
def cls_eval(cls_preds, cls_labels):
    """Count the anchors whose highest-scoring class equals the label.

    The predicted class index is cast to ``cls_preds``' dtype before the
    comparison; the count is returned as a Python float.
    """
    predicted = cls_preds.argmax(dim=-1).type(cls_preds.dtype)
    num_correct = (predicted == cls_labels).sum()
    return float(num_correct)

def bbox_eval(bbox_preds, bbox_labels, bbox_masks):
    """Sum of masked absolute offset errors, returned as a Python float.

    Dividing by the number of offset entries gives the mean absolute error.
    """
    masked_error = (bbox_labels - bbox_preds) * bbox_masks
    return float(masked_error.abs().sum())


#训练模型
# Train TinySSD and plot per-epoch class error and bbox MAE.
num_epochs, timer = 20, d2l.Timer()
animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                        legend=['class error', 'bbox mae'])

net = net.to(device)
for epoch in range(num_epochs):
    # Slots: correct class predictions, number of class labels,
    # summed absolute offset error, number of offset labels.
    metric = d2l.Accumulator(4)
    net.train()
    for features, target in train_iter:
        timer.start()
        trainer.zero_grad()
        X, Y = features.to(device), target.to(device)
        # Multiscale anchors plus class / offset predictions for each anchor.
        anchors, cls_preds, bbox_preds = net(X)
        # Label every anchor with a target class and target offsets.
        bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors, Y)
        l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,
                      bbox_masks)
        l.mean().backward()
        trainer.step()
        # Pair every start() with stop(). Otherwise no interval is recorded
        # and the timer measures nothing.
        timer.stop()
        metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),
                   bbox_eval(bbox_preds, bbox_labels, bbox_masks),
                   bbox_labels.numel())
    cls_err, bbox_mae = 1 - metric[0] / metric[1], metric[2] / metric[3]
    animator.add(epoch + 1, (cls_err, bbox_mae))
print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')
print(f'{num_epochs * len(train_iter.dataset) / timer.sum():.1f} '
      f'examples/sec on {device}')

#预测目标
# Read the test image, add a leading batch axis, and convert it to float for
# the network.
X = torchvision.io.read_image('../img/banana.jpg').unsqueeze(0).float()
# Channels-last integer copy of the single image, kept for display.
img = X[0].permute(1, 2, 0).long()

def predict(X):
    """Run the trained global ``net`` on ``X`` and return kept detections.

    Uses the module-level ``net`` and ``device``. Predictions are decoded
    with ``d2l.multibox_detection``. Rows of the first image whose class id
    is -1 (background or suppressed by NMS in d2l's convention) are dropped.
    Each kept row is presumably ``[class_id, confidence, x1, y1, x2, y2]``;
    verify against the d2l version in use.
    """
    net.eval()
    with torch.no_grad():  # inference only; no autograd graph needed
        anchors, cls_preds, bbox_preds = net(X.to(device))
        # Softmax over the class axis, then move classes to dim 1: the
        # (batch, classes, anchors) layout multibox_detection expects.
        cls_probs = F.softmax(cls_preds, dim=2).permute(0, 2, 1)
        output = d2l.multibox_detection(cls_probs, bbox_preds, anchors)
    keep = [i for i, row in enumerate(output[0]) if row[0] != -1]
    return output[0, keep]

 

posted @ 2021-11-05 20:44  Maggieisxin  阅读(142)  评论(0)  编辑  收藏  举报