pytorch保存训练后的模型和加载模型
在模型训练之后,需要保存,可以选择这种方式,只保存参数,不全部保存,推荐这种方式。
保存模型
torch.save(model.state_dict(), 'my_model.pth')
加载模型
重新实例化自己的模型,不要和之前的训练模型冲突
loaded_model = NewsClassifier(n_classes=2)
这里strict=False表示不严格对齐参数,不然有可能会报错:embeddings.position_ids
loaded_model.load_state_dict(torch.load('/other_code_files/my_model.pth'), strict=False) loaded_model = loaded_model.to(device)
加载完成后,转换模式
loaded_model.eval()
此处只是记录自己遇到的问题,训练使用的单卡GPU。

浙公网安备 33010602011771号