Pytorch深度学习:自动微分

自动微分

根据设计好的模型,系统会构建一个计算图(computational graph), 来跟踪计算是哪些数据通过哪些操作组合起来产生输出。 自动微分使系统能够随后反向传播梯度。 反向传播(backpropagate)意味着跟踪整个计算图,填充关于每个参数的偏导数。

我们先看一个简单的例子:

函数y是一个标量函数

有一个列向量\(x=(x_1,x_2,x_3,x_4)^T\),函数:

\[y=2x^Tx=2(x_1^2+x_2^2+x_3^2+x_4^2) \]

那么我们想求解y关于x的梯度,怎么求呢?

import torch
x = torch.arange(4.0)
x.requires_grad_(True)  # 等价于x=torch.arange(4.0,requires_grad=True)
# 此时x.grad默认值是None
y = 2 * torch.dot(x, x)
y.backward() #反向传播,计算y关于x的每个分量的梯度
print(x.grad) #梯度保存在这里,结果为:tensor([ 0.,  4.,  8., 12.])

如果我们还想计算x的另一个函数,那么先对x的梯度清0:

# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值
x.grad.zero_()
y = x.sum()
y.backward()
# x.grad为tensor([1., 1., 1., 1.])

函数y是向量时

当函数y不是标量时,向量y关于向量x的导数是一个矩阵。 求导的结果是一个高阶张量。

但一般我们调用向量的反向传播时,通常是计算一个批量里每个样本的偏导数之和。

# 对非标量调用backward需要传入一个gradient参数,该参数指定微分函数关于self的梯度。
# 本例只想求偏导数的和,所以传递一个1的梯度是合适的
x.grad.zero_()
y = x * x #y=(x1^2,...,x4^2),是一个向量
y.sum().backward()
x.grad

上面代码里y.sum().backward()也可以等价地写成:

y.backward(torch.ones(len(x)))

backward函数接受了一个叫做gradient的参数,当y是标量时不需要该参数,但如果是向量则必须传入参数gradient。该参数的作用是:

\(\textbf{y}\)\(\textbf{x}\)求导时,结果是一个梯度矩阵,为:

image-20230302154241836

当获取x的梯度时,有:

\[x.grad = \frac{\partial \textbf{y}}{\partial\textbf{x}} \times gradient \]

参考:[backward函数中gradient参数的一些理解](https://www.cnblogs.com/meitiandouyaokaixin/p/16339669.html#:~:text=backward函数中gradient参数的一些理解 当标量对向量求导时不需要该参数,但当向量对向量求导时,若不加上该参数则会报错,显示“grad can be implicitly created only,for scalar outputs”,对该gradient参数解释如下。 当 y 对 x 求导时,结果为梯度矩阵,数学表达如下:)

分离计算

如果我们希望把某些计算移动到计算图之外,那么可以采用分离计算来实现:

x = torch.arange(4.0)
y = x * x 		# y = (x1^2,...,x4^2) = (1,4,9,16)
u = y.detach()  # u = (1,4,9,16)
z = u * x 		# z = (x1, 4*x2, 9*x3, 16*x4)
z.sum().backward() 
# 结果:x.grad 和 u 相等

u就是y移除计算图的变量,只是一个常数张量。

注意,此时u.requires_grad为False,也可以通过u.requires_grad_(True)将其设置为叶子节点,不过它就相当于一个刚创建的张量,之前的计算图和它没有关系。

注意事项

  1. 当尝试输出非叶子节点的梯度时

对于如下代码:

x = torch.arange(4.0, requires_grad=True)
y = x * x
z = y * x
z.sum().backward()
print(y.grad) 
# 结果为None

当尝试输出叶子节点y的梯度值时,会报出warning,警告不要获取非叶子节点的梯度,并且返回None。

  1. 当尝试两次调用backward函数时:

对于上面的代码,我们连续调用:

z.sum().backward()
z.sum().backward()

会报错,报错信息显示Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().

  1. 好文要顶

一文解释 PyTorch求导相关 (backward, autograd.grad) - 知乎 (zhihu.com)这篇文章写的太好了。

posted @ 2023-03-02 16:22  KouweiLee  阅读(70)  评论(1编辑  收藏  举报