transpose()和permute()

pytorch转置用的函数就只有这两个:transpose()permute(),本文将详细地介绍这两个函数以及它们之间的区别。

 transpose()

torch.transpose(input, dim0, dim1, out=None) → Tensor

函数返回输入矩阵input的转置。交换维度dim0dim1

参数:

  • input (Tensor) – 输入张量,必填
  • dim0 (int) – 转置的第一维,默认0,可选
  • dim1 (int) – 转置的第二维,默认1,可选

注意只能有两个相关的交换的位置参数。

例子:

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 1.0028, -0.9893,  0.5809],
        [-0.1669,  0.7299,  0.4942]])
>>> torch.transpose(x, 0, 1)
tensor([[ 1.0028, -0.1669],
        [-0.9893,  0.7299],
        [ 0.5809,  0.4942]])

 

permute()

参数:
dims (int…*)-换位顺序,必填

 例子:

>>> x = torch.randn(2, 3, 5) 
>>> x.size() 
torch.Size([2, 3, 5]) 
>>> x.permute(2, 0, 1).size() 
torch.Size([5, 2, 3])

 

transpose与permute的异同

  • permute相当于可以同时操作于tensor的若干维度,transpose只能同时作用于tensor的两个维度;
  • torch.transpose(x)合法, x.transpose()合法。torch.permute(x)不合法,x.permute()合法。
  • 与contiguous、view函数之关联。contiguous:view只能作用在contiguous的variable上,如果在view之前调用了transpose、permute等,就需要调用contiguous()来返回一个contiguous copy;一种可能的解释是:有些tensor并不是占用一整块内存,而是由不同的数据块组成,而tensor的view()操作依赖于内存是整块的,这时只需要执行contiguous()这个函数,把tensor变成在内存中连续分布的形式;判断ternsor是否为contiguous,可以调用torch.Tensor.is_contiguous()函数:
    import torch 
    x = torch.ones(10, 10) 
    x.is_contiguous()                                 # True 
    x.transpose(0, 1).is_contiguous()                 # False
    x.transpose(0, 1).contiguous().is_contiguous()    # True
    

      另:在pytorch的最新版本0.4版本中,增加了torch.reshape(),与 numpy.reshape() 的功能类似,大致相当于 tensor.contiguous().view(),这样就省去了对tensor做view()变换前,调用contiguous()的麻烦;

  

  

posted on 2021-11-21 20:33  朴素贝叶斯  阅读(1210)  评论(0编辑  收藏  举报

导航