torch小知识
1.torch.eq
out = torch.eq(input, other)
- 功能: 比较两个张量的元素是否相同
- 参数:
- input:输入的张量
- other: 用于比较的张量
- out: 输出元素为True或者是Flase的张量
- 例子:
例子1:
>>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ True, False],
[False, True]])
例子2:
outputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data))
2.torch.topk
(value, idx) = torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor)
-
功能: 沿输入数据的指定维度比较,返回k个最大的元素
-
参数:
- input:输入的张量
- k: 取前top_k个数据
- dim: 用于比较的维度
- largest:设置为True,是返回最大的k个值;设置为Flase,是返回最小的k个值
- sorted: 设置为True,返回的结果按照顺序返回
- value: k个元素的值
- idx: k个元素的索引值
-
例子:
>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1., 2., 3., 4., 5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))
参考:torch官方手册:https://pytorch.org/docs/stable/index.html