torch.max
First version:
torch.max(input) → Tensor
Returns the maximum value of all elements in the input tensor.
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763, 0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)
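A minimal sketch of this full-reduction form follows. It assumes a fixed tensor in place of `torch.randn(1, 3)` so the result is reproducible. It shows that `torch.max(input)` returns a 0-dimensional tensor, that `.item()` turns it into a plain Python number, and that the method form `a.max()` gives the same result.

```python
import torch

# Fixed values copied from the example above instead of torch.randn(1, 3)
a = torch.tensor([[0.6763, 0.7445, -2.2369]])

m = torch.max(a)                 # reduces over every element of a
print(m)                         # tensor(0.7445)
print(m.dim())                   # 0: the result is a scalar (0-d) tensor
print(m.item())                  # the same value as a Python float
print(torch.equal(m, a.max()))   # True: the method form is equivalent
```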
Second version:
torch.max(input, dim, keepdim=False, *, out=None)
- Returns a namedtuple (values, indices), where values is the maximum value of each row of the input tensor in the given dimension dim, and indices is the index location of each maximum value found (argmax).
- If keepdim is True, the output tensors are of the same size as input except in the dimension dim, where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having 1 fewer dimension than input.
If there are multiple maximal values in a reduced row then the indices of the first maximal value are returned.
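The tie-breaking rule can be checked with a small hand-made tensor. This is a sketch run on a CPU tensor, where one row contains the same maximum twice; the reported index should be the earlier of the two positions.

```python
import torch

# Row 0 has its maximum (3.0) at positions 1 and 2; row 1 has a unique maximum.
x = torch.tensor([[1.0, 3.0, 3.0, 0.0],
                  [5.0, 2.0, 4.0, 1.0]])

values, indices = torch.max(x, dim=1)
print(values)    # tensor([3., 5.])
print(indices)   # tensor([1, 0]): the first of the tied positions is reported
```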
Parameters
- input (Tensor) – the input tensor.
- dim (int) – the dimension to reduce.
- keepdim (bool) – whether the output tensor has dim retained or not. Default: False.
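To see what keepdim changes, here is a sketch on an arbitrary 3-D tensor (the shape (2, 3, 4) is chosen only for illustration). It also reads the result through the namedtuple fields `values` and `indices` rather than by position.

```python
import torch

x = torch.arange(24.0).reshape(2, 3, 4)   # shape (2, 3, 4), increasing along every dim

# keepdim=False (default): the reduced dimension is squeezed away
r = torch.max(x, dim=1)
print(r.values.shape)    # torch.Size([2, 4])
print(r.indices.shape)   # torch.Size([2, 4])

# keepdim=True: the reduced dimension is kept with size 1
rk = torch.max(x, dim=1, keepdim=True)
print(rk.values.shape)   # torch.Size([2, 1, 4])

# Keeping the dimension lets the result broadcast back against the input,
# e.g. to subtract the maximum taken along dim 1 from every element
centered = x - rk.values
print(centered.shape)    # torch.Size([2, 3, 4])
```

Keeping the size-1 dimension is mainly useful when the result is combined with the original tensor again, as in the subtraction above; without keepdim=True the shapes (2, 3, 4) and (2, 4) would not broadcast.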
>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222, 0.8475],
[ 1.1949, -1.1127, -2.2379, -0.6702],
[ 1.5717, -0.9207, 0.1297, -1.8768],
[-0.6172, 1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
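The example above can be reproduced deterministically by hard-coding the printed values instead of calling `torch.randn(4, 4)`. The sketch below unpacks the namedtuple, compares `indices` with `torch.argmax` along the same dimension, and uses `gather` to confirm that the indices point at the returned values.

```python
import torch

# The 4x4 values printed in the example above, fixed instead of torch.randn(4, 4)
a = torch.tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
                  [ 1.1949, -1.1127, -2.2379, -0.6702],
                  [ 1.5717, -0.9207,  0.1297, -1.8768],
                  [-0.6172,  1.0036, -0.6060, -0.2432]])

# The returned namedtuple can be unpacked like an ordinary tuple
values, indices = torch.max(a, 1)
print(values)    # tensor([0.8475, 1.1949, 1.5717, 1.0036])
print(indices)   # tensor([3, 0, 0, 1])

# indices agrees with torch.argmax along the same dimension
print(torch.equal(indices, torch.argmax(a, dim=1)))  # True

# Gathering a at those indices recovers the maximum values
picked = a.gather(1, indices.unsqueeze(1)).squeeze(1)
print(torch.equal(picked, values))                   # True
```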