PyTorch: Resume Training from a Checkpoint

Saving and restoring checkpoints so that training of a PyTorch model can be interrupted and resumed.

A checkpoint bundles everything needed to resume: the model weights, the optimizer state, and the current epoch.

    checkpoint = {"model_state_dict": net.state_dict(),
                  "optimizer_state_dict": optimizer.state_dict(),
                  "epoch": epoch}
Save model:

    # ----------
    # save model
    # ----------
    if epoch % args.save_interval == 0:
        checkpoint = {"model_state_dict": generator.state_dict(),
                      "optimizer_state_dict": G_optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./model/celeba/checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)
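If a checkpoint saved on a GPU machine is later restored on a CPU-only machine, `map_location` remaps the tensors at load time. A minimal sketch (the model and file name here are illustrative, not from the post):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Save a state dict to a temporary file, then restore it onto the CPU explicitly.
model = nn.Linear(3, 3)
path = os.path.join(tempfile.gettempdir(), "checkpoint_demo_cpu.pkl")
torch.save({"model_state_dict": model.state_dict(), "epoch": 0}, path)

# map_location="cpu" places every loaded tensor on the CPU,
# regardless of the device it was saved from.
checkpoint = torch.load(path, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
```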
Load model (note that `start_iter` is restored from the checkpoint, so training continues where it left off rather than from `args.start_iter`):

    start_iter = args.start_iter
    if args.pre_trained != '':
        checkpoint = torch.load(args.pre_trained)
        generator.load_state_dict(checkpoint['model_state_dict'])
        G_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # resume from the epoch after the one stored in the checkpoint
        start_iter = checkpoint['epoch'] + 1
        print('Starting from iter', start_iter)
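Putting the two halves together, a minimal end-to-end sketch; the toy model, optimizer, epoch value, and file name are illustrative stand-ins for the post's `generator` / `G_optimizer`:

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Toy model and optimizer standing in for the post's generator / G_optimizer.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# One dummy training step so the optimizer has state worth saving.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

# --- save (as in the training loop above) ---
path_checkpoint = os.path.join(tempfile.gettempdir(), "checkpoint_demo.pkl")
torch.save({"model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": 5}, path_checkpoint)

# --- load (e.g. in a fresh process) ---
model2 = nn.Linear(4, 2)
optimizer2 = optim.Adam(model2.parameters(), lr=1e-3)
checkpoint = torch.load(path_checkpoint)
model2.load_state_dict(checkpoint["model_state_dict"])
optimizer2.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1

# The restored weights match the saved ones exactly.
assert all(torch.equal(a, b)
           for a, b in zip(model.state_dict().values(),
                           model2.state_dict().values()))
print("resuming at epoch", start_epoch)  # resuming at epoch 6
```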
posted @ 2021-04-16 09:04 临近边缘