dataset类
from torch.utils.data import Dataset # 导入PyTorch的Dataset基类,自定义数据集必须继承它
from PIL import Image # 导入PIL库的Image模块,用于读取和处理图像文件
import os # 导入os库,用于处理文件路径、目录操作等系统相关功能class MyData(Dataset)
: # 定义MyData类,继承自PyTorch的Dataset抽象类
# 这是数据集的"构造方法",用于初始化数据集的基本信息(如路径、文件列表等)
def __init__(self, root_dir, label_dir): # root_dir:根目录路径;label_dir:标签目录路径
# 保存根目录路径到实例变量(方便类内部其他方法调用)
self.root_dir = root_dir
# 保存标签目录路径到实例变量(标签目录名称通常就是该类别的标签,如"ants"、"bees")
self.label_dir = label_dir
# 拼接根目录和标签目录,得到图像文件所在的完整目录路径(例如"dataset/train/ants")
self.path = os.path.join(self.root_dir, self.label_dir)
# 获取该目录下所有文件的名称列表(例如["0013035.jpg", "003454.jpg"...])
self.img_path = os.listdir(self.path)
# 这是数据集的核心方法,用于根据索引idx获取一个样本(图像+标签)
def __getitem__(self, idx): # idx:样本的索引(从0开始)
# 根据索引idx从图像名称列表中取出对应的图像文件名(例如第0个是"0013035.jpg")
img_name = self.img_path[idx]
# 拼接根目录、标签目录、图像文件名,得到该图像的完整路径(例如"dataset/train/ants/0013035.jpg")
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
# 用PIL的Image.open()方法打开图像文件,得到图像对象(可后续转为PyTorch张量)
img = Image.open(img_item_path)
# 将标签目录名称作为该图像的标签(例如"ants"目录下的图像标签就是"ants")
label = self.label_dir
# 返回该索引对应的图像和标签(这是PyTorch要求的格式:(数据, 标签))
return img, label
# 这是数据集的长度方法,返回数据集中样本的总数量
def __len__(self):
# 图像名称列表的长度就是样本数量(因为每个文件名对应一个图像)
return len(self.img_path)
### 3. 实例化数据集并组合
```python
root_dir = "dataset/train" # 定义根目录路径(存放训练集的总目录)
ant_label_dir = "ants" # 定义"蚂蚁"类别的标签目录名称
bee_label_dir = "bees" # 定义"蜜蜂"类别的标签目录名称
# 实例化"蚂蚁"数据集:传入根目录和蚂蚁标签目录,得到只包含蚂蚁图像的数据集
antset = MyData(root_dir, ant_label_dir)
# 实例化"蜜蜂"数据集:传入根目录和蜜蜂标签目录,得到只包含蜜蜂图像的数据集
beeset = MyData(root_dir, bee_label_dir)
# 将蚂蚁数据集和蜜蜂数据集合并,得到完整的训练集(PyTorch的Dataset支持用"+"拼接)
trainset = antset + beeset
以上代码可通过pycharm的代码逐步调试,并实时查看代码运行是否无误
浙公网安备 33010602011771号