深度学习笔记007:PictureClassificationDataSet(图片分类数据集)

跟着李沐老师学的,但是今天有点偷懒,没学多少,笔记如下:

(关于这堂课之前十几分钟讲的Softmax回归的笔记放在下一篇博客里)

 1 import torch
 2 import torchvision
 3 from torch.utils import data
 4 from torchvision import transforms
 5 from d2l import torch as d2l
 6 
 7 import pylab
 8 
 9 d2l.use_svg_display()
10 
11 #下载数据集,跑一次即可(把download设为false)
12 trans=transforms.ToTensor()
13 mnist_train=torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=False)
14 mnist_test=torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=False)
15 print(len(mnist_test),len(mnist_train))
16 
17 # 将数据集的文本标签返回
def get_fashion_mnist_labels(labels):
    """Translate Fashion-MNIST class indices into their text names.

    Each element of ``labels`` is converted with ``int()`` and used to index
    the fixed table of the ten class names below; the names are returned as a
    list in the same order as ``labels``.
    """
    names = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
             'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')
    result = []
    for label in labels:
        result.append(names[int(label)])
    return result
21 '''
22 这里补充一下上面 return 语句用到的语法知识:
23 列表推导式(list comprehension):
24 list2 = [int(i) for i in list1.split(' ')]
25 相当于for循环(就功能而言):
26 list2 = []
27 for i in list1.split(' '):
28     list2.append(int(i))
29 '''
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """Plot a list of images in a ``num_rows`` x ``num_cols`` grid.

    Args:
        imgs: Iterable of images. Torch tensors are converted to NumPy before
            plotting; anything else is passed to ``imshow`` unchanged. Extra
            images beyond the grid size are ignored (``zip`` stops early).
        num_rows: Number of subplot rows.
        num_cols: Number of subplot columns.
        titles: Optional sequence of per-image titles, indexed in the same
            order as ``imgs``. Skipped when empty or None.
        scale: Figure size (in matplotlib figsize units) of each grid cell.

    Returns:
        A flat array of the grid's matplotlib Axes.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False guarantees a 2-D array of Axes, so flatten() also works
    # for a 1x1 grid (plain subplots(1, 1) returns a bare Axes object, which
    # has no flatten method).
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                               squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # detach/cpu let tensors that track gradients or live on a GPU be
            # plotted; for a plain CPU tensor this is the same as .numpy().
            ax.imshow(img.detach().cpu().numpy())
        else:
            ax.imshow(img)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
45 
# Take the first 18 training examples and plot them in a 2 x 9 grid, using
# their text labels as titles.
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(
    X.reshape(18, 28, 28),  # one 28 x 28 pixel array per image for imshow
    num_rows=2,
    num_cols=9,
    titles=get_fashion_mnist_labels(y),
)
48 
batch_size = 256  # number of examples per minibatch
def get_dataloader_workers():
    """Return the number of worker processes each DataLoader should use.

    NOTE(review): with num_workers > 0 on platforms that start workers by
    spawning (e.g. Windows), top-level script code like this usually needs an
    ``if __name__ == '__main__':`` guard. Confirm on the target OS.
    """
    num_workers = 4
    return num_workers
train_iter = data.DataLoader(
    mnist_train,
    batch_size,
    shuffle=True,
    num_workers=get_dataloader_workers(),
)
timer = d2l.Timer()
# Make one full pass over the training data without any computation, to
# measure pure reading speed: reading data has to be faster than computing on it.
for X, y in train_iter:
    pass
print(f'{timer.stop():.2f} sec')
57 
58 
def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """Download Fashion-MNIST and return training and test DataLoaders.

    Uses torchvision and ``torch.utils.data``, the same stack as the rest of
    this script. A ``gluon``-based body would raise NameError here because
    MXNet is never imported.

    Args:
        batch_size: Number of examples per minibatch for both loaders.
        resize: Optional size passed to ``transforms.Resize``. When falsy,
            images keep their stored size.

    Returns:
        A ``(train_iter, test_iter)`` pair of DataLoaders. Only the training
        loader shuffles.
    """
    trans = [transforms.ToTensor()]
    if resize:
        # Inserted at position 0 so resizing is applied before ToTensor.
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
72 
73 
# Display every figure created above (e.g. the show_images grid); in a plain
# script the plot windows only appear once this is called.
pylab.show()

 

posted @ 2022-01-03 23:42  爱和九九  阅读(108)  评论(0)    收藏  举报