Loading

torch小知识

1.torch.eq

out = torch.eq(input, other)

  • 功能: 比较两个张量的元素是否相同
  • 参数:
    • input:输入的张量
    • other: 用于比较的张量
    • out: 输出元素为True或者是Flase的张量
  • 例子:
例子1:
>>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ True, False],
        [False, True]])
例子2:
outputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data))

2.torch.topk

(value, idx) = torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor)

  • 功能: 沿输入数据的指定维度比较,返回k个最大的元素

  • 参数:

    • input:输入的张量
    • k: 取前top_k个数据
    • dim: 用于比较的维度
    • largest:设置为True,是返回最大的k个值;设置为Flase,是返回最小的k个值
    • sorted: 设置为True,返回的结果按照顺序返回
    • value: k个元素的值
    • idx: k个元素的索引值
  • 例子:

>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))

参考:torch官方手册:https://pytorch.org/docs/stable/index.html

posted @ 2021-03-30 22:26  Guang'Jun  阅读(155)  评论(0)    收藏  举报