torch.nn.Flatten()

本文主要是将官网的解释翻译一下

官网解释:

torch.nn.Flatten(start_dim=1, end_dim=- 1)

将一个张量在连续维度内进行展平,用于Sequential

输入和输出形状:

  • 输入形状:

    \((*,S_{start},...,S_i,...,S_{end},*)\)

  • 输出形状:

    \((*,\prod^{end}_{i=start}S_i,*)\)

    解释一下:这个公式的意思是将开始维度和结束维度之间的所有维度进行相乘。

参数:

  • start_dim-需要展平的维度范围的开始,默认设置为1
  • end_dim-需要展平的维度范围的结束,默认设置为-1,也就是最后一个维度

官网例子:

>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
#第一个例子的意思就是将第1个维度(从维度0开始计算)到最后一个维度展平,也就是(32,1*5*5)


>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])
#将第0个维度到第2个维度展平也就是(32*1*5,5)

自己再举一个小例子

import torch
from torch import nn as nn
X= torch.rand(1,32,64,64)
model = nn.Sequential(
    nn.Conv2d(in_channels=32,out_channels=128,stride=1,padding=1,kernel_size=3),
    nn.AvgPool2d(kernel_size=3,stride=2,padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32*64*64,2),
)
out = model(X)
for layer in model:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
    
    
#输出结果为
#Conv2d output shape:	 torch.Size([1, 128, 64, 64])
#AvgPool2d output shape:	 torch.Size([1, 128, 32, 32])
#ReLU output shape:	 torch.Size([1, 128, 32, 32])
#Flatten output shape:	 torch.Size([1, 131072])
#Linear output shape:	 torch.Size([1, 2])

解释一下这个例子的意思:

首先是生成了一个(1,32,64,64)的张量,通过第一层卷积和平均池化层之后通道数翻倍,高宽减半,就变成了(1,128,32,32)的张量,经过nn.Flatten()之后,除了第0维以外,其他的维度被压缩

posted @ 2022-08-12 11:26  Zongxi_giegie  阅读(729)  评论(0)    收藏  举报