2023.3.30学习记录
P11:torchvision中数据集的使用
#torchvision中数据集的使用
#导入torchvision包
#从torch.utils.tensorboard 这个工具箱中导入SummaryWriter工具
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
#使用torchvision包中的transforms工具箱中的Compose方法(里面为transform中的Totensor方法)
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
#从torchvision包中的datasets工具箱中下载CIFAR10数据集(root参数为数据集下载的根目录,transform参数为选择何种transform,train参数为下载的为训练集(train)还是测试集(False))
train_set = torchvision.datasets.CIFAR10(root="./dataset", transform=dataset_transform, train=True, download=True)#训练集
test_set = torchvision.datasets.CIFAR10(root="./dataset", transform=dataset_transform, train=False, download=True)#测试集
print(test_set[0]) #本数据集共有50000张训练集,10000张测试集,testset[0]为测试集的第一张
print(train_set.classes)#类别名存放在classes列表中
img,target = test_set[0]#img为3*32*32的单张图片 ,target为图片的目标类别
print(img)#输出为Tosenor数据类型的图片
print(target)#输出为该图片的目标类别
print(test_set.classes[target])#输出目标类别的具体类别名
print("1")#标识符
# img.show()
writer = SummaryWriter("p11logs")#新建日志文件
#循环赋值目标图片与目标类别,并使用TensorBorad展示
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()#关闭日志
P12:DataLoader的使用
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
#准备测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
#dataset参数表示数据集的位置,batch_size表示每次取多少张图片,shuffle表示每轮取图片的顺序是否相同(True是不同,False是相同,drop_last表示余下的最后一次batch是否采用,False表示采用)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
#测试数据集的第一张图片及target
img, target = test_data[0]
# print(img.shape)
# print(target)
writer = SummaryWriter("dataloader")
# 循环读取图片并显示在TensorBoard中
for epoch in range(2):
step = 0
for data in test_loader:
imgs, target = data
writer.add_images("Epoch{}_False".format(epoch), imgs, step)
step = step + 1
writer.close()
浙公网安备 33010602011771号