torch.nn.Unfold && torch.nn.Fold

torch.nn.Unfold

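Extracts sliding-window patches from a batched input tensor. Each patch is flattened into one column of the output.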
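For reference, the output shape follows the rule stated in the PyTorch documentation for `torch.nn.Unfold`. Here $n_d$ is the spatial size, $k_d$ the kernel size, $p_d$ the padding, $\delta_d$ the dilation and $s_d$ the stride along spatial dimension $d$. The second line plugs in this post's example: a 1×2×4×4 input, a 2×2 kernel, stride 2, and the default padding 0 and dilation 1.

```latex
\text{output shape} = \Bigl(N,\; C \cdot \prod_d k_d,\; L\Bigr), \qquad
L = \prod_d \left\lfloor \frac{n_d + 2p_d - \delta_d (k_d - 1) - 1}{s_d} + 1 \right\rfloor

L = \left\lfloor \frac{4 + 0 - 1\cdot(2 - 1) - 1}{2} + 1 \right\rfloor^{2} = 2^{2} = 4, \qquad
C \cdot k_h k_w = 2 \cdot 2 \cdot 2 = 8 \;\Rightarrow\; (1,\, 8,\, 4)
```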

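```python
import torch

inputs = torch.randn(1, 2, 4, 4)  # (N, C, H, W)
print(inputs.size())  # torch.Size([1, 2, 4, 4])
print(inputs)
# Non-overlapping 2x2 windows: kernel_size=(2, 2), stride=2
unfold = torch.nn.Unfold(kernel_size=(2, 2), stride=2)
patches = unfold(inputs)
print(patches.size())  # torch.Size([1, 8, 4]): (N, C*2*2, number of patches)
print(patches)
```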
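As a sanity check, the sketch below rebuilds the same patches by hand with `Tensor.unfold`. It assumes the same 1×2×4×4 input, and the fixed random seed is only there for reproducibility. The sketch shows the column layout: each column is one 2×2 window, with its values ordered by channel first, then kernel row, then kernel column. It then applies `torch.nn.Fold`, the counterpart named in the title, which scatters the columns back into place and sums any overlaps. Because the stride equals the kernel size here, the windows do not overlap and the round trip recovers the input exactly.

```python
import torch

# Same setup as the example above: batch 1, 2 channels, 4x4 spatial size.
torch.manual_seed(0)
inputs = torch.randn(1, 2, 4, 4)
unfold = torch.nn.Unfold(kernel_size=(2, 2), stride=2)
patches = unfold(inputs)  # shape (1, 2*2*2, 4) = (1, 8, 4)

# Manual extraction of the same non-overlapping 2x2 blocks.
# Tensor.unfold(dim, size, step) slides along one dimension at a time:
# (1, 2, 4, 4) -> (1, 2, 2, 2, 2, 2) = (N, C, block_row, block_col, kh, kw)
blocks = inputs.unfold(2, 2, 2).unfold(3, 2, 2)
# Reorder to (N, C, kh, kw, block_row, block_col), then flatten:
# C*kh*kw along dim 1, blocks in row-major order along dim 2.
manual = blocks.permute(0, 1, 4, 5, 2, 3).reshape(1, 2 * 2 * 2, 4)
print(torch.equal(patches, manual))  # True

# torch.nn.Fold sums the columns back into their spatial positions.
# With stride == kernel_size the blocks do not overlap, so this inverts Unfold.
fold = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=2)
restored = fold(patches)
print(torch.equal(restored, inputs))  # True
```

With overlapping windows (for example `stride=1`), `Fold(Unfold(x))` is no longer `x`: each element is multiplied by the number of windows that cover it.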


posted @ 2021-04-18 10:43 临近边缘