keepdims阻止降维

keepdims 是 NumPy 中一个非常实用的参数,它在对数组进行聚合操作(如 sum, mean, max 等)时,用来控制输出数组的维度


keepdims 的作用

当你对一个数组进行聚合操作时,结果通常会减少一个维度

例如,对一个 (3, 4) 的二维数组按行求和:

import numpy as np
arr = np.array([[1, 2, 3, 4],
                [5, 6, 7, 8],
                [9, 10, 11, 12]])

# 形状: (3, 4)
print(arr.shape)

# 按列求和 (axis=0)
# 结果形状: (4,),维度从2D降到1D
sum_by_col = np.sum(arr, axis=0)
print(sum_by_col.shape) # 输出: (4,)

# 按行求和 (axis=1)
# 结果形状: (3,),维度从2D降到1D
sum_by_row = np.sum(arr, axis=1)
print(sum_by_row.shape) # 输出: (3,)

keepdims=True 的作用就是阻止这个维度的下降。它会在输出结果中保留一个大小为1的维度。

# 按列求和 (axis=0),但保留维度
sum_by_col_keep = np.sum(arr, axis=0, keepdims=True)
print(sum_by_col_keep.shape) # 输出: (1, 4)

# 按行求和 (axis=1),但保留维度
sum_by_row_keep = np.sum(arr, axis=1, keepdims=True)
print(sum_by_row_keep.shape) # 输出: (3, 1)

为什么 keepdims 很重要?

keepdims 在机器学习中至关重要,因为它能确保**广播(Broadcasting)**操作的正确性。广播是 NumPy 的一个强大功能,它允许不同形状的数组进行数学运算。

例如,在你的代码中,计算偏置的梯度 delta_b 时:
delta_b = np.mean(G, axis=0, keepdims=True)

  • G 的形状(batch_size, 10)
  • np.mean(G, axis=0) 的形状是 (10,)
  • keepdims=True 使 delta_b 的形状变为 (1, 10)

为什么需要 (1, 10) 而不是 (10,)?因为在参数更新时,你需要用 delta_b 来更新偏置 b
b = b - delta_b * lr

  • b 的形状(1, 10)
  • delta_b 的形状(1, 10)

两者形状完全匹配,可以直接进行逐元素相减。

如果 delta_b 的形状是 (10,),NumPy 依然可以进行广播,但不保留维度会使代码更难理解,也可能在更复杂的运算中导致意想不到的错误

结论:使用 keepdims=True 可以让你的代码在进行聚合运算时,保持数组的维度结构,这对于后续的矩阵运算(尤其是广播)来说,能大大提高代码的健壮性和可读性。

posted @ 2025-09-19 15:54  李大嘟嘟  阅读(8)  评论(0)    收藏  举报