# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# 版权所有 2022 Dakewe Biotech Corporation。保留所有权利。
# Licensed under the Apache License, Version 2.0 (the "License");
# 根据Apache License 2.0许可证授权;
# you may not use this file except in compliance with the License.
# 除非符合许可证条款,否则不得使用本文件。
# You may obtain a copy of the License at
# 您可以在以下位置获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# 除非法律要求或书面同意,软件
# distributed under the License is distributed on an "AS IS" BASIS,
# 根据许可证分发的软件按"原样"提供,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 不提供任何明示或暗示的担保或条件。
# See the License for the specific language governing permissions and
# 有关权限和限制的具体规定,请参见许可证。
# limitations under the License.
# ==============================================================================
# 导入必要的库
import argparse # 用于解析命令行参数
import os # 用于文件路径和目录操作
import time # 用于计时
from typing import Any # 用于类型提示
import cv2 # OpenCV库,用于图像处理
import torch # PyTorch深度学习框架
import yaml # 用于读取YAML配置文件
from torch import nn # PyTorch的神经网络模块
from torch.utils.data import DataLoader # 用于数据加载的工具
import model # 导入自定义的模型模块
from dataset import CUDAPrefetcher, PairedImageDataset # 导入自定义数据集和数据预取器
from imgproc import tensor_to_image # 导入将张量转换为图像的工具函数
from utils import (build_iqa_model, # 导入构建图像质量评估模型的工具
load_pretrained_state_dict, # 导入加载预训练模型权重的工具
make_directory, # 导入创建目录的工具
AverageMeter, # 导入用于计算平均值的工具类
ProgressMeter, # 导入进度条工具类
Summary) # 导入用于指标汇总的工具类
def load_dataset(config: Any, device: torch.device) -> CUDAPrefetcher:
"""
加载测试数据集并返回CUDA预取器(加速数据加载)
参数:
config: 配置信息(从YAML文件读取)
device: 计算设备(CPU或GPU)
返回:
CUDAPrefetcher: 用于加速数据加载的预取器
"""
# 实例化成对图像数据集(包含高分辨率GT图像和低分辨率LR图像)
test_datasets = PairedImageDataset(
config["TEST"]["DATASET"]["PAIRED_TEST_GT_IMAGES_DIR"], # 测试集GT图像目录
config["TEST"]["DATASET"]["PAIRED_TEST_LR_IMAGES_DIR"] # 测试集LR图像目录
)
# 创建数据加载器
test_dataloader = DataLoader(
test_datasets,
batch_size=config["TEST"]["HYP"]["IMGS_PER_BATCH"], # 批次大小
shuffle=config["TEST"]["HYP"]["SHUFFLE"], # 是否打乱数据
num_workers=config["TEST"]["HYP"]["NUM_WORKERS"], # 数据加载进程数
pin_memory=config["TEST"]["HYP"]["PIN_MEMORY"], # 是否固定内存(加速GPU传输)
drop_last=False, # 是否丢弃最后一个不完整的批次
persistent_workers=config["TEST"]["HYP"]["PERSISTENT_WORKERS"] # 是否保持工作进程存活
)
# 创建CUDA预取器,提前将数据加载到GPU
test_test_data_prefetcher = CUDAPrefetcher(test_dataloader, device)
return test_test_data_prefetcher
def build_model(config: Any, device: torch.device):
"""
根据配置构建生成模型(超分辨率模型)
参数:
config: 配置信息
device: 计算设备
返回:
构建好的生成模型(已移动到指定设备)
"""
# 从model模块中根据配置的模型名称实例化模型
g_model = model.__dict__[config["MODEL"]["G"]["NAME"]](
in_channels=config["MODEL"]["G"]["IN_CHANNELS"], # 输入图像通道数
out_channels=config["MODEL"]["G"]["OUT_CHANNELS"], # 输出图像通道数
channels=config["MODEL"]["G"]["CHANNELS"], # 模型中间层通道数
num_rcb=config["MODEL"]["G"]["NUM_RCB"] # RCB模块数量(残差连接块)
)
# 将模型移动到指定计算设备
g_model = g_model.to(device)
# 如果配置开启模型编译,使用torch.compile加速推理
if config["MODEL"]["G"]["COMPILED"]:
g_model = torch.compile(g_model)
return g_model
def test(
g_model: nn.Module,
test_data_prefetcher: CUDAPrefetcher,
psnr_model: nn.Module,
ssim_model: nn.Module,
device: torch.device,
config: Any,
) -> [float, float]:
"""
执行测试过程,计算模型在测试集上的PSNR和SSIM指标,并可选保存超分辨率结果
参数:
g_model: 生成模型(超分辨率模型)
test_data_prefetcher: 测试数据预取器
psnr_model: PSNR评估模型
ssim_model: SSIM评估模型
device: 计算设备
config: 配置信息
返回:
平均PSNR和平均SSIM值
"""
save_image = False # 是否保存超分辨率图像的标志
save_image_dir = "" # 保存图像的目录
# 如果配置中设置了保存图像目录,则开启保存功能并创建目录
if config["TEST"]["SAVE_IMAGE_DIR"]:
save_image = True
# 构建保存目录路径(实验名称作为子目录)
save_image_dir = os.path.join(config["TEST"]["SAVE_IMAGE_DIR"], config["EXP_NAME"])
make_directory(save_image_dir) # 创建目录(如果不存在)
# 计算测试集的批次数量
batches = len(test_data_prefetcher)
# 设置进度条打印间隔(超过100批则每100批打印一次,否则每批都打印)
if batches > 100:
print_freq = 100
else:
print_freq = batches
# 初始化指标记录器
batch_time = AverageMeter("Time", ":6.3f", Summary.NONE) # 记录每批处理时间
psnres = AverageMeter("PSNR", ":4.2f", Summary.AVERAGE) # 记录PSNR值(平均)
ssimes = AverageMeter("SSIM", ":4.4f", Summary.AVERAGE) # 记录SSIM值(平均)
# 初始化进度条
progress = ProgressMeter(
len(test_data_prefetcher),
[batch_time, psnres, ssimes],
prefix=f"Test: "
)
# 将模型设置为评估模式(关闭dropout等训练特定层)
g_model.eval()
# 关闭梯度计算(节省内存,加速推理)
with torch.no_grad():
batch_index = 0 # 批次索引初始化
# 重置数据预取器指针并加载第一批数据
test_data_prefetcher.reset()
batch_data = test_data_prefetcher.next()
# 记录当前批次开始时间
end = time.time()
# 循环处理所有批次数据
while batch_data is not None:
# 加载批次数据并移动到指定设备
gt = batch_data["gt"].to(device, non_blocking=True) # 高分辨率参考图像
lr = batch_data["lr"].to(device, non_blocking=True) # 低分辨率输入图像
# 模型推理:生成超分辨率图像
sr = g_model(lr)
# 计算图像质量评估指标
psnr = psnr_model(sr, gt) # 计算PSNR
ssim = ssim_model(sr, gt) # 计算SSIM
# 更新指标记录器(累计平均值)
psnres.update(psnr.item(), sr.size(0)) # sr.size(0)为当前批次样本数
ssimes.update(ssim.item(), ssim.size(0))
# 更新批次处理时间记录
batch_time.update(time.time() - end)
end = time.time()
# 按间隔打印进度信息
if batch_index % print_freq == 0:
progress.display(batch_index)
# 检查图像名称是否存在
if batch_data["image_name"] == "":
raise ValueError("图像名称为空,请检查数据集。")
# 如果开启保存图像功能,则保存超分辨率结果
if save_image:
image_name = os.path.basename(batch_data["image_name"][0]) # 获取图像文件名
sr_image = tensor_to_image(sr, False, False) # 将张量转换为图像格式
sr_image = cv2.cvtColor(sr_image, cv2.COLOR_RGB2BGR) # 转换颜色通道(RGB->BGR,适应OpenCV)
cv2.imwrite(os.path.join(save_image_dir, image_name), sr_image) # 保存图像
# 预加载下一批数据
batch_data = test_data_prefetcher.next()
# 批次索引加1
batch_index += 1
# 打印测试集整体性能指标汇总
progress.display_summary()
# 返回平均PSNR和SSIM
return psnres.avg, ssimes.avg
def main() -> None:
"""主函数:解析参数、加载配置、准备测试环境并执行测试"""
# 创建命令行参数解析器
parser = argparse.ArgumentParser()
parser.add_argument("--config_path",
type=str,
default="./configs/test/SRGAN_x4-SRGAN_ImageNet-Set5.yaml",
required=True,
help="测试配置文件的路径。")
args = parser.parse_args() # 解析参数
# 读取YAML配置文件
with open(args.config_path, "r") as f:
config = yaml.full_load(f)
# 设置计算设备(GPU,指定设备ID)
device = torch.device("cuda", config["DEVICE_ID"])
# 加载测试数据集
test_data_prefetcher = load_dataset(config, device)
# 构建生成模型
g_model = build_model(config, device)
# 构建PSNR和SSIM评估模型
psnr_model, ssim_model = build_iqa_model(
config["SCALE"], # 超分倍数
config["TEST"]["ONLY_TEST_Y_CHANNEL"], # 是否只在Y通道评估(针对YCrCb格式)
device,
)
# 加载预训练模型权重
g_model = load_pretrained_state_dict(g_model, config["MODEL"]["G"]["COMPILED"], config["MODEL_WEIGHTS_PATH"])
# 执行测试
test(g_model,
test_data_prefetcher,
psnr_model,
ssim_model,
device,
config)
# 程序入口:当脚本被直接运行时执行main函数
if __name__ == "__main__":
main()