matplotlib绘制批次张量中的图片
matplotlib绘制批次张量中的图片
-
首先需要导入matplotlib.pyplot这个包
-
然后设置四个变量batch_size, channels, height, width接受批次张量的信息
-
使用
plt.subplots()
创建一个具有指定行数和列数的图形网格,并设置画布大小为 (10, 10)。plt.subplots()
是 Matplotlib 库中用于创建图形网格的函数。它的作用是创建一个包含多个子图(axes)的图形网格,并返回一个包含整个图形(fig)和子图数组的元组(axes)。 -
然后就是遍历张量batch_images,并转换为numpy格式,因为Matplotlib 中通常期望图像的通道顺序是 [ height, width ,channels]所以需要permute(1, 2, 0)来交换通道维度,而且Matplotlib 中的
imshow()
函数通常接受 NumPy 数组作为输入来显示图像。因此,需要将张量还需要转换为 NumPy 数组以进行显示。 -
然后就可以调用imshow,绘制子图了
import matplotlib.pyplot as plt
# 假设 batch_images 是你的批次张量,形状为 (batch_size, channels, height, width)
batch_size, channels, height, width = batch_images.shape
# 创建一个图像网格,每行显示一张图像
rows = 4 # 设置行数
cols = 4 # 设置列数
fig, axes = plt.subplots(rows, cols, figsize=(10, 10))
# 在图像网格中显示图像
for i in range(rows):
for j in range(cols):
# 计算当前图像在批次张量中的索引
index = i * cols + j
if index < batch_size:
# 获取当前图像的张量
img_tensor = batch_images[index]
# 将张量转换为 NumPy 数组,并交换通道维度
img_array = img_tensor.permute(1, 2, 0).numpy()
# 显示图像
axes[i, j].imshow(img_array)
axes[i, j].axis("off") # 关闭坐标轴
axes[i, j].set_title(f"Image {index}") # 设置标题
else:
# 如果索引超出了批次大小,显示空白图像
axes[i, j].axis("off") # 关闭坐标轴
plt.tight_layout() # 调整子图之间的间距
plt.show() # 显示图像网格
注意
fig
是一个代表整个图形(Figure)对象的变量。Figure 是一个空白的画布,它是所有子图的容器。你可以在这个画布上绘制图形,或者添加子图。fig
可以控制整个图形的属性,比如大小、标题等。axes
是一个包含子图(Axes)对象的数组。子图对象是具体的绘图区域,它包含了绘图的大部分元素,比如坐标轴、数据、标签等。在 Matplotlib 中,大多数的绘图函数都是在 Axes 对象上操作的。axes
数组的维度与plt.subplots()
中指定的行数和列数相对应,每个元素代表图形网格中的一个子图。