pytorch(05)计算图
张量的一系列操作,增多,导致可能出现多个操作之间的串行并行,协同不同的底层之间的协作,避免操作的冗余。计算图就是为了解决这些问题产生的。
计算图与动态图机制
1. 计算图
计算图用来描述运算的有向无环图,计算图有两个主要元素:结点Node和边Edge。
结点表示数据,如向量、矩阵、张量。
边表示运算,如加减乘除卷积、激活函数等
用计算图表示:y = (x + w)*(w+1)
a = x + w
b = w+ 1
y = a * b
从下往上进行。使用计算图的好处就是对于梯度求导比较方便
2. 计算图与梯度求导
y = (x+ w) * (w+1)
a = x+ w
b = w+1
y = a*b
\[\frac{\partial y}{\partial w}=\frac{\partial y}{\partial a}\frac{\partial a}{\partial w}+\frac{\partial y}{\partial b}\frac{\partial b}{\partial w}
\]
=b*1+a*1
=b+a
= (w+1)+(x+w)
=2*w+x+1
=2*1+2+1=5
3.叶子结点
用户创建的结点称为叶子结点,如x与 w
is_leaf:指示张量是否为叶子结点。他是整个计算图的根基
为什么要设立叶子结点,是为了节省内存,因为当程序结束后整个计算图的非叶子结点都是释放掉的
如果想使用非叶子结点的梯度,使用retain_grad(),执行这个方法就可以保留非叶子结点的梯度。
grad_fn:记录创建该张量时所用的方法(函数)在梯度传播的时候用到
比如y,y.grad_fn = <MulBackward0>y是用乘法得到的,所以在求解a和b的梯度时就会用乘法的求导法则去求解。叶子结点的grad_fn都是none
import torch
import numpy as np
x = torch.tensor([2.], requires_grad=True)
w = torch.tensor([1.], requires_grad=True)
a = torch.add(x,w)
a.retain_grad()
b = torch.add(w,1)
y = torch.mul(a,b)
y.backward()
# print(w.grad)
print("is_leaf:", x.is_leaf, w.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)
print("gradient",x.grad,w.grad,a.grad,b.grad,y.grad)
print("gradient_fun",x.grad_fn,w.grad_fn,a.grad_fn,b.grad_fn,y.grad_fn)
is_leaf: True True False False False
gradient tensor([2.]) tensor([5.]) tensor([2.]) None None
gradient_fun None None <AddBackward0 object at 0x00000214F45B6610> <AddBackward0 object at 0x00000214F65CC580> <MulBackward0 object at 0x00000214F65CC730>
浙公网安备 33010602011771号