删除tensor中不是1的维度

squeeze只能删除是1的维度,因此对不是1的维度先切片为1,然后再删除

input_torch=torch.randn([1,5,1,2])
print(input_torch[:, :, :1, :1].squeeze(dim=(2, 3)).shape) #先切片为[1,5,1,1],再删除后两维

 

posted @ 2024-11-26 00:31  夕西行  阅读(29)  评论(0)    收藏  举报