Pytorch的squeeze()和unsqueeze()

squeeze():压缩,对张量的维度进行减少的操作。

unsqueeze():扩充。

()中数字若为:正数,则在之前插入;负数,则在之后插入。

注:压缩或者扩充的维度为1

定义张量weights

1 weights = torch.tensor([0.2126, 0.7152, 0.0722])
2 weights.shape
3 
4 torch.Size([3])

对weights扩充维度unsqueeze(-1)

1 weights.unsqueeze(-1)
2 
3 tensor([[0.2126],
4         [0.7152],
5         [0.0722]])
1 weights.unsqueeze(-1).shape
2 
3 torch.Size([3, 1])

在上一步的基础上再扩充维度

1 weights.unsqueeze(-1).unsqueeze_(-1)
2 
3 tensor([[[0.2126]],
4 
5         [[0.7152]],
6 
7         [[0.0722]]])
1 weights.unsqueeze(-1).unsqueeze_(-1).shape
2 
3 torch.Size([3, 1, 1])

 

注:https://www.cnblogs.com/datasnail/p/13086803.html说的比较详细

posted @ 2020-12-11 22:40  vv_869  阅读(894)  评论(0编辑  收藏  举报