kZjPBD.jpg

tensor.detach()

x = torch.tensor(2.0)
x.requires_grad_(True)
y = 2 * x
z = 5 * x

w = y + z.detach()
w.backward()

 

print(x.grad)

=> 2

 

本来应该x的梯度为7,但是detach()那一路切段了梯度的传播,导致5没有向后传递

posted @ 2022-04-03 21:16  Through_The_Night  阅读(63)  评论(0编辑  收藏  举报