reshape
reshape() 是 NumPy 数组的一个方法,用于改变数组的形状,同时不改变其数据。
简单来说,它就像一个魔术师,能在不改变魔方块数量的情况下,将一个魔方重新排列成不同的形状。
reshape() 的基本用法
arr.reshape(new_shape)
arr是你要改变形状的 NumPy 数组。new_shape是一个元组或整数,表示你想要的新形状。
最重要的一点是:新形状必须包含与原始数组相同数量的元素。
import numpy as np
# 创建一个包含 12 个元素的一维数组
arr = np.arange(1, 13)
print(f"原始数组:\n{arr}")
print(f"原始形状: {arr.shape}")
# 将一维数组重塑成一个 3x4 的二维数组
reshaped_arr = arr.reshape(3, 4)
print(f"\n重塑后的数组:\n{reshaped_arr}")
print(f"新形状: {reshaped_arr.shape}")
输出:
原始数组:
[ 1 2 3 4 5 6 7 8 9 10 11 12]
原始形状: (12,)
重塑后的数组:
[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]]
新形状: (3, 4)
在这个例子中,原始数组有 12 个元素,3x4 也是 12,所以重塑成功。
常见的应用场景
在手写数字识别中,reshape() 的应用非常频繁。MNIST 数据集的图像通常是 28x28 像素,但为了进行处理,我们有时需要改变它们的形状。
1. 从二维转一维(展平)
在将图像数据输入到某些机器学习模型(如 SVM 或 KNN)之前,需要将 28x28 的二维图像展平成一个 784 像素的一维向量。
# 假设 img 是一个 28x28 的 NumPy 数组
img = np.zeros((28, 28))
# 展平图像
flattened_img = img.reshape(784)
# 或者使用更常用的方法:
# flattened_img = img.reshape(-1)
reshape(-1) 是一种非常方便的用法。-1 告诉 NumPy 自动计算这个维度的大小,以确保总元素数量不变。当只有一个维度是 -1 时,它会很智能地计算出这个维度应该是什么值。
2. 从一维转多维
在深度学习中,特别是使用卷积神经网络(CNN)时,需要将展平的图像数据重新塑形为四维张量。
# 假设我们有 60000 张展平的图片
X_flattened = np.zeros((60000, 784))
# 将其重塑为四维张量 (样本数, 通道数, 高, 宽)
# 单通道图像通常用 1 表示通道数
X_reshaped = X_flattened.reshape(60000, 1, 28, 28)
在这个例子中,reshape() 将 60000 个 784 像素的向量,重新组织成了 60000 张 1 通道、28x28 像素的图片。
reshape() 的重要特性
- 返回视图:
reshape()通常返回原始数组的一个视图(view)。这意味着新数组和原始数组共享底层数据。如果你修改了新数组中的一个元素,原始数组中的对应元素也会随之改变。 - 不会改变原始数组:
reshape()不会就地(in-place)改变原始数组,它会返回一个新的、形状不同的数组。

浙公网安备 33010602011771号