torch.argmax和argmin返回值

  在进行深度学习张量计算时,经常要获取张量在某个维度的最大值和最小值,以及这些值的位置。如果只需要知道位置,则torch.argmax和torch.argmin函数便可以实现。

Torch.argmax(input, dim=None, keepdim=False):返回指定维度最大值的序号。

  有时候返回的值比较难理解,所以这里直接放example以帮助理解:

 1 import torch
 2 
 3 t = torch.tensor([[1,2],[3,4],[2,8]])
 4 
 5 print(torch.argmax(t,0))
 6 
 7 
 8 g = torch.tensor([[[1,2,3],[2,3,4],[5,6,7]], [[3,4,5],[7,6,5],[5,4,3]], [[8,9,0],        
 9                             [2,8,4],[7,5,3]]])
10 print(g)
11 print(torch.argmax(g,0))

先从简单的2维张量来看,t 是一个2维张量,大小为(3,2)。t 为 ,此时我们使dim=0,意思使求第0维的(即(3,2)中的3行)中的最大值的序号,所以固定行,直接看列,第一列中3最大,故得到值1,第2列中8最大,故得到值2。最终的结果为  tensor([1,2])

 


再来看一个3维张量g , tensor([[[1, 2, 3],

              [2, 3, 4],
              [5, 6, 7]],

              [[3, 4, 5],
              [7, 6, 5],
              [5, 4, 3]],

              [[8, 9, 0],
              [2, 8, 4],
              [7, 5, 3]]]),其大小为(3,3,3) 其中我们希望在dim=0的维度中求最大值的序号,则固定第一个维度,第一个维度为channel,则每个channel中对应位置进行比较。

比如每个channel中的(0,0)比较,1<3<8,所以得到的值为2;(0,1)比较,2<4<9,依然得到2,....以此类推。最终得到结果tensor([[2, 2, 1],[1, 2, 1],[2, 0, 0]])。

 

posted @ 2020-07-12 19:47  ASTHNONT  阅读(3876)  评论(0编辑  收藏  举报