Follow me on GitHub

torchvision 批量可视化图片

1.1 简介

计算机视觉中,我们需要观察我们的神经网络输出是否合理。因此就需要进行可视化的操作。

orchvision是独立于pytorch的关于图像操作的一些方便工具库。

torchvision的详细介绍在:https://pypi.org/project/torchvision/0.1.8/

这里主要使用的是make_grid函数,参数的tensor是一个 (B x C x H x W) - (Batchsize, Channel, Heigjt, Weight)的张量,nrow是输出图片网格的列数。padding是每张图片之间宽度间隔。

make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False)

Example usage is given in this notebook<https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>

举个例子。如果你的batch size 是一个(32,3,256,256)的一组图片,设置为nrow = 8,则最后输出的图片是一个4*8的网格,每个网格是一张图片。

 

2.1 代码

batch_image是([5, 3, 256, 256])大小的张量。

batch_labels是 ([5, 15, 2]) 的坐标点。用于标记每张图中15个关键点的 [x, y] 坐标

vis_flipped ([1, 5, 14]) 记录每个关键点可见的情况,0为不可见,1为可见

output_root 是保存图片的路径

i_loader是data loader 的索引

j_loader是batch的索引

代码的关键是要保存正确的关键的信息在一个大网格内,因此,需要把每个关键点的坐标,写一个for 循环。

x = 行数*图片宽 + padding +x ,

y = 列数*图片高 + padding +y

import cv2
import os
import torchvision
import numpy as npdef save_visualize_result(batch_image,labels,batch_labels,raw_image,vis_flipped,output_root,i_loader,j_batch):
    # batch_image.shape ([5, 3, 256, 256])
    # labels.shape ([1, 5, 15, 31, 31])
    # batch_labels.shape ([5,15,2])
    # raw_image.shape ([ 5, 3 , width_raw,height_raw ])
    # flipped_labels.shape ([1,5,28])[x1,x2,x3 ...,x14,y1,y2,y3...y14
    # vis_flipped [1, 5, 14]
    # i_loader -- which loader, j_batch -- which_batch


    batch_size, n_stages, n_joints = labels.shape[0], labels.shape[1], labels.shape[2]
    xmaps = n_stages
    ymaps = batch_size

    image_size = batch_image.shape[-2]
    label_size = labels.shape[-2]
    rotation = image_size / label_size

    grid = torchvision.utils.make_grid(batch_image, nrow=n_stages, padding=2, normalize=True)
    ndarr = grid.mul(255).clamp(0, 255).byte().cpu().permute(1, 2, 0).numpy()
    b, g, r = cv2.split(ndarr)

    ndarr = cv2.merge([r, g, b])
    ndarr = ndarr.copy()
    
    padding = 2

    height = int(batch_image.size(2) + padding)
    width = int(batch_image.size(3) + padding)
    k = 0
    # mpii_order = [13, 11, 9, 8, 10, 12, 4, 6, 14, 1, 7, 5, 3, 2]
    # transformed order [13, 11,  9,  8, 10, 12,  4,  6, 14,  1,  7,  5,  3,  2]
    names = ['ra', 'rk', 'rh', 'lh', 'lk', 'la', 'le', 'lw', 'neck', 'head', 'rw', 're', 'rs', 'ls']

    ### mapped ###
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            raw_vis = vis_flipped[0, k, :]
            joints = batch_labels[k, :, :] * rotation
            for i_name, joint in enumerate(joints):
                if i_name < 14:
                    if raw_vis[i_name] == 0:
                        continue
                    joint[0] = x * width + padding + joint[0]
                    joint[1] = y * height + padding + joint[1]
                    cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 2, [255, 0, 0], 2)
                    cv2.putText(ndarr, names[i_name], org=(int(joint[0]), int(joint[1])),
                                fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.5, color=[0, 0, 255])
            k = k + 1
    cv2.imwrite(os.path.join(output_root, 'loader_' + str(i_loader) + '_batch_' + str(j_batch) + '_mapped.png'), ndarr)
    print('loader_' + str(i_loader) + '_batch_' + str(j_batch) + '_mapped.png' + 'saved successfuly!')
 

 

3.1 结果

 

 

posted @ 2019-04-14 00:07  SiyuanChen  阅读(2500)  评论(0编辑  收藏  举报