深度学习笔记007：图片分类数据集（Picture Classification Dataset）
跟着李沐老师学的，但今天有点偷懒，没学多少，笔记如下：
（这堂课前十几分钟讲的是 Softmax 回归，相关笔记放在下一篇博客里。）
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

import pylab

# List comprehension refresher (used in get_fashion_mnist_labels below):
#     list2 = [int(i) for i in list1.split(' ')]
# is functionally the same as the explicit loop
#     list2 = []
#     for i in list1.split(' '):
#         list2.append(int(i))


def get_fashion_mnist_labels(labels):
    """Map Fashion-MNIST integer labels to their text class names.

    Args:
        labels: an iterable of integer-like labels (e.g. a 1-D label tensor).

    Returns:
        A list with one class-name string per label.
    """
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    # int(i) also handles the 0-d tensors produced by iterating a label tensor.
    return [text_labels[int(i)] for i in labels]


def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """Draw a list of images on a num_rows x num_cols grid.

    Args:
        imgs: images to draw; torch tensors are converted with .numpy(),
            anything else is passed to imshow unchanged.
        num_rows: number of grid rows.
        num_cols: number of grid columns.
        titles: optional sequence of titles, matched to imgs by position.
        scale: width/height (in inches) of each grid cell.

    Returns:
        A flat array of the grid's matplotlib Axes.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid (where
    # subplots would otherwise return a bare Axes with no .flatten()).
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                               squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        ax.imshow(img.numpy() if torch.is_tensor(img) else img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes


def get_dataloader_workers():
    """Return the number of worker processes each DataLoader uses."""
    return 4


def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """Load Fashion-MNIST (via torchvision) into train/test DataLoaders.

    Args:
        batch_size: mini-batch size for both loaders.
        resize: optional size passed to transforms.Resize, which is applied
            before ToTensor.

    Returns:
        (train_iter, test_iter): the training loader shuffles, the test
        loader does not.
    """
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))


# The guard matters because DataLoader(num_workers > 0) starts worker
# processes; with the "spawn" start method (Windows, macOS) each worker
# re-imports this module and would otherwise rerun the whole script.
if __name__ == "__main__":
    d2l.use_svg_display()

    # download=True is safe to leave on: torchvision skips the download when
    # the dataset files already exist under root.
    trans = transforms.ToTensor()
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    print(len(mnist_test), len(mnist_train))

    # Show the first 18 training images with their text labels.
    X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
    show_images(X.reshape(18, 28, 28), 2, 9,
                titles=get_fashion_mnist_labels(y))

    # Time one full pass over the training set: reading data must be faster
    # than the computation that consumes it.
    batch_size = 256
    train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                                 num_workers=get_dataloader_workers())
    timer = d2l.Timer()
    for X, y in train_iter:
        continue
    print(f'{timer.stop():.2f} sec')

    pylab.show()
 
                    
                     
                    
                 
                    
                
 
 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号