torch.max函数在分类问题中的使用 学习

适用于在pytorch的张量上,求某一维度的最大值。
一般在模型测试阶段,求模型预测输出类别的时候使用。

假设是10分类问题,比如mnist
对于一个批次的输入 images 将它传入net(images)
会得到输出out(bs,10) 但是第二个维度,仅仅是模型在十个类别的预测值
需要取最大值,才能得到预测结果。
对于这个问题来说,最大值的索引即为预测结果。

torch.max(input,dim)
输入为一个张量,以及指定dim 在哪一个维度上求max
输出 为values,indices
values为 一个存储最大值实际值的张量
indices为存储最大值索引的张量

则对于分类问题来说,indices是我们需要的

所以常用的代码为:

_,predicted=torch.max(out,dim=1)
posted @ 2025-09-24 13:00  朱朱成  阅读(8)  评论(0)    收藏  举报