pytorch 和 matplotlib 从 torchvision FashionMNIST 数据集中显示图像

matplotlib 从 torchvision FashionMNIST 数据集中显示图像

如何导入 mnist_train 对象可以参考 https://www.cnblogs.com/fanbal/p/19196275

重要片段

实现效果:
image

核心的片段如下:

tensor1: torch.Tensor = mnist_train[0][0]  
plt.imshow(tensor1.squeeze()) # squeeze() 可以让原先 1x28x28的shape结构变成 28x28,符合图片的尺寸,这样就可以在 plot 中呈现了。
plt.show() # 一定要写,如果不写的话就不会呈现

全部片段

import torch
import numpy as np
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt

trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root=r"data",
    train=True,
    transform=trans,
    download=False,
)

tensor1: torch.Tensor = mnist_train[0][0]

plt.imshow(tensor1.squeeze())
plt.show()
posted @ 2025-11-06 13:26  fanbal  阅读(9)  评论(0)    收藏  举报