pytorch 不使用转置卷积来实现上采样

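上采样(upsampling)一般包括2种方式:第一种是 Resize,即直接对特征图做插值缩放(如最近邻插值、双线性插值),类似于图像缩放;第二种是转置卷积(Transposed Convolution,也常被称为 Deconvolution)。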

第二种方法(转置卷积)在 PyTorch 中可用 nn.ConvTranspose2d 实现,具体用法可见上面的链接

 

这里想要介绍的是如何使用pytorch实现第一种方法:

 

举例:

1)使用torch.nn模块实现一个生成器为:

import torch.nn as nn
import torch.nn.functional as F


class ConvLayer(nn.Module):
    """Convolution preceded by reflection padding.

    The input is reflection-padded by ``kernel_size // 2`` pixels on every
    side and then passed through an ``nn.Conv2d`` that is given no padding
    argument of its own.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Integer kernel size (it is floor-divided by 2 to get
            the padding width).
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        # Attribute names double as state_dict prefixes and repr labels.
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x):
        """Return ``conv(reflection_pad(x))``."""
        return self.conv(self.reflection_pad(x))

class Generator(nn.Module):
    """Encoder-decoder generator that upsamples without transposed convolution.

    The encoder stacks three stride-2 ``ConvLayer`` blocks producing 32, 64
    and 128 channels, with BatchNorm + ReLU after the first two.  The decoder
    repeats "bilinear ``nn.Upsample`` (scale factor 2) followed by a 1x1
    convolution" three times (128 -> 64 -> 32 -> 3 channels) and ends with
    ``Tanh``.

    Args:
        in_channels: Number of channels of the input image.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        downsampling = [
            ConvLayer(self.in_channels, 32, kernel_size=3, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            ConvLayer(32, 64, kernel_size=3, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            ConvLayer(64, 128, kernel_size=3, stride=2),
        ]
        self.encoder = nn.Sequential(*downsampling)

        # A single Upsample instance is deliberately reused at all three
        # decoder stages, exactly as a shared module object.
        double_size = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=True
        )
        upsampling = [
            double_size,
            nn.Conv2d(128, 64, kernel_size=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            double_size,
            nn.Conv2d(64, 32, kernel_size=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            double_size,
            nn.Conv2d(32, 3, kernel_size=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*upsampling)

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder."""
        return self.decoder(self.encoder(x))

def test():
    """Smoke-test the generator on a random batch.

    Builds ``Generator(3)``, prints each top-level child module, runs one
    forward pass on a random tensor of size ``(2, 3, 224, 224)`` and prints
    the size and type of the output.
    """
    # Imported locally: the module-level imports only bind ``nn`` and ``F``,
    # so the bare name ``torch`` would otherwise raise NameError.
    import torch

    net = Generator(3)
    for module in net.children():
        print(module)
    # Modules accept tensors directly; the legacy ``Variable`` wrapper (which
    # was never imported in this file) is not needed.
    x = torch.randn(2, 3, 224, 224)
    output = net(x)
    print('output :', output.size())
    print(type(output))

# Script entry point: run the smoke test only on direct execution.
if __name__ == "__main__":
    test()
View Code

返回:

model.py .Sequential(
  (0): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
  )
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
  )
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
  (6): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
  )
)
Sequential(
  (0): Upsample(scale_factor=2, mode=bilinear)
  (1): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
  (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ReLU()
  (4): Upsample(scale_factor=2, mode=bilinear)
  (5): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): ReLU()
  (8): Upsample(scale_factor=2, mode=bilinear)
  (9): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
  (10): Tanh()
)
output : torch.Size([2, 3, 224, 224])
<class 'torch.Tensor'>
View Code

但是这个会有警告:

 UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.

 

可使用torch.nn.functional模块替换为:

import torch.nn as nn
import torch.nn.functional as F


class ConvLayer(nn.Module):
    """Reflection-pad the input, then convolve it.

    Padding of ``kernel_size // 2`` pixels is added on every side by
    ``nn.ReflectionPad2d``; the following ``nn.Conv2d`` is given no padding
    argument of its own.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Integer kernel size (it is floor-divided by 2 to get
            the padding width).
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        half_kernel = kernel_size // 2
        # Attribute names double as state_dict prefixes and repr labels.
        self.reflection_pad = nn.ReflectionPad2d(half_kernel)
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride
        )

    def forward(self, x):
        """Return the convolution of the reflection-padded input."""
        padded = self.reflection_pad(x)
        return self.conv(padded)

class Generator(nn.Module):
    """Encoder-decoder generator that upsamples with ``F.interpolate``.

    The encoder stacks three stride-2 ``ConvLayer`` blocks producing 32, 64
    and 128 channels, with BatchNorm + ReLU after the first two.  Decoding is
    done in three stages; before each stage the feature map is enlarged by a
    factor of 2 with bilinear interpolation, and each stage itself is a 1x1
    convolution (128 -> 64 -> 32 -> 3 channels).  The last stage ends with
    ``Tanh``.

    Args:
        in_channels: Number of channels of the input image.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        downsampling = [
            ConvLayer(self.in_channels, 32, kernel_size=3, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            ConvLayer(32, 64, kernel_size=3, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            ConvLayer(64, 128, kernel_size=3, stride=2),
        ]
        self.encoder = nn.Sequential(*downsampling)

        # The decoder stages hold only the learnable layers; the
        # interpolation between them happens in forward().
        self.decoder1 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.decoder2 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )
        self.decoder3 = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, then alternate 2x bilinear upsampling and decoding."""
        features = self.encoder(x)
        for stage in (self.decoder1, self.decoder2, self.decoder3):
            features = F.interpolate(
                features, scale_factor=2, mode='bilinear', align_corners=True
            )
            features = stage(features)
        return features

def test():
    """Smoke-test the generator on a random batch.

    Builds ``Generator(3)``, prints each top-level child module, runs one
    forward pass on a random tensor of size ``(2, 3, 224, 224)`` and prints
    the size and type of the output.
    """
    # Imported locally: the module-level imports only bind ``nn`` and ``F``,
    # so the bare name ``torch`` would otherwise raise NameError.
    import torch

    net = Generator(3)
    for module in net.children():
        print(module)
    # Modules accept tensors directly; the legacy ``Variable`` wrapper (which
    # was never imported in this file) is not needed.
    x = torch.randn(2, 3, 224, 224)
    output = net(x)
    print('output :', output.size())
    print(type(output))

# Script entry point: run the smoke test only on direct execution.
if __name__ == "__main__":
    test()
View Code

返回:

model.py .Sequential(
  (0): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
  )
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
  )
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
  (6): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
  )
)
Sequential(
  (0): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)
Sequential(
  (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)
Sequential(
  (0): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
  (1): Tanh()
)
output : torch.Size([2, 3, 224, 224])
<class 'torch.Tensor'>
View Code

 

posted @ 2019-08-23 16:28  慢行厚积  阅读(8585)  评论(0)  编辑  收藏  举报