torch.nn.Flatten()
本文主要是将官网的解释翻译一下
torch.nn.Flatten(start_dim=1, end_dim=- 1)
将一个张量在连续维度内进行展平,用于Sequential
输入和输出形状:
-
输入形状:
\((*,S_{start},...,S_i,...,S_{end},*)\)
-
输出形状:
\((*,\prod^{end}_{i=start}S_i,*)\)
解释一下:这个公式的意思是将开始维度和结束维度之间的所有维度进行相乘。
参数:
- start_dim-需要展平的维度范围的开始,默认设置为1
- end_dim-需要展平的维度范围的结束,默认设置为-1,也就是最后一个维度
官网例子:
>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
#第一个例子的意思就是将第1个维度(从维度0开始计算)到最后一个维度展平,也就是(32,1*5*5)
>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])
#将第0个维度到第2个维度展平也就是(32*1*5,5)
自己再举一个小例子
import torch
from torch import nn as nn
X= torch.rand(1,32,64,64)
model = nn.Sequential(
nn.Conv2d(in_channels=32,out_channels=128,stride=1,padding=1,kernel_size=3),
nn.AvgPool2d(kernel_size=3,stride=2,padding=1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(32*64*64,2),
)
out = model(X)
for layer in model:
X = layer(X)
print(layer.__class__.__name__,'output shape:\t', X.shape)
#输出结果为
#Conv2d output shape: torch.Size([1, 128, 64, 64])
#AvgPool2d output shape: torch.Size([1, 128, 32, 32])
#ReLU output shape: torch.Size([1, 128, 32, 32])
#Flatten output shape: torch.Size([1, 131072])
#Linear output shape: torch.Size([1, 2])
解释一下这个例子的意思:
首先是生成了一个(1,32,64,64)的张量,通过第一层卷积和平均池化层之后通道数翻倍,高宽减半,就变成了(1,128,32,32)的张量,经过nn.Flatten()之后,除了第0维以外,其他的维度被压缩

浙公网安备 33010602011771号