pytorch load模型时报错:missing keys in state_dict error in loading state_dict
具体问题报错信息为:
missing keys in state_dict error in loading state_dict
诊断出错原因,采用了如下代码
# 获取当前模型的状态字典
current_state_dict = model.state_dict()
# 获取保存的状态字典
saved_state_dict = torch.load('model.pth')['model_state_dict']
# 检查缺失的键
missing_keys = [key for key in current_state_dict.keys()
if key not in saved_state_dict]
# 检查多余的键
extra_keys = [key for key in saved_state_dict.keys()
if key not in current_state_dict]
print(f"Missing keys: {missing_keys}")
print(f"Extra keys: {extra_keys}")
然后对比了输出后发现,没有一个对应上的,再使用调试模式可以发现:
保存的模型中所有的keys都加了一个module
的前缀,具体来看就是
conv1.weight(模型的) --> module.conv1.weight(保存的)
查询以后可以发现,其原因在于训练时使用了nn.DataParallel(原网址)
ptrblck
Aug 2018
It seems you’ve used nn.DataParallel to save the model.
You could wrap your current model again in nn.DataParallel or just remove the .module keys.
Here is a similar thread with some suggestions.
主要的作用是进行多GPU训练的,可以参考这个博客
nn.DataParallel详细解析
torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0):
这个函数主要有三个参数:
module:即模型,此处注意,虽然输入数据被均分到不同gpu上,但每个gpu上都要拷贝一份模型。
device_ids:即参与训练的gpu列表,例如三块卡, device_ids = [0,1,2]。
output_device:指定输出gpu,一般省略。在省略的情况下,默认为第一块卡,即索引为0的卡。此处有一个问题,输入计算是被几块卡均分的,但输出loss的计算是由这一张卡独自承担的,这就造成这张卡所承受的计算量要大于其他参与训练的卡。
原文链接:https://blog.csdn.net/qq_38410428/article/details/119392993
device_ids = [0, 1]
net = torch.nn.DataParallel(net, device_ids=device_ids)
回到最初的问题上面,现在知道是keys中多了module
,那么把他去掉就行:
# original saved file with DataParallel
state_dict = torch.load('myfile.pth.tar')
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)