PyTorch grad_fn的作用以及RepeatBackward, SliceBackward示例
来源 https://www.cnblogs.com/picassooo/p/13757403.html
PyTorch grad_fn的作用以及RepeatBackward, SliceBackward示例
变量.grad_fn表明该变量是怎么来的,用于指导反向传播。例如loss = a+b,则loss.gard_fn为<AddBackward0 at 0x7f2c90393748>,表明loss是由相加得来的,这个grad_fn可指导怎么求a和b的导数。
程序示例:
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
import torchw1 = torch.tensor(2.0, requires_grad=True)a = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)tmp = a[0, :]tmp.retain_grad() # tmp是非叶子张量,需用.retain_grad()方法保留导数,否则导数将会在反向传播完成之后被释放掉b = tmp.repeat([3, 1])b.retain_grad()loss = (b * w1).mean()loss.backward()print(b.grad_fn) # 输出: <RepeatBackward object at 0x7f2c903a10f0>print(b.grad) # 输出: tensor([[0.3333, 0.3333], # [0.3333, 0.3333], # [0.3333, 0.3333]])print(tmp.grad_fn) # 输出:<SliceBackward object at 0x7f2c90393f60>print(tmp.grad) # 输出:tensor([1., 1.])print(a.grad) # 输出:tensor([[1., 1.], # [0., 0.]]) |
手动推导:

手动推导的结果和程序的结果是一致的。

浙公网安备 33010602011771号