【pytorch】土堆pytorch教程学习(六)神经网络的基本骨架——nn.module的使用
torch.nn 是 pytorch 的一个神经网络库(nn 是 neural network 的简称)。
Containers
torch.nn 构建神经网络的模型容器(Containers,骨架)有以下六个:
- Module
- Sequential
- ModuleList
- ModuleDict
- ParameterList
- ParameterDict
本文将介绍神经网络的基本骨架——nn.module的使用。
Module
所有神经网络模块的基类。自定义的模型也应该继承该类。
创建模型有两个要素:构建子模块和拼接子模块。构建子模块包括构建卷积层、池化层、全连接层等。拼接子模块即按照一定的顺序把构建好的子模块拼接起来。
自定义模型继承该类要重写 __init__() 和 forward():
- 在
__init__()里构建子模块,将子模块作为当前模块类的常规属性。一般将网络中具有可学习参数的层放在__init__中。 forward()前向传播函数,拼接子模块。
# 官方案例
import torch.nn as nn
import torch.nn.functional as F
# 自定义模型
class Model(nn.Module):
def __init__(self):
super().__init__() # 在对子类进行赋值之前,必须对父类进行__init__调用。
# 构建子模块
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
# 前向传播函数
def forward(self, x):
# 拼接子模块
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
# 模型调用
x = torch.randn(3, 1, 10, 20)
model = Model()
y = model(x)
为什么 forward() 方法能在model(x)时自动调用?
在 python 中当一个类定义了 __call__方法,则这个类实例就成为了可调用对象。而nn.Module 中的 __call__ 方法中调用了 forward() 方法,因此继承了 nn.Module 的子类对象就可以通过 model(x) 来调用 forward() 函数。
只要在 nn.Module 的子类中定义了 forward 函数,backward 函数就会被自动实现(利用Autograd)。
总结:自定义网络模型需要继承 nn.Module,并实现 __init__ 和 forward 函数。一个 Module 里可包含多个子 Module,比如 LeNet 是一个 Module,里面包括多个卷积层、池化层、全连接层等子 module。
本文来自博客园,作者:hzyuan,转载请注明原文链接:https://www.cnblogs.com/hzyuan/p/17384567.html

浙公网安备 33010602011771号