keepdims阻止降维
keepdims 是 NumPy 中一个非常实用的参数,它在对数组进行聚合操作(如 sum, mean, max 等)时,用来控制输出数组的维度。
keepdims 的作用
当你对一个数组进行聚合操作时,结果通常会减少一个维度。
例如,对一个 (3, 4) 的二维数组按行求和:
import numpy as np
arr = np.array([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 形状: (3, 4)
print(arr.shape)
# 按列求和 (axis=0)
# 结果形状: (4,),维度从2D降到1D
sum_by_col = np.sum(arr, axis=0)
print(sum_by_col.shape) # 输出: (4,)
# 按行求和 (axis=1)
# 结果形状: (3,),维度从2D降到1D
sum_by_row = np.sum(arr, axis=1)
print(sum_by_row.shape) # 输出: (3,)
keepdims=True 的作用就是阻止这个维度的下降。它会在输出结果中保留一个大小为1的维度。
# 按列求和 (axis=0),但保留维度
sum_by_col_keep = np.sum(arr, axis=0, keepdims=True)
print(sum_by_col_keep.shape) # 输出: (1, 4)
# 按行求和 (axis=1),但保留维度
sum_by_row_keep = np.sum(arr, axis=1, keepdims=True)
print(sum_by_row_keep.shape) # 输出: (3, 1)
为什么 keepdims 很重要?
keepdims 在机器学习中至关重要,因为它能确保**广播(Broadcasting)**操作的正确性。广播是 NumPy 的一个强大功能,它允许不同形状的数组进行数学运算。
例如,在你的代码中,计算偏置的梯度 delta_b 时:
delta_b = np.mean(G, axis=0, keepdims=True)
G的形状:(batch_size, 10)np.mean(G, axis=0)的形状是(10,)。keepdims=True使delta_b的形状变为(1, 10)。
为什么需要 (1, 10) 而不是 (10,)?因为在参数更新时,你需要用 delta_b 来更新偏置 b:
b = b - delta_b * lr
b的形状是(1, 10)。delta_b的形状是(1, 10)。
两者形状完全匹配,可以直接进行逐元素相减。
如果 delta_b 的形状是 (10,),NumPy 依然可以进行广播,但不保留维度会使代码更难理解,也可能在更复杂的运算中导致意想不到的错误。
结论:使用 keepdims=True 可以让你的代码在进行聚合运算时,保持数组的维度结构,这对于后续的矩阵运算(尤其是广播)来说,能大大提高代码的健壮性和可读性。

浙公网安备 33010602011771号