【基本语法】PyTorch 中 __init__ 和 forward 方法的调用机制详解
PyTorch 中 __init__ 和 forward 方法的调用机制详解
在 PyTorch 的神经网络模块中,__init__ 和 forward 是两个核心方法,它们协同工作但调用时机和方式完全不同。下面通过完整示例和原理分析来详细说明它们的调用机制。
完整示例代码
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleModel, self).__init__()
print("__init__ 方法被调用 - 初始化模型结构")
# 定义网络层
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, output_size)
# 初始化权重
self._initialize_weights()
def _initialize_weights(self):
print("内部方法 _initialize_weights 被调用")
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.zeros_(self.fc1.bias)
nn.init.xavier_uniform_(self.fc2.weight)
nn.init.zeros_(self.fc2.bias)
def forward(self, x):
print("forward 方法被调用 - 执行前向传播")
# 前向传播流程
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
# 1. 实例化模型 (调用 __init__)
print("创建模型实例...")
model = SimpleModel(input_size=10, hidden_size=5, output_size=2)
print("\n模型创建完成\n")
# 2. 创建输入数据
input_data = torch.randn(3, 10) # batch_size=3, 特征数=10
print("输入数据形状:", input_data.shape)
# 3. 调用模型 (隐式调用 forward)
print("\n执行模型推理...")
output = model(input_data) # 这里调用 forward 方法
print("输出数据形状:", output.shape)
print("输出数据:\n", output.detach().numpy())
# 4. 显式调用 forward (不推荐)
print("\n显式调用 forward...")
direct_output = model.forward(input_data)
print("直接输出形状:", direct_output.shape)
执行流程与输出
创建模型实例...
__init__ 方法被调用 - 初始化模型结构
内部方法 _initialize_weights 被调用
模型创建完成
输入数据形状: torch.Size([3, 10])
执行模型推理...
forward 方法被调用 - 执行前向传播
输出数据形状: torch.Size([3, 2])
输出数据:
[[-0.112 0.543]
[ 0.234 -0.876]
[-0.765 0.432]]
显式调用 forward...
forward 方法被调用 - 执行前向传播
直接输出形状: torch.Size([3, 2])
关键机制解析
1. __init__ 方法的调用
- 调用时机:在创建模型实例时自动调用
- 主要作用:
- 定义网络层结构(如 Linear, Conv2d 等)
- 初始化权重参数
- 注册子模块
- 调用方式:
model = MyModel(params) # 自动触发 __init__
2. forward 方法的调用
- 调用时机:在模型被调用时自动触发
- 主要作用:
- 定义数据流经网络的路径
- 实现前向传播逻辑
- 返回计算结果
- 调用方式:
output = model(input_data) # 推荐方式
3. 底层原理:__call__ 方法
PyTorch 的 nn.Module 类实现了特殊的 __call__ 方法,这是 forward 方法自动调用的关键:
class Module:
def __call__(self, *input, **kwargs):
# ... 前置处理 (如 hook 注册) ...
result = self.forward(*input, **kwargs)
# ... 后置处理 ...
return result
当执行 model(input_data) 时:
- Python 解释器调用
model.__call__(input_data) __call__方法内部调用self.forward(input_data)- 返回
forward方法的结果
4. 两种调用方式的区别
| 调用方式 | 代码示例 | 是否推荐 | 说明 |
|---|---|---|---|
| 隐式调用 | output = model(input_data) |
✓ 推荐 | 通过 __call__ 触发,包含完整的前后处理 |
| 显式调用 | output = model.forward(input_data) |
✗ 不推荐 | 直接调用,绕过 PyTorch 的内部处理机制 |
5. 重要注意事项
-
不要直接调用
forward:- 绕过 PyTorch 的钩子(hooks)系统
- 可能影响梯度计算和分布式训练
- 丢失模型状态跟踪功能
-
__init__中的初始化顺序:class CorrectModel(nn.Module): def __init__(self): super().__init__() # 必须先调用 super().__init__() self.layer = nn.Linear(10, 5) # 然后定义子模块 -
模块注册机制:
- 在
__init__中定义的子模块会自动注册 - PyTorch 可以跟踪所有参数:
for param in model.parameters(): print(param.shape)
- 在
-
前向传播中的动态行为:
def forward(self, x): # 可以根据输入动态调整结构 if x.shape[1] > 10: x = self.extra_layer(x) return self.main_layer(x)
实际应用场景
训练循环中的调用
model = MyModel().train() # 调用 __init__
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(epochs):
for data, target in dataloader:
optimizer.zero_grad()
output = model(data) # 调用 forward
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
自定义前向传播逻辑
class ResidualBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
def forward(self, x):
residual = x
x = torch.relu(self.conv1(x))
x = self.conv2(x)
x += residual # 残差连接
return torch.relu(x)
总结
-
__init__方法:- 在创建模型实例时自动调用一次
- 用于定义网络结构和初始化参数
-
forward方法:- 通过
model(input_data)隐式调用 - 每次推理/训练时执行
- 定义数据流和计算逻辑
- 通过
-
最佳实践:
- 始终通过模型实例调用 (
model(input_data)) 而不是直接调用forward - 在
__init__中定义所有持久性组件 - 在
forward中保持计算逻辑的纯净性
- 始终通过模型实例调用 (
理解这两个方法的调用机制是有效使用 PyTorch 构建神经网络的基础,它们共同定义了模型的结构和行为,同时确保 PyTorch 能够正确管理计算图和自动微分。

浙公网安备 33010602011771号