Pytorh之requires_grad
转载注明出处:PyTorch学习系列(十)--如何在训练时固定一些层? - CodeTutor - CSDN博客
Pytorch 保存和加载模型的代码:
保存模型: torch.save(net1.state_dict(), 'net_params.pkl') # 只保存神经网络的模型参数
与之对应的加载模型参数: model.load_state_dict(torch.load('net_params.pkl'))
在用户手动定义Variable时,参数requires_grad默认值是False。而在Module中的层在定义时,相关Variable的requires_grad参数默认是True。
在计算图中,如果有一个输入的requires_grad是True,那么输出的requires_grad也是True。只有在所有输入的requires_grad都为False时,输出的requires_grad才为False。
>>>x = Variable(torch.randn(2, 3), requires_grad=True) >>>y = Variable(torch.randn(2, 3), requires_grad=False) >>>z = Variable(torch.randn(2, 3), requires_grad=False) >>>out1 = x+y >>>out1.requires_grad True >>>out2 = y+z >>>out2.requires_grad Fals
在训练时如果想要固定网络的底层,那么可以令这部分网络对应子图的参数requires_grad为False。这样,在反向过程中就不会计算这些参数对应的梯度:
model = torchvision.models.resnet18(pretrained=True) for param in model.parameters():#nn.Module有成员函数parameters() param.requires_grad = False # Replace the last fully-connected layer # Parameters of newly constructed modules have requires_grad=True by default model.fc = nn.Linear(512, 100)#resnet18中有self.fc,作为前向过程的最后一层。 # Optimize only the classifier optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)#optimizer用于更新网络参数,默认情况下更新所有的参数
volatile=True
Variable的参数volatile=True和requires_grad=False的功能差不多,但是volatile的力量更大。当有一个输入的volatile=True时,那么输出的volatile=True。volatile=True推荐在模型的推理过程(测试)中使用,这时只需要令输入的voliate=True,保证用最小的内存来执行推理,不会保存任何中间状态。
>>> regular_input = Variable(torch.randn(5, 5))
>>> volatile_input = Variable(torch.randn(5, 5), volatile=True)
>>> model = torchvision.models.resnet18(pretrained=True)
>>> model(regular_input).requires_grad #输出的requires_grad应该是True,因为中间层的Variable的requires_grad默认是True
True
>>> model(volatile_input).requires_grad#输出的requires_grad是False,因为输出的volatile是True(等价于requires_grad是False)
False
>>> model(volatile_input).volatile
True
instance:
def rsample(self, return_pretanh_value=False):
"""
Sampling in the reparameterization case.
"""
z = (
self.normal_mean +
self.normal_std *
Normal(
torch.zeros(self.normal_mean.size(), device=self.normal_mean.device),
torch.ones(self.normal_std.size(), device=self.normal_mean.device)
).sample()
)
z.requires_grad_()
if return_pretanh_value:
return torch.tanh(z), z
else:
return torch.tanh(z)
references:
https://zhuanlan.zhihu.com/p/30830840