PyTorch: Resuming Training from a Checkpoint
To resume training later, a checkpoint must capture everything needed to restore the run: the model weights, the optimizer state, and the current epoch, bundled into a single dict:

checkpoint = {"model_state_dict": net.state_dict(),
              "optimizer_state_dict": optimizer.state_dict(),
              "epoch": epoch}
During training, write such a checkpoint to disk every save_interval epochs:

# ----------
# save model
# ----------
if epoch % args.save_interval == 0:
    checkpoint = {"model_state_dict": generator.state_dict(),
                  "optimizer_state_dict": G_optimizer.state_dict(),
                  "epoch": epoch}
    path_checkpoint = "./model/celeba/checkpoint_{}_epoch.pkl".format(epoch)
    torch.save(checkpoint, path_checkpoint)
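If the run also uses a learning-rate scheduler, its state belongs in the checkpoint as well, otherwise the schedule restarts from zero on resume. A minimal self-contained sketch, assuming a StepLR scheduler (the model, optimizer, scheduler, and path names here are illustrative, not from the original post):

import torch
import torch.nn as nn

# Hypothetical stand-ins so the snippet runs on its own;
# in the snippet above these would be generator / G_optimizer.
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30)

epoch = 0
checkpoint = {"model_state_dict": model.state_dict(),
              "optimizer_state_dict": optimizer.state_dict(),
              # restore later with scheduler.load_state_dict(...)
              "scheduler_state_dict": scheduler.state_dict(),
              "epoch": epoch}
torch.save(checkpoint, "./checkpoint_with_scheduler.pkl")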
To resume, load the checkpoint, restore both the model and optimizer states, and pick the epoch counter back up from the checkpoint so the loop continues where it left off:

start_iter = args.start_iter
if args.pre_trained != '':
    checkpoint = torch.load(args.pre_trained)
    generator.load_state_dict(checkpoint['model_state_dict'])
    G_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_iter = checkpoint['epoch'] + 1  # continue from the next epoch
    print('Starting from iter', start_iter)
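Putting the two halves together, here is a minimal self-contained sketch of the whole pattern, assuming a toy nn.Linear model and a hypothetical ckpt.pkl path (neither is from the original post); map_location lets the checkpoint load even when saving and resuming happen on different devices, e.g. saved on GPU and resumed on CPU:

import os
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = nn.Linear(10, 1).to(device)          # stand-in for the real model
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

ckpt_path = "./ckpt.pkl"                   # hypothetical path
start_epoch = 0
if os.path.exists(ckpt_path):
    # map_location remaps stored tensors onto the current device
    checkpoint = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1  # resume from the next epoch
    print("Resuming from epoch", start_epoch)

for epoch in range(start_epoch, 100):
    # ... one epoch of training ...
    torch.save({"model_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch}, ckpt_path)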
