dataset类

from torch.utils.data import Dataset  # 导入PyTorch的Dataset基类,自定义数据集必须继承它
from PIL import Image  # 导入PIL库的Image模块,用于读取和处理图像文件
import os  # 导入os库,用于处理文件路径、目录操作等系统相关功能class MyData(Dataset)
:  # 定义MyData类,继承自PyTorch的Dataset抽象类
    # 这是数据集的"构造方法",用于初始化数据集的基本信息(如路径、文件列表等)
    def __init__(self, root_dir, label_dir):  # root_dir:根目录路径;label_dir:标签目录路径
        # 保存根目录路径到实例变量(方便类内部其他方法调用)
        self.root_dir = root_dir  
        # 保存标签目录路径到实例变量(标签目录名称通常就是该类别的标签,如"ants"、"bees")
        self.label_dir = label_dir  
        # 拼接根目录和标签目录,得到图像文件所在的完整目录路径(例如"dataset/train/ants")
        self.path = os.path.join(self.root_dir, self.label_dir)  
        # 获取该目录下所有文件的名称列表(例如["0013035.jpg", "003454.jpg"...])
        self.img_path = os.listdir(self.path)  


    # 这是数据集的核心方法,用于根据索引idx获取一个样本(图像+标签)
    def __getitem__(self, idx):  # idx:样本的索引(从0开始)
        # 根据索引idx从图像名称列表中取出对应的图像文件名(例如第0个是"0013035.jpg")
        img_name = self.img_path[idx]  
        # 拼接根目录、标签目录、图像文件名,得到该图像的完整路径(例如"dataset/train/ants/0013035.jpg")
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  
        # 用PIL的Image.open()方法打开图像文件,得到图像对象(可后续转为PyTorch张量)
        img = Image.open(img_item_path)  
        # 将标签目录名称作为该图像的标签(例如"ants"目录下的图像标签就是"ants")
        label = self.label_dir  
        # 返回该索引对应的图像和标签(这是PyTorch要求的格式:(数据, 标签))
        return img, label  


    # 这是数据集的长度方法,返回数据集中样本的总数量
    def __len__(self):  
        # 图像名称列表的长度就是样本数量(因为每个文件名对应一个图像)
        return len(self.img_path)  


### 3. 实例化数据集并组合
```python
root_dir = "dataset/train"  # 定义根目录路径(存放训练集的总目录)
ant_label_dir = "ants"  # 定义"蚂蚁"类别的标签目录名称
bee_label_dir = "bees"  # 定义"蜜蜂"类别的标签目录名称

# 实例化"蚂蚁"数据集:传入根目录和蚂蚁标签目录,得到只包含蚂蚁图像的数据集
antset = MyData(root_dir, ant_label_dir)  
# 实例化"蜜蜂"数据集:传入根目录和蜜蜂标签目录,得到只包含蜜蜂图像的数据集
beeset = MyData(root_dir, bee_label_dir)  

# 将蚂蚁数据集和蜜蜂数据集合并,得到完整的训练集(PyTorch的Dataset支持用"+"拼接)
trainset = antset + beeset

以上代码可通过pycharm的代码逐步调试,并实时查看代码运行是否无误

posted @ 2025-10-12 11:05  病友白某  阅读(13)  评论(0)    收藏  举报