数学相关
考虑这段代码:
x.grad.zero_()
y = x * x
u = y.detach()
z = u * x
z.sum().backward()
x.grad == u
这里u = y.detach()
意味着不将u看作一个变量组成的向量,而仅将 u 看作一个数字组成的向量,即u = [0, 1, 4, 9]
,故z = u * x
即为z = [0 * x1, 1 * x2, 4 * x3, 9 * x4]
,故z.sum().backward()
执行后,x.grad
为[0, 1, 4, 9]