matplotlib绘制批次张量中的图片

matplotlib绘制批次张量中的图片

  1. 首先需要导入matplotlib.pyplot这个包

  2. 然后设置四个变量batch_size, channels, height, width接受批次张量的信息

  3. 使用 plt.subplots() 创建一个具有指定行数和列数的图形网格,并设置画布大小为 (10, 10)。plt.subplots() 是 Matplotlib 库中用于创建图形网格的函数。它的作用是创建一个包含多个子图(axes)的图形网格,并返回一个包含整个图形(fig)和子图数组的元组(axes)。

  4. 然后就是遍历张量batch_images,并转换为numpy格式,因为Matplotlib 中通常期望图像的通道顺序是 [ height, width ,channels]所以需要permute(1, 2, 0)来交换通道维度,而且Matplotlib 中的 imshow() 函数通常接受 NumPy 数组作为输入来显示图像。因此,需要将张量还需要转换为 NumPy 数组以进行显示。

  5. 然后就可以调用imshow,绘制子图了

import matplotlib.pyplot as plt
# 假设 batch_images 是你的批次张量,形状为 (batch_size, channels, height, width)
batch_size, channels, height, width = batch_images.shape

# 创建一个图像网格,每行显示一张图像
rows = 4  # 设置行数
cols = 4  # 设置列数
fig, axes = plt.subplots(rows, cols, figsize=(10, 10))

# 在图像网格中显示图像
for i in range(rows):
    for j in range(cols):
        # 计算当前图像在批次张量中的索引
        index = i * cols + j
        if index < batch_size:
            # 获取当前图像的张量
            img_tensor = batch_images[index]
            # 将张量转换为 NumPy 数组,并交换通道维度
            img_array = img_tensor.permute(1, 2, 0).numpy()
            # 显示图像
            axes[i, j].imshow(img_array)
            axes[i, j].axis("off")  # 关闭坐标轴
            axes[i, j].set_title(f"Image {index}")  # 设置标题
        else:
            # 如果索引超出了批次大小,显示空白图像
            axes[i, j].axis("off")  # 关闭坐标轴

plt.tight_layout()  # 调整子图之间的间距
plt.show()  # 显示图像网格

image-20240207003335813

注意

  • fig 是一个代表整个图形(Figure)对象的变量。Figure 是一个空白的画布,它是所有子图的容器。你可以在这个画布上绘制图形,或者添加子图。fig 可以控制整个图形的属性,比如大小、标题等。
  • axes 是一个包含子图(Axes)对象的数组。子图对象是具体的绘图区域,它包含了绘图的大部分元素,比如坐标轴、数据、标签等。在 Matplotlib 中,大多数的绘图函数都是在 Axes 对象上操作的。axes 数组的维度与 plt.subplots() 中指定的行数和列数相对应,每个元素代表图形网格中的一个子图。
posted @ 2024-02-07 00:42  微放  阅读(59)  评论(0)    收藏  举报