Loading

【cv】GAN代码解析 base_dataset.py

"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.

It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod
# 导入依赖:random、numpy、torch.utils.data(起别名 data)、
# PIL.Image、torchvision.transforms(起别名 transforms)、
# 以及抽象基类机制 ABC/abstractmethod


class BaseDataset(data.Dataset, ABC):
    # 定义类 BaseDataset(data.Dataset, ABC),继承 PyTorch 的 Dataset 与 ABC,用于约束数据集子类接口

    """
    To create a subclass, you need to implement the following four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        # 接受命令行对象
        self.opt = opt
        self.root = opt.dataroot

    @staticmethod
    def modify_commandline_options(parser, is_train):
        # 为数据集添加/改写命令行参数的钩子。默认直接返回 parser(即不做改动);子类可重写该方法定制选项
        return parser

    @abstractmethod
    def __len__(self):
        # 返回数据集大小的抽象方法,这里仅做占位(return 0);子类必须实现
        return 0

    @abstractmethod
    def __getitem__(self, index):
        # 返回一条数据及其元信息的抽象方法,子类必须实现。
        # 通过index索引指定数据的方法
        pass


def get_params(opt, size):
    # 在这个方法里获取crop_resize的参数
    w, h = size
    new_h = h
    new_w = w
    if opt.preprocess == 'resize_and_crop':
        new_h = new_w = opt.load_size
    elif opt.preprocess == 'scale_width_and_crop':
        new_w = opt.load_size
        new_h = opt.load_size * h // w
    # 取原图尺寸 (w, h),根据 opt.preprocess 计算预调整尺寸 (new_w, new_h)

    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
    # 这是一个变量赋值语句,将右边表达式的结果赋值给变量x
    # random.randint(a, b)函数调用:
    # 该函数来自 Python 标准库random,用于生成一个闭区间[a, b]内的随机整数
    # (a和b必须是整数,且a <= b,否则会报错)
    # 在这里就是生成一个[0,np.maximum(0, new_w - opt.crop_size)]范围内的随机整数
    # np.maximum是 NumPy 库的函数,用于计算两个参数中的最大值(此处比较0和new_w - opt.crop_size)。
    # 这样写的目的是确保第二个参数不会小于 0

    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
    # 在新尺寸上随机采样裁剪起点 (x, y),范围由 new_w/new_h 与 crop_size 决定;

    flip = random.random() > 0.5
    # 随机采样水平翻转标志 flip(50% 概率)

    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(opt, params=None, grayscale=False,
                  method=transforms.InterpolationMode.BICUBIC, convert=True):
    # 构造 torchvision 变换序列
    # 用人话说就是根据opt参数来确定需要对张量进行哪些操作,并加到transform_list序列里去
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    # 初始化 transform_list=[];若 grayscale=True(灰度图),加 transforms.Grayscale(1)(单通道)

    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, method))
    # 注意这里没加lambda是因为本来就已经有了resize这个函数,加lambda是为了定义新的匿名函数
    elif 'scale_width' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
    # Lambda类与匿名函数:transforms.Lambda(...) 是一个接受函数作为参数的类,
    # 这里传入的是一个lambda 函数(Python 中的匿名函数)。
    # lambda 函数的语法是lambda 参数: 表达式,用于定义简单的单行函数,无需显式命名。
    # 这里的lambda img: __scale_width(...)定义了一个接受img参数的匿名函数,
    # 函数体是调用__scale_width函数,传入img、opt.load_size、opt.crop_size、method四个参数,并返回其结果。

    if 'crop' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if opt.preprocess == 'none':
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
    # 若 opt.preprocess=='none':使用 __make_power_2 将尺寸调整为最接近的 4 的倍数,
    # 避免下采样层对非对齐尺寸的限制(常见于某些卷积架构)

    if not opt.no_flip:
    # 水平翻转
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    # 张量化与归一化

    return transforms.Compose(transform_list)


def __transforms2pil_resize(method):
# 把 torchvision.transforms.InterpolationMode 映射到 PIL 的插值常量
    mapper = {transforms.InterpolationMode.BILINEAR: Image.BILINEAR,
              transforms.InterpolationMode.BICUBIC: Image.BICUBIC,
              transforms.InterpolationMode.NEAREST: Image.NEAREST,
              transforms.InterpolationMode.LANCZOS: Image.LANCZOS,}
    # 函数内部定义了变量mapper,其值是一个字典(用{}表示)。字典中包含 4 组键值对,
    # 格式为key: value,每组之间用逗号分隔:
    # 键(key)是transforms.InterpolationMode类的属性(可能是枚举成员,如BILINEAR、BICUBIC等)
    # 值(value)是Image类的属性(可能是 PIL 库中定义的插值模式常量,如Image.BILINEAR等)
    return mapper[method]
# 最后通过return mapper[method]语句返回结果:
# mapper[method]表示通过键method从字典mapper中获取对应的值(即根据输入的method参数,返回对应的 PIL 插值常量)


def __make_power_2(img, base, method=transforms.InterpolationMode.BICUBIC):
    # 这个方法的作用应该是处理图像的尺寸,以符合PIL处理的要求
    method = __transforms2pil_resize(method)
    # 用上面定义的函数,将插值 method 先映射到 PIL 常量;
    ow, oh = img.size
    # 读取原尺寸(ow, oh)
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if h == oh and w == ow:
        return img
    # 计算四舍五入到 base(默认 4)的倍数后的 (w, h);若尺寸已满足则原样返回;

    __print_size_warning(ow, oh, w, h)
    # 否则打印一次性警告并 resize
    return img.resize((w, h), method)


def __scale_width(img, target_size, crop_size, method=transforms.InterpolationMode.BICUBIC):
    method = __transforms2pil_resize(method)
    ow, oh = img.size
    if ow == target_size and oh >= crop_size:
        return img
    w = target_size
    h = int(max(target_size * oh / ow, crop_size))
    return img.resize((w, h), method)


def __crop(img, pos, size):
    # 定义crop操作的具体操作
    # 若原宽已等于 target_size 且原高≥crop_size,直接返回;
    # 否则将宽缩放到 target_size,高按比例放缩,同时保证高≥crop_size;最后 resize
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __flip(img, flip):
    # flip操作的具体操作
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def __print_size_warning(ow, oh, w, h):
    # 定义上面用到的打印警告的函数
    # 仅首次打印“尺寸需为 4 的倍数,已从 (ow,oh) 调整到 (w,h)”的警告;使用函数属性 has_printed 防止重复输出
    """Print warning information about image size(only print once)"""
    if not hasattr(__print_size_warning, 'has_printed'):
        print("The image size needs to be a multiple of 4. "
              "The loaded image size was (%d, %d), so it was adjusted to "
              "(%d, %d). This adjustment will be done to all images "
              "whose sizes are not multiples of 4" % (ow, oh, w, h))
        __print_size_warning.has_printed = True

posted @ 2025-09-24 14:56  SaTsuki26681534  阅读(12)  评论(0)    收藏  举报