PyTorch 可视化热力图

 

可视化热力图可以有两种方式:

1)特征图可视化:取各通道特征在每个位置上的最大值作为热力图的像素值进行可视化——特征图的保存方式可参考相关博客中一种较为灵活的做法;

2)结合梯度与特征图计算热力图:热力图中高亮的区域对应梯度较大的位置,也就是网络重点关注的区域。

 

基于梯度的热力图可视化已有一些相关工作(如 Grad-CAM),也有一些现成的脚本,不过这些脚本的通用性不强,

因此这里基于 PyTorch 的 hook 机制进行可视化。这是一种基础且通用性很好的策略,很容易在自己的模型上进行尝试。

 

代码的逻辑结构如下:

首先定义模型并加载权重,然后对想要可视化的网络层注册 hook,接着进行前向推理并反向传播梯度即可。

farward_hook 和 backward_hook 会分别自动获取对应网络层的特征图和反传梯度,随后对其进行处理并将结果保存到本地。

# coding: utf-8
import cv2
import os
import torch
import numpy as np


def img_preprocess(img_in):
    """Placeholder for turning the raw image into the model's input.

    Not implemented here: replace the body with whatever preprocessing your
    model expects (for example colour conversion, resizing, normalisation and
    adding a batch dimension). As written it ignores ``img_in`` and returns None.
    """
    return None

# get gradient map
def backward_hook(module, grad_in, grad_out):
    """Backward hook: record the gradient flowing out of the hooked module.

    Appends ``grad_out[0]``, detached from the autograd graph, to the
    module-level list ``grad_block`` (which must exist before backward runs).
    Returns None, so the gradients themselves are not modified.
    """
    output_grad = grad_out[0]
    grad_block.extend([output_grad.detach()])

# get activation (feature) map
def farward_hook(module, input, output):
    """Forward hook: record the hooked module's output feature map.

    Appends the output to the module-level list ``fmap_block`` (which must
    exist before the forward pass runs). Tensor outputs are detached so the
    stored feature map does not keep the autograd graph referenced; any other
    output type is stored unchanged. Returns None, so the forward pass itself
    is unaffected.
    """
    if isinstance(output, torch.Tensor):
        output = output.detach()
    fmap_block.append(output)

# apply color to heatmap and save the result
def show_cam_on_image(img, mask, out_dir):
    """Overlay a CAM on an image and save the overlay and the raw image.

    :param img: np.array, [H, W, 3] image with 0-255 pixel values
        (e.g. as returned by ``cv2.imread``).
    :param mask: np.array, [h, w] CAM, expected in [0, 1]; it is resized to
        the image size here, so it need not match ``img``'s resolution.
    :param out_dir: directory that receives ``cam.jpg`` (overlay) and
        ``raw.jpg`` (input image); created if it does not exist.
    """
    h, w = img.shape[:2]
    # The CAM is usually much smaller than the image: upsample it first
    # (cv2.resize takes dsize as (width, height)), then clip so the uint8
    # conversion below cannot wrap around.
    mask = cv2.resize(np.float32(mask), (w, h))
    mask = np.clip(mask, 0.0, 1.0)
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    # Bring the image into [0, 1] as well, so the sum can be renormalised
    # instead of overflowing uint8.
    img = np.float32(img) / 255
    cam = heatmap + img  # show heatmap in original image
    cam = cam / np.max(cam)  # rescale the overlay back into [0, 1]

    os.makedirs(out_dir, exist_ok=True)
    cv2.imwrite(os.path.join(out_dir, "cam.jpg"), np.uint8(255 * cam))
    cv2.imwrite(os.path.join(out_dir, "raw.jpg"), np.uint8(255 * img))


def gen_cam(feature_map, grads, size=(32, 32)):
    """
    Build a Grad-CAM map from a layer's feature map and its gradients.

    :param feature_map: np.array, in [C, H, W]
    :param grads: np.array, in [C, H, W]
    :param size: output size passed to ``cv2.resize`` as (width, height);
        defaults to (32, 32). ``None`` keeps the feature-map resolution [H, W].
    :return: np.array (float32), min-max normalised to [0, 1]; all zeros when
        the map has no positive response.
    """
    # Per-channel weights: gradients averaged over the spatial dimensions.
    weights = np.mean(grads, axis=(1, 2))
    # Weighted sum of the channel maps, sum_c w_c * A_c -> [H, W].
    cam = np.tensordot(weights, feature_map, axes=(0, 0)).astype(np.float32)

    cam = np.maximum(cam, 0)  # keep only positively contributing regions
    if size is not None:
        cam = cv2.resize(cam, tuple(size))
    cam -= np.min(cam)
    peak = np.max(cam)
    # A constant map (e.g. everything negative before the ReLU) would
    # otherwise divide 0/0 and produce NaNs.
    if peak > 0:
        cam /= peak

    return cam


if __name__ == '__main__':

    BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # helper for building the paths below
    path_img, path_net, output_dir = None, None, None  # change to yours

    # Module-level buffers that farward_hook / backward_hook append to.
    fmap_block = list()
    grad_block = list()

    # Read the image and load the network.
    img = cv2.imread(path_img, 1)  # H*W*C
    if img is None:
        # cv2.imread reports an unreadable file by returning None, not raising.
        raise FileNotFoundError(f"could not read image: {path_img!r}")
    img_input = img_preprocess(img)
    model = ResNet50()  # your model class
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(path_net, map_location="cpu"))

    # Register hooks on the layer to visualise.
    model.layer4.register_forward_hook(farward_hook)  # get activation map
    # register_backward_hook is deprecated and may report wrong gradients when
    # the module runs several autograd ops; the "full" variant reports the
    # gradient w.r.t. the module's output.
    model.layer4.register_full_backward_hook(backward_hook)  # get gradient map

    # forward
    # NOTE: the model is left in training mode (model.training is True), so
    # layers such as BatchNorm/Dropout use training behaviour; call
    # model.eval() first if inference behaviour is wanted.
    output = model(img_input)

    # backward
    model.zero_grad()
    loss = model.get_loss(output)  # your scalar target for the backward pass
    loss.backward()

    # Build the CAM from the first captured gradient / feature map.
    grads_val = grad_block[0].detach().cpu().numpy().squeeze()
    fmap = fmap_block[0].detach().cpu().numpy().squeeze()
    cam = gen_cam(fmap, grads_val)

    # Save the CAM overlay and the raw image.
    show_cam_on_image(img, cam, output_dir)

——主体参考代码

posted @ 2023-04-24 17:50  谷小雨  阅读(950)  评论(0编辑  收藏  举报