初识pytorch:神经网络骨架nn.Module的使用

nn.Module

nn.Module的定义

是的,大家伙们,我们终于来到了正式定义神经网络的环节。首先我们要认识今天的内容,nn.Module(注意这里的nn是Neural network的缩写)。

nn.Module 是 PyTorch 中构建神经网络的核心基类,所有自定义网络层、模型组件或完整模型都需要继承这个类。它提供了统一的接口来管理网络参数、定义前向传播逻辑,并支持自动求导、模型保存与加载等关键功能。

在pytorch中,我们定义的所有网络模型首先就要继承nn.Module这个父类,这是后续所有功能可以正常实现的基础。

nn.Module的作用

1.参数管理
自动跟踪和管理网络中的可学习参数(权重、偏置等),无需手动定义和维护参数列表。例如,当你在 nn.Module 子类中定义 nn.Linear、nn.Conv2d 等层时,这些层的参数会被自动注册到模型的参数集合中,可通过 model.parameters() 访问。

2.前向传播标准化
要求子类必须实现 forward() 方法,统一了模型的前向计算逻辑。调用模型时(如 model(input)),会自动触发 forward() 方法,并在底层处理自动求导相关的钩子(hook)逻辑。

3.设备管理
提供 to(device) 方法,可将模型参数和缓冲区统一迁移到指定设备(如 CPU 或 GPU),简化跨设备计算流程。
注意这里的to方法可以用到tensor张量上,loss损失函数上,模型上,还有优化器上

4.模型嵌套与组合
支持在一个 nn.Module 中嵌套其他 nn.Module 实例,轻松构建复杂网络结构(如 “层→块→模型” 的层级结构)。

5.状态保存与加载
通过 state_dict() 方法获取模型参数的字典表示,通过 load_state_dict() 加载参数,方便模型的保存、加载和迁移学习。

nn.Module的使用

nn.Module的基本使用方法

使用nn.Module的步骤如下:

  • 1.继承 nn.Module 类

  • 2.在 init 中定义子模块(如卷积层、全连接层等)

  • 3.实现 forward() 方法,定义前向传播逻辑,示例代码如下(这里定义了一个简单的全连接网络):

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        # 调用父类构造函数
        super(SimpleNet, self).__init__()
        
        # 定义子模块(会自动注册参数)
        self.fc1 = nn.Linear(input_size, hidden_size)  # 全连接层1
        self.relu = nn.ReLU()                          # 激活函数
        self.fc2 = nn.Linear(hidden_size, output_size) # 全连接层2
    
    def forward(self, x):
        # 定义前向传播:输入x → fc1 → ReLU → fc2 → 输出
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# 实例化模型
model = SimpleNet(input_size=784, hidden_size=256, output_size=10)
print(model)

输出的结果(模型结构):

SimpleNet(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)

nn.Module的核心属性和方法

1.常用的属性

  • parameters():返回模型所有可学习参数的迭代器(用于优化器)。
    示例:optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  • children():返回直接子模块的迭代器(不包含嵌套子模块)。
  • modules():返回所有子模块的迭代器(包含嵌套子模块,用于遍历整个模型)。
  • state_dict():返回模型参数的字典(键为参数名,值为参数张量),用于保存模型。

2.常用的方法

  • forward(x):定义前向传播逻辑,必须在子类中实现。调用 model(x) 时自动执行。
  • to(device):将模型参数迁移到指定设备(如 model.to('cuda') 迁移到 GPU)。
  • eval():将模型切换到评估模式(禁用 Dropout、BatchNorm 等训练时的特殊行为)。
  • train():将模型切换到训练模式(启用 Dropout、BatchNorm 等)。
  • load_state_dict(state_dict):加载保存的模型参数。

nn.Module使用时的注意事项

1.不要在 forward() 中定义新模块
子模块必须在 init 中定义,否则参数无法被自动注册,导致梯度无法更新。

2.区分 eval() 和 train()
评估模型时必须调用 model.eval(),否则 Dropout 层会随机丢弃神经元,BatchNorm 会使用训练时的统计量,导致预测结果不稳定。

3.参数冻结与微调
迁移学习中,可通过 requires_grad 冻结部分参数,示例代码如下:

# 冻结前几层参数
for param in model.fc1.parameters():
    param.requires_grad = False

总结部分

nn.Module 是 PyTorch 构建神经网络的 “基石”,它通过统一的接口简化了模型定义、参数管理和训练流程。其核心优势在于:

  • 自动管理参数和梯度,无需手动操作;
  • 支持灵活的模型嵌套,轻松构建复杂网络;
  • 提供设备迁移、状态保存等实用功能,适配各种训练场景。
posted @ 2025-10-14 15:58  沃德天sama  阅读(11)  评论(0)    收藏  举报
1