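# 神经网络特征图的可视化 (Visualizing neural-network feature maps)
# 神经网络的可视化过程 (Neural-network visualization walkthrough)
# 特征图可视化 (Feature-map visualization)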
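"""
Neural-network visualization walkthrough (神经网络的可视化过程):
run an image through a small CNN and render its feature maps as image grids.
"""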
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import os
from PIL import Image
import numpy as np
import cv2
from torchsummary import summary
# Presumably set to stop Intel's OpenMP runtime from aborting when more than one
# copy of it gets loaded (e.g. by torch and numpy/MKL). The runtime itself calls
# this an unsafe workaround -- NOTE(review): confirm it is still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")
class Test(nn.Module):
    """Small two-stage CNN whose output feature maps are visualized below.

    conv1: 3 -> 16 channels (3x3 conv, padding 1), two 2x2 max-pools, ReLU.
    conv2: 16 -> 64 channels (5x5 conv, padding 1), one 2x2 max-pool, ReLU.

    A 224x224 input therefore yields a [N, 64, 27, 27] output:
    224 -> 112 -> 56 (conv1), then 56 + 2 - 5 + 1 = 54 -> 27 (conv2).

    NOTE(review): the back-to-back max-pools and the shrinking 5x5/padding=1
    conv may be unintended; they are kept exactly as written.
    """

    def __init__(self):
        super().__init__()
        # Assignment order fixes the order of ``_modules`` and the state_dict
        # keys (conv1.0.*, conv2.0.*), so keep conv1 before conv2.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.MaxPool2d(2),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 64, kernel_size=5, padding=1),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return ``conv2(conv1(x))`` for an [N, 3, H, W] batch."""
        return self.conv2(self.conv1(x))
# --- Load one image, preprocess it, and run it through the model -------------
path = "2.png"
# Resize the PIL image *before* ToTensor. Resize accepts PIL images in every
# torchvision version (tensor input only in newer ones), and PIL resizing is
# antialiased, so the preprocessed image does not depend on the installed
# torchvision version.
trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Maps ToTensor's [0, 1] range to [-1, 1] per channel: (v - 0.5) / 0.5.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# The context manager closes the image's file handle once the pixels are read.
with Image.open(path) as pil_image:
    x = trans(pil_image.convert("RGB"))
x = torch.unsqueeze(x, 0)  # add a batch dimension: [3, 224, 224] -> [1, 3, 224, 224]
modol = Test()  # name kept as-is: other top-level code refers to `modol`
# Inference only: visualization needs no autograd graph.
with torch.no_grad():
    y = modol(x)  # [N, C, H, W] = [1, 64, 27, 27]
# Feature-map visualization
def _feature_grid(activation, images_per_row):
    """Tile the channels of one [C, H, W] activation array into a 2-D grid.

    Channel ``k`` is placed at grid row ``k // images_per_row`` and grid column
    ``k % images_per_row``. Each channel is standardized (zero mean, unit std),
    mapped to ``value * 64 + 128`` and clipped to [0, 255] as uint8 values.
    Grid cells not covered by a channel stay 0. ``activation`` is not modified.
    """
    n_features, height, width = activation.shape
    # Ceiling division keeps the last, partially filled grid row.
    n_grid_rows = -(-n_features // images_per_row)
    grid = np.zeros((n_grid_rows * height, images_per_row * width))
    for idx in range(n_features):
        grid_row, grid_col = divmod(idx, images_per_row)
        # astype() returns a copy, so the in-place arithmetic below cannot
        # write through to the caller's array / tensor storage.
        channel = activation[idx].astype(np.float64)
        channel -= channel.mean()
        std = channel.std()
        if std > 0:  # a constant channel (e.g. all zeros after ReLU) would give NaN
            channel /= std
        channel = np.clip(channel * 64 + 128, 0, 255).astype("uint8")
        grid[grid_row * height:(grid_row + 1) * height,
             grid_col * width:(grid_col + 1) * width] = channel
    return grid


def map_feature(img, images_per_row, *, layer_name="feature_map"):
    """Plot every channel of a feature-map tensor as a tiled grid and save it.

    Each sample in the batch gets its own figure, titled and saved as
    ``<layer_name>.png``. When the batch holds more than one sample, the file
    is ``<layer_name>_<i>.png`` instead. All figures are shown at the end.

    Args:
        img: Activation tensor shaped [N, C, H, W]. A single [C, H, W] map is
            also accepted. It is only read; the caller's tensor is not modified.
        images_per_row: Number of channel tiles per grid row (must be >= 1).
            Unused cells of the last grid row are left as zeros.
        layer_name: Figure title and output file stem.

    Raises:
        ValueError: if ``images_per_row`` < 1 or ``img`` is not 3-D / 4-D.

    Note:
        The previous version zipped the global model's layer names with the
        *batch* axis. It therefore titled and saved sample 0 of the final output
        as "conv1", and silently dropped samples beyond the number of layers. It
        also dropped channels that did not fill a whole grid row. Finally, it
        normalized a NumPy view sharing memory with ``img``, which overwrote the
        caller's tensor.
    """
    if images_per_row < 1:
        raise ValueError(f"images_per_row must be >= 1, got {images_per_row}")
    # .cpu() lets CUDA tensors through; .numpy() alone would fail on them.
    activations = img.detach().cpu().numpy()
    if activations.ndim == 3:
        activations = activations[np.newaxis]
    if activations.ndim != 4:
        raise ValueError(
            f"expected a [N, C, H, W] or [C, H, W] tensor, got shape {tuple(activations.shape)}"
        )
    n_samples = activations.shape[0]
    for sample_idx, activation in enumerate(activations):
        name = layer_name if n_samples == 1 else f"{layer_name}_{sample_idx}"
        grid = _feature_grid(activation, images_per_row)
        # Make each channel tile 1 inch tall (tile height H pixels -> 1 inch).
        scale = 1.0 / activation.shape[1]
        plt.figure(figsize=(scale * grid.shape[1], scale * grid.shape[0]))
        plt.title(name)
        plt.grid(False)
        plt.imshow(grid, aspect="auto", cmap="viridis")
        plt.savefig(name + ".png")
    plt.show()
# Render and save the model output's feature maps, eight channel tiles per grid row.
map_feature(y, images_per_row=8)
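# 代码无注释，哪句有问题，欢迎留言，顺便给个关注。
# (Author's note: the code has no comments; if any line is wrong, please leave a comment — and feel free to follow.)
# 浙公网安备 33010602011771号 (website footer: Zhejiang public-security filing number)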