PyTorch-权重衰退(Weight Decay)

PyTorch Weight Decay

目录

  1. 摘要
  2. 概念与理论
    • 2.1 核心概念
    • 2.2 与 L2 正则化的关系
    • 2.3 核心作用
  3. PyTorch 实践
    • 3.1 如何设置 λ(权重衰减系数)
    • 3.2 不同架构的常见设置
    • 3.3 PyTorch 实现方式
    • 3.4 高级技巧

1. 摘要

Weight Decay(权重衰减)是深度学习中重要的正则化技术,通过在训练过程中对模型权重施加惩罚,防止过拟合,提升模型泛化能力。

2. 概念与理论

2.1 核心概念

Weight Decay是一种正则化技术,在损失函数中添加与权重大小相关的惩罚项,鼓励模型学习更小的权重值,得到更简单、平滑的模型。

带Weight Decay的总损失函数:

L_total = L_original + λ/2 * ||w||²

其中λ是权重衰减系数,控制惩罚项权重:λ越大,对大幅值权重的惩罚越重,模型越简单。

2.2 与 L2 正则化的关系

在标准随机梯度下降(SGD)中,Weight Decay完全等价于L2正则化。

但在使用自适应优化器(如Adam, AdamW)时,传统实现方式会导致不等价。Adam等优化器会为每个参数计算自适应学习率,如果直接将L2正则项加到损失函数中,会像处理普通梯度一样处理正则项的梯度,导致正则化效果被扭曲。

AdamW(Adam with Weight Decay)解决了这个问题,将Weight Decay项从损失函数中解耦出来,直接在权重更新时添加,而不影响梯度计算。

AdamW的更新规则:
w = w - lr * d(L_original)/dw - lr * λ * w

关键区别:AdamW中的λ * w项不参与梯度、一阶矩、二阶矩的计算,是独立的衰减项,效果更纯粹稳定。

2.3 核心作用

防止过拟合:通过惩罚大的权重,限制模型复杂度,使其无法完美"记忆"训练数据中的噪声和细节。

提升泛化能力:更简单的模型在未见过的数据上通常表现更好。

3. PyTorch 实践

3.1 如何设置 λ(权重衰减系数)

λ是关键超参数,需要仔细调整。没有通用值。

典型范围:λ通常在1e-4到1e-2之间(0.0001到0.01)。

  • 1e-4是常用且安全的起始点
  • 1e-3和1e-4是最常见的选择
  • 1e-2是非常强的衰减,只适用于特定场景

调整策略

  • 从默认值开始:λ = 1e-4或1e-3
  • 与学习率协同调整:通常需要将两者一起搜索
  • 观察训练与验证曲线:
    • 欠拟合(训练误差和验证误差都很大):减小λ或设为0
    • 过拟合(训练误差很小,验证误差很大):增大λ

3.2 不同架构的常见设置

计算机视觉(CNN):常用1e-4量级。ResNet、VGG等经典网络通常使用此值。

自然语言处理(Transformer):AdamW是标准优化器。常用值为0.01或0.1。

其他领域:RNN/LSTM通常从1e-4开始尝试。

3.3 PyTorch 实现方式

方式一:使用SGD优化器

optimizer = torch.optim.SGD(model.parameters(),
                            lr=0.01,
                            momentum=0.9,
                            weight_decay=1e-4)

方式二:使用AdamW优化器(推荐)

optimizer = torch.optim.AdamW(model.parameters(),
                              lr=1e-4,
                              weight_decay=0.01)

注意:避免使用Adam + L2,会导致自适应学习率问题。

3.4 高级技巧

不对偏置和归一化层进行衰减

只对权重应用Weight Decay,不对偏置和层归一化、批归一化参数应用。

# 示例:将权重和偏置参数分开
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
    if any(nd in name for nd in ["bias", "norm.weight", "norm.bias"]):
        # 偏置和Norm层的参数不衰减
        no_decay_params.append(param)
    else:
        # 其他权重参数衰减
        decay_params.append(param)

optimizer = torch.optim.AdamW([
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0}
], lr=1e-4)
posted @ 2025-09-19 19:10  aaooli  阅读(114)  评论(0)    收藏  举报