木槿花篱

嘻嘻,欢迎~

全网最详细的深度学习经典模型RESNET解析【京东特邀专家 朱利明】(bilibili视频学习)(代码解析)

这是一篇学习记录贴

  1 import torch
  2 import torch.nn as nn
  3 from .utils import load_state_dict_from_url
  4 
  5 
  6 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
  7            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
  8            'wide_resnet50_2', 'wide_resnet101_2']
  9 
 10 
 11 model_urls = {
 12     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
 13     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
 14     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
 15     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
 16     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
 17     'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
 18     'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
 19     'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
 20     'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
 21 }
 22 
 23 # Thin factory wrappers around nn.Conv2d
 24 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
 25     """3x3 convolution with padding"""
 26     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
 27                      padding=dilation, groups=groups, bias=False, dilation=dilation)
 28 
 29 
 30 def conv1x1(in_planes, out_planes, stride=1):
 31     """1x1 convolution"""
 32     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
 33 
 34 # Residual block definitions
 35 class BasicBlock(nn.Module):
 36     expansion = 1
 37     #定义
 38     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
 39                  base_width=64, dilation=1, norm_layer=None):
 40         super(BasicBlock, self).__init__()
 41         if norm_layer is None:
 42             norm_layer = nn.BatchNorm2d
 43         if groups != 1 or base_width != 64:
 44             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
 45         if dilation > 1:
 46             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
 47         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
 48         self.conv1 = conv3x3(inplanes, planes, stride)
 49         self.bn1 = norm_layer(planes) # 归一化
 50         self.relu = nn.ReLU(inplace=True)
 51         self.conv2 = conv3x3(planes, planes)
 52         self.bn2 = norm_layer(planes)
 53         self.downsample = downsample
 54         self.stride = stride
 55     #实现
 56     def forward(self, x):
 57         # 保存x 做残差
 58         identity = x
 59 
 60         out = self.conv1(x)
 61         out = self.bn1(out)
 62         out = self.relu(out)
 63 
 64         out = self.conv2(out)
 65         out = self.bn2(out)
 66 
 67         # 下采样
 68         if self.downsample is not None:
 69             identity = self.downsample(x)
 70 
 71         out += identity # 先和 X 融合,再做relu
 72         out = self.relu(out)
 73 
 74         return out
 75 
 76 # Bottleneck block (used by the 50-layer and deeper ResNets)
 77 class Bottleneck(nn.Module):
 78     # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
 79     # while original implementation places the stride at the first 1x1 convolution(self.conv1)
 80     # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
 81     # This variant is also known as ResNet V1.5 and improves accuracy according to
 82     # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
 83 
 84     # 通道放大的倍数
 85     expansion = 4
 86 
 87     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
 88                  base_width=64, dilation=1, norm_layer=None):
 89         super(Bottleneck, self).__init__()
 90         if norm_layer is None:
 91             norm_layer = nn.BatchNorm2d
 92         width = int(planes * (base_width / 64.)) * groups
 93         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
 94         self.conv1 = conv1x1(inplanes, width)
 95         self.bn1 = norm_layer(width)
 96         self.conv2 = conv3x3(width, width, stride, groups, dilation)
 97         self.bn2 = norm_layer(width)
 98         self.conv3 = conv1x1(width, planes * self.expansion)
 99         self.bn3 = norm_layer(planes * self.expansion)
100         self.relu = nn.ReLU(inplace=True)
101         self.downsample = downsample
102         self.stride = stride
103 
104     # 网络前向传播过程(调用过程)
105     def forward(self, x):
106         identity = x
107 
108         out = self.conv1(x)
109         out = self.bn1(out)
110         out = self.relu(out)
111 
112         out = self.conv2(out)
113         out = self.bn2(out)
114         out = self.relu(out)
115 
116         out = self.conv3(out)
117         out = self.bn3(out)
118 
119         if self.downsample is not None:
120             identity = self.downsample(x)
121 
122         out += identity
123         out = self.relu(out)
124 
125         return out
126 
127 
class ResNet(nn.Module):
    """ResNet feature extractor with a linear classification head.

    The network is a 7x7 stride-2 convolution on 3-channel input, a 3x3
    stride-2 max-pool, four residual stages (``layer1`` .. ``layer4``),
    adaptive average pooling to 1x1 and one fully connected layer.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``). It
            must expose an ``expansion`` class attribute.
        layers: indexable of four ints, the number of blocks in each stage.
        num_classes: output size of the final fully connected layer.
        zero_init_residual: if True, zero the weight of the last norm layer
            in every residual block so each block starts as an identity.
        groups: forwarded to every block as ``groups``.
        width_per_group: forwarded to every block as ``base_width``.
        replace_stride_with_dilation: None, or three booleans saying whether
            stages 2-4 replace their stride with a dilated convolution.
        norm_layer: callable taking a channel count and returning a
            normalisation module; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` is given and does
            not have exactly three elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running state consumed and updated by _make_layer.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Parameter initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers built with affine=False carry weight/bias of
                # None; there is nothing to initialise in that case.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    last_norm = m.bn3
                elif isinstance(m, BasicBlock):
                    last_norm = m.bn2
                else:
                    continue
                # Skip norm layers that have no learnable scale.
                if getattr(last_norm, 'weight', None) is not None:
                    nn.init.constant_(last_norm.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage as an ``nn.Sequential`` of ``blocks`` blocks.

        Only the first block receives ``stride`` and the optional downsample
        shortcut. Updates ``self.inplanes`` (and ``self.dilation`` when
        ``dilate`` is True) for the next stage.

        Args:
            block: residual block class.
            planes: base channel count of the stage.
            blocks: number of blocks in the stage.
            stride: stride of the first block.
            dilate: if True, multiply ``self.dilation`` by ``stride`` and use
                stride 1 instead.
        """
        norm_layer = self._norm_layer
        out_planes = planes * block.expansion
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        # A projection shortcut is needed whenever the first block changes
        # the spatial size or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

        # The first block uses the dilation in effect before this stage.
        layers = [block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run stem, four stages, pooling, flatten and classifier on ``x``."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        """Return the classifier output for input ``x``."""
        return self._forward_impl(x)
227 
228 
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a ``ResNet`` and optionally load its published weights.

    Args:
        arch: key into ``model_urls`` selecting the checkpoint to download.
        block: residual block class passed to ``ResNet``.
        layers: per-stage block counts passed to ``ResNet``.
        pretrained: if truthy, download and load the state dict for ``arch``.
        progress: forwarded to ``load_state_dict_from_url`` as ``progress``.
        **kwargs: forwarded unchanged to the ``ResNet`` constructor.

    Returns:
        The constructed model.
    """
    model = ResNet(block, layers, **kwargs)
    if not pretrained:
        return model

    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    model.load_state_dict(weights)
    return model
236 
237 
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-18 (``BasicBlock``, stage depths 2-2-2-2).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded through ``_resnet`` to the ``ResNet`` constructor
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, stage_depths, pretrained, progress,
                   **kwargs)
247 
248 
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-34 (``BasicBlock``, stage depths 3-4-6-3).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded through ``_resnet`` to the ``ResNet`` constructor
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, stage_depths, pretrained, progress,
                   **kwargs)
258 
259 
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-50 (``Bottleneck``, stage depths 3-4-6-3).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded through ``_resnet`` to the ``ResNet`` constructor
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, pretrained, progress,
                   **kwargs)
269 
270 
def resnet101(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-101 (``Bottleneck``, stage depths 3-4-23-3).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded through ``_resnet`` to the ``ResNet`` constructor
    """
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck, stage_depths, pretrained, progress,
                   **kwargs)
280 
281 
def resnet152(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-152 (``Bottleneck``, stage depths 3-8-36-3).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded through ``_resnet`` to the ``ResNet`` constructor
    """
    stage_depths = [3, 8, 36, 3]
    return _resnet('resnet152', Bottleneck, stage_depths, pretrained, progress,
                   **kwargs)
291 
292 
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""Build ResNeXt-50 32x4d (``Bottleneck``, stage depths 3-4-6-3).

    Architecture from `"Aggregated Residual Transformation for Deep Neural Networks"
    <https://arxiv.org/pdf/1611.05431.pdf>`_.

    ``groups`` is forced to 32 and ``width_per_group`` to 4; values for
    those two keys supplied by the caller are overwritten.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded through ``_resnet`` to the ``ResNet`` constructor
    """
    kwargs.update(groups=32, width_per_group=4)
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnext50_32x4d', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
304 
305 
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""Build ResNeXt-101 32x8d (``Bottleneck``, stage depths 3-4-23-3).

    Architecture from `"Aggregated Residual Transformation for Deep Neural Networks"
    <https://arxiv.org/pdf/1611.05431.pdf>`_.

    ``groups`` is forced to 32 and ``width_per_group`` to 8; values for
    those two keys supplied by the caller are overwritten.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded through ``_resnet`` to the ``ResNet`` constructor
    """
    kwargs.update(groups=32, width_per_group=8)
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnext101_32x8d', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
317 
318 
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Build Wide ResNet-50-2 (``Bottleneck``, stage depths 3-4-6-3).

    Architecture from `"Wide Residual Networks"
    <https://arxiv.org/pdf/1605.07146.pdf>`_.

    The model is the same as ResNet except for the bottleneck number of
    channels, which is twice larger in every block. The number of channels
    in the outer 1x1 convolutions is the same, e.g. the last block in
    ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has
    2048-1024-2048.

    ``width_per_group`` is forced to ``64 * 2``; a value supplied by the
    caller is overwritten.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded through ``_resnet`` to the ``ResNet`` constructor
    """
    kwargs.update(width_per_group=64 * 2)
    stage_depths = [3, 4, 6, 3]
    return _resnet('wide_resnet50_2', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
333 
334 
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Build Wide ResNet-101-2 (``Bottleneck``, stage depths 3-4-23-3).

    Architecture from `"Wide Residual Networks"
    <https://arxiv.org/pdf/1605.07146.pdf>`_.

    The model is the same as ResNet except for the bottleneck number of
    channels, which is twice larger in every block. The number of channels
    in the outer 1x1 convolutions is the same, e.g. the last block in
    ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has
    2048-1024-2048.

    ``width_per_group`` is forced to ``64 * 2``; a value supplied by the
    caller is overwritten.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded through ``_resnet`` to the ``ResNet`` constructor
    """
    kwargs.update(width_per_group=64 * 2)
    stage_depths = [3, 4, 23, 3]
    return _resnet('wide_resnet101_2', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)

 

posted @ 2020-07-28 15:42  木槿花篱  阅读(398)  评论(0编辑  收藏  举报