5.1.2 顺序块
在自定义的 MySequential 类中,使用 self._modules 而不是列表来存储子模块的原因主要有以下两点:
1. 参数注册与跟踪
PyTorch 的 nn.Module 机制要求子模块必须被显式注册到父模块的 self._modules 字典中,这样才能被正确跟踪。具体来说:
- 参数收集:当调用
model.parameters()时,PyTorch 会自动收集所有注册到self._modules的子模块的参数。如果子模块未被注册(例如仅存储在普通列表中),其参数将不会被父模块识别,导致优化器无法更新这些参数。 - 设备转移与状态管理:子模块的设备(如 CPU/GPU)和状态(如
training模式)需要通过self._modules的注册机制来统一管理。未注册的子模块可能无法正确转移到 GPU 或进行其他状态同步。
示例:
class MySequential(nn.Module):
def __init__(self, *args):
super().__init__()
self.layers = list(args) # ❌ 未注册到 self._modules
model = MySequential(nn.Linear(10, 20), nn.ReLU())
print(list(model.parameters())) # 输出为空,因为参数未被注册
通过 self._modules 注册后:
class MySequential(nn.Module):
def __init__(self, *args):
super().__init__()
for idx, module in enumerate(args):
self._modules[str(idx)] = module # ✅ 正确注册
model = MySequential(nn.Linear(10, 20), nn.ReLU())
print(len(list(model.parameters()))) # 输出 2(线性层的权重和偏置)
2. 顺序保证与遍历一致性
虽然列表本身是有序的,但 self._modules 是一个 OrderedDict,确保了以下两点:
- 顺序一致性:在
forward方法中遍历self._modules.values()时,顺序与构造函数中传入的*args完全一致。OrderedDict的设计确保了插入顺序的可预测性。 - 符合 PyTorch 的模块管理规范:
nn.Sequential内部正是通过OrderedDict实现的,使用self._modules能确保与 PyTorch 的标准行为一致。
代码对比:
# 使用列表存储(未注册)
self.layers = [nn.Linear(10, 20), nn.ReLU()]
# 遍历顺序虽然正确,但参数未被跟踪!
# 使用 self._modules(正确)
for idx, module in enumerate(args):
self._modules[str(idx)] = module
# 顺序正确且参数被跟踪
总结
- 必须注册到
self._modules:确保子模块的参数、设备、状态被父模块正确管理。 - 顺序由
OrderedDict保证:与列表的顺序一致,同时符合 PyTorch 的设计规范。
因此,直接使用列表存储子模块会导致参数无法被跟踪,而通过 self._modules 注册则解决了这一问题,同时保证了顺序的正确性。
是的,如果参数未被正确跟踪(即未被注册到父模块的 self._modules 中),那么 trainer.step()(或 optimizer.step())将无法更新这些参数。以下是详细解释和示例:
为什么参数无法被跟踪?
在 PyTorch 中,当子模块未被显式注册到父模块的 self._modules 中时,父模块的 parameters() 方法不会递归遍历这些子模块的参数。因此,这些子模块的参数不会被包含在父模块的参数列表中,导致优化器(如 SGD、Adam)在初始化时无法获取到这些参数,最终无法更新它们。
关键问题:参数未被收集
假设你有一个自定义的 MySequential 类,但错误地将子模块存储在普通列表中(而非通过 self._modules 注册):
import torch
import torch.nn as nn
class MySequential(nn.Module):
def __init__(self, *args):
super().__init__()
self.layers = list(args) # ❌ 未注册到 self._modules,参数未被跟踪
def forward(self, X):
for layer in self.layers:
X = layer(X)
return X
# 创建模型
model = MySequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 1)
)
# 获取参数
params = list(model.parameters()) # 输出为空列表!
print(params) # 输出:[]
此时,model.parameters() 返回的参数列表是空的,因为子模块未被注册到 self._modules,导致父模块无法感知它们的存在。
trainer.step() 无法更新参数的原因
当使用优化器(如 torch.optim.SGD)时,优化器的参数列表是通过 model.parameters() 获取的:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
如果 model.parameters() 返回空列表,那么优化器的参数列表也是空的。此时,调用 optimizer.step() 时,没有任何参数会被更新,因为优化器根本不知道这些参数的存在。
正确实现:使用 self._modules 注册子模块
通过将子模块注册到 self._modules,父模块可以正确收集所有参数:
class MySequential(nn.Module):
def __init__(self, *args):
super().__init__()
for idx, module in enumerate(args):
self._modules[str(idx)] = module # ✅ 正确注册到 self._modules
def forward(self, X):
for block in self._modules.values():
X = block(X)
return X
model = MySequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 1)
)
params = list(model.parameters()) # 输出包含所有子模块的参数
print(len(params)) # 输出:4(两个线性层的权重和偏置)
此时,model.parameters() 会正确收集所有子模块的参数,优化器可以正常工作:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss.backward()
optimizer.step() # ✅ 参数会被更新
总结
| 情况 | 参数是否被跟踪 | optimizer.step() 是否更新参数 |
|---|---|---|
子模块未注册到 self._modules |
否 | 否 |
子模块正确注册到 self._modules |
是 | 是 |
关键点
- 参数注册机制:PyTorch 的
nn.Module通过self._modules管理子模块,确保参数被正确收集。 - 优化器的参数列表:优化器初始化时会通过
model.parameters()获取参数列表,未被收集的参数不会被更新。 - 自定义模块的正确实现:必须将子模块注册到
self._modules中,否则它们的参数会被忽略。
通过这种方式,PyTorch 能够自动管理所有子模块的参数,确保训练时参数的正确更新。
课后题目的第一问说的是正常计算,计算是没问题的,但是不可以反向传播

浙公网安备 33010602011771号