# Learning object-detection algorithms: SSD (目标检测算法学习: SSD)
from sys import gettrace
from typing import ForwardRef
import torch
from torch.nn.modules import flatten
import torchvision
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
# Based on the d2l deep-learning course:
# https://zh-v2.d2l.ai/chapter_preface/index.html


def cls_predictor(num_inputs, num_anchors, num_classes):
    """Class prediction layer: (num_classes + 1) scores per anchor
    (the extra class is background), one set per spatial position."""
    return nn.Conv2d(num_inputs, num_anchors * (num_classes + 1),
                     kernel_size=3, padding=1)


def bbox_predictor(num_inputs, num_anchors):
    """Bounding-box prediction layer: 4 offsets per anchor."""
    return nn.Conv2d(num_inputs, num_anchors * 4, kernel_size=3, padding=1)


def forward(x, block):
    """Apply a single block to x (handy for probing output shapes)."""
    return block(x)


# Y1 = forward(torch.zeros((2, 8, 20, 20)), cls_predictor(8, 5, 10))


def flatten_pred(pred):
    """Reshape (N, C, H, W) -> (N, H*W*C): channels moved last so that
    predictions for the same spatial position stay contiguous."""
    return torch.flatten(pred.permute(0, 2, 3, 1), start_dim=1)


def concat_preds(preds):
    """Concatenate the flattened predictions of all scales along dim 1."""
    return torch.cat([flatten_pred(p) for p in preds], dim=1)


def down_sample_blk(in_channels, out_channels):
    """Two Conv-BN-ReLU layers followed by a 2x2 max pool, which halves
    the height and width of the feature map."""
    blk = []
    for _ in range(2):
        blk.append(nn.Conv2d(in_channels, out_channels,
                             kernel_size=3, padding=1))
        blk.append(nn.BatchNorm2d(out_channels))
        blk.append(nn.ReLU())
        in_channels = out_channels
    blk.append(nn.MaxPool2d(2))
    return nn.Sequential(*blk)


def base_net():
    """Base feature extractor: three down-sample blocks, so a 256x256
    input yields a 32x32 feature map (256 / (2*2*2) = 32)."""
    blk = []
    num_filters = [3, 16, 32, 64]
    for i in range(len(num_filters) - 1):
        blk.append(down_sample_blk(num_filters[i], num_filters[i + 1]))
    return nn.Sequential(*blk)


def get_blk(i):
    """Return stage i (0..4) of the detector backbone."""
    if i == 0:
        blk = base_net()
    elif i == 1:
        blk = down_sample_blk(64, 128)
    elif i == 4:
        # Global pooling collapses the map to 1x1 for the coarsest scale.
        blk = nn.AdaptiveAvgPool2d((1, 1))
    else:
        blk = down_sample_blk(128, 128)
    return blk


def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):
    """Forward one stage; returns the feature map, the anchors generated
    on it, and the class / offset predictions for those anchors."""
    Y = blk(X)
    anchors = d2l.multibox_prior(Y, sizes=size, ratios=ratio)
    cls_preds = cls_predictor(Y)
    bbox_preds = bbox_predictor(Y)
    return (Y, anchors, cls_preds, bbox_preds)


# BUG FIX: `sizes`, `ratios` and `num_anchors` were used below but never
# defined anywhere in the source (NameError). Standard 5-scale anchor
# configuration, one (sizes, ratios) pair per detector stage.
sizes = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619],
         [0.71, 0.79], [0.88, 0.961]]
ratios = [[1, 2, 0.5]] * 5
num_anchors = len(sizes[0]) + len(ratios[0]) - 1


class TinySSD(nn.Module):
    """Tiny single-shot detector with 5 scales; each scale has its own
    backbone stage, class predictor and bbox predictor."""

    def __init__(self, num_classes, **kwargs):
        super(TinySSD, self).__init__(**kwargs)
        self.num_classes = num_classes
        idx_to_in_channels = [64, 128, 128, 128, 128]
        for i in range(5):
            # setattr(self, name, val) is equivalent to `self.name = val`
            # (赋值语句) but lets us build the attribute name dynamically.
            setattr(self, f'blk_{i}', get_blk(i))
            setattr(self, f'cls_{i}',
                    cls_predictor(idx_to_in_channels[i], num_anchors,
                                  num_classes))
            # BUG FIX: was f'bbox_[i]' — a literal attribute named
            # 'bbox_[i]', so getattr(self, f'bbox_{i}') in forward()
            # raised AttributeError. Must interpolate i.
            setattr(self, f'bbox_{i}',
                    bbox_predictor(idx_to_in_channels[i], num_anchors))

    def forward(self, X):
        anchors, cls_preds, bbox_preds = [None] * 5, [None] * 5, [None] * 5
        for i in range(5):
            X, anchors[i], cls_preds[i], bbox_preds[i] = blk_forward(
                X, getattr(self, f'blk_{i}'), sizes[i], ratios[i],
                getattr(self, f'cls_{i}'), getattr(self, f'bbox_{i}'))
        anchors = torch.cat(anchors, dim=1)
        cls_preds = concat_preds(cls_preds)
        # Reshape so the class dimension is explicit: (N, total_anchors,
        # num_classes + 1).
        cls_preds = cls_preds.reshape(
            cls_preds.shape[0], -1, self.num_classes + 1)
        bbox_preds = concat_preds(bbox_preds)
        return anchors, cls_preds, bbox_preds


# ---- Training ----
# Load the banana-detection dataset and initialize model + optimizer.
batch_size = 32
train_iter, _ = d2l.load_data_bananas(batch_size)
device, net = d2l.try_gpu(), TinySSD(num_classes=1)
trainer = torch.optim.SGD(net.parameters(), lr=0.2, weight_decay=5e-4)

# Loss functions: cross-entropy for classes, L1 for box offsets
# (reduction='none' so we can average per example).
cls_loss = nn.CrossEntropyLoss(reduction='none')
bbox_loss = nn.L1Loss(reduction='none')


def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):
    """Per-example loss: mean class loss + mean masked offset loss."""
    batch_size, num_classes = cls_preds.shape[0], cls_preds.shape[2]
    cls = cls_loss(cls_preds.reshape(-1, num_classes),
                   cls_labels.reshape(-1)).reshape(batch_size, -1).mean(dim=1)
    # BUG FIX: nn.L1Loss takes (input, target); the original passed only
    # one tensor, which raises TypeError at runtime. The masks zero out
    # anchors assigned to background.
    bbox = bbox_loss(bbox_preds * bbox_masks,
                     bbox_labels * bbox_masks).mean(dim=1)
    return cls + bbox


def cls_eval(cls_preds, cls_labels):
    """Number of correctly classified anchors in the batch."""
    return float((cls_preds.argmax(dim=-1).type(
        cls_preds.dtype) == cls_labels).sum())


def bbox_eval(bbox_preds, bbox_labels, bbox_masks):
    """Sum of absolute offset errors over non-background anchors (MAE
    numerator; the caller divides by the element count)."""
    return float((torch.abs((bbox_labels - bbox_preds) * bbox_masks)).sum())


# ---- Training loop ----
num_epochs, timer = 20, d2l.Timer()
animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                        legend=['class error', 'bbox mae'])
net = net.to(device)
for epoch in range(num_epochs):
    # Accumulates: correct classes, #class labels, abs bbox error, #bbox labels.
    metric = d2l.Accumulator(4)
    net.train()
    for features, target in train_iter:
        timer.start()
        trainer.zero_grad()
        X, Y = features.to(device), target.to(device)
        anchors, cls_preds, bbox_preds = net(X)
        # Assign ground-truth classes and offsets to every anchor.
        bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors, Y)
        l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,
                      bbox_masks)
        l.mean().backward()
        trainer.step()
        metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),
                   bbox_eval(bbox_preds, bbox_labels, bbox_masks),
                   bbox_labels.numel())
    cls_err, bbox_mae = 1 - metric[0] / metric[1], metric[2] / metric[3]
    animator.add(epoch + 1, (cls_err, bbox_mae))

# ---- Prediction ----
X = torchvision.io.read_image('../img/banana.jpg').unsqueeze(0).float()
img = X.squeeze(0).permute(1, 2, 0).long()


def predict(X):
    # NOTE(review): this function appears truncated in the source — it
    # stops after computing the raw network outputs and returns nothing.
    # Presumably it should continue with softmax + d2l.multibox_detection;
    # confirm against the original notebook before use.
    net.eval()
    anchors, cls_preds, bbox_preds = net(X.to(device))