Saliency map实现

import PIL, torch, torchvision
import matplotlib.pyplot as plt
import sys
import pandas as pd

# Min-max normalisation
def normalize(image):
    """Rescale ``image`` linearly so its values span ``[0, 1]``.

    Args:
        image: Array-like object supporting ``.min()``, ``.max()`` and
            element-wise arithmetic (e.g. a torch tensor or numpy array).

    Returns:
        ``(image - min) / (max - min)``.  When every element is equal the
        range is zero, so an all-zero result is returned instead of the
        NaNs that a 0/0 division would produce.
    """
    lowest = image.min()
    span = image.max() - lowest
    if span == 0:
        # Constant input: ``image - lowest`` is already all zeros.
        return image - lowest
    return (image - lowest) / span


def show_saliency_map(img_path, model, size=100, cmap=plt.cm.hot,
                      class_csv='./1000class_dict.csv'):
    """Show an image next to its RGB and grayscale gradient saliency maps.

    The image is resized to ``size x size`` and fed to ``model``; the score of
    the top-scoring class is differentiated with respect to the input pixels.
    The predicted class names are printed and a three-panel matplotlib figure
    (input image, RGB saliency, grayscale saliency) is displayed.

    Args:
        img_path: Path of the image file to explain.
        model: Module called as ``model(batch)`` with a ``(1, 3, size, size)``
            float tensor; its output is indexed as ``output[0, class]``.
        size: Side length in pixels the image is resized to.
        cmap: Matplotlib colormap for the grayscale saliency panel.
        class_csv: CSV file of class names.  The row position is used as the
            class index; columns 2 and 3 are printed (presumably the Chinese
            and English names, judging by the original variable names), and
            column 3 titles the first panel.

    Returns:
        None.  The result is the printed line and the displayed figure.
    """
    model.eval()

    # Image transforms: tensor for the model, plain resize for display,
    # and tensor -> PIL for rendering the RGB saliency map.
    to_tensor = torchvision.transforms.Compose([
        torchvision.transforms.Resize((size, size)),
        torchvision.transforms.ToTensor(),
    ])
    resize = torchvision.transforms.Resize((size, size))
    to_pil = torchvision.transforms.ToPILImage()

    # Close the file handle promptly; convert() returns an independent image.
    with PIL.Image.open(img_path) as handle:
        img = handle.convert("RGB")

    # Run on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # NOTE(review): no mean/std normalisation is applied to the input, while
    # torchvision's pretrained ImageNet weights are documented as expecting
    # it -- confirm whether that omission is intended.
    timg = to_tensor(img).unsqueeze(0).to(device).requires_grad_(True)

    # Forward pass, then pick the index of the highest-scoring class.
    output = model(timg)
    timg_class = output.argmax(dim=1).item()

    # Look up the class names by row position in the class table.
    pd_data = pd.read_csv(class_csv)
    name_zh = pd_data.iloc[timg_class, 2]
    name_en = pd_data.iloc[timg_class, 3]
    print(name_zh, name_en)

    # Gradient of the winning class score w.r.t. the input only.  Using
    # autograd.grad (instead of backward()) avoids accumulating gradients
    # into the caller's model parameters.
    score = output[0, timg_class]
    (batch_grad,) = torch.autograd.grad(score, timg)
    abs_grad = batch_grad[0].abs().cpu()  # (3, size, size)

    # Grayscale map: per-pixel maximum of |gradient| over the colour channels.
    # The original additionally subtracted the constant lambd * ||I||_2^2 from
    # every pixel; that term is the paper's class-model regulariser, cannot
    # change relative saliency, and its magnitude only coarsened the float32
    # values, so it is omitted here.
    saliency_map_gray = torch.max(abs_grad, dim=0)[0].numpy()

    # RGB map: |gradient| min-max scaled jointly over all three channels
    # (this matches what the original loop over the batch dimension did).
    saliency_map_rgb = normalize(abs_grad)

    _, axes = plt.subplots(1, 3)
    axes[0].imshow(resize(img))
    axes[0].set_title(name_en)

    axes[1].imshow(to_pil(saliency_map_rgb))
    axes[1].set_title('RGB map')

    axes[2].imshow(saliency_map_gray, cmap=cmap)
    axes[2].set_title('gray map')
    plt.show()

# Demo: explain a pretrained ResNet-18's prediction for a sample image.
model = torchvision.models.resnet18(pretrained=True)
img = './panda.png'
show_saliency_map(img_path=img, model=model, size=224)

参考:Simonyan, Vedaldi, Zisserman, "Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps", https://arxiv.org/abs/1312.6034

posted @ 2020-11-15 19:27  Mydrizzle  阅读(393)  评论(0编辑  收藏  举报