pytorch加载数据集

pytorch加载数据集

在 PyTorch 中,加载数据集主要依赖于两个核心工具:torch.utils.data.Dataset(用于定义数据集结构)和torch.utils.data.DataLoader(用于批量加载数据)。下面分内置数据集自定义数据集两种场景详细说明。

一、核心组件说明

  • Dataset:抽象类,用于定义数据集的基本结构(存储数据和标签),需要重写__len__(返回数据集大小)和__getitem__(根据索引返回单条数据和标签)。
  • DataLoader:基于Dataset的包装器,用于批量加载数据,支持自动打乱顺序、多进程并行加载、设置批次大小等功能。

二、加载内置数据集(以torchvision为例)

PyTorch 的torchvision.datasets模块内置了常见数据集(如 MNIST、CIFAR10、ImageNet 等),可直接调用,无需手动处理数据文件。

步骤:

  1. 导入必要库;
  2. 定义数据预处理(transform);
  3. 实例化内置数据集
  4. 批量加载数据集。

例:加载torchvision下的数据集MNIST手写数字数据集(0-9)

import torch
from torch.utils.data import DataLoader
from torchvision import datasets,transforms#transforms预处理工具
transform=transforms.Compose(
    [
        transforms.ToTensor(),#转张量,归一化
#Normalize将经过ToTensor转换后的图像数据调整为 “均值接近 0、标准差接近 1” 的分布
        transforms.Normalize((0.1307,), (0.3081,))#(0.1307,)均值, (0.3081,)标准差
    ]#预处理
)#流水线工作Compose多个工作合并
#定义训练数据集
train_data=datasets.MNIST(
    root='./data',
    download=True,#如果本地没有,就会自动将数据集下载到data目录下(data目录自动生成)
    train=True,#训练集
    transform=transform#预处理
)
#定义测试数据集
test_data=datasets.MNIST(
    root='./data',
    transform=transform,#加工
    download=True,
    train=False
)
#批量加载训练数据集
train_loader=DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,#训练打乱
    num_workers=2
)
#批量加载测试数据集
test_loader=DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=False,#测试不打乱
    num_workers=2#并行进程数
)
if __name__=='__main__':
    #循环次数是批量数=总样本数(60000)/批量大小(64)
    for batch_idx,(data,target) in enumerate(train_loader):#enumerate作用产生索引batch_idx
        print(f'批量大小是{batch_idx}')
        print(f'数据形状{data.shape}')
        print(f'标签{target.shape}')

三.自定义数据集

例:猫狗数据集

提前准备好猫狗图片以及csv文件记录图片名和标签

image

image

image

import torch
import os#系统路径
import pandas as pd
from PIL import Image #处理图片
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# 1. 自定义Dataset类
class CustomImageDataset(Dataset):#这三个方法要写
    def __init__(self, img_dir, label_file, transform=None):
        self.img_dir = img_dir  # 图片文件夹路径
        self.labels = pd.read_csv(label_file)  # 读取标签文件,labels实例属性
        self.transform = transform  # 预处理方法

    def __len__(self):
        # 返回数据集总样本数
        return len(self.labels)

    def __getitem__(self, idx):
        # 根据索引idx返回单条数据和标签
        img_name = self.labels.iloc[idx, 0]  # 第idx行第0列:文件名
        img_path = os.path.join(self.img_dir, img_name)  # 图片完整路径
        image = Image.open(img_path).convert('RGB')  # 用PIL打开图像,并强制转为RGB三通道(避免灰度图通道不一致)
        label = self.labels.iloc[idx, 1]  # 第idx行第1列:标签

        # 应用预处理(若有)
        if self.transform:
            image = self.transform(image)

        return image, label  # 返回(数据,标签)

# 2. 定义预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 缩放为224x224
    transforms.ToTensor(),          # 转为Tensor
    #为什么均值、方差都是三维?因为转为rgb图像
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet均值/标准差
])

# 3. 实例化自定义Dataset
custom_dataset = CustomImageDataset(
    img_dir=r'F:\masterleaning\custom_data\images',
    label_file=r'F:\masterleaning\custom_data\labels.csv',
    transform=transform
)

# 4. 用DataLoader加载
custom_loader = DataLoader(
    dataset=custom_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2
)
if __name__=='__main__':
# 5. 迭代验证
    for data, labels in custom_loader:
 #批量大小batch_size=labels.shape因为一批多少样本就要打多少标签
        print(f"数据形状: {data.shape}")  # (32, 3, 224, 224)(batch_size, 通道, 高, 宽)
        print(f"标签形状: {labels.shape}")  # (32,)
        break

 

posted @ 2025-11-12 14:47  Annaprincess  阅读(48)  评论(0)    收藏  举报