torch vision中的数据集使用
1.数据集的下载与说明
去pytorch官网找到torchvision进入datasets寻找数据集即可,相关说明如下



2.数据集使用
import torchvision
# 训练数据集
# download 设置为true自动下载,想看下载链接可以跳转目标数据集的函数寻找
train_set = torchvision.datasets.CIFAR10(root = "./dataset", train = True, download=True)
# 测试数据集
test_set = torchvision.datasets.CIFAR10(root = "./dataset", train = True, download=True)
# 查看训练集
print(test_set[0])
print(test_set.classes)
img, target = test_set[0]
print(img)
print(target)
print(test_set[target])
img.show()

可以看出图片确实是frog

3. tensorboard 可视化
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# 训练数据集
train_set = torchvision.datasets.CIFAR10(root = "../dataset1", train = True, transform=dataset_transform, download=True)
# 测试数据集
test_set = torchvision.datasets.CIFAR10(root = "../dataset1", train = True, transform=dataset_transform, download=True)
# 查看训练集
# print(test_set[0])
# print(test_set.classes)
#
# img, target = test_set[0]
# print(img)
# print(target)
# print(test_set[target])
# img.show()
print(test_set[0]) #可以看到是tensor数据类型
#使用tensorboard可视化
writer = SummaryWriter("p10")
# 展示测试集前十张图片
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()


浙公网安备 33010602011771号