with torch.set_grad_enabled & requires_grad

with torch.set_grad_enabled(False):
    x = torch.zeros(1, requires_grad=True)
    print(x.requires_grad)
    y = x * 2
    print(y.requires_grad)
    
True
False

Inside `with torch.set_grad_enabled(...)`, the `requires_grad` flag of a newly created leaf tensor is not affected by the context (it defaults to False unless set explicitly). The context does, however, determine the `requires_grad` of any tensor produced by an operation.
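To make the contrast explicit, here is a minimal sketch (the names x, y, z are illustrative): the same operation `x * 2` produces a tracked tensor when grad mode is on and a detached one when it is off, while the leaf tensor keeps its explicit flag either way.

```python
import torch

# Leaf tensor: its explicit requires_grad=True survives regardless of grad mode
x = torch.zeros(1, requires_grad=True)

with torch.set_grad_enabled(False):
    y = x * 2           # grad mode off: the operation result is detached
print(y.requires_grad)  # False

z = x * 2               # grad mode on (the default): the result tracks gradients
print(z.requires_grad)  # True
```

`torch.set_grad_enabled` is thus a switch on operation recording, not on tensor creation.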

with torch.set_grad_enabled(True):
    print(torch.zeros((1, 1), requires_grad=True).requires_grad)
    
True

with torch.set_grad_enabled(True):
    print(torch.zeros((1, 1), requires_grad=True).long().requires_grad)
    
False

Because only floating-point (and complex) tensors can carry gradients in PyTorch, `.long()` converts to an integer dtype and the resulting tensor has requires_grad=False.
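A related consequence worth noting: asking an integer tensor to require gradients at creation time fails outright, which is consistent with `.long()` silently dropping the flag above. A minimal sketch:

```python
import torch

# Requesting gradients on an integer dtype raises a RuntimeError
try:
    torch.zeros(1, dtype=torch.long, requires_grad=True)
except RuntimeError as e:
    print("RuntimeError:", e)

# Casting away from floating point drops the flag instead of erroring
print(torch.zeros(1, requires_grad=True).long().requires_grad)  # False
```

So the flag can only ever be True on a differentiable dtype, whether set explicitly or inherited through an operation.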

posted @ 2021-06-17 22:28  思所匪夷