Loading

【看代码】SRGAN的test.py代码逐行解析

# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# 版权所有 2022 Dakewe Biotech Corporation。保留所有权利。
# Licensed under the Apache License, Version 2.0 (the "License");
# 根据Apache License 2.0许可证授权;
#   you may not use this file except in compliance with the License.
#   除非符合许可证条款,否则不得使用本文件。
#   You may obtain a copy of the License at
#   您可以在以下位置获取许可证副本:
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# 除非法律要求或书面同意,软件
# distributed under the License is distributed on an "AS IS" BASIS,
# 根据许可证分发的软件按"原样"提供,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 不提供任何明示或暗示的担保或条件。
# See the License for the specific language governing permissions and
# 有关权限和限制的具体规定,请参见许可证。
# limitations under the License.
# ==============================================================================
# 导入必要的库
import argparse  # 用于解析命令行参数
import os  # 用于文件路径和目录操作
import time  # 用于计时
from typing import Any  # 用于类型提示

import cv2  # OpenCV库,用于图像处理
import torch  # PyTorch深度学习框架
import yaml  # 用于读取YAML配置文件
from torch import nn  # PyTorch的神经网络模块
from torch.utils.data import DataLoader  # 用于数据加载的工具

import model  # 导入自定义的模型模块
from dataset import CUDAPrefetcher, PairedImageDataset  # 导入自定义数据集和数据预取器
from imgproc import tensor_to_image  # 导入将张量转换为图像的工具函数
from utils import (build_iqa_model,  # 导入构建图像质量评估模型的工具
                   load_pretrained_state_dict,  # 导入加载预训练模型权重的工具
                   make_directory,  # 导入创建目录的工具
                   AverageMeter,  # 导入用于计算平均值的工具类
                   ProgressMeter,  # 导入进度条工具类
                   Summary)  # 导入用于指标汇总的工具类


def load_dataset(config: Any, device: torch.device) -> CUDAPrefetcher:
    """
    加载测试数据集并返回CUDA预取器(加速数据加载)
    
    参数:
        config: 配置信息(从YAML文件读取)
        device: 计算设备(CPU或GPU)
    
    返回:
        CUDAPrefetcher: 用于加速数据加载的预取器
    """
    # 实例化成对图像数据集(包含高分辨率GT图像和低分辨率LR图像)
    test_datasets = PairedImageDataset(
        config["TEST"]["DATASET"]["PAIRED_TEST_GT_IMAGES_DIR"],  # 测试集GT图像目录
        config["TEST"]["DATASET"]["PAIRED_TEST_LR_IMAGES_DIR"]   # 测试集LR图像目录
    )
    # 创建数据加载器
    test_dataloader = DataLoader(
        test_datasets,
        batch_size=config["TEST"]["HYP"]["IMGS_PER_BATCH"],  # 批次大小
        shuffle=config["TEST"]["HYP"]["SHUFFLE"],  # 是否打乱数据
        num_workers=config["TEST"]["HYP"]["NUM_WORKERS"],  # 数据加载进程数
        pin_memory=config["TEST"]["HYP"]["PIN_MEMORY"],  # 是否固定内存(加速GPU传输)
        drop_last=False,  # 是否丢弃最后一个不完整的批次
        persistent_workers=config["TEST"]["HYP"]["PERSISTENT_WORKERS"]  # 是否保持工作进程存活
    )
    # 创建CUDA预取器,提前将数据加载到GPU
    test_test_data_prefetcher = CUDAPrefetcher(test_dataloader, device)

    return test_test_data_prefetcher


def build_model(config: Any, device: torch.device):
    """
    根据配置构建生成模型(超分辨率模型)
    
    参数:
        config: 配置信息
        device: 计算设备
    
    返回:
        构建好的生成模型(已移动到指定设备)
    """
    # 从model模块中根据配置的模型名称实例化模型
    g_model = model.__dict__[config["MODEL"]["G"]["NAME"]](
        in_channels=config["MODEL"]["G"]["IN_CHANNELS"],  # 输入图像通道数
        out_channels=config["MODEL"]["G"]["OUT_CHANNELS"],  # 输出图像通道数
        channels=config["MODEL"]["G"]["CHANNELS"],  # 模型中间层通道数
        num_rcb=config["MODEL"]["G"]["NUM_RCB"]  # RCB模块数量(残差连接块)
    )
    # 将模型移动到指定计算设备
    g_model = g_model.to(device)

    # 如果配置开启模型编译,使用torch.compile加速推理
    if config["MODEL"]["G"]["COMPILED"]:
        g_model = torch.compile(g_model)

    return g_model


def test(
        g_model: nn.Module,
        test_data_prefetcher: CUDAPrefetcher,
        psnr_model: nn.Module,
        ssim_model: nn.Module,
        device: torch.device,
        config: Any,
) -> [float, float]:
    """
    执行测试过程,计算模型在测试集上的PSNR和SSIM指标,并可选保存超分辨率结果
    
    参数:
        g_model: 生成模型(超分辨率模型)
        test_data_prefetcher: 测试数据预取器
        psnr_model: PSNR评估模型
        ssim_model: SSIM评估模型
        device: 计算设备
        config: 配置信息
    
    返回:
        平均PSNR和平均SSIM值
    """
    save_image = False  # 是否保存超分辨率图像的标志
    save_image_dir = ""  # 保存图像的目录

    # 如果配置中设置了保存图像目录,则开启保存功能并创建目录
    if config["TEST"]["SAVE_IMAGE_DIR"]:
        save_image = True
        # 构建保存目录路径(实验名称作为子目录)
        save_image_dir = os.path.join(config["TEST"]["SAVE_IMAGE_DIR"], config["EXP_NAME"])
        make_directory(save_image_dir)  # 创建目录(如果不存在)

    # 计算测试集的批次数量
    batches = len(test_data_prefetcher)
    # 设置进度条打印间隔(超过100批则每100批打印一次,否则每批都打印)
    if batches > 100:
        print_freq = 100
    else:
        print_freq = batches
    # 初始化指标记录器
    batch_time = AverageMeter("Time", ":6.3f", Summary.NONE)  # 记录每批处理时间
    psnres = AverageMeter("PSNR", ":4.2f", Summary.AVERAGE)  # 记录PSNR值(平均)
    ssimes = AverageMeter("SSIM", ":4.4f", Summary.AVERAGE)  # 记录SSIM值(平均)
    # 初始化进度条
    progress = ProgressMeter(
        len(test_data_prefetcher),
        [batch_time, psnres, ssimes],
        prefix=f"Test: "
    )

    # 将模型设置为评估模式(关闭dropout等训练特定层)
    g_model.eval()

    # 关闭梯度计算(节省内存,加速推理)
    with torch.no_grad():
        batch_index = 0  # 批次索引初始化

        # 重置数据预取器指针并加载第一批数据
        test_data_prefetcher.reset()
        batch_data = test_data_prefetcher.next()

        # 记录当前批次开始时间
        end = time.time()

        # 循环处理所有批次数据
        while batch_data is not None:
            # 加载批次数据并移动到指定设备
            gt = batch_data["gt"].to(device, non_blocking=True)  # 高分辨率参考图像
            lr = batch_data["lr"].to(device, non_blocking=True)  # 低分辨率输入图像

            # 模型推理:生成超分辨率图像
            sr = g_model(lr)

            # 计算图像质量评估指标
            psnr = psnr_model(sr, gt)  # 计算PSNR
            ssim = ssim_model(sr, gt)  # 计算SSIM

            # 更新指标记录器(累计平均值)
            psnres.update(psnr.item(), sr.size(0))  # sr.size(0)为当前批次样本数
            ssimes.update(ssim.item(), ssim.size(0))

            # 更新批次处理时间记录
            batch_time.update(time.time() - end)
            end = time.time()

            # 按间隔打印进度信息
            if batch_index % print_freq == 0:
                progress.display(batch_index)

            # 检查图像名称是否存在
            if batch_data["image_name"] == "":
                raise ValueError("图像名称为空,请检查数据集。")
            # 如果开启保存图像功能,则保存超分辨率结果
            if save_image:
                image_name = os.path.basename(batch_data["image_name"][0])  # 获取图像文件名
                sr_image = tensor_to_image(sr, False, False)  # 将张量转换为图像格式
                sr_image = cv2.cvtColor(sr_image, cv2.COLOR_RGB2BGR)  # 转换颜色通道(RGB->BGR,适应OpenCV)
                cv2.imwrite(os.path.join(save_image_dir, image_name), sr_image)  # 保存图像

            # 预加载下一批数据
            batch_data = test_data_prefetcher.next()

            # 批次索引加1
            batch_index += 1

    # 打印测试集整体性能指标汇总
    progress.display_summary()

    # 返回平均PSNR和SSIM
    return psnres.avg, ssimes.avg


def main() -> None:
    """主函数:解析参数、加载配置、准备测试环境并执行测试"""
    # 创建命令行参数解析器
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path",
                        type=str,
                        default="./configs/test/SRGAN_x4-SRGAN_ImageNet-Set5.yaml",
                        required=True,
                        help="测试配置文件的路径。")
    args = parser.parse_args()  # 解析参数

    # 读取YAML配置文件
    with open(args.config_path, "r") as f:
        config = yaml.full_load(f)

    # 设置计算设备(GPU,指定设备ID)
    device = torch.device("cuda", config["DEVICE_ID"])
    # 加载测试数据集
    test_data_prefetcher = load_dataset(config, device)
    # 构建生成模型
    g_model = build_model(config, device)
    # 构建PSNR和SSIM评估模型
    psnr_model, ssim_model = build_iqa_model(
        config["SCALE"],  # 超分倍数
        config["TEST"]["ONLY_TEST_Y_CHANNEL"],  # 是否只在Y通道评估(针对YCrCb格式)
        device,
    )

    # 加载预训练模型权重
    g_model = load_pretrained_state_dict(g_model, config["MODEL"]["G"]["COMPILED"], config["MODEL_WEIGHTS_PATH"])

    # 执行测试
    test(g_model,
         test_data_prefetcher,
         psnr_model,
         ssim_model,
         device,
         config)


# 程序入口:当脚本被直接运行时执行main函数
if __name__ == "__main__":
    main()
posted @ 2025-09-25 21:25  SaTsuki26681534  阅读(17)  评论(0)    收藏  举报