pytorch loss.backward()核心理解 和lstm中hidden的典型错误
X = torch.ones(2, 2, requires_grad=False)
w = torch.ones(2, 2, requires_grad=True)
c1 = X*w
c1.backward()
print(w.grad)
# 会报错
c2 = c1 * X
c2.backward()
#重新建立c1的链子
c1=X*w
c2 = c1 * X
c2.backward()
![]()


浙公网安备 33010602011771号