torch.premute的介绍

permute进行的是置换。
permute的dim需要参数表示进行置换的维度。
import torch
import numpy as np
x=np.arange(24).reshape((2,3,4))
x = torch.tensor(x)
print(x)
y=x.permute((2, 1, 0))
print(y)
由代码可知列和批进行了置换。每一列都对应了一个新批,而批又转为了列。
输出:

tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],

[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]], dtype=torch.int32)
tensor([[[ 0, 12],
[ 4, 16],
[ 8, 20]],

[[ 1, 13],
[ 5, 17],
[ 9, 21]],

[[ 2, 14],
[ 6, 18],
[10, 22]],

[[ 3, 15],
[ 7, 19],
[11, 23]]], dtype=torch.int32)

 

import torch
import numpy as np
x=np.arange(24).reshape((2,3,4))
x = torch.tensor(x)
print(x)
y=x.permute((2, 0, 1))
permute(2,0,1)操作行变列,批变行,列变批。
print(y)

tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],

[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]], dtype=torch.int32)

tensor([[[ 0, 4, 8],
[12, 16, 20]],

[[ 1, 5, 9],
[13, 17, 21]],

[[ 2, 6, 10],
[14, 18, 22]],

[[ 3, 7, 11],
[15, 19, 23]]], dtype=torch.int32)

 

 

 

 

 

 

 

 

 

posted @ 2022-12-24 07:44  祥瑞哈哈哈  阅读(193)  评论(0)    收藏  举报