5.1.2 顺序块

在自定义的 MySequential 类中,使用 self._modules 而不是列表来存储子模块的原因主要有以下两点:


1. 参数注册与跟踪

PyTorch 的 nn.Module 机制要求子模块必须被显式注册到父模块的 self._modules 字典中,这样才能被正确跟踪。具体来说:

  • 参数收集:当调用 model.parameters() 时,PyTorch 会自动收集所有注册到 self._modules 的子模块的参数。如果子模块未被注册(例如仅存储在普通列表中),其参数将不会被父模块识别,导致优化器无法更新这些参数。
  • 设备转移与状态管理:子模块的设备(如 CPU/GPU)和状态(如 training 模式)需要通过 self._modules 的注册机制来统一管理。未注册的子模块可能无法正确转移到 GPU 或进行其他状态同步。

示例

class MySequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.layers = list(args)  # ❌ 未注册到 self._modules

model = MySequential(nn.Linear(10, 20), nn.ReLU())
print(list(model.parameters()))  # 输出为空,因为参数未被注册

通过 self._modules 注册后:

class MySequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            self._modules[str(idx)] = module  # ✅ 正确注册

model = MySequential(nn.Linear(10, 20), nn.ReLU())
print(len(list(model.parameters())))  # 输出 2(线性层的权重和偏置)

2. 顺序保证与遍历一致性

虽然列表本身是有序的,但 self._modules 是一个 OrderedDict,确保了以下两点:

  • 顺序一致性:在 forward 方法中遍历 self._modules.values() 时,顺序与构造函数中传入的 *args 完全一致。OrderedDict 的设计确保了插入顺序的可预测性。
  • 符合 PyTorch 的模块管理规范nn.Sequential 内部正是通过 OrderedDict 实现的,使用 self._modules 能确保与 PyTorch 的标准行为一致。

代码对比

# 使用列表存储(未注册)
self.layers = [nn.Linear(10, 20), nn.ReLU()]
# 遍历顺序虽然正确,但参数未被跟踪!

# 使用 self._modules(正确)
for idx, module in enumerate(args):
    self._modules[str(idx)] = module
# 顺序正确且参数被跟踪

总结

  • 必须注册到 self._modules:确保子模块的参数、设备、状态被父模块正确管理。
  • 顺序由 OrderedDict 保证:与列表的顺序一致,同时符合 PyTorch 的设计规范。

因此,直接使用列表存储子模块会导致参数无法被跟踪,而通过 self._modules 注册则解决了这一问题,同时保证了顺序的正确性。

是的,如果参数未被正确跟踪(即未被注册到父模块的 self._modules 中),那么 trainer.step()(或 optimizer.step())将无法更新这些参数。以下是详细解释和示例:


为什么参数无法被跟踪?

在 PyTorch 中,当子模块未被显式注册到父模块的 self._modules 中时,父模块的 parameters() 方法不会递归遍历这些子模块的参数。因此,这些子模块的参数不会被包含在父模块的参数列表中,导致优化器(如 SGDAdam)在初始化时无法获取到这些参数,最终无法更新它们。


关键问题:参数未被收集

假设你有一个自定义的 MySequential 类,但错误地将子模块存储在普通列表中(而非通过 self._modules 注册):

import torch
import torch.nn as nn

class MySequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.layers = list(args)  # ❌ 未注册到 self._modules,参数未被跟踪

    def forward(self, X):
        for layer in self.layers:
            X = layer(X)
        return X

# 创建模型
model = MySequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 1)
)

# 获取参数
params = list(model.parameters())  # 输出为空列表!
print(params)  # 输出:[]

此时,model.parameters() 返回的参数列表是空的,因为子模块未被注册到 self._modules,导致父模块无法感知它们的存在。


trainer.step() 无法更新参数的原因

当使用优化器(如 torch.optim.SGD)时,优化器的参数列表是通过 model.parameters() 获取的:

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

如果 model.parameters() 返回空列表,那么优化器的参数列表也是空的。此时,调用 optimizer.step() 时,没有任何参数会被更新,因为优化器根本不知道这些参数的存在。


正确实现:使用 self._modules 注册子模块

通过将子模块注册到 self._modules,父模块可以正确收集所有参数:

class MySequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            self._modules[str(idx)] = module  # ✅ 正确注册到 self._modules

    def forward(self, X):
        for block in self._modules.values():
            X = block(X)
        return X

model = MySequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 1)
)

params = list(model.parameters())  # 输出包含所有子模块的参数
print(len(params))  # 输出:4(两个线性层的权重和偏置)

此时,model.parameters() 会正确收集所有子模块的参数,优化器可以正常工作:

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss.backward()
optimizer.step()  # ✅ 参数会被更新

总结

情况 参数是否被跟踪 optimizer.step() 是否更新参数
子模块未注册到 self._modules
子模块正确注册到 self._modules

关键点

  1. 参数注册机制:PyTorch 的 nn.Module 通过 self._modules 管理子模块,确保参数被正确收集。
  2. 优化器的参数列表:优化器初始化时会通过 model.parameters() 获取参数列表,未被收集的参数不会被更新。
  3. 自定义模块的正确实现:必须将子模块注册到 self._modules 中,否则它们的参数会被忽略。

通过这种方式,PyTorch 能够自动管理所有子模块的参数,确保训练时参数的正确更新。

课后题目的第一问说的是正常计算,计算是没问题的,但是不可以反向传播

posted @ 2025-03-12 12:44  最爱丁珰  阅读(32)  评论(0)    收藏  举报