11.11.2 学习率调度器

在 PyTorch 中,trainer.param_groups 是一个列表(list),其中每个元素是一个字典(dict),表示优化器管理的参数组(parameter group)。每个参数组包含一组模型参数及其对应的超参数(如学习率、动量等)。以下是详细解释:


1. trainer.param_groups[0] 是什么?

  • 默认参数组:当你初始化优化器时(如 torch.optim.SGD(net.parameters(), lr=0.1)),所有参数会被自动归为第一个参数组(索引为 0)。
  • 包含的变量
    • params:该组管理的模型参数(如 net.parameters() 返回的参数)。
    • lr:该组的学习率。
    • momentum:动量系数(如果优化器支持,如 SGD)。
    • weight_decay:权重衰减(L2 正则化系数)。
    • 其他优化器特定的超参数(如 Adam 中的 betas)。

示例

trainer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
print(trainer.param_groups[0].keys())
# 输出:dict_keys(['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'])

2. trainer.param_groups[1] 是什么?

  • 多参数组场景:默认情况下,优化器只有一个参数组(索引 0)。如果你在初始化优化器时显式定义了多个参数组,才会存在 param_groups[1]param_groups[2] 等。
  • 典型用途:对模型的不同部分(如不同层)设置不同的学习率或超参数。

示例

# 为模型的不同层设置不同的学习率
optimizer = torch.optim.SGD(
    [
        {'params': net.layer1.parameters(), 'lr': 0.1},  # 参数组 0
        {'params': net.layer2.parameters(), 'lr': 0.01}   # 参数组 1
    ],
    momentum=0.9
)
print(optimizer.param_groups[1]["lr"])  # 输出:0.01

3. 关键注意事项

  • 索引越界:如果优化器只有一个参数组,访问 param_groups[1] 会触发 IndexError
  • 动态调整超参数:可以通过修改 param_groups 中的值动态调整学习率等参数:
    # 将第一个参数组的学习率改为 0.001
    trainer.param_groups[0]["lr"] = 0.001
    
  • 分组技巧:常用于以下场景:
    • 微调预训练模型(冻结部分层,只训练部分层)。
    • 对不同层使用不同的学习率(如学习率衰减策略)。
    • 对不同参数应用不同的正则化系数。

4. 完整示例

import torch
import torch.nn as nn

# 定义一个简单模型
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.layer2 = nn.Linear(20, 2)

net = Net()

# 创建优化器,分两个参数组
optimizer = torch.optim.SGD(
    [
        {'params': net.layer1.parameters(), 'lr': 0.1},  # 参数组 0
        {'params': net.layer2.parameters(), 'lr': 0.01}   # 参数组 1
    ],
    momentum=0.9
)

# 查看参数组内容
print("参数组 0 的学习率:", optimizer.param_groups[0]["lr"])  # 输出:0.1
print("参数组 1 的学习率:", optimizer.param_groups[1]["lr"])  # 输出:0.01

# 动态调整参数组 1 的学习率
optimizer.param_groups[1]["lr"] = 0.001
print("调整后的学习率:", optimizer.param_groups[1]["lr"])  # 输出:0.001

总结

  • param_groups[0] 是默认的参数组,包含所有未显式分组的参数。
  • param_groups[1] 及后续索引仅在显式定义多参数组时存在。
  • 通过 param_groups 可以灵活控制不同参数组的超参数,适用于复杂训练策略。
posted @ 2025-03-28 20:56  最爱丁珰  阅读(35)  评论(0)    收藏  举报