pytorch的argmax:只改变要挑选的维度
只改变要挑选的维度,其他维度不变
A=torch.tensor([[[3,4]]])
dec_X = A.argmax(dim=2)# 只在dim=2上挑选最大值,得到索引为scalar
dec_X #相当于把最内侧的[3,4]的维度去掉了,得到结果1。其他维度不变
tensor([[1]])
只改变要挑选的维度,其他维度不变
A=torch.tensor([[[3,4]]])
dec_X = A.argmax(dim=2)# 只在dim=2上挑选最大值,得到索引为scalar
dec_X #相当于把最内侧的[3,4]的维度去掉了,得到结果1。其他维度不变
tensor([[1]])