torchvision 批量可视化图片
1.1 简介
计算机视觉中,我们需要观察我们的神经网络输出是否合理。因此就需要进行可视化的操作。
orchvision是独立于pytorch的关于图像操作的一些方便工具库。
torchvision的详细介绍在:https://pypi.org/project/torchvision/0.1.8/
这里主要使用的是make_grid函数,参数的tensor是一个 (B x C x H x W) - (Batchsize, Channel, Heigjt, Weight)的张量,nrow是输出图片网格的列数。padding是每张图片之间宽度间隔。
make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False)
Example usage is given in this notebook<https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>
举个例子。如果你的batch size 是一个(32,3,256,256)的一组图片,设置为nrow = 8,则最后输出的图片是一个4*8的网格,每个网格是一张图片。
2.1 代码
batch_image是([5, 3, 256, 256])大小的张量。
batch_labels是 ([5, 15, 2]) 的坐标点。用于标记每张图中15个关键点的 [x, y] 坐标
vis_flipped ([1, 5, 14]) 记录每个关键点可见的情况,0为不可见,1为可见
output_root 是保存图片的路径
i_loader是data loader 的索引
j_loader是batch的索引
代码的关键是要保存正确的关键的信息在一个大网格内,因此,需要把每个关键点的坐标,写一个for 循环。
x = 行数*图片宽 + padding +x ,
y = 列数*图片高 + padding +y
import cv2 import os import torchvision import numpy as npdef save_visualize_result(batch_image,labels,batch_labels,raw_image,vis_flipped,output_root,i_loader,j_batch): # batch_image.shape ([5, 3, 256, 256]) # labels.shape ([1, 5, 15, 31, 31]) # batch_labels.shape ([5,15,2]) # raw_image.shape ([ 5, 3 , width_raw,height_raw ]) # flipped_labels.shape ([1,5,28])[x1,x2,x3 ...,x14,y1,y2,y3...y14 # vis_flipped [1, 5, 14] # i_loader -- which loader, j_batch -- which_batch batch_size, n_stages, n_joints = labels.shape[0], labels.shape[1], labels.shape[2] xmaps = n_stages ymaps = batch_size image_size = batch_image.shape[-2] label_size = labels.shape[-2] rotation = image_size / label_size grid = torchvision.utils.make_grid(batch_image, nrow=n_stages, padding=2, normalize=True) ndarr = grid.mul(255).clamp(0, 255).byte().cpu().permute(1, 2, 0).numpy() b, g, r = cv2.split(ndarr) ndarr = cv2.merge([r, g, b]) ndarr = ndarr.copy() padding = 2 height = int(batch_image.size(2) + padding) width = int(batch_image.size(3) + padding) k = 0 # mpii_order = [13, 11, 9, 8, 10, 12, 4, 6, 14, 1, 7, 5, 3, 2] # transformed order [13, 11, 9, 8, 10, 12, 4, 6, 14, 1, 7, 5, 3, 2] names = ['ra', 'rk', 'rh', 'lh', 'lk', 'la', 'le', 'lw', 'neck', 'head', 'rw', 're', 'rs', 'ls'] ### mapped ### k = 0 for y in range(ymaps): for x in range(xmaps): raw_vis = vis_flipped[0, k, :] joints = batch_labels[k, :, :] * rotation for i_name, joint in enumerate(joints): if i_name < 14: if raw_vis[i_name] == 0: continue joint[0] = x * width + padding + joint[0] joint[1] = y * height + padding + joint[1] cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 2, [255, 0, 0], 2) cv2.putText(ndarr, names[i_name], org=(int(joint[0]), int(joint[1])), fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.5, color=[0, 0, 255]) k = k + 1 cv2.imwrite(os.path.join(output_root, 'loader_' + str(i_loader) + '_batch_' + str(j_batch) + '_mapped.png'), ndarr) print('loader_' + str(i_loader) + '_batch_' + str(j_batch) + '_mapped.png' + 'saved successfuly!')
3.1 结果