np.argmax


argmax 是 NumPy 和许多其他科学计算库(如 PyTorch、TensorFlow)中的一个非常常用的函数,它的作用是返回数组中最大值的索引

简单来说,argmax 告诉你最大值在哪里,而不是最大值是多少

argmax 的基本用法

np.argmax(a, axis=None, out=None)

  • a: 你要查找最大值索引的数组。
  • axis: (可选)指定在哪个维度上查找。
    • None(默认值):在整个数组中查找最大值的索引。
    • 一个整数:指定维度。例如,axis=0 表示按列查找,axis=1 表示按行查找。

1. 在一维数组中的应用

对于一维数组,argmax 会返回单个索引。

import numpy as np

scores = np.array([85, 90, 78, 92, 88])
# 数组中最大值是 92,它的索引是 3 (从 0 开始)
max_index = np.argmax(scores)

print(f"最大值的索引是: {max_index}")
# 输出: 最大值的索引是: 3
print(f"最大值是: {scores[max_index]}")
# 输出: 最大值是: 92

2. 在二维数组中的应用

在二维数组中,axis 参数变得非常重要。

import numpy as np

# 假设这是一个 3x4 的数组
grades = np.array([[80, 85, 90, 75],
                   [95, 88, 92, 90],
                   [70, 75, 80, 85]])

# 在整个数组中查找最大值的索引
# 先将数组展平为一维,再查找
overall_max_index = np.argmax(grades)
print(f"整个数组最大值的索引 (展平后): {overall_max_index}")
# 输出: 整个数组最大值的索引 (展平后): 4

# 按列查找最大值的索引 (axis=0)
max_per_column = np.argmax(grades, axis=0)
print(f"每列最大值的索引: {max_per_column}")
# 输出: 每列最大值的索引: [1 0 1 2]
# 解释: 第0列最大值95在索引1,第1列最大值85在索引0,以此类推。

# 按行查找最大值的索引 (axis=1)
max_per_row = np.argmax(grades, axis=1)
print(f"每行最大值的索引: {max_per_row}")
# 输出: 每行最大值的索引: [2 0 3]
# 解释: 第0行最大值90在索引2,第1行最大值95在索引0,以此类推。

在机器学习中的应用

argmax 在机器学习中非常常见,尤其是在多分类任务中。

一个典型的神经网络会对每个可能的类别输出一个分数概率。为了做出最终的预测,我们需要找出哪个类别的分数最高。

在你的手写数字识别代码中,np.argmax(pre, axis=1) 的作用就是:

  1. pre 是一个形状为 (1, 10) 的二维数组,其中每个值代表模型对 0-9 这 10 个数字的原始预测分数。
  2. axis=1 告诉 argmax 在这个数组的方向上查找最大值。
  3. 它返回最大值所在位置的索引。这个索引,就是模型最终预测的数字。

例如,如果 pre[[ -1.5, 0.2, 3.8, ... ]]np.argmax(pre, axis=1) 会返回 [2],这告诉我们模型预测的数字是 2。

posted @ 2025-09-16 16:58  李大嘟嘟  阅读(5)  评论(0)    收藏  举报