Loading

【cv】cycleGAN代码解析:util包-util.py

"""This module contains simple helper functions"""
# ↑ 模块文档字符串:本文件包含一些简单的辅助函数(图像/张量转换、保存、目录创建等)

from __future__ import print_function  # 兼容旧版 Python 的 print 函数语法(现代环境一般无须此行)
import torch                           # PyTorch 主库
import numpy as np                     # 数值计算库
from PIL import Image                  # Pillow:图像读写/保存
from pathlib import Path               # 跨平台路径处理
def tensor2im(input_image, imtype=np.uint8):
# 把张量转换为图像
    """ "Converts a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) --  the input image tensor array
        imtype (type)        --  the desired type of the converted numpy array
    """
    # 作用:把张量(C,H,W,且通常值域在[-1,1])转换为 numpy 图像数组(H,W,C,uint8 0~255)
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # 若是 torch.Tensor,先取其 .data(历史写法,等价于 tensor 本身)
            image_tensor = input_image.data
        else:
            return input_image  # 非 numpy 且非 tensor(如已是图像数组等),直接原样返回
        # 将 tensor 的第 0 张(batch 维)拿出来,搬到 CPU,转为 float,再转为 numpy
        image_numpy = image_tensor[0].cpu().float().numpy()  # 形状此时为 (C,H,W)
        if image_numpy.shape[0] == 1:  # 若是单通道灰度图,扩展为 3 通道(RGB)
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        # 通道维移到最后,并把值域从[-1,1]线性映射到[0,255]
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # (H,W,C)
    else:  # 若本来就是 numpy 数组,直接返回(不做处理)
        image_numpy = input_image
    return image_numpy.astype(imtype)  # 转成目标 dtype(默认 uint8)


def diagnose_network(net, name="network"):
# 诊断网络状态
    """Calculate and print the mean of average absolute(gradients)

    Parameters:
        net (torch network) -- Torch network
        name (str) -- the name of the network
    """
    # 作用:统计并打印“各参数梯度绝对值的均值”的平均值(用于诊断梯度是否消失/爆炸)
    mean = 0.0
    count = 0
    for param in net.parameters():
        if param.grad is not None:  # 只统计有梯度的参数
            mean += torch.mean(torch.abs(param.grad.data))
            count += 1
    if count > 0:
        mean = mean / count  # 求所有参数均值的平均
    print(name)
    print(mean)  # 输出形如:network\n tensor(0.xxx, device=..., dtype=...)


def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array
        image_path (str)          -- the path of the image
    """
    # 作用:把 numpy 图像保存到磁盘,可按给定纵横比做一次简单的缩放
    image_pil = Image.fromarray(image_numpy)  # numpy(H,W,C)->PIL Image
    h, w, _ = image_numpy.shape               # 注意:h=高度,w=宽度

    # 注意:PIL.Image.resize 的 size 参数是 (width, height)
    # 下面代码使用 (h, ...)、( ..., w),顺序看起来“对调”,此实现沿用自原始仓库
    if aspect_ratio > 1.0:
        image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
    if aspect_ratio < 1.0:
        image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
    image_pil.save(image_path)  # 保存到目标路径


def print_numpy(x, val=True, shp=False):
    """Print the mean, min, max, median, std, and size of a numpy array

    Parameters:
        val (bool) -- if print the values of the numpy array
        shp (bool) -- if print the shape of the numpy array
    """
    # 作用:打印 numpy 数组的统计信息(形状/均值/极值/中位数/标准差)
    x = x.astype(np.float64)  # 先转成 float64,避免整数整除等精度问题
    if shp:
        print("shape,", x.shape)
    if val:
        x = x.flatten()
        print(
            "mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f"
            % (np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))
        )


def mkdirs(paths):
    """create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list of directory paths
    """
    # 作用:批量创建目录(若不存在则创建)
    # 这里用 "isinstance(paths, list) and not isinstance(paths, str)" 来避免把字符串当作可迭代列表逐字符处理
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)   # 逐个创建
    else:
        mkdir(paths)      # 单一路径


def mkdir(path):
    """create a single empty directory

posted @ 2025-09-25 15:24  SaTsuki26681534  阅读(13)  评论(0)    收藏  举报