[PyTorch]论文pytorch复现中遇到的BUG

1. zip argument #1 must support iteration

在多gpu训练的时候,自动把你的batch_size分成n_gpu份,每个gpu跑一些数据, 最后再合起来。我之所以出现这个bug是因为返回的时候 返回了一个常量。。

2. torch.nn.DataParallel

在使用torch.nn.DataParallel时候,要先把模型放在gpu上,再进行parallel。

3. model.state_dict()

一般在现有的网络加载预训练模型通常是找到预训练模型在现有的model里面的参数,然后model进行更新,遇到一个bug, 发现加载预训练模型的时候, 效果很差,跟参数没有更新一样,找了一大顿bug,最后才发现,之前是单gpu进行的预训练,现在的模型使用的是多gpu, 打印现在模型的参数你会发现他所有的参数前面都加了一个module. 所以向以前一样更新,没有一个参数会被更新,因此写了一个万能模型参数加载函数。

pretrained_dict = checkpoint['state_dict']
model_dict = self.model.state_dict()
if checkpoint['config']['n_gpu'] > 1 and self.config['n_gpu'] == 1:
    new_dict = OrderedDict()
    for k, v in pretrained_dict.items():
        name = k[7:]
        new_dict[name] = v
    pretrained_dict = new_dict
elif checkpoint['config']['n_gpu'] == 1 and self.config['n_gpu'] > 1:
    new_dict = OrderedDict()
    for k, v in pretrained_dict.items():
        name = "module."+k
        new_dict[name] = v
    pretrained_dict = new_dict
print("The pretrained model's para is following")
for k, v in pretrained_dict.items():
    print(k)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
self.model.load_state_dict(model_dict)
posted @ 2018-12-18 20:43  向前奔跑的少年  阅读(2854)  评论(1编辑  收藏  举报