A walk through PyTorch's official FasterRCNN code
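PyTorch now ships FasterRCNN and MaskRCNN implementations in the torchvision package. To help readers understand the details of the model, this article walks through the FasterRCNN code so that newcomers can get a grip on the main issues in two-stage detection.
This article assumes the reader already has some understanding of how FasterRCNN works. If not, please read the previous article first:
The torchvision documentation for FasterRCNN is here: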
https://pytorch.org/docs/stable/torchvision/models.html#faster-r-cnn
Once torchvision is installed in Python, the following commands show its version and where the code lives:
import torchvision
print(torchvision.__version__)
# '0.6.0'
print(torchvision.__path__)
# ['/usr/local/lib/python3.7/site-packages/torchvision']
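To go from the package path to the detection sources themselves, a small sketch like the one below may help. It assumes the detection models sit under models/detection inside the package and that the file names listed are present; both are worth checking on your own install, which is why the script only prints whether each file exists. detection_source_dir is a hypothetical helper name used for illustration, not a torchvision function.

```python
import importlib.util
import os


def detection_source_dir(package_dir):
    """Return the folder expected to hold the detection models (assumed layout)."""
    return os.path.join(package_dir, "models", "detection")


# Locate the package without importing it, so this also runs where torchvision is absent.
spec = importlib.util.find_spec("torchvision")
if spec is not None and spec.submodule_search_locations:
    package_dir = list(spec.submodule_search_locations)[0]
    folder = detection_source_dir(package_dir)
    # File names assumed from the torchvision layout; existence is checked, not presumed.
    for name in ("generalized_rcnn.py", "faster_rcnn.py", "rpn.py",
                 "roi_heads.py", "transform.py"):
        print(name, os.path.exists(os.path.join(folder, name)))
else:
    print("torchvision is not installed in this environment")
```

Opening these files next to the article makes it easier to follow the listings that come later.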
△ Code structure
Figure 1
GeneralizedRCNN is the base class for object detection in torchvision. It inherits from torch.nn.Module, and both FasterRCNN and MaskRCNN derive from it (MaskRCNN does so by subclassing FasterRCNN).
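As a rough illustration of this structure, here is a pure-Python toy: a base class that only wires four components together, and a subclass that builds the components and hands them to the base class. Every name starting with Toy, all the stub functions, and the fake loss values are hypothetical stand-ins, not torchvision code. How the two loss dictionaries are combined is also an assumption of this sketch.

```python
class ToyGeneralizedRCNN:
    """Stand-in base class: it only chains transform, backbone, rpn and roi_heads."""

    def __init__(self, backbone, rpn, roi_heads, transform):
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        self.training = True  # plays the role of nn.Module's training flag

    def forward(self, images, targets=None):
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        images, targets = self.transform(images, targets)
        features = self.backbone(images)
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, targets)
        # Assumption of this sketch: the two loss dicts are simply merged.
        losses = {**proposal_losses, **detector_losses}
        return losses if self.training else detections


class ToyFasterRCNN(ToyGeneralizedRCNN):
    """Stand-in subclass: builds concrete components, then defers to the base class."""

    def __init__(self):
        def transform(images, targets):
            return images, targets  # no resizing or normalization in this toy

        def backbone(images):
            return {"0": [len(img) for img in images]}  # fake "features"

        def rpn(images, features, targets):
            proposals = [[(0, 0, 1, 1)] for _ in images]  # one fake box per image
            return proposals, {"loss_rpn": 0.5}

        def roi_heads(features, proposals, targets):
            detections = [{"boxes": p} for p in proposals]
            return detections, {"loss_roi": 0.25}

        super().__init__(backbone, rpn, roi_heads, transform)


model = ToyFasterRCNN()
train_out = model.forward(["img_a", "img_b"], targets=[{}, {}])  # training: losses
model.training = False
eval_out = model.forward(["img_a", "img_b"])  # evaluation: detections
print(train_out)
print(eval_out)
```

The point of the split is that the base class never needs to know which backbone or heads it is driving, so a subclass only has to decide what to build.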
△ GeneralizedRCNN
GeneralizedRCNN inherits from the base class nn.Module. Let's start with its code:
class GeneralizedRCNN(nn.Module):
    def __init__(self, backbone, rpn, roi_heads, transform):
        super(GeneralizedRCNN, self).__init__()
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        # training returns the losses, evaluation returns the detections
        if self.training:
            return losses

        return detections

    def forward(self, images, targets=None):
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        # record each image's (height, width) before the transform resizes it
        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        images, targets = self.transform(images, targets)
        features = self.backbone(images.tensors)
        # a backbone that returns a single tensor is wrapped into a one-entry dict
        if isinstance(features, torch.Tensor):
            features = OrderedDict([('0', features)])
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        # map the detections back to the original image sizes
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        losses = 
