pytorch保存训练后的模型和加载模型

在模型训练之后,需要保存,可以选择这种方式,只保存参数,不全部保存,推荐这种方式。
保存模型
torch.save(model.state_dict(), 'my_model.pth')
加载模型
重新实例化自己的模型,不要和之前的训练模型冲突
loaded_model = NewsClassifier(n_classes=2)

这里strict=False表示不严格对齐参数,不然有可能会报错:embeddings.position_ids
loaded_model.load_state_dict(torch.load('/other_code_files/my_model.pth'), strict=False)
loaded_model = loaded_model.to(device)

 加载完成后,转换模式

loaded_model.eval()

此处只是记录自己遇到的问题,训练使用的单卡GPU。

posted @ 2024-11-13 21:32  django_start  阅读(121)  评论(0)    收藏  举报