Loading

【cv】cycleGAN代码解析:test.py

"""通用的图像到图像转换测试脚本。

当你使用train.py训练好模型后,可以使用此脚本来测试模型。
它会从'--checkpoints_dir'加载保存的模型,并将结果保存到'--results_dir'。

脚本首先根据选项创建模型和数据集。它会硬编码一些参数。
然后对'--num_test'张图像运行推理,并将结果保存到HTML文件中。

示例(你需要先训练模型或从我们的网站下载预训练模型):
    测试CycleGAN模型(双向转换):
        python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan

    测试CycleGAN模型(仅单向转换):
        python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout

    选项'--model test'用于仅生成CycleGAN的单向结果。
    此选项会自动设置'--dataset_mode single',即只从一个集合加载图像。
    相反,使用'--model cycle_gan'需要加载并生成双向结果,
    这有时是不必要的。结果将保存到./results/目录。
    使用'--results_dir <directory_path_to_save_result>'指定结果保存目录。

    测试pix2pix模型:
        python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA

更多测试选项参见options/base_options.py和options/test_options.py。
训练和测试技巧参见:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
常见问题参见:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""

import os
from pathlib import Path  # 用于路径处理的库
from options.test_options import TestOptions  # 导入测试选项类
from data import create_dataset  # 导入创建数据集的函数
from models import create_model  # 导入创建模型的函数
from util.visualizer import save_images  # 导入保存图像的函数
from util import html  # 用于生成HTML结果页面的工具

try:
    import wandb  # 尝试导入wandb(用于实验跟踪)
except ImportError:
    # 如果导入失败,打印警告信息
    print('警告:未找到wandb包。使用选项"--use_wandb"将导致错误。')


if __name__ == "__main__":
    # 解析测试选项:获取命令行参数并生成选项实例
    opt = TestOptions().parse()  
    # 硬编码测试所需的参数(测试代码的限制)
    opt.num_threads = 0  # 测试代码仅支持0个线程
    opt.batch_size = 1  # 测试代码仅支持批大小为1
    opt.serial_batches = True  # 禁用数据打乱;如果需要随机选择图像的结果,可注释此行
    opt.no_flip = True  # 不进行图像翻转;如果需要翻转图像的结果,可注释此行
    opt.display_id = -1  # 不使用visdom显示;测试代码将结果保存到HTML文件
    
    # 根据选项创建数据集(基于dataset_mode和其他参数)
    dataset = create_dataset(opt)  
    # 根据选项创建模型(基于model参数和其他配置)
    model = create_model(opt)  
    # 模型常规设置:加载并打印网络结构、创建调度器
    model.setup(opt)  

    # 创建用于保存结果的网页目录
    # 目录路径由结果根目录、实验名称、阶段和轮次组成
    web_dir = Path(opt.results_dir) / opt.name / f"{opt.phase}_{opt.epoch}"  
    if opt.load_iter > 0:  # 如果指定了加载迭代次数(默认0),则在目录名后添加迭代次数
        web_dir = Path(f"{web_dir}_iter{opt.load_iter}")
    print(f"创建网页目录 {web_dir}")
    # 创建HTML对象,用于组织和保存结果
    webpage = html.HTML(web_dir, f"实验 = {opt.name}, 阶段 = {opt.phase}, 轮次 = {opt.epoch}")
    
    # 测试时使用评估模式。这仅影响batchnorm和dropout等层。
    # 对于[pix2pix]:原始pix2pix使用batchnorm和dropout,可尝试开启/关闭eval()模式。
    # 对于[CycleGAN]:不影响,因为CycleGAN使用instancenorm且无dropout。
    if opt.eval:
        model.eval()
    
    # 遍历数据集进行测试
    for i, data in enumerate(dataset):
        # 只处理指定数量(num_test)的图像
        if i >= opt.num_test:  
            break
        model.set_input(data)  # 从数据加载器中解析输入数据
        model.test()  # 运行推理(测试)

test方法的逻辑:

这部分定义在base_model.py里,相比train比较简单,只需要前向传播即可,不需要对网络进行反向传播优化

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()
        visuals = model.get_current_visuals()  # 获取当前的图像结果
        img_path = model.get_image_paths()  # 获取当前图像的路径
        # 每处理5张图像,打印一次进度信息
        if i % 5 == 0:  
            print(f"正在处理第({i:04d})张图像... {img_path}")
        # 将图像保存到HTML页面中
        save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
    
    webpage.save()  # 保存HTML页面(最终结果汇总)
posted @ 2025-09-25 16:46  SaTsuki26681534  阅读(18)  评论(0)    收藏  举报