爽歪歪666
不以物喜,不以己悲,努力才是永恒的主题。
filename = 'cvae_' + str(epoch+1) + '.pkl'
save_path = save_dir / Path(filename)
states = {}
states['model'] = cvae.state_dict() # 模型参数
states['z_dim'] = args.z_dim
states['x_dim'] = args.x_dim
states['s_dim'] = args.s_dim
states['optim'] = cvae.state_dict()
torch.save(states, str(save_path)) #检查点:将states字典存放在save_path文件下 
保存和加载模型的时候,配对的函数:
对于仅保存state_dict()的方式,那保存和加载模型的方式为:
保存:torch.save(model.state_dict(), PATH)
加载:model.laod_state_dict(torch.load(PATH))
一般加载模型是在训练完成后用模型做测试,这时候加载模型记得要加上model.eval(),把模型切换到evaluation模式,这时候会调整dropout和bactch的模式。

对于保存和加载整个模型的情况:
torch.save(model, PATH)
model = torch.load(PATH)
可以看到,前面的model.load_state_dict()和这里的不同,前面的情况需要你先定义一个模型,然后再load_state_dict()
但是这里load整个模型,会把模型的定义一起load进来。完成了模型的定义和加载参数的两个过程。
详细代码
 1     def save(self):
 2         save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
 3 
 4         if not os.path.exists(save_dir):
 5             os.makedirs(save_dir)
 6 
 7         torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
 8         torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
 9 
10         with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
11             pickle.dump(self.train_hist, f)
12 # 使用方法:对模型初始化以后,使用以下方法,加载模型的参数,以至于不用再对数据集进行训练
13     def load(self):
14         save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
15 
16         self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
17         self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))

note:

pickle.dump(obj, file, [,protocol]) 序列化对象,将对象obj保存到文件file中去。self.train_hist用于存放模型文件

pickle.load(file) 反序列化对象,将文件中的数据解析为一个python对象。file中有read()接口和readline()接口



posted on 2020-03-14 16:21  爽歪歪666  阅读(6607)  评论(0编辑  收藏  举报