pytorch保存模型

对比tf来说简单了好几个量级。。

1. 先建立一个字典,保存三个参数:
# 只保存了参数
state = {‘net':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
torch.save(state, model_dir)

2. 当你想恢复某一阶段的训练(或者进行测试)时,那么就可以读取之前保存的网络模型参数等。 
checkpoint
= torch.load(model_dir)
# 加载参数
model.load_state_dict(checkpoint[
'net'])
optimizer.load_state_dict(checkpoint[
'optimizer'])
start_epoch
= checkpoint['epoch'] + 1

3. 保存完整模型
torch.save(model, model_dir) 
model = torch.load(model_dir)

ref:https://www.cnblogs.com/qinduanyinghua/p/9311410.html

posted @ 2020-11-27 15:09  小小马进阶笔记  阅读(409)  评论(0)    收藏  举报