pytorch detach函数
用于截断反向传播
detach()源码:
def detach(self): result = NoGrad()(self) # this is needed, because it merges version counters result._grad_fn = None return result
它的返回结果与调用者共享一个data tensor,且会将grad_fn设为None,这样就不知道该Tensor是由什么操作建立的,截断反向传播
这个时候再一个tensor使用In_place操作会导致另一个的data tensor也会发生改变
import torch a = torch.tensor([1, 2, 3.], requires_grad=True) out = a.sigmoid() print(out)#tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>) c = out.detach() print(c)#tensor([0.7311, 0.8808, 0.9526])
这个时候可以看到,c和out的区别就是一个有grad_fn,一个没有grad_fn
执行out.sum().backward()没有问题,但执行c.sum().backward()报错:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
这个时候不论是对out还是对c进行inplace操作改变它们的data,这个改动会被autograd追踪,这个时候再执行out.sum().backward()会报错
假设对out进行inplace操作,会出现:
out.zero_() #tensor([0., 0., 0.], grad_fn=<ZeroBackward>) out.sum().backward() #报错
错误信息为
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3]], which is output 0 of SigmoidBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
如果不对out进行inplace操作而是对c进行inplace操作,结果是一样的,Out不能再进行反向传播了
为了解决这种情况,就要对tensor的data操作,使其不被autograd记录
重新得到一个out,把它的data部分给c
c = out.data #tensor([0.7311, 0.8808, 0.9526]) out #tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
这里可以看到,c中没有Out中有的grad_fn信息
这回修改c的值,发现out的data值依然改了,但是执行out.sum().backward()不报错了
detach_()
def detach_(self): """Detaches the Variable from the graph that created it, making it a leaf. """ self._grad_fn = None self.requires_grad = False
做了两件事:1grad_fn设none2requires_grad设false
它不会新生成一个Variable而是使用原来的variable
 
                     
                    
                 
                    
                
 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号