YOLOv5 Transfer Learning Pitfalls

The pretrained model has 5 classes, but the dataset I'm training on next has 8.
Running training directly fails with:

size mismatch for model.24.m.0.weight: copying a param with shape torch.Size([30, 192, 1, 1]) from checkpoint, the shape in current model is torch.Size([39, 192, 1, 1]).
size mismatch for model.24.m.0.bias: copying a param with shape torch.Size([30]) from checkpoint, the shape in current model is torch.Size([39]).
size mismatch for model.24.m.1.weight: copying a param with shape torch.Size([30, 384, 1, 1]) from checkpoint, the shape in current model is torch.Size([39, 384, 1, 1]).
size mismatch for model.24.m.1.bias: copying a param with shape torch.Size([30]) from checkpoint, the shape in current model is torch.Size([39]).
size mismatch for model.24.m.2.weight: copying a param with shape torch.Size([30, 768, 1, 1]) from checkpoint, the shape in current model is torch.Size([39, 768, 1, 1]).
size mismatch for model.24.m.2.bias: copying a param with shape torch.Size([30]) from checkpoint, the shape in current model is torch.Size([39]).
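The two numbers in the error come from how the Detect head (model.24) is sized: each of its three 1x1 output convs, one per detection scale, produces na * (nc + 5) channels, i.e. for every anchor 4 box coordinates, 1 objectness score and nc class scores. A quick check, assuming YOLOv5's default of 3 anchors per scale (detect_out_channels is just an illustrative helper, not part of YOLOv5):

```python
def detect_out_channels(nc: int, na: int = 3) -> int:
    # Per anchor: 4 box coordinates + 1 objectness score + nc class scores.
    return na * (nc + 5)


print(detect_out_channels(5))  # 30: shape stored in the 5-class checkpoint
print(detect_out_channels(8))  # 39: shape the 8-class model expects
```

The input widths (192/384/768) do not depend on nc, which is why only the first dimension differs.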

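Fix: the error is raised where train.py loads the checkpoint's EMA weights, so handle the shapes of these layers by hand there: don't load them from the checkpoint, fill them with random values (torch.rand) of the new shape instead.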

    # EMA
    ema = ModelEMA(model) if RANK in {-1, 0} else None

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Check ckpt['ema'] before touching it: stripped checkpoints store None there,
        # and calling .state_dict() on None would crash.
        if ema and ckpt.get('ema'):
            # The checkpoint's Detect head (model.24) was trained for 5 classes
            # (3 * (5 + 5) = 30 output channels); the new model needs 8 classes
            # (3 * (8 + 5) = 39). Replace those tensors with random ones of the new
            # shape instead of loading them, so load_state_dict() no longer fails.
            my_weight = {}
            for k, v in ckpt['ema'].float().state_dict().items():
                if k.startswith('model.24.m.0.weight'):
                    v = torch.rand([39, 192, 1, 1])
                elif k.startswith('model.24.m.1.weight'):
                    v = torch.rand([39, 384, 1, 1])
                elif k.startswith('model.24.m.2.weight'):
                    v = torch.rand([39, 768, 1, 1])
                elif k.startswith('model.24.m'):  # the three matching biases
                    v = torch.rand([39])
                my_weight[k] = v  # every other tensor is taken from the checkpoint as-is
            ema.ema.load_state_dict(my_weight)
            ema.updates = ckpt['updates']
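Hardcoding 39/192/384/768 ties the patch to one model size and one class count. A more general alternative, in the spirit of YOLOv5's intersect_dicts helper (which train.py already uses for the plain model weights), is to compare shapes and keep the freshly initialized tensor for anything that does not match. This is a sketch, not the approach used above: merge_matching and toy_model are hypothetical names, and the toy model only mimics a Detect-style 1x1 conv; it is not YOLOv5's architecture.

```python
import torch
from torch import nn


def merge_matching(ckpt_sd, model_sd):
    """Build a state dict for the current model: take the checkpoint tensor when both
    name and shape match, otherwise keep the model's own initialized tensor.
    Returns the merged dict and the names of tensors that were not transferred."""
    merged, skipped = {}, []
    for k, v in model_sd.items():
        src = ckpt_sd.get(k)
        if src is not None and src.shape == v.shape:
            merged[k] = src
        else:
            merged[k] = v  # e.g. Detect convs whose class count changed
            skipped.append(k)
    return merged, skipped


def toy_model(nc, na=3):
    # Hypothetical stand-in: one backbone conv plus one Detect-style 1x1 conv
    # with na * (nc + 5) output channels.
    return nn.Sequential(nn.Conv2d(3, 192, 3), nn.Conv2d(192, na * (nc + 5), 1))


old, new = toy_model(nc=5), toy_model(nc=8)
merged, skipped = merge_matching(old.state_dict(), new.state_dict())
new.load_state_dict(merged)
print(skipped)  # ['1.weight', '1.bias']
```

Inside train.py this would presumably become `merged, _ = merge_matching(ckpt['ema'].float().state_dict(), ema.ema.state_dict())` followed by `ema.ema.load_state_dict(merged)` (untested against the real repo). Unlike torch.rand, it leaves the new head with the initialization the model got when it was built rather than uniform noise in [0, 1), and it needs no edits when the class count or model size changes.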

posted @ 2022-07-11 00:00  GhostCai