Loading

【CVCVCV】GAN代码解析-数据集处理 __init__.py

0

通过读GAN的代码了解一下常规深度学习模型的整体工程结构

1

"""This package includes all the modules related to data loading and preprocessing

 To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
 You need to implement four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point from data loader.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
# 定义一些根据相关参数加载数据集的类和方法

import importlib
import torch.utils.data
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.

    这个函数会导入 data/<dataset_name>_dataset.py,
    并在其中找到一个名为 DatasetNameDataset 的类(大小写不敏感),且该类必须继承 BaseDataset
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    # 构造模块路径字符串,例如 dataset_name='dummy' → "data.dummy_dataset
    datasetlib = importlib.import_module(dataset_filename)
    # 动态导入上面拼出的模块对象(等价于 import data.dummy_dataset as datasetlib 这种效果)
    # 导入构造数据集需要用到的py文件

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    # 预置 dataset=None 以便后面找到类后再赋值。
    # 生成“目标类名”的匹配基准:把 dataset_name 里的下划线都去掉,再加上 'dataset'
    # 可以理解为现在需要在datasetlib这个Py文件里找到目标的数据集类,这里就是为了把dataset_name转化为数据集类的名字

    for name, cls in datasetlib.__dict__.items():
    # __dict__.items():获取键值对视图
    # 字典的 items() 方法会返回一个键值对视图对象(view object),其中每个元素是一个元组 (key, value),分别对应字典中的键和值。对于 datasetlib.__dict__.items(),其返回的每个元组格式为 (属性名称, 属性对象)(例如 ('DummyDataset', <class 'data.dummy_dataset.DummyDataset'>))。
    #
    # for name, cls in ...:遍历键值对
    # 这是一个 for 循环,用于遍历 items() 返回的键值对视图:
    # name 被赋值为每个元组中的第一个元素(即属性名称,字符串类型,如 'DummyDataset');
    # cls 被赋值为每个元组中的第二个元素(即属性对象,可能是类、函数、变量等,这里主要关注类对象)

        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            # Python 的“显式换行符”(line continuation)。一行末尾放 \ 表示把下一行和当前行当成同一条语句来解析
            dataset = cls
    # 遍历被导入模块的符号表(名字到对象的映射)。
    # 条件:

    # name.lower() == target_dataset_name.lower():名称大小写不敏感精确匹配前面生成的目标类名。
    # issubclass(cls, BaseDataset):同时要求它是 BaseDataset 的子类。
    #
    # 一旦匹配成功,就把这个类对象记录到 dataset
    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options
    # 作用:根据名字找到数据集“类”,然后返回它的静态方法 modify_commandline_options。(用于读取数据集相关参数或者类似操作)
    # 用途:让外部(如命令行解析逻辑)在知道数据集类型后,先让该数据集类有机会“增补/修改”命令行参数的默认值与说明


def create_dataset(opt):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
        This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt)
    # customDatasetDataloader在下面有定义
    dataset = data_loader.load_data()
    return dataset
# 实例化数据加载器包装类 CustomDatasetDataLoader(opt);
# 调用其 load_data() 得到“可迭代的数据对象”(这个实现中 load_data() 直接返回 self,
# 所以 dataset 实际上就是数据加载器包装自身);
# 返回这个对象(可在训练/测试循环中用于迭代批次)


class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""
# 这个类包装了具体数据集类,并利用 torch.utils.data.DataLoader 实现多线程/多进程的数据加载

    def __init__(self, opt):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        先创建数据集实例,再创建多线程数据加载器。
        """
        self.opt = opt
        # 保存配置对象 opt(通常是命令行解析得到的),后面会使用其中的字段。
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        # 调用刚才的find_dataset_using_name方法,根据 opt.dataset_mode(如 "dummy")找到对应的数据集“类”。
        # parser.add_argument('--dataset_mode', type=str, default='unaligned',
        # help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
        # 上面是dataset_mode的参数
        self.dataset = dataset_class(opt)
        #实例化数据集,把同一个 opt 传入数据集构造函数(你的自定义数据集类在 __init__ 里可读取各种参数)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        # parser.add_argument('--name', type=str, default='experiment_name',
        # help='name of the experiment. It decides where to store samples and models')
        # 创建的数据集类的名字和experiment_name相同
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads))
        # 创建 PyTorch 的 DataLoader:
        # self.dataset:上面创建的 Dataset 实例(需实现 __len__ 和 __getitem__)。
        #   batch_size=opt.batch_size:批大小来自配置。
        #   shuffle=not opt.serial_batches:如果 不是 串行模式,则打乱(通常训练集打乱,验证/测试集串行)。
        #   num_workers=int(opt.num_threads):数据加载的工作进程数(多进程/多线程,具体由后端与平台决定;数值来自配置)
    def load_data(self):
        return self
    # 返回自身,使外部拿到的对象本身就是一个“可用于迭代批数据”的包装器

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)
    # 返回可用数据条目数的上限:取实际数据集长度与 opt.max_dataset_size 的较小值。
    # 这样可以通过配置限制“最多使用多少样本”,便于快速试验或节省时间

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
    # 使这个包装器成为可迭代对象(for batch in dataset: 这种用法)。

    # dataIter 通常是 data iterator(数据迭代器) 的缩写,
    # 指用于按批次遍历数据的迭代器对象。
    # 它的核心作用是允许程序通过循环逐步获取数据(尤其是大规模数据集),
    # 而无需一次性加载全部数据到内存

    # 逐批从 self.dataloader 里取数据。
    # 根据批索引 i 与 batch_size 计算已取样本数量,一旦达到/超过 max_dataset_size 便提前 break。

    # dataIter:即数据迭代器(在代码中对应CustomDatasetDataLoader的实例),是遍历数据的接口,负责按批次输出数据(通过__iter__方法)。
    # batchSize:每个批次包含的样本数量(对应代码中的opt.batch_size),决定了每次迭代返回的数据量。
    # NumWorkers:数据加载时使用的子线程数量(对应代码中的opt.num_threads),用于并行加载数据以提高效率
    # 初始化时,dataIter根据batchSize和NumWorkers配置self.dataloader;

    # 迭代时,dataIter通过__iter__方法将self.dataloader按batchSize生成的批次数据返回给用户
    # (如训练代码中的for batch_data in dataset循环)

    # yield data 把当前批返回给外部训练/测试循环
posted @ 2025-09-23 16:06  SaTsuki26681534  阅读(12)  评论(0)    收藏  举报