pytorch保存模型并记录最优模型

 

# https://github.com/tczhangzhi/pytorch-distributed/blob/master/distributed.py

# remember best acc@1 and save checkpoint
is_best = acc1 > best_acc1
best_acc1 = max(acc1, best_acc1)

if args.local_rank == 0:
    save_checkpoint(
               {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.module.state_dict(),
                    'best_acc1': best_acc1,
                }, is_best)



def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

  

shutil.copyfile(filename, 'model_best.pth.tar') # 如果是当前最优精度的模型,则保存时维护一个副本

 

posted on 2021-04-01 14:44  那抹阳光1994  阅读(3399)  评论(0编辑  收藏  举报

导航