"""This module contains simple helper functions"""
# ^ Module docstring: this file contains simple helper functions (image/tensor conversion, saving, directory creation, etc.)
from __future__ import print_function # 兼容旧版 Python 的 print 函数语法(现代环境一般无须此行)
import torch # PyTorch 主库
import numpy as np # 数值计算库
from PIL import Image # Pillow:图像读写/保存
from pathlib import Path # 跨平台路径处理
def tensor2im(input_image, imtype=np.uint8):
    """Convert a tensor into a numpy image array.

    A ``torch.Tensor`` input is treated as a batch: only its first element is
    used. That element is moved to the CPU, converted to float, expanded from
    one channel to three if needed, transposed so the leading axis becomes the
    last one, and rescaled with ``(x + 1) / 2 * 255``.

    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array

    Returns:
        * for a tensor: the converted array cast to ``imtype``
        * for a numpy array: the same data cast to ``imtype`` (no rescaling)
        * for anything else: ``input_image`` itself, untouched
    """
    if isinstance(input_image, np.ndarray):
        # Already a numpy array: no rescaling, only the dtype cast.
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        # Neither a tensor nor a numpy array: hand it back unchanged.
        return input_image

    # First item of the batch, on the CPU, as a float numpy array.
    image_numpy = input_image.data[0].cpu().float().numpy()
    if image_numpy.shape[0] == 1:
        # Single channel (grayscale): repeat it three times to get RGB.
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    # Channel axis last, then the linear rescale (assumes values in [-1, 1]).
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0

    # Casting an out-of-range float to an integer dtype wraps around or is
    # undefined, so bound the values first when the target is an integer type.
    # Float targets are left unbounded to keep their previous behaviour.
    target = np.dtype(imtype)
    if np.issubdtype(target, np.integer):
        limits = np.iinfo(target)
        image_numpy = np.clip(
            image_numpy, max(0, limits.min), min(255, limits.max)
        )
    return image_numpy.astype(imtype)
def diagnose_network(net, name="network"):
    """Print the average, over parameters, of the mean absolute gradient.

    Useful as a quick check for vanishing or exploding gradients.

    Parameters:
        net (torch network) -- Torch network
        name (str)          -- the name of the network, printed before the value
    """
    # One mean(|grad|) per parameter that actually has a gradient.
    per_param = [
        torch.mean(torch.abs(p.grad.data))
        for p in net.parameters()
        if p.grad is not None
    ]
    # Stays the plain float 0.0 when no parameter has a gradient.
    mean = sum(per_param, 0.0)
    if per_param:
        mean = mean / len(per_param)
    print(name)
    print(mean)
def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array; its first two axes are
                                     read as (height, width)
        image_path (str)          -- the path of the image
        aspect_ratio (float)      -- values above 1 stretch the height by that
                                     factor, values below 1 stretch the width by
                                     its inverse, and 1.0 saves the image as is
    """
    image_pil = Image.fromarray(image_numpy)
    # Only the first two axes are needed, so 2-D (grayscale) arrays work too.
    h, w = image_numpy.shape[:2]
    # PIL's resize expects (width, height). The earlier code passed the height
    # in the width slot, which transposed the dimensions of non-square images.
    # For square images the result is the same as before.
    if aspect_ratio > 1.0:
        image_pil = image_pil.resize((w, int(h * aspect_ratio)), Image.BICUBIC)
    elif aspect_ratio < 1.0:
        image_pil = image_pil.resize((int(w / aspect_ratio), h), Image.BICUBIC)
    image_pil.save(image_path)
def print_numpy(x, val=True, shp=False):
    """Print summary statistics (and optionally the shape) of a numpy array.

    Parameters:
        x (numpy array) -- the array to describe
        val (bool)      -- if True, print mean, min, max, median and std
        shp (bool)      -- if True, print the shape of the array
    """
    # Work on a float64 copy so the statistics are computed in floating point.
    as_float = x.astype(np.float64)
    if shp:
        print("shape,", as_float.shape)
    if not val:
        return
    flat = as_float.flatten()
    stats = (
        np.mean(flat),
        np.min(flat),
        np.max(flat),
        np.median(flat),
        np.std(flat),
    )
    print(
        "mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f"
        % stats
    )
def mkdirs(paths):
    """Create one or several directories by delegating to ``mkdir``.

    Parameters:
        paths (str or str list) -- a single directory path, or a list of them
    """
    # Only a real list is iterated. Anything else, including a plain string,
    # is treated as one path, so a string is never walked character by character.
    targets = paths if isinstance(paths, list) else [paths]
    for target in targets:
        mkdir(target)
def mkdir(path):
"""create a single empty directory