Loading

【CV】GAN代码解析 image_folder.py

"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""

import torch.utils.data as data
# 导入 torch.utils.data 并起别名 data,以便继承 data.Dataset 实现自定义数据集
from pathlib import Path
# 从标准库导入 Path,用于跨平台、易用的路径处理
from PIL import Image

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]
# 定义可接受的图片扩展名列表(大小写都覆盖),后面用来做文件类型过滤


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
# 判断 filename 是否以任一允许的扩展名结尾;返回布尔值

def make_dataset(dir, max_dataset_size=float("inf")):
# 定义函数 make_dataset(dir, max_dataset_size=float("inf")),
# 用于从目录递归收集图片路径,并支持数量上限
    images = []
    dir_path = Path(dir)
# 把传入的 dir 包装成 Path 对象 dir_path
# 用人话说就是用dir变量创建一个Path类的对象
    assert dir_path.is_dir(), f'{dir} is not a valid directory'
    # 断言 dir_path 必须是一个目录,否则抛错

    for path in sorted(dir_path.rglob('*')):
        # 使用 dir_path.rglob('*') 递归遍历该目录及子目录的所有路径,
        # 并先按字典序 sorted 固定遍历顺序(保证复现性)
        # 人话:遍历dir_path下的所有文件,并判断每一个文件是不是图像文件
        # 如果是图像的话就把对应图像路径加到images列表里
        if path.is_file() and is_image_file(path.name):
            images.append(str(path))
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    # 定义 default_loader(path)。
    # 用 Pillow 打开图片并统一 convert('RGB'),确保三通道(便于后续变换与模型输入统一)
    # 从str格式的图像路径中load RGB格式的图像对象
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):
# 定义数据集类 ImageFolder,继承自 data.Dataset
    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        # 从数据集根目录root里收集图像文件,并存在imgs里
        if len(imgs) == 0: # 如果根目录里没有图像,则抛出异常
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        # return_paths 表示 __getitem__ 是否连同路径一起返回
        self.loader = loader
        # loader 为读图函数默认用 default_loader

    def __getitem__(self, index):
        # 根据index随机获取图像文件(在获取之前还需要进行对应的预处理)
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        # 若设置了 transform,则对图像执行变换(例如 ToTensor/Normalize/Resize 等)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
    # 返回数据集中图片的数量 len(self.imgs)

posted @ 2025-09-24 16:07  SaTsuki26681534  阅读(11)  评论(0)    收藏  举报