nn.Linear输入输出维度的理解
目录
最近写代码发现nn.Linear的输入维度其实可以不是一维的输入。
一般来说,nn.Linear的输入一般是(batch_size, input_dimen)。但是官方文档说明了:
input:(,input_dimen)
output:(,output_dimen)
因此其实*是可以任意维度的。
但是要注意的是,无论*的维度怎么样,真正的输入层和输出层神经元的个数和结构,是由input_dimen和output_dimen决定的,因此神经网络的参数也是其决定的。
也就是说:无论的维度是多少,输入的数据都是相当于batch_size的作用,公用一套神经网络参数。
官方文档的解释:
https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear

本文来自博客园,作者:JaxonYe,转载请注明原文链接:https://www.cnblogs.com/yechangxin/articles/18213439
侵权必究

浙公网安备 33010602011771号