【pytorch】制作网格图像,直接将tensor格式的图像保存到本地
写到前面
这是torchvision.utils模块里面的两个方法,因为比较常用,所以pytorch直接封装好了。
制作网格
网络图像一般用于训练数据或测试数据的可视化。
torchvision.utils.make_grid(tensor, nrow, padding) → torch.Tensor
- 描述
将多张tensor格式的图像以网格的方式封装到一起。
- 参数
tensor (tensor or list):四维 (B x C x H x W) mini-batch的tensor数据或者是包含同一尺寸的图片列表。
nrow (int):网格每行图片的个数,默认是8;千万不要理解为图片的行数。
padding (int):四周填充的宽度,默认是2,你可以理解为网格中图片之间的间距。默认填充值是0,也就是黑色。
注:这是三个比较常用的参数,其它参数请参考官方文档。
- 示例
# 以mnist数据集为例,train_loader的batch_size设置为9
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
images = torchvision.utils.make_grid(images, 3, 0)
print(images.size()) # torch.Size([3, 84, 84])
- 绘图
![在这里插入图片描述]()
保存本地
tensor数据类型保存时不用再转为PIL.Image或numpy.ndarray,pytorch直接给我们写好了一个方法。
torchvision.utils.save_image(tensor, fp) → None
- 描述
直接将tensor数据保存为图像。
- 参数
tensor (Tensor or list):待保存的tensor数据。如果给以一个四维的mini-batch的tensor,将调用网格方法,然后再保存到本地。
fp (string or file object)):图像的保存路径。
注:这是两个比较常用的参数,其它参数请参考官方文档。
- 示例
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
images = torchvision.utils.make_grid(images, 3, 0)
print(images.size()) # torch.Size([3, 84, 84])
torchvision.utils.save_image(images, 'test.jpg')
完整代码
#%% 导入模块
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
#%% 下载数据集
train_file = datasets.MNIST(
root='./dataset/',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),
download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
dataset=train_file,
batch_size=9,
shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
images = make_grid(images, 3, 0)
print(images.size()) # torch.Size([3, 84, 84])
save_image(images, 'test.jpg')


浙公网安备 33010602011771号