1. PyTorch offers two ways to update only part of the weights: set requires_grad=False on the parameters to freeze, or pass only the trainable parameters to the optimizer.
(1) Method 1: set requires_grad=False. This can be done inside the network definition, or by parameter name before the parameters are passed to the optimizer.
Setting it inside the network definition:
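A minimal sketch of this approach (the same lines appear, commented out, in the attached code below); the layer names follow that example:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self, inchannel, outchannel):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(inchannel, 64, kernel_size=3, padding=1, stride=2)
        # ... other layers to freeze ...
        # every parameter registered up to this point is frozen
        for p in self.parameters():
            p.requires_grad = False
        # layers defined after the loop keep requires_grad=True and are trained
        self.linear2 = nn.Linear(256, outchannel)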

Setting it by name before passing the parameters to the optimizer:
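A minimal sketch, assuming (as in the attached code) that only the parameters whose name contains 'linear2' should stay trainable; note that named_parameters() is needed to get the names:

for name, param in model.named_parameters():
    if 'linear2' not in name:
        param.requires_grad = False
# frozen parameters receive no gradient, so the optimizer skips them
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)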

(2) Method 2: pass only the parameters to be trained to the optimizer:
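A minimal sketch of the same filtering, done by building the parameter list that is handed to the optimizer ('linear2' again names the only trainable layer, as in the attached code):

para = [v for k, v in model.named_parameters() if 'linear2' in k]
optimizer = torch.optim.Adam(para, lr=1e-3, weight_decay=1e-5)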

2. TensorFlow: updating part of the weights (filtering variables by name):
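The original TensorFlow snippet is not reproduced here; the following is a minimal sketch assuming the TF1 graph-style API, where loss is the training loss already defined in the graph and 'linear2' stands in for whatever name filter you actually need:

import tensorflow as tf

# keep only the trainable variables whose name matches the filter
train_vars = [v for v in tf.trainable_variables() if 'linear2' in v.name]
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
# gradients are computed and applied only for the variables in var_list
train_op = optimizer.minimize(loss, var_list=train_vars)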

Full PyTorch code:
import torch.nn as nn
import torch


class Net(nn.Module):
    def __init__(self, inchannel, outchannel):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(inchannel, 64, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.conv3 = nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2)
        self.conv4 = nn.Conv2d(512, 512, kernel_size=3, padding=1, stride=2)
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.linear1 = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU()
        )
        # Method 1a: set requires_grad=False inside the network definition.
        # Parameters defined before this point are not updated.
        # for p in self.parameters():
        #     p.requires_grad = False
        self.linear2 = nn.Linear(256, outchannel)

    def forward(self, input):
        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        avg = self.avg(conv4)
        avg_flat = avg.view(avg.shape[0], -1)
        fc1 = self.linear1(avg_flat)
        cls = self.linear2(fc1)
        return cls


def print_conv_wght(model, key):
    print(model.state_dict()[key])


if __name__ == '__main__':
    model = Net(3, 3).cuda()
    # Method 1b: set requires_grad=False by iterating over the parameters and
    # filtering by name (note: named_parameters, not parameters).
    # for k, v in model.named_parameters():
    #     if 'linear2' not in k:
    #         v.requires_grad = False
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

    # Method 2: pass only the parameters to be updated to the optimizer
    # (same idea as in TensorFlow).
    para = []
    for k, v in model.named_parameters():
        if 'linear2' in k:
            para.append(v)
    optimizer = torch.optim.Adam(para, lr=1e-3, weight_decay=1e-5)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    loss_func = nn.CrossEntropyLoss()
    print(model.state_dict().keys())
    for i in range(0, 1000):
        data = torch.randn(4, 3, 512, 512).cuda()
        gt = torch.randint(0, 3, (4,), dtype=torch.long).cuda()
        output = model(data)
        loss = loss_func(output, gt)
        if i % 10 == 0:
            optimizer.zero_grad()
            print('conv1.0.weight before update::::::::')
            print_conv_wght(model, 'conv1.0.weight')
            print('linear2.weight before update::::::::')
            print_conv_wght(model, 'linear2.weight')
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            print('conv1.0.weight after update::::::::')
            print_conv_wght(model, 'conv1.0.weight')
            print('linear2.weight after update::::::::')
            print_conv_wght(model, 'linear2.weight')
