torch.squeeze

torch.squeeze(A,N)

如果不指定位置参数N,如果数组A的维度为(1,1,3)那么执行 torch.squeeze(A,1) 后A的维度变为 (1,3),中间的维度被删除

注:

1. 如果指定的维度大于1,那么将操作无效

2. 如果不指定维度N,那么将删除所有维度为1的维度

torch.unsqueeze(A,N)

torch.unsqueeze()函数的作用增加数组A指定位置N的维度,例如两行三列的数组A维度为(2,3),那么这个数组就有三个位置可以增加维度,分别是( [位置0] 2,[位置1] 3 [位置2] )或者是 ( [位置-3] 2,[位置-2] 3 [位置-1] ),如果执行 torch.unsqueeze(A,1),数据的维度就变为了 (2,1,3)

代码演示:

1. squeeze函数

a=torch.randn(1,1,3)
print(a.shape)
b=torch.squeeze(a)
print(b.shape)
c=torch.squeeze(a,0)
print(c.shape)
d=torch.squeeze(a,1)
print(d.shape)
e=torch.squeeze(a,2)#如果去掉第三维,则数不够放了,所以直接保留
print(e.shape)

输出:

torch.Size([1, 1, 3])
torch.Size([3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 1, 3])

2. unsqueeze函数

a=torch.randn(1,3)
print(a.shape)
b=torch.unsqueeze(a,0)
print(b.shape)
c=torch.unsqueeze(a,1)
print(c.shape)
d=torch.unsqueeze(a,2)
print(d.shape)

输出

torch.Size([1, 3])
torch.Size([1, 1, 3])
torch.Size([1, 1, 3])
torch.Size([1, 3, 1])
posted @ 2022-05-14 13:07  今天记笔记了吗  阅读(489)  评论(0)    收藏  举报