Optimization and Improvements: Ceramic Pattern Generation Based on a Conditional GAN

I. Code changes: modify the generator loss to add a perceptual loss (perceptual_loss)

1. Create models/perceptual_loss.py

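import torch
import torch.nn as nn
import torchvision.models as models


class PerceptualLoss(nn.Module):
    """L1 distance between VGG16 relu1_2 features of the generated and real images."""

    def __init__(self, device):
        super(PerceptualLoss, self).__init__()

        # torchvision >= 0.13: weights=... replaces the deprecated pretrained=True
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features

        # Keep only the first 4 layers (conv1_1, ReLU, conv1_2, ReLU -> relu1_2);
        # this matters a lot for GPU memory
        self.vgg = nn.Sequential(*list(vgg[:4])).to(device).eval()

        # Freeze VGG: it is only a fixed feature extractor
        for param in self.vgg.parameters():
            param.requires_grad = False

        # ImageNet per-channel statistics the pretrained VGG16 was trained with
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1))

        self.criterion = nn.L1Loss()

    def forward(self, fake, real):
        # The pix2pix generator outputs [-1, 1] (tanh): map to [0, 1] first,
        # then apply the ImageNet normalization VGG expects
        fake = ((fake + 1) / 2.0 - self.mean) / self.std
        real = ((real + 1) / 2.0 - self.mean) / self.std

        fake_f = self.vgg(fake)
        real_f = self.vgg(real)

        loss = self.criterion(fake_f, real_f)
        return loss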
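The pretrained VGG16 in torchvision was trained on images scaled to [0, 1] and then normalized per channel with the ImageNet mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225). The pure-Python sketch below traces a few generator output values through both steps. The helper to_vgg_input is hypothetical and exists only for illustration. The sketch shows that rescaling to [0, 1] alone does not reproduce the input distribution VGG saw during training:

```python
# ImageNet statistics used to train torchvision's VGG16, in (R, G, B) order
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def to_vgg_input(x, channel):
    """Hypothetical helper: map one generator output value x in [-1, 1] to VGG16 input.

    Step 1: tanh range [-1, 1] -> [0, 1]
    Step 2: ImageNet normalization (v - mean) / std for the given channel
    """
    v = (x + 1) / 2.0
    return (v - IMAGENET_MEAN[channel]) / IMAGENET_STD[channel]


# Black (-1), mid-grey (0) and white (1) pixels, per channel
for x in (-1.0, 0.0, 1.0):
    print(x, [round(to_vgg_input(x, c), 3) for c in range(3)])
```

After normalization, black lands between about -2.1 and -1.8 and white between about 2.2 and 2.6 across the three channels. Without normalization, the frozen relu1_2 features would be computed on values confined to [0, 1].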

2. Modify models/pix2pix_model.py

## Add at the top of the file
from models.perceptual_loss import PerceptualLoss

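## In __init__, initialize the loss (inside the if self.isTrain: block, after criterionL1 is defined)
self.lambda_perceptual = opt.lambda_perceptual  # tunable via --lambda_perceptual (5~10 recommended)
self.criterionPerceptual = PerceptualLoss(self.device)
self.loss_names.append('perceptual')  # optional: print/plot loss_perceptual with the other losses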

## In backward_G, change how the generator loss is computed
    def backward_G(self):
        """Calculate GAN, L1 and perceptual loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B at pixel level (unweighted here; lambda_L1 is applied below)
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B)
        # Third, G(A) and B should match in VGG feature space (already weighted)
        self.loss_perceptual = self.criterionPerceptual(self.fake_B, self.real_B) * self.lambda_perceptual
        # Combine losses and calculate gradients
        # (original pix2pix: self.loss_G = self.loss_G_GAN + self.loss_G_L1, with loss_G_L1 pre-multiplied by lambda_L1)
        self.loss_G = self.loss_G_GAN + self.opt.lambda_L1 * self.loss_G_L1 + self.loss_perceptual

        self.loss_G.backward()
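Written out, backward_G computes the generator objective below. This is a sketch in pix2pix-style notation. Here x is the conditioning input (real_A), y is the target pattern (real_B), and φ is the frozen VGG16 relu1_2 feature extractor with the input rescaling folded in. Each norm is the mean absolute error computed by nn.L1Loss (default reduction='mean'):

```latex
\mathcal{L}_G
  = \mathcal{L}_{\mathrm{GAN}}\big(D(x, G(x))\big)
  + \lambda_{L1}\,\big\lVert y - G(x) \big\rVert_1
  + \lambda_{\mathrm{perc}}\,\big\lVert \phi(y) - \phi(G(x)) \big\rVert_1
```

Here λ_L1 is opt.lambda_L1 (100 by default in pix2pix) and λ_perc is lambda_perceptual. The L1 term is averaged over pixels, while the perceptual term is averaged over feature activations, so the two weights are not directly comparable and are best tuned by inspecting generated samples.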

3. Open options/train_options.py and register the parameter lambda_perceptual (inside initialize(), next to the other training options)

parser.add_argument('--lambda_perceptual', type=float, default=5.0)
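Registering the flag is what makes opt.lambda_perceptual available in the model's __init__. The stand-alone argparse sketch below illustrates this. It is not the project's option class; it only mirrors the two loss weights, assuming pix2pix's usual lambda_L1 default of 100:

```python
import argparse

# Stand-alone parser mirroring the two loss weights (illustration only;
# in the real project these arguments live in the option classes)
parser = argparse.ArgumentParser()
parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
parser.add_argument('--lambda_perceptual', type=float, default=5.0,
                    help='weight for the VGG perceptual loss')

# Simulates: python train.py --lambda_perceptual 10
opt = parser.parse_args(['--lambda_perceptual', '10'])
print(opt.lambda_L1, opt.lambda_perceptual)

# Without the flag, the registered default is used
opt_default = parser.parse_args([])
print(opt_default.lambda_perceptual)
```

Register the flag in only one place. If the same option string is added twice to one parser, argparse raises a conflicting-option error.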

II. Changes to the system architecture diagram

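  1. Relabel the Generator and Discriminator in the diagram as Generator (U-Net) and Discriminator (PatchGAN).
    • Role of PatchGAN: ① instead of a single real/fake score for the whole image, it outputs a grid of scores, each judging one local patch, so it emphasizes local texture; ② this suits pattern/texture tasks such as ceramic ornament; ③ it is a core design choice of Pix2Pix (U-Net generator + PatchGAN discriminator).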
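To see what "local" means concretely, the plain-Python sketch below computes the patch size seen by each output score and the size of the score map. It assumes the default NLayerDiscriminator of the official pytorch-CycleGAN-and-pix2pix code (n_layers=3, 4×4 kernels, padding 1) and a 256×256 input. The helper names are hypothetical:

```python
def patchgan_layers(n_layers=3, kw=4, padw=1):
    """(kernel, stride, padding) of each conv in the assumed NLayerDiscriminator."""
    layers = [(kw, 2, padw)]                    # first conv, stride 2
    layers += [(kw, 2, padw)] * (n_layers - 1)  # further downsampling convs
    layers += [(kw, 1, padw), (kw, 1, padw)]    # stride-1 conv + 1-channel output conv
    return layers


def receptive_field(layers):
    """Input patch size seen by one output score (walk the layers backwards)."""
    r = 1
    for k, s, _ in reversed(layers):
        r = (r - 1) * s + k
    return r


def output_size(in_size, layers):
    """Spatial size of the real/fake score map for a square input."""
    size = in_size
    for k, s, p in layers:
        size = (size + 2 * p - k) // s + 1
    return size


layers = patchgan_layers()
print(receptive_field(layers), output_size(256, layers))
```

Under these assumptions, each of the 30×30 scores judges a 70×70 patch. This matches the "70×70 PatchGAN" usually quoted for pix2pix. With n_layers=1 the patch shrinks to 16×16, which focuses on finer texture. More layers enlarge the patch and capture more global layout.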
posted @ 2026-06-21 11:55  jsqup