Pytorch自动求导

PyTorch 的 autograd 是自动微分引擎的核心模块,它实现了反向传播的自动求导机制。以下是其关键概念和工作原理的详细说明:


1. 核心概念

1.1 Tensor 的 requires_grad 属性

  • 作用:标记张量是否需要计算梯度。若为 True,PyTorch 会跟踪其运算历史。
  • 示例
    x = torch.tensor([2.0], requires_grad=True)
    y = x ** 2  # y 的 grad_fn 是 `<PowBackward0>`
    

1.2 计算图(Computational Graph)

  • 动态构建:每次张量操作(如加减乘除)会生成一个 Function 节点(记录操作类型),并动态构建一个有向无环图(DAG)。
  • 节点与边
    • 节点:张量(Tensor)或操作(Function)。
    • :数据流向(输入→输出)和梯度反向传播路径(输出→输入)。

1.3 grad_fn

  • 每个张量有一个 grad_fn 属性,指向创建它的 Function 对象,保存了反向传播所需的计算逻辑。

2. 反向传播流程

2.1 触发条件

调用 loss.backward() 时,PyTorch 会从 loss 张量开始,沿着计算图反向传播梯度。

2.2 梯度计算

  • 链式法则:梯度从输出端逐层传递到输入端,每个 Function 节点根据其记录的运算类型计算局部梯度。
  • 示例
    x = torch.tensor(3.0, requires_grad=True)
    y = x**2 + 2*x + 1
    y.backward()  # dy/dx = 2x + 2 = 8
    print(x.grad)  # 输出 tensor(8.)
    

2.3 梯度累积

  • 特性:默认情况下,梯度会累积在 .grad 属性中,而非覆盖。需手动清零(如 optimizer.zero_grad())。

3. 关键方法

3.1 backward()

  • 参数
    • gradient:若输出是标量(如损失值),可省略;若为张量,需指定初始梯度(通常为全1)。
    • retain_graph:是否保留计算图(用于多次反向传播)。
  • 示例
    y = x.sum()
    y.backward(retain_graph=True)  # 保留计算图
    

3.2 detach()

  • 作用:从计算图中分离张量,阻止梯度跟踪。
  • 用途:冻结模型部分参数或保存中间结果时节省内存。
    z = y.detach()  # z 的运算不再被跟踪
    

3.3 with torch.no_grad():

  • 作用:上下文管理器,禁用梯度计算,常用于推理或更新参数。
    with torch.no_grad():
        prediction = model(input)  # 不跟踪梯度
    

4. 注意事项

  1. 叶子节点与非叶子节点
    • 叶子节点:用户直接创建的张量(如 x = torch.tensor(...))。
    • 梯度默认仅存储在叶子节点的 .grad 中。
  2. 内存优化
    • 调用 backward() 后,中间节点的梯度会被释放。若需保留,使用 retain_grad()
      y.retain_grad()  # 保留中间梯度
      
  3. 高阶导数
    • 设置 create_graph=True 可计算二阶导数:
      dy_dx = torch.autograd.grad(y, x, create_graph=True)
      d2y_dx2 = torch.autograd.grad(dy_dx, x)
      

5. 性能与调试

  • 优点
    • 动态图灵活性:支持条件分支、循环等复杂逻辑。
    • 调试友好:可直接在 Python 中逐行检查中间结果。
  • 缺点
    • 动态图构建开销可能影响性能(对比静态图框架如 TensorFlow)。

示例:完整流程

import torch

# 定义输入和参数
x = torch.tensor([1.0], requires_grad=True)
w = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([3.0], requires_grad=True)

# 前向计算
y = w * x + b       # y = 2*1 + 3 = 5
loss = (y - 5)**2   # loss = (5-5)^2 = 0

# 反向传播
loss.backward()

# 查看梯度
print(w.grad)  # d(loss)/dw = 2*(y-5)*x = 0 → tensor([0.])
print(x.grad)  # d(loss)/dx = 2*(y-5)*w = 0 → tensor([0.])

通过 autograd,PyTorch 将梯度计算与模型训练解耦,使得实现复杂神经网络变得简洁高效。理解其底层机制有助于避免梯度错误(如未清零、未分离计算图)和优化性能。

posted @ 2025-04-30 10:11  程序员shaun  阅读(84)  评论(0)    收藏  举报