pytorch-Flatten操作

1 class Flatten(nn.Module):
2     def __init__(self):
3         super(Flatten,self).__init__()
4         
5     def forward(self,input):
6         shape = torch.prod(torch.tensor(x.shape[1:])).item()
7         # -1 把第一个维度保持住
8         return x.view(-1,shape)

 

posted @ 2020-02-19 15:40  一大碗小米粥  阅读(3144)  评论(0编辑  收藏  举报