Pytorch自动求导
PyTorch 的 autograd 是自动微分引擎的核心模块,它实现了反向传播的自动求导机制。以下是其关键概念和工作原理的详细说明:
1. 核心概念
1.1 Tensor 的 requires_grad 属性
- 作用:标记张量是否需要计算梯度。若为
True,PyTorch 会跟踪其运算历史。 - 示例:
x = torch.tensor([2.0], requires_grad=True) y = x ** 2 # y 的 grad_fn 是 `<PowBackward0>`
1.2 计算图(Computational Graph)
- 动态构建:每次张量操作(如加减乘除)会生成一个
Function节点(记录操作类型),并动态构建一个有向无环图(DAG)。 - 节点与边:
- 节点:张量(Tensor)或操作(Function)。
- 边:数据流向(输入→输出)和梯度反向传播路径(输出→输入)。
1.3 grad_fn
- 每个张量有一个
grad_fn属性,指向创建它的Function对象,保存了反向传播所需的计算逻辑。
2. 反向传播流程
2.1 触发条件
调用 loss.backward() 时,PyTorch 会从 loss 张量开始,沿着计算图反向传播梯度。
2.2 梯度计算
- 链式法则:梯度从输出端逐层传递到输入端,每个
Function节点根据其记录的运算类型计算局部梯度。 - 示例:
x = torch.tensor(3.0, requires_grad=True) y = x**2 + 2*x + 1 y.backward() # dy/dx = 2x + 2 = 8 print(x.grad) # 输出 tensor(8.)
2.3 梯度累积
- 特性:默认情况下,梯度会累积在
.grad属性中,而非覆盖。需手动清零(如optimizer.zero_grad())。
3. 关键方法
3.1 backward()
- 参数:
gradient:若输出是标量(如损失值),可省略;若为张量,需指定初始梯度(通常为全1)。retain_graph:是否保留计算图(用于多次反向传播)。
- 示例:
y = x.sum() y.backward(retain_graph=True) # 保留计算图
3.2 detach()
- 作用:从计算图中分离张量,阻止梯度跟踪。
- 用途:冻结模型部分参数或保存中间结果时节省内存。
z = y.detach() # z 的运算不再被跟踪
3.3 with torch.no_grad():
- 作用:上下文管理器,禁用梯度计算,常用于推理或更新参数。
with torch.no_grad(): prediction = model(input) # 不跟踪梯度
4. 注意事项
- 叶子节点与非叶子节点:
- 叶子节点:用户直接创建的张量(如
x = torch.tensor(...))。 - 梯度默认仅存储在叶子节点的
.grad中。
- 叶子节点:用户直接创建的张量(如
- 内存优化:
- 调用
backward()后,中间节点的梯度会被释放。若需保留,使用retain_grad():y.retain_grad() # 保留中间梯度
- 调用
- 高阶导数:
- 设置
create_graph=True可计算二阶导数:dy_dx = torch.autograd.grad(y, x, create_graph=True) d2y_dx2 = torch.autograd.grad(dy_dx, x)
- 设置
5. 性能与调试
- 优点:
- 动态图灵活性:支持条件分支、循环等复杂逻辑。
- 调试友好:可直接在 Python 中逐行检查中间结果。
- 缺点:
- 动态图构建开销可能影响性能(对比静态图框架如 TensorFlow)。
示例:完整流程
import torch
# 定义输入和参数
x = torch.tensor([1.0], requires_grad=True)
w = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([3.0], requires_grad=True)
# 前向计算
y = w * x + b # y = 2*1 + 3 = 5
loss = (y - 5)**2 # loss = (5-5)^2 = 0
# 反向传播
loss.backward()
# 查看梯度
print(w.grad) # d(loss)/dw = 2*(y-5)*x = 0 → tensor([0.])
print(x.grad) # d(loss)/dx = 2*(y-5)*w = 0 → tensor([0.])
通过 autograd,PyTorch 将梯度计算与模型训练解耦,使得实现复杂神经网络变得简洁高效。理解其底层机制有助于避免梯度错误(如未清零、未分离计算图)和优化性能。

浙公网安备 33010602011771号