PyTorch中的optim.SGD

optim.SGD是PyTorch中的一个优化器,其实现了随机梯度下降(Stochastic Gradient Descent,SGD)算法。在深度学习中,我们通常使用优化器来更新神经网络中的内部参数,以使得损失函数尽可能地小。

在PyTorch中使用optim.SGD优化器,一般需要指定以下参数:

  • params:需要更新的参数,通常为模型中的权重和偏置项。
  • lr:学习率,即每次参数更新时的步长。
  • momentum:动量,用来加速模型收敛速度,避免模型陷入局部最优解。动量机制类似于小球从坡上滚下时积累惯性,能平滑更新路径,避免随机梯度下降(SGD)因梯度噪声引起的摆动或局部震荡。动量值通常设置为 0.9(默认为 0,即关闭动量),表示保留 90% 的上一次更新方向;值越大惯性越强,适合平坦区域但可能冲过最优值,值过小则接近普通 SGD。‌实际使用中,带动量时学习率常比纯 SGD 更小(如从 0.001 开始),以避免累积步长过大导致不稳定。‌
  • dampening:动量衰减,用来控制动量的衰减速度。
  • weight_decay:权重衰减,用来防止模型过拟合,即通过对权重的L2正则化来约束模型的复杂度。
  • nesterov:是否使用Nesterov动量。

在优化过程中,optim.SGD会根据当前的梯度和学习率计算出每个参数的更新量,并更新模型的参数。更新量的计算公式如下:

v(t) = momentum * v(t-1) - lr * (grad + weight_decay * w(t))
w(t) = w(t-1) + v(t)

其中,v(t)表示当前时刻的速度,v(t-1)表示上一个时刻的速度,grad表示当前时刻的梯度,w(t)表示当前时刻的权重,w(t-1)表示上一个时刻的权重。

optim.SGD算法中的动量(momentum)可以看作是一个惯性项,用来在参数更新时保留之前的状态。当梯度方向发生改变时,动量能够加速模型收敛,并降低震荡。Nesterov动量可以在动量的基础上进一步优化模型的性能,它会先根据上一个时刻的速度来计算下一个时刻的梯度,然后再更新参数。

需要注意的是:在使用optim.SGD时,要适当调整学习率和动量等超参数,以便在训练中达到更好的性能。

示例如下:

# 网络模型实例化
net = InceptionBlock(num_classes=2, in_channels=1, init_weights=True)
net.to(device)
# 选择损失函数
loss_function = nn.CrossEntropyLoss()
# 选择优化器
optimizer = optim.SGD(net.parameters(), lr=0.003, momentum=0.9)

epochs = 50  # 迭代次数
best_acc = 0.0  # 精度

# 网络结构保存路径
save_path = '../model/signal_classes.pth'

train_steps = len(train_loader)
for epoch in range(epochs):
    # 开启训练模式
    net.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader)
    for step, data in enumerate(train_bar):
        signal_series, label_series = data
        signal_series = signal_series.reshape(batch_size, 1, 1200)
        # 优化器归零
        optimizer.zero_grad()
        # 模型计算
        model_predict = net(signal_series.to(device))
        # 计算损失函数
        loss = loss_function(model_predict, f.one_hot(label_series - 1, num_classes=2).float())
        # 反向传播计算
        loss.backward()
        # 优化器迭代
        optimizer.step()
        # 损失计算
        running_loss += loss.item()
        # 进度条计算
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

posted on 2025-12-30 11:10  『潇洒の背影』  阅读(1)  评论(0)    收藏  举报

导航