【cv】cycleGAN代码解析:test.py
"""通用的图像到图像转换测试脚本。
当你使用train.py训练好模型后,可以使用此脚本来测试模型。
它会从'--checkpoints_dir'加载保存的模型,并将结果保存到'--results_dir'。
脚本首先根据选项创建模型和数据集。它会硬编码一些参数。
然后对'--num_test'张图像运行推理,并将结果保存到HTML文件中。
示例(你需要先训练模型或从我们的网站下载预训练模型):
测试CycleGAN模型(双向转换):
python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
测试CycleGAN模型(仅单向转换):
python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
选项'--model test'用于仅生成CycleGAN的单向结果。
此选项会自动设置'--dataset_mode single',即只从一个集合加载图像。
相反,使用'--model cycle_gan'需要加载并生成双向结果,
这有时是不必要的。结果将保存到./results/目录。
使用'--results_dir <directory_path_to_save_result>'指定结果保存目录。
测试pix2pix模型:
python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
更多测试选项参见options/base_options.py和options/test_options.py。
训练和测试技巧参见:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
常见问题参见:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import os
from pathlib import Path # 用于路径处理的库
from options.test_options import TestOptions # 导入测试选项类
from data import create_dataset # 导入创建数据集的函数
from models import create_model # 导入创建模型的函数
from util.visualizer import save_images # 导入保存图像的函数
from util import html # 用于生成HTML结果页面的工具
try:
import wandb # 尝试导入wandb(用于实验跟踪)
except ImportError:
# 如果导入失败,打印警告信息
print('警告:未找到wandb包。使用选项"--use_wandb"将导致错误。')
if __name__ == "__main__":
# 解析测试选项:获取命令行参数并生成选项实例
opt = TestOptions().parse()
# 硬编码测试所需的参数(测试代码的限制)
opt.num_threads = 0 # 测试代码仅支持0个线程
opt.batch_size = 1 # 测试代码仅支持批大小为1
opt.serial_batches = True # 禁用数据打乱;如果需要随机选择图像的结果,可注释此行
opt.no_flip = True # 不进行图像翻转;如果需要翻转图像的结果,可注释此行
opt.display_id = -1 # 不使用visdom显示;测试代码将结果保存到HTML文件
# 根据选项创建数据集(基于dataset_mode和其他参数)
dataset = create_dataset(opt)
# 根据选项创建模型(基于model参数和其他配置)
model = create_model(opt)
# 模型常规设置:加载并打印网络结构、创建调度器
model.setup(opt)
# 创建用于保存结果的网页目录
# 目录路径由结果根目录、实验名称、阶段和轮次组成
web_dir = Path(opt.results_dir) / opt.name / f"{opt.phase}_{opt.epoch}"
if opt.load_iter > 0: # 如果指定了加载迭代次数(默认0),则在目录名后添加迭代次数
web_dir = Path(f"{web_dir}_iter{opt.load_iter}")
print(f"创建网页目录 {web_dir}")
# 创建HTML对象,用于组织和保存结果
webpage = html.HTML(web_dir, f"实验 = {opt.name}, 阶段 = {opt.phase}, 轮次 = {opt.epoch}")
# 测试时使用评估模式。这仅影响batchnorm和dropout等层。
# 对于[pix2pix]:原始pix2pix使用batchnorm和dropout,可尝试开启/关闭eval()模式。
# 对于[CycleGAN]:不影响,因为CycleGAN使用instancenorm且无dropout。
if opt.eval:
model.eval()
# 遍历数据集进行测试
for i, data in enumerate(dataset):
# 只处理指定数量(num_test)的图像
if i >= opt.num_test:
break
model.set_input(data) # 从数据加载器中解析输入数据
model.test() # 运行推理(测试)
test方法的逻辑:
这部分定义在base_model.py里,相比train比较简单,只需要前向传播即可,不需要对网络进行反向传播优化
def test(self):
"""Forward function used in test time.
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
with torch.no_grad():
self.forward()
self.compute_visuals()
visuals = model.get_current_visuals() # 获取当前的图像结果
img_path = model.get_image_paths() # 获取当前图像的路径
# 每处理5张图像,打印一次进度信息
if i % 5 == 0:
print(f"正在处理第({i:04d})张图像... {img_path}")
# 将图像保存到HTML页面中
save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
webpage.save() # 保存HTML页面(最终结果汇总)

浙公网安备 33010602011771号