去噪(denoise)、超分辨、GAN 网络的 PyTorch 实现

去噪(denoise)、超分辨、GAN 网络实现

一、去噪网络

import torch
import torch.nn as nn
import time
from tqdm import tqdm
from torchvision.transforms.functional import to_pil_image, to_tensor

from torchvision import transforms
# Side length, in pixels, of the square patches cropped for training.
size = 64

# Training pipeline: take a random size x size crop of the PIL image and
# convert the crop to a tensor.
_train_steps = [
    transforms.RandomCrop(size),
    transforms.ToTensor(),
]
transform_image = transforms.Compose(_train_steps)

# Test pipeline: convert the whole image to a tensor, no cropping.
transform_test = transforms.Compose([transforms.ToTensor()])

from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image

class ImageFolder(Dataset):
    """Dataset that serves every file of one directory as a training sample.

    Each sample is loaded with PIL, forced to three RGB channels and passed
    through the module-level ``transform_image`` pipeline.

    Args:
        path: Directory containing the image files (relative or absolute).
    """

    def __init__(self, path):
        super(ImageFolder, self).__init__()
        self.path = os.path.abspath(path)
        # os.listdir() returns entries in arbitrary order and also lists
        # sub-directories (e.g. notebook checkpoint folders). Keep regular
        # files only and sort them so that index -> file is reproducible.
        self.image_list = sorted(
            name for name in os.listdir(self.path)
            if os.path.isfile(os.path.join(self.path, name))
        )

    def __getitem__(self, item):
        """Return the transformed image stored at position ``item``."""
        image_path = os.path.join(self.path, self.image_list[item])
        # The context manager closes the file handle as soon as the pixels
        # are converted; convert('RGB') guarantees three channels even for
        # grayscale, palette or RGBA files.
        with Image.open(image_path) as image_pil:
            return transform_image(image_pil.convert('RGB'))

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.image_list)
   
# Training data: every file under /content/data, served as shuffled
# mini-batches of 20 samples by two worker processes.
dataset = ImageFolder('/content/data')
loader = DataLoader(
    dataset=dataset,
    batch_size=20,
    shuffle=True,
    num_workers=2,
)

def add_noise(img_tensor, std_gaus):
    """Add zero-mean Gaussian noise to an image tensor and clip to [0, 1].

    Args:
        img_tensor: Floating-point tensor of any shape.
        std_gaus: Standard deviation of the noise.

    Returns:
        A new tensor with the same shape, dtype and device as ``img_tensor``
        whose values are clamped to the range [0, 1].
    """
    # randn_like draws the noise directly on the input's device and in its
    # dtype, so no CPU tensor has to be created and copied for every batch.
    noise = torch.randn_like(img_tensor) * std_gaus
    return torch.clamp(img_tensor + noise, min=0., max=1.)

class CNNDenoiser(nn.Module):
    """Five-layer convolutional denoiser that keeps the spatial size.

    The first and last convolutions use 5x5 kernels (padding 2), the three
    in between use 3x3 kernels (padding 1). A ReLU follows every
    convolution, including the last one, so the output is non-negative.

    Args:
        num_channel: Number of channels of the input and output image.
        num_f: Number of feature maps in the hidden layers.
    """

    def __init__(self, num_channel=3, num_f=64):
        super().__init__()
        outer = dict(kernel_size=5, padding=2)
        inner = dict(kernel_size=3, padding=1)
        # (N, num_channel, H, W) -> (N, num_f, H, W)
        self.conv1 = nn.Conv2d(num_channel, num_f, **outer)
        # Three hidden layers, shape stays (N, num_f, H, W).
        self.conv2 = nn.Conv2d(num_f, num_f, **inner)
        self.conv3 = nn.Conv2d(num_f, num_f, **inner)
        self.conv4 = nn.Conv2d(num_f, num_f, **inner)
        # (N, num_f, H, W) -> (N, num_channel, H, W)
        self.conv5 = nn.Conv2d(num_f, num_channel, **outer)

    def forward(self, x):
        """Run ``x`` through the five conv + ReLU stages in order."""
        relu = nn.functional.relu
        stages = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)
        for conv in stages:
            x = relu(conv(x))
        return x
    
# ---- Denoiser: training configuration ----
num_epoches = 200
loss_func = nn.MSELoss()
model = CNNDenoiser()  # five-layer convolutional denoiser
training_loader = loader

optimizer = torch.optim.Adam(model.parameters(), 0.001)

# Create the output directory up front so the save() calls below do not
# fail on a fresh machine.
os.makedirs('/content/tests', exist_ok=True)

# Fixed validation image: stored once clean (ground truth) and once with the
# same noise level (std 0.1) that is used during training.
test_image = transform_test(Image.open('/content/data/im_16.bmp'))
transforms.functional.to_pil_image(test_image).save('/content/tests/GT.png')  # ground truth
test_image_noised = add_noise(test_image, 0.1)
transforms.functional.to_pil_image(test_image_noised).save('/content/tests/input.png')

model.cuda()
test_image_noised = test_image_noised.cuda()
loss_func.cuda()

# ---- Training loop ----
for e in range(num_epoches):
    # Save the denoised validation image produced by the weights as they are
    # at the start of epoch `e`. no_grad() stops autograd from recording a
    # graph for this full-size forward pass, which is never back-propagated.
    with torch.no_grad():
        test_output = model(test_image_noised.unsqueeze(0)).squeeze(0).cpu()
    test_results = transforms.functional.to_pil_image(
        torch.clamp(test_output, min=0., max=1.)
    )
    test_results.save(f'/content/tests/{e}.png')

    losses = []
    start_time = time.time()
    for i, data_batch in enumerate(training_loader):
        data = data_batch.cuda()
        # Fresh noise is drawn for every batch; the clean batch is the target.
        noised_data = add_noise(data, 0.1)
        denoised = model(noised_data)
        loss = loss_func(denoised, data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    end_time = time.time()
    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batch.
    print(f'{e + 1} epoch. Training Avg Loss {sum(losses) / max(len(losses), 1):.4f}, time {end_time - start_time:.2f} s')

二、超分辨网络

import torch
import torch.nn as nn
import time
from torchvision.transforms.functional import to_pil_image, to_tensor

from torchvision import transforms
# Side length, in pixels, of the square high-resolution training patches.
size = 64

# Training pipeline: random size x size crop of the PIL image, then
# conversion of the crop to a tensor.
_train_steps = [
    transforms.RandomCrop(size),
    transforms.ToTensor(),
]
transform_image = transforms.Compose(_train_steps)

# Test pipeline: convert the whole image to a tensor, no cropping.
transform_test = transforms.Compose([transforms.ToTensor()])

from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image

class ImageFolderSR(Dataset):
    """Dataset of high-resolution patches for super-resolution training.

    Every regular file of one directory is a sample. It is loaded with PIL,
    forced to three RGB channels and passed through the module-level
    ``transform_image`` pipeline.

    Args:
        path: Directory containing the image files (relative or absolute).
    """

    def __init__(self, path):
        super(ImageFolderSR, self).__init__()
        self.path = os.path.abspath(path)
        # os.listdir() returns entries in arbitrary order and also lists
        # sub-directories (e.g. notebook checkpoint folders). Keep regular
        # files only and sort them so that index -> file is reproducible.
        self.image_list = sorted(
            name for name in os.listdir(self.path)
            if os.path.isfile(os.path.join(self.path, name))
        )

    def __getitem__(self, item):
        """Return the transformed image stored at position ``item``."""
        image_path = os.path.join(self.path, self.image_list[item])
        # The context manager closes the file handle as soon as the pixels
        # are converted; convert('RGB') guarantees three channels even for
        # grayscale, palette or RGBA files.
        with Image.open(image_path) as image_pil:
            return transform_image(image_pil.convert('RGB'))

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.image_list)
    
# Training data: every file under /content/data, served as shuffled
# mini-batches of 20 samples by two worker processes.
dataset = ImageFolderSR('/content/data')
loader = DataLoader(
    dataset=dataset,
    batch_size=20,
    shuffle=True,
    num_workers=2,
)

def add_noise(img_tensor, std_gaus):
    """Add zero-mean Gaussian noise to an image tensor and clip to [0, 1].

    Args:
        img_tensor: Floating-point tensor of any shape.
        std_gaus: Standard deviation of the noise.

    Returns:
        A new tensor with the same shape, dtype and device as ``img_tensor``
        whose values are clamped to the range [0, 1].
    """
    # randn_like draws the noise directly on the input's device and in its
    # dtype, so no CPU tensor has to be created and copied for every call.
    noise = torch.randn_like(img_tensor) * std_gaus
    return torch.clamp(img_tensor + noise, min=0., max=1.)

# bilinear & bicubic
def resample_bic(img_tensor, sr_factor):
    """Rescale an image batch by ``sr_factor`` with bicubic interpolation.

    Args:
        img_tensor: Batched image tensor; the callers in this file pass
            4-D (N, C, H, W) tensors.
        sr_factor: Multiplier for the spatial size; values below 1 shrink
            the image, values above 1 enlarge it.

    Returns:
        The resampled tensor. Bicubic interpolation can overshoot, so values
        may fall slightly outside the input range.
    """
    interpolate = nn.functional.interpolate
    resized = interpolate(input=img_tensor, scale_factor=sr_factor, mode='bicubic')
    return resized


class SRCNN(nn.Module):
    """Convolutional super-resolution network with a PixelShuffle upsampler.

    Five conv + ReLU stages work at the low resolution; the fifth expands
    the channels to ``num_f * sr_factor ** 2`` so that PixelShuffle can
    trade them for an ``sr_factor`` times larger image. A final 5x5
    convolution, without activation, maps the features back to image
    channels.

    Args:
        num_channel: Number of channels of the input and output image.
        sr_factor: Integer upscaling factor.
        num_f: Number of feature maps in the hidden layers.
    """

    def __init__(self, num_channel=3, sr_factor=4, num_f=64):
        super().__init__()
        small = dict(kernel_size=3, padding=1)
        large = dict(kernel_size=5, padding=2)
        # (N, num_channel, h, w) -> (N, num_f, h, w)
        self.conv1 = nn.Conv2d(num_channel, num_f, **large)
        # Hidden layers, shape stays (N, num_f, h, w).
        self.conv2 = nn.Conv2d(num_f, num_f, **small)
        self.conv3 = nn.Conv2d(num_f, num_f, **small)
        self.conv4 = nn.Conv2d(num_f, num_f, **small)
        # (N, num_f, h, w) -> (N, num_f * sr_factor^2, h, w)
        self.conv5 = nn.Conv2d(num_f, num_f * sr_factor ** 2, **small)
        # (N, num_f * sr_factor^2, h, w) -> (N, num_f, h * sr_factor, w * sr_factor)
        self.upsample = nn.PixelShuffle(sr_factor)
        # (N, num_f, H, W) -> (N, num_channel, H, W)
        self.conv6 = nn.Conv2d(num_f, num_channel, **large)

    def forward(self, x):
        """Upscale the low-resolution batch ``x`` by the network's factor."""
        relu = nn.functional.relu
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = relu(conv(x))
        # The output convolution has no activation, so values are unbounded.
        return self.conv6(self.upsample(x))
    
class LPIPS_Loss(torch.nn.Module):
    """Perceptual (LPIPS) distance combined with an L1 pixel loss.

    The returned value is ``lpips_weight * mean(LPIPS) + L1``.

    Args:
        net_str: Backbone name handed to ``lpips.LPIPS`` (e.g. ``'alex'``).
        lpips_weight: Multiplier applied to the LPIPS term.
        normalize: When True (default) the inputs are treated as images in
            [0, 1] (what ``ToTensor`` produces) and LPIPS rescales them to
            the [-1, 1] range it was trained on. Pass False if the inputs
            are already in [-1, 1].
    """

    def __init__(self, net_str='alex', lpips_weight=1.0, normalize=True):  # hyper-parameters
        super(LPIPS_Loss, self).__init__()
        # The file never imports `lpips` at module level, so import it here
        # to avoid a NameError when the loss is constructed.
        import lpips
        self.lpips = lpips.LPIPS(net=net_str)
        self.l1norm = nn.L1Loss()
        self.lpips_weight = lpips_weight
        self.normalize = normalize

    def forward(self, img0, img1):
        """Return the weighted LPIPS + L1 loss between two image batches."""
        # LPIPS yields one score per sample; average them into a scalar.
        lpips_score = torch.mean(self.lpips(img0, img1, normalize=self.normalize))
        l1_norm = self.l1norm(img0, img1)
        return lpips_score * self.lpips_weight + l1_norm
    
# ---- Super-resolution: training configuration ----
num_epoches = 1000
loss_func = LPIPS_Loss(lpips_weight=10.0)
model = SRCNN()
training_loader = loader

optimizer = torch.optim.Adam(model.parameters(), 0.001)

# Create the output directory up front so the save() calls below do not
# fail on a fresh machine.
os.makedirs('/content/test2', exist_ok=True)

# Fixed validation image: ground truth, its 4x bicubic downsample (the
# network input) and a plain bicubic upsample as a visual baseline.
test_image = transform_test(Image.open('/content/data/im_16.bmp'))
transforms.functional.to_pil_image(test_image).save('/content/test2/GT.png')  # ground truth
test_image_down = resample_bic(test_image.unsqueeze(0), 1/4)
test_image_vis = torch.clamp(resample_bic(test_image_down, 4), min=0., max=1.)
# Bicubic resampling can overshoot [0, 1]; clamp the copy that is written to
# disk so the 8-bit conversion does not wrap around. The tensor fed to the
# model stays unclamped, matching how training batches are produced.
transforms.functional.to_pil_image(
    torch.clamp(test_image_down.squeeze(0), min=0., max=1.)
).save('/content/test2/input.png')
transforms.functional.to_pil_image(test_image_vis.squeeze(0)).save('/content/test2/bic.png')


model.cuda()
test_image_down = test_image_down.cuda()
loss_func.cuda()

# ---- Training loop ----
for e in range(num_epoches):
    # Save a preview on every 20th epoch (0, 20, 40, ...).
    if e % 20 == 0:
        # no_grad() stops autograd from recording a graph for this forward
        # pass, which is never back-propagated.
        with torch.no_grad():
            test_output = model(test_image_down).squeeze(0).cpu()
        test_results = transforms.functional.to_pil_image(
            torch.clamp(test_output, min=0., max=1.)
        )
        test_results.save(f'/content/test2/{e}.png')

    losses = []
    start_time = time.time()
    for i, data_batch in enumerate(training_loader):
        data = data_batch.cuda()
        # The low-resolution input is synthesised from the target patch.
        low_resolution_data = resample_bic(data, 1/4)
        sr = model(low_resolution_data)
        loss = loss_func(sr, data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    end_time = time.time()
    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batch.
    print(f'{e + 1} epoch. Training Avg Loss {sum(losses) / max(len(losses), 1):.4f}, time {end_time - start_time:.2f} s')

三、对抗生成网络

import torch
import torch.nn as nn
import time
from torchvision.transforms.functional import to_pil_image, to_tensor

import torchvision
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image

# Pre-processing for CelebA: centre-crop a 170x170 region, shrink it to
# 64x64 and convert it to a tensor.
_celeb_steps = [
    transforms.CenterCrop((170, 170)),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
]
transform_rgb = transforms.Compose(_celeb_steps)

# CelebA training split, downloaded into /content when it is not there yet.
celeb_dataset = datasets.CelebA(
    root='/content',
    split='train',
    download=True,
    transform=transform_rgb,
)

# Shuffled mini-batches of 128 samples, loaded by two worker processes.
celeb_dataloader = DataLoader(
    dataset=celeb_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)


# Preview of sample 11; as a bare expression this is only displayed when it
# is the last statement of a notebook cell.
transforms.functional.to_pil_image(celeb_dataset[11][0])

# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialise one module in place; meant for ``net.apply(...)``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left
    untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


# Generator Code
class Generator(nn.Module):
    """DCGAN-style generator: latent vector -> 64x64 image.

    Four transposed-conv + BatchNorm + ReLU stages grow a (nz, 1, 1) input
    to (ngf, 32, 32); a last transposed convolution, with no activation
    after it, produces the (nc, 64, 64) output.

    Args:
        ngpu: Stored on the instance; not used by the code in this class.
        nz: Number of channels of the latent input.
        ngf: Base number of generator feature maps.
        nc: Number of channels of the generated image.
    """

    def __init__(self, ngpu=1, nz=128, ngf=64, nc=3):
        super().__init__()
        self.ngpu = ngpu
        # (in_channels, out_channels, stride, padding) of each hidden stage;
        # spatial size goes 1 -> 4 -> 8 -> 16 -> 32.
        plan = (
            (nz, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf, 2, 1),
        )
        layers = []
        for c_in, c_out, stride, pad in plan:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Output stage: (ngf, 32, 32) -> (nc, 64, 64).
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map a batch of latent vectors (N, nz, 1, 1) to images."""
        return self.main(input)


class Discriminator(nn.Module):
    """DCGAN-style discriminator: 64x64 image -> probability of being real.

    Four stride-2 convolutions with LeakyReLU(0.2) halve the resolution from
    64 down to 4 (BatchNorm on all but the first); a final 4x4 convolution
    and a sigmoid reduce each image to a single value in [0, 1].

    Args:
        ngpu: Stored on the instance; not used by the code in this class.
        ndf: Base number of discriminator feature maps.
        nc: Number of channels of the input image.
    """

    def __init__(self, ngpu=1, ndf=64, nc=3):
        super().__init__()
        self.ngpu = ngpu
        # Input stage, no BatchNorm: (nc, 64, 64) -> (ndf, 32, 32).
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three stages that double the channels and halve the size:
        # 32 -> 16 -> 8 -> 4.
        width = ndf
        for _ in range(3):
            layers.append(nn.Conv2d(width, width * 2, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(width * 2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            width *= 2
        # Output stage: (ndf*8, 4, 4) -> (1, 1, 1), squashed into [0, 1].
        layers.append(nn.Conv2d(width, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return one score per image, shaped (N, 1, 1, 1) for 64x64 input."""
        return self.main(input)
    
import torch.optim as optim

# Label convention used for the BCE targets, and the number of epochs.
real_label, fake_label = 1., 0.
num_epochs = 40

# Progress bookkeeping: image grids of the generator output, per-iteration
# losses and a global iteration counter.
img_list, G_losses, D_losses = [], [], []
iters = 0

# Build both networks on the GPU and re-initialise their weights.
# apply() returns the module itself, so the calls can be chained.
netG = Generator().to('cuda').apply(weights_init)
netD = Discriminator().to('cuda').apply(weights_init)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Fixed latent batch (64 x 128 x 1 x 1), reused for every snapshot so the
# generator's progress can be compared on identical inputs.
fixed_noise = torch.randn(64, 128, 1, 1, device='cuda')

# Both networks get their own Adam optimizer with identical settings.
adam_settings = dict(lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), **adam_settings)
optimizerG = optim.Adam(netG.parameters(), **adam_settings)

# Adversarial training. Every batch performs two updates in this order:
#   (1) one discriminator step on a real batch plus a generated batch,
#   (2) one generator step that tries to make D label the fakes as real.
# The statement order matters: the G step reuses `fake` produced in (1) and
# evaluates it with the discriminator that was just updated.
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(celeb_dataloader, 0):


        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        # Clears D's gradients, including those accumulated as a side effect
        # of the previous iteration's generator step.
        netD.zero_grad()

        # Format batch
        # data[0] is the image batch. Despite its name, `real_cpu` is moved
        # to the GPU here.
        real_cpu = data[0].to('cuda')
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device='cuda')

        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        # Mean discriminator output on real images (logging only).
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, 128, 1, 1, device='cuda')
        # Generate fake image batch with G
        fake = netG(noise)
        # Reuse the label tensor in place, now filled with the "fake" target.
        label.fill_(fake_label)
        # Classify all fake batch with D
        # detach() cuts the graph so this backward pass does not reach G.
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        # Mean discriminator output on fakes before D is updated (logging only).
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()


        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        # No detach() here: the gradient has to flow through D back into G.
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        # Mean discriminator output on fakes after D was updated (logging only).
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(celeb_dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        # A snapshot is taken every 500 iterations and on the very last batch.
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(celeb_dataloader)-1)):
            with torch.no_grad():
                # Rebinds `fake`; both updates of this iteration are already
                # done, so the training batch is no longer needed.
                # NOTE(review): netG is not switched to eval() here, so
                # BatchNorm presumably still uses batch statistics - confirm
                # this is intended.
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(torchvision.utils.make_grid(fake, padding=2, normalize=True))

        iters += 1
        
# Bare expression: builds a PIL image from the last saved grid. Presumably
# relies on notebook cell output to display it; the result is not stored.
transforms.functional.to_pil_image(img_list[-1])
posted @ 2021-09-15 17:22  梁君牧  阅读(571)  评论(0编辑  收藏  举报