Pytorch中的Sort的使用

>>> a = torch.randn(3,3)
>>> a
tensor([[ 0.5805, 0.1940, 1.2591],
[-0.0863, 0.5350, -0.7737],
[-0.4059, -0.0447, -0.3434]])
>>> a.sort(0,True)[0]
tensor([[ 0.5805, 0.5350, 1.2591],
[-0.0863, 0.1940, -0.3434],
[-0.4059, -0.0447, -0.7737]])
>>> a.sort(0,False)[0]
tensor([[-0.4059, -0.0447, -0.7737],
[-0.0863, 0.1940, -0.3434],
[ 0.5805, 0.5350, 1.2591]])
>>> a.sort(0,True)[1]
tensor([[0, 1, 0],
[1, 0, 2],
[2, 2, 1]])
>>> a.sort(1,True)[0]
tensor([[ 1.2591, 0.5805, 0.1940],
[ 0.5350, -0.0863, -0.7737],
[-0.0447, -0.3434, -0.4059]])
>>> a = torch.randn(3,3)
>>> a
tensor([[ 0.6073, -0.7748, -1.4459],
[ 0.8176, -0.9419, 1.2187],
[ 0.0301, -0.2075, -1.2473]])
>>> a.sort(1,True)[0]
tensor([[ 0.6073, -0.7748, -1.4459],
[ 1.2187, 0.8176, -0.9419],
[ 0.0301, -0.2075, -1.2473]])
>>>

>>> a.max(0)
torch.return_types.max(
values=tensor([ 0.8176, -0.2075, 1.2187]),
indices=tensor([1, 2, 1]))
>>> a.max(1)
torch.return_types.max(
values=tensor([0.6073, 1.2187, 0.0301]),
indices=tensor([0, 2, 0]))
>>>()中的0.1是行和列的区别。。。。

posted @ 2021-01-19 12:25  _八级大狂风  阅读(553)  评论(0编辑  收藏  举报