优化改进_基于条件GAN的陶瓷纹样生成
一、代码调整:调整损失函数的计算公式,增加感知损失perceptual_loss
1. 新建models/perceptual_loss.py
import torch
import torch.nn as nn
import torchvision.models as models
class PerceptualLoss(nn.Module):
def __init__(self, device):
super(PerceptualLoss, self).__init__()
vgg = models.vgg16(pretrained=True).features
# 只取前4层(非常重要:省显存)
self.vgg = nn.Sequential(*list(vgg[:4])).to(device)
for param in self.vgg.parameters():
param.requires_grad = False
self.criterion = nn.L1Loss()
def forward(self, fake, real):
# VGG要求输入是 [-1,1] → [0,1]
fake = (fake + 1) / 2.0
real = (real + 1) / 2.0
fake_f = self.vgg(fake)
real_f = self.vgg(real)
loss = self.criterion(fake_f, real_f)
return loss
2. 修改models/pix2pix_model.py
## 文件顶部增加
from models.perceptual_loss import PerceptualLoss
## init方法中初始化loss
self.lambda_perceptual = 10 # 可调(建议5~10)
self.criterionPerceptual = PerceptualLoss(self.device)
## backward_G方法中调整损失函数的计算方式
def backward_G(self):
"""Calculate GAN and L1 loss for the generator"""
# First, G(A) should fake the discriminator
fake_AB = torch.cat((self.real_A, self.fake_B), 1)
pred_fake = self.netD(fake_AB)
# self.loss_G_GAN = self.criterionGAN(pred_fake, True)
# Second, G(A) = B
# self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
# combine loss and calculate gradients
# self.loss_G = self.loss_G_GAN + self.loss_G_L1
self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B)
self.loss_G_GAN = self.criterionGAN(pred_fake, True)
self.loss_perceptual = self.criterionPerceptual(self.fake_B, self.real_B) * self.lambda_perceptual
self.loss_G = self.loss_G_GAN + self.opt.lambda_L1 * self.loss_G_L1 + self.loss_perceptual
self.loss_G.backward()
3.打开options/train_options.py注册参数perceptual_loss
parser.add_argument('--lambda_perceptual', type=float, default=5.0)
二、系统架构图调整
- 调整图中的Generator(生成器)和Discriminator(判别器)为 Generator(U-Net) 和 Discriminator(PatchGAN)。
- PatchGAN作用:①强调局部纹理;②适合纹样/texture任务;③Pix2Pix核心设计;
本文来自博客园,作者:jsqup,转载请注明原文链接:https://www.cnblogs.com/jsqup/p/20676333

浙公网安备 33010602011771号