【CVCVCV】GAN代码解析-数据集处理 __init__.py
0
通过读GAN的代码了解一下常规深度学习模型的整体工程结构
1
"""This package includes all the modules related to data loading and preprocessing
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point from data loader.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
# 定义一些根据相关参数加载数据集的类和方法
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset
def find_dataset_using_name(dataset_name):
"""Import the module "data/[dataset_name]_dataset.py".
In the file, the class called DatasetNameDataset() will
be instantiated. It has to be a subclass of BaseDataset,
and it is case-insensitive.
这个函数会导入 data/<dataset_name>_dataset.py,
并在其中找到一个名为 DatasetNameDataset 的类(大小写不敏感),且该类必须继承 BaseDataset
"""
dataset_filename = "data." + dataset_name + "_dataset"
# 构造模块路径字符串,例如 dataset_name='dummy' → "data.dummy_dataset
datasetlib = importlib.import_module(dataset_filename)
# 动态导入上面拼出的模块对象(等价于 import data.dummy_dataset as datasetlib 这种效果)
# 导入构造数据集需要用到的py文件
dataset = None
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
# 预置 dataset=None 以便后面找到类后再赋值。
# 生成“目标类名”的匹配基准:把 dataset_name 里的下划线都去掉,再加上 'dataset'
# 可以理解为现在需要在datasetlib这个Py文件里找到目标的数据集类,这里就是为了把dataset_name转化为数据集类的名字
for name, cls in datasetlib.__dict__.items():
# __dict__.items():获取键值对视图
# 字典的 items() 方法会返回一个键值对视图对象(view object),其中每个元素是一个元组 (key, value),分别对应字典中的键和值。对于 datasetlib.__dict__.items(),其返回的每个元组格式为 (属性名称, 属性对象)(例如 ('DummyDataset', <class 'data.dummy_dataset.DummyDataset'>))。
#
# for name, cls in ...:遍历键值对
# 这是一个 for 循环,用于遍历 items() 返回的键值对视图:
# name 被赋值为每个元组中的第一个元素(即属性名称,字符串类型,如 'DummyDataset');
# cls 被赋值为每个元组中的第二个元素(即属性对象,可能是类、函数、变量等,这里主要关注类对象)
if name.lower() == target_dataset_name.lower() \
and issubclass(cls, BaseDataset):
# Python 的“显式换行符”(line continuation)。一行末尾放 \ 表示把下一行和当前行当成同一条语句来解析
dataset = cls
# 遍历被导入模块的符号表(名字到对象的映射)。
# 条件:
# name.lower() == target_dataset_name.lower():名称大小写不敏感精确匹配前面生成的目标类名。
# issubclass(cls, BaseDataset):同时要求它是 BaseDataset 的子类。
#
# 一旦匹配成功,就把这个类对象记录到 dataset
if dataset is None:
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
return dataset
def get_option_setter(dataset_name):
"""Return the static method <modify_commandline_options> of the dataset class."""
dataset_class = find_dataset_using_name(dataset_name)
return dataset_class.modify_commandline_options
# 作用:根据名字找到数据集“类”,然后返回它的静态方法 modify_commandline_options。(用于读取数据集相关参数或者类似操作)
# 用途:让外部(如命令行解析逻辑)在知道数据集类型后,先让该数据集类有机会“增补/修改”命令行参数的默认值与说明
def create_dataset(opt):
"""Create a dataset given the option.
This function wraps the class CustomDatasetDataLoader.
This is the main interface between this package and 'train.py'/'test.py'
Example:
>>> from data import create_dataset
>>> dataset = create_dataset(opt)
"""
data_loader = CustomDatasetDataLoader(opt)
# customDatasetDataloader在下面有定义
dataset = data_loader.load_data()
return dataset
# 实例化数据加载器包装类 CustomDatasetDataLoader(opt);
# 调用其 load_data() 得到“可迭代的数据对象”(这个实现中 load_data() 直接返回 self,
# 所以 dataset 实际上就是数据加载器包装自身);
# 返回这个对象(可在训练/测试循环中用于迭代批次)
class CustomDatasetDataLoader():
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
# 这个类包装了具体数据集类,并利用 torch.utils.data.DataLoader 实现多线程/多进程的数据加载
def __init__(self, opt):
"""Initialize this class
Step 1: create a dataset instance given the name [dataset_mode]
Step 2: create a multi-threaded data loader.
先创建数据集实例,再创建多线程数据加载器。
"""
self.opt = opt
# 保存配置对象 opt(通常是命令行解析得到的),后面会使用其中的字段。
dataset_class = find_dataset_using_name(opt.dataset_mode)
# 调用刚才的find_dataset_using_name方法,根据 opt.dataset_mode(如 "dummy")找到对应的数据集“类”。
# parser.add_argument('--dataset_mode', type=str, default='unaligned',
# help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
# 上面是dataset_mode的参数
self.dataset = dataset_class(opt)
#实例化数据集,把同一个 opt 传入数据集构造函数(你的自定义数据集类在 __init__ 里可读取各种参数)
print("dataset [%s] was created" % type(self.dataset).__name__)
# parser.add_argument('--name', type=str, default='experiment_name',
# help='name of the experiment. It decides where to store samples and models')
# 创建的数据集类的名字和experiment_name相同
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batch_size,
shuffle=not opt.serial_batches,
num_workers=int(opt.num_threads))
# 创建 PyTorch 的 DataLoader:
# self.dataset:上面创建的 Dataset 实例(需实现 __len__ 和 __getitem__)。
# batch_size=opt.batch_size:批大小来自配置。
# shuffle=not opt.serial_batches:如果 不是 串行模式,则打乱(通常训练集打乱,验证/测试集串行)。
# num_workers=int(opt.num_threads):数据加载的工作进程数(多进程/多线程,具体由后端与平台决定;数值来自配置)
def load_data(self):
return self
# 返回自身,使外部拿到的对象本身就是一个“可用于迭代批数据”的包装器
def __len__(self):
"""Return the number of data in the dataset"""
return min(len(self.dataset), self.opt.max_dataset_size)
# 返回可用数据条目数的上限:取实际数据集长度与 opt.max_dataset_size 的较小值。
# 这样可以通过配置限制“最多使用多少样本”,便于快速试验或节省时间
def __iter__(self):
"""Return a batch of data"""
for i, data in enumerate(self.dataloader):
if i * self.opt.batch_size >= self.opt.max_dataset_size:
break
yield data
# 使这个包装器成为可迭代对象(for batch in dataset: 这种用法)。
# dataIter 通常是 data iterator(数据迭代器) 的缩写,
# 指用于按批次遍历数据的迭代器对象。
# 它的核心作用是允许程序通过循环逐步获取数据(尤其是大规模数据集),
# 而无需一次性加载全部数据到内存
# 逐批从 self.dataloader 里取数据。
# 根据批索引 i 与 batch_size 计算已取样本数量,一旦达到/超过 max_dataset_size 便提前 break。
# dataIter:即数据迭代器(在代码中对应CustomDatasetDataLoader的实例),是遍历数据的接口,负责按批次输出数据(通过__iter__方法)。
# batchSize:每个批次包含的样本数量(对应代码中的opt.batch_size),决定了每次迭代返回的数据量。
# NumWorkers:数据加载时使用的子线程数量(对应代码中的opt.num_threads),用于并行加载数据以提高效率
# 初始化时,dataIter根据batchSize和NumWorkers配置self.dataloader;
# 迭代时,dataIter通过__iter__方法将self.dataloader按batchSize生成的批次数据返回给用户
# (如训练代码中的for batch_data in dataset循环)
# yield data 把当前批返回给外部训练/测试循环

浙公网安备 33010602011771号