13.6.0 头文件
import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l
from matplotlib import pyplot as plt
13.6.1 下载banana-detection训练集和测试集,并读入内存,然后按照批量大小分别对训练集和测试集进行切割并封装,返回训练集和测试集
# 训练集banana-detection的下载地址和加密签名
d2l.DATA_HUB['banana-detection'] = (d2l.DATA_URL + 'banana-detection.zip','5de26c8fce5ccdea9f91267273464dc968d20d72')
# 下载'banana-detection'训练集或测试集,并读入到内存中,返回图像集合和标签集合
def read_data_bananas(is_train=True):
"""读取香蕉检测数据集中的图像和标签"""
# 下载'banana-detection'数据集保存到data文件夹下,并对其进行解压,返回数据集文件路径(..\data\banana-detection)
data_dir = d2l.download_extract('banana-detection')
# csv_fname:训练集或测试集的标签文件
csv_fname = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'label.csv')
# 读入标签文件
csv_data = pd.read_csv(csv_fname)
# 将img_name列设置为索引
csv_data = csv_data.set_index('img_name')
# images用来存放训练集或数据集的图像集合,targets用来存放训练集或数据集的标签集合包含(类别,左上角x,左上角y,右下角x,右下角y)
images, targets = [], []
# csv_data.iterrows():行迭代器,返回行索引和行对象
for img_name, target in csv_data.iterrows():
# 打包图像集合
images.append(torchvision.io.read_image(os.path.join(data_dir, 'bananas_train' if is_train else'bananas_val', 'images', f'{img_name}')))
# 打包标签集合
targets.append(list(target))
# 返回训练姐或测试集的图像集合和标签集合(标签集合升维,表示一张图片中可以存在多件物品)
return images, torch.tensor(targets).unsqueeze(1) / 256
class BananasDataset(torch.utils.data.Dataset):
"""一个用于加载香蕉检测数据集的自定义数据集"""
def __init__(self, is_train):
# 下载'banana-detection'训练集或测试集,并读入到内存中,返回图像集合和标签集合
self.features, self.labels = read_data_bananas(is_train)
print('read ' + str(len(self.features)) + (f' training examples' if is_train else f' validation examples'))
def __getitem__(self, idx):
return (self.features[idx].float(), self.labels[idx])
def __len__(self):
return len(self.features)
# 下载'banana-detection'训练集和测试集,然后按照批量大小分别对训练集和测试集进行切割并封装,返回训练集和测试集
def load_data_bananas(batch_size):
# 下载'banana-detection'训练集,并读入到内存中,然后按照批量大小对训练集进行切割并封装
train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),batch_size, shuffle=True)
# 下载'banana-detection'测试集,并读入到内存中,然后按照批量大小对测试集进行切割并封装
val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),batch_size)
# 返回训练集和测试集
return train_iter, val_iter
# 定义批量大小
batch_size, edge_size = 32, 256
# 下载'banana-detection'训练集,然后按照批量大小对训练集进行切割并封装,返回训练集
train_iter, _ = load_data_bananas(batch_size)
# 获取训练集中第一个批次的样本
batch = next(iter(train_iter))
# 获取一个批次样本中图像集合的形状、一个批次样本中标签集合的形状
print(batch[0].shape, batch[1].shape)
# 输出:
# torch.Size([32, 3, 256, 256]) torch.Size([32, 1, 5])
13.6.2 显示一个批次中前10个样本,并框出每个样本中的香蕉
# 显示一个批量的前10个样本图像
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
# 框处每个样本图像中的香蕉
for ax, label in zip(axes, batch[1][0:10]):
d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
plt.show()

本小节完整代码如下
import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l
from matplotlib import pyplot as plt
# ------------------------------下载banana-detection训练集和测试集,并读入内存,然后按照批量大小分别对训练集和测试集进行切割并封装,返回训练集和测试集------------------------------------
# 训练集banana-detection的下载地址和加密签名
d2l.DATA_HUB['banana-detection'] = (d2l.DATA_URL + 'banana-detection.zip','5de26c8fce5ccdea9f91267273464dc968d20d72')
# 下载'banana-detection'训练集或测试集,并读入到内存中,返回图像集合和标签集合
def read_data_bananas(is_train=True):
"""读取香蕉检测数据集中的图像和标签"""
# 下载'banana-detection'数据集保存到data文件夹下,并对其进行解压,返回数据集文件路径(..\data\banana-detection)
data_dir = d2l.download_extract('banana-detection')
# csv_fname:训练集或测试集的标签文件
csv_fname = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'label.csv')
# 读入标签文件
csv_data = pd.read_csv(csv_fname)
# 将img_name列设置为索引
csv_data = csv_data.set_index('img_name')
# images用来存放训练集或数据集的图像集合,targets用来存放训练集或数据集的标签集合包含(类别,左上角x,左上角y,右下角x,右下角y)
images, targets = [], []
# csv_data.iterrows():行迭代器,返回行索引和行对象
for img_name, target in csv_data.iterrows():
# 打包图像集合
images.append(torchvision.io.read_image(os.path.join(data_dir, 'bananas_train' if is_train else'bananas_val', 'images', f'{img_name}')))
# 打包标签集合
targets.append(list(target))
# 返回训练姐或测试集的图像集合和标签集合(标签集合升维,表示一张图片中可以存在多件物品)
return images, torch.tensor(targets).unsqueeze(1) / 256
class BananasDataset(torch.utils.data.Dataset):
"""一个用于加载香蕉检测数据集的自定义数据集"""
def __init__(self, is_train):
# 下载'banana-detection'训练集或测试集,并读入到内存中,返回图像集合和标签集合
self.features, self.labels = read_data_bananas(is_train)
print('read ' + str(len(self.features)) + (f' training examples' if is_train else f' validation examples'))
def __getitem__(self, idx):
return (self.features[idx].float(), self.labels[idx])
def __len__(self):
return len(self.features)
# 下载'banana-detection'训练集和测试集,然后按照批量大小分别对训练集和测试集进行切割并封装,返回训练集和测试集
def load_data_bananas(batch_size):
# 下载'banana-detection'训练集,并读入到内存中,然后按照批量大小对训练集进行切割并封装
train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),batch_size, shuffle=True)
# 下载'banana-detection'测试集,并读入到内存中,然后按照批量大小对测试集进行切割并封装
val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),batch_size)
# 返回训练集和测试集
return train_iter, val_iter
# 定义批量大小
batch_size, edge_size = 32, 256
# 下载'banana-detection'训练集,然后按照批量大小对训练集进行切割并封装,返回训练集
train_iter, _ = load_data_bananas(batch_size)
# 获取训练集中第一个批次的样本
batch = next(iter(train_iter))
# 获取一个批次样本中图像集合的形状、一个批次样本中标签集合的形状
print(batch[0].shape, batch[1].shape)
# 输出:
# torch.Size([32, 3, 256, 256]) torch.Size([32, 1, 5])
# ------------------------------显示一个批次中前10个样本,并框出每个样本中的香蕉------------------------------------
# 显示一个批量的前10个样本图像
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
# 框处每个样本图像中的香蕉
for ax, label in zip(axes, batch[1][0:10]):
d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
plt.show()
浙公网安备 33010602011771号