14 torchvision的数据集使用

一、pytorch官网

pytorch官网

  1. 基本模块
    image

  2. torchvision

  • 数据集

image

  • 模型

image

数据集基本参数:

image

image

  • transform和工具类
    image

二、实验阶段

  • 数据集的下载

如果通过代码下载不下来,可以自己下好(ctrl+对应数据集的名称,进入到数据集的方法里,找到对应链接),放到对应的目录下,再运行,或者用迅雷

下载链接:
image

  • 解析数据集:

image

  • 查看数据集:
#查看数据集
print(test_set[0])

相当于把分类对应成数字
image

debug之后可以看到所有的分类:
image

  • 完整代码
import torchvision
from torch.utils.tensorboard import SummaryWriter

# 4.设置数据集transform
dataset_transform=torchvision.transforms.Compose({
    torchvision.transforms.ToTensor()
})
# 4.1 添加transform参数
train_set=torchvision.datasets.CIFAR10(root="./CIFAR10_dataset",transform=dataset_transform,train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./CIFAR10_dataset",transform=dataset_transform,train=False,download=True)

# # 1. 查看数据集
# print(test_set[0])
# # 2. 输出所有的类别
# print(test_set.classes)
#
# # 3. 分别输出图片信息和类别
# img,target=test_set[0]
# print(img)
# print(target)#数字版的类别
# print(test_set.classes[target])#文字版的类别
# img.show()#输出图片

# 4.2 输出图片,用tensorboard,需要注释1,2,3
writer=SummaryWriter("logs14")
for i in range(10):
    img,target=test_set[i]
    writer.add_image("test_set",img,i)

writer.close()

运行结果:
image

posted @ 2022-05-12 12:29  Trouvaille_fighting  阅读(106)  评论(0)    收藏  举报