13.6.0 头文件

 

import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l
from matplotlib import pyplot as plt

 

 

 

13.6.1 下载banana-detection训练集和测试集,并读入内存,然后按照批量大小分别对训练集和测试集进行切割并封装,返回训练集和测试集

 

# 训练集banana-detection的下载地址和加密签名
d2l.DATA_HUB['banana-detection'] = (d2l.DATA_URL + 'banana-detection.zip','5de26c8fce5ccdea9f91267273464dc968d20d72')

# 下载'banana-detection'训练集或测试集,并读入到内存中,返回图像集合和标签集合
def read_data_bananas(is_train=True):
    """读取香蕉检测数据集中的图像和标签"""
    # 下载'banana-detection'数据集保存到data文件夹下,并对其进行解压,返回数据集文件路径(..\data\banana-detection)
    data_dir = d2l.download_extract('banana-detection')
    # csv_fname:训练集或测试集的标签文件
    csv_fname = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'label.csv')
    # 读入标签文件
    csv_data = pd.read_csv(csv_fname)
    # 将img_name列设置为索引
    csv_data = csv_data.set_index('img_name')
    # images用来存放训练集或数据集的图像集合,targets用来存放训练集或数据集的标签集合包含(类别,左上角x,左上角y,右下角x,右下角y)
    images, targets = [], []
    # csv_data.iterrows():行迭代器,返回行索引和行对象
    for img_name, target in csv_data.iterrows():
        # 打包图像集合
        images.append(torchvision.io.read_image(os.path.join(data_dir, 'bananas_train' if is_train else'bananas_val', 'images', f'{img_name}')))
        # 打包标签集合
        targets.append(list(target))
    # 返回训练姐或测试集的图像集合和标签集合(标签集合升维,表示一张图片中可以存在多件物品)
    return images, torch.tensor(targets).unsqueeze(1) / 256

class BananasDataset(torch.utils.data.Dataset):
    """一个用于加载香蕉检测数据集的自定义数据集"""
    def __init__(self, is_train):
        # 下载'banana-detection'训练集或测试集,并读入到内存中,返回图像集合和标签集合
        self.features, self.labels = read_data_bananas(is_train)
        print('read ' + str(len(self.features)) + (f' training examples' if is_train else f' validation examples'))

    def __getitem__(self, idx):
        return (self.features[idx].float(), self.labels[idx])

    def __len__(self):
        return len(self.features)


# 下载'banana-detection'训练集和测试集,然后按照批量大小分别对训练集和测试集进行切割并封装,返回训练集和测试集
def load_data_bananas(batch_size):
    # 下载'banana-detection'训练集,并读入到内存中,然后按照批量大小对训练集进行切割并封装
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),batch_size, shuffle=True)
    # 下载'banana-detection'测试集,并读入到内存中,然后按照批量大小对测试集进行切割并封装
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),batch_size)
    # 返回训练集和测试集
    return train_iter, val_iter

# 定义批量大小
batch_size, edge_size = 32, 256
# 下载'banana-detection'训练集,然后按照批量大小对训练集进行切割并封装,返回训练集
train_iter, _ = load_data_bananas(batch_size)
# 获取训练集中第一个批次的样本
batch = next(iter(train_iter))
# 获取一个批次样本中图像集合的形状、一个批次样本中标签集合的形状
print(batch[0].shape, batch[1].shape)
# 输出:
# torch.Size([32, 3, 256, 256]) torch.Size([32, 1, 5])

 

 

 

13.6.2 显示一个批次中前10个样本,并框出每个样本中的香蕉

 

# 显示一个批量的前10个样本图像
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
# 框处每个样本图像中的香蕉
for ax, label in zip(axes, batch[1][0:10]):
    d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
plt.show()

 

 

 

 

本小节完整代码如下

 

import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l
from matplotlib import pyplot as plt

# ------------------------------下载banana-detection训练集和测试集,并读入内存,然后按照批量大小分别对训练集和测试集进行切割并封装,返回训练集和测试集------------------------------------

# 训练集banana-detection的下载地址和加密签名
d2l.DATA_HUB['banana-detection'] = (d2l.DATA_URL + 'banana-detection.zip','5de26c8fce5ccdea9f91267273464dc968d20d72')

# 下载'banana-detection'训练集或测试集,并读入到内存中,返回图像集合和标签集合
def read_data_bananas(is_train=True):
    """读取香蕉检测数据集中的图像和标签"""
    # 下载'banana-detection'数据集保存到data文件夹下,并对其进行解压,返回数据集文件路径(..\data\banana-detection)
    data_dir = d2l.download_extract('banana-detection')
    # csv_fname:训练集或测试集的标签文件
    csv_fname = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'label.csv')
    # 读入标签文件
    csv_data = pd.read_csv(csv_fname)
    # 将img_name列设置为索引
    csv_data = csv_data.set_index('img_name')
    # images用来存放训练集或数据集的图像集合,targets用来存放训练集或数据集的标签集合包含(类别,左上角x,左上角y,右下角x,右下角y)
    images, targets = [], []
    # csv_data.iterrows():行迭代器,返回行索引和行对象
    for img_name, target in csv_data.iterrows():
        # 打包图像集合
        images.append(torchvision.io.read_image(os.path.join(data_dir, 'bananas_train' if is_train else'bananas_val', 'images', f'{img_name}')))
        # 打包标签集合
        targets.append(list(target))
    # 返回训练姐或测试集的图像集合和标签集合(标签集合升维,表示一张图片中可以存在多件物品)
    return images, torch.tensor(targets).unsqueeze(1) / 256

class BananasDataset(torch.utils.data.Dataset):
    """一个用于加载香蕉检测数据集的自定义数据集"""
    def __init__(self, is_train):
        # 下载'banana-detection'训练集或测试集,并读入到内存中,返回图像集合和标签集合
        self.features, self.labels = read_data_bananas(is_train)
        print('read ' + str(len(self.features)) + (f' training examples' if is_train else f' validation examples'))

    def __getitem__(self, idx):
        return (self.features[idx].float(), self.labels[idx])

    def __len__(self):
        return len(self.features)


# 下载'banana-detection'训练集和测试集,然后按照批量大小分别对训练集和测试集进行切割并封装,返回训练集和测试集
def load_data_bananas(batch_size):
    # 下载'banana-detection'训练集,并读入到内存中,然后按照批量大小对训练集进行切割并封装
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),batch_size, shuffle=True)
    # 下载'banana-detection'测试集,并读入到内存中,然后按照批量大小对测试集进行切割并封装
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),batch_size)
    # 返回训练集和测试集
    return train_iter, val_iter

# 定义批量大小
batch_size, edge_size = 32, 256
# 下载'banana-detection'训练集,然后按照批量大小对训练集进行切割并封装,返回训练集
train_iter, _ = load_data_bananas(batch_size)
# 获取训练集中第一个批次的样本
batch = next(iter(train_iter))
# 获取一个批次样本中图像集合的形状、一个批次样本中标签集合的形状
print(batch[0].shape, batch[1].shape)
# 输出:
# torch.Size([32, 3, 256, 256]) torch.Size([32, 1, 5])

# ------------------------------显示一个批次中前10个样本,并框出每个样本中的香蕉------------------------------------

# 显示一个批量的前10个样本图像
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
# 框处每个样本图像中的香蕉
for ax, label in zip(axes, batch[1][0:10]):
    d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
plt.show()

 

posted on 2022-11-27 15:26  yc-limitless  阅读(424)  评论(0)    收藏  举报