with torch.set_grad_enabled & requires_grad
with torch.set_grad_enabled(False):
x = torch.zeros(1, requires_grad=True)
print(x.requires_grad)
y=x*2
print(y.requires_grad)
True
False在with torch.set_grad_enabled中创建的张量requires_grad不受其影响(默认为False),如果经过运算会影响新张量的requires_grad.
with torch.set_grad_enabled(True):
print(torch.zeros((1, 1), requires_grad=True).requires_grad)
True
with torch.set_grad_enabled(True):
print(torch.zeros((1, 1), requires_grad=True).long().requires_grad)
False因为Pytorch中浮点类型的tensor才能有梯度,.long()会改变requires_grad

浙公网安备 33010602011771号