P25_网络模型的保存与读取
25.1网络模型的保存
(1)保存方式1:模型结构+模型参数
点击查看代码
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
#保存方式1,模型结构+模型参数
torch.save(vgg16,"vgg16_method1.pth")
不仅保存了网络模型的结构,还保存了网络模型中的参数
保存结果:

(2)保存方式2:模型参数(官方推荐)
点击查看代码
#保存方式2,模型参数(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
(3)查看保存方式1和2的文件大小
【老师这里是mac,命令是ls(也就是list),window中是dir(directory的意思),或者以下命令:】
Get-ChildItem -Force | Format-List

25.2网络模型的读取
(1)两种读取方式
点击查看代码
import torch
#方式1->保存方式1,加载模型model1
model1 = torch.load("vgg16_method1.pth")
print(model1)
#方式2->保存方式2,加载模型model2
model2 = torch.load("vgg16_method2.pth")
print(model2)
(2)读取结果
①方式1保存的是模型:
点击查看代码
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
......
②方式2保存的是字典形式,不再是网络模型:
点击查看代码
OrderedDict([('features.0.weight', tensor([[[[-1.3508e-01, -6.5499e-02, 3.7206e-02],
[ 7.5496e-02, 4.2653e-02, -4.6485e-02],
[-4.0531e-02, -1.8580e-02, 2.3787e-02]],
[[-4.1484e-02, -4.8372e-02, 4.7685e-02],
[-4.8050e-02, -7.0310e-02, -5.8736e-02],
[-6.9456e-02, 5.2128e-04, 5.4732e-02]],
......
(3)想把方式2保存的字典格式恢复成网络模型
点击查看代码
#①新建网络模型结构
vgg16 = torchvision.models.vgg16(pretrained=False)
#②调用网络模型
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)
25.3总结
(1)模型的保存与读取1:
torch.save(实例, 保存名称), torch.load(实例, 保存名称)
【保存:模型结构+模型参数】
(2)模型的保存与读取2:
torch.save(实例.state_dict(), 保存名称), torch.load(实例.state_dict(), 保存名称)
【保存:模型参数】(官方推荐)
(3)方法1的陷阱:
用方法1的时候要确保读取模型的文件里有定义该模型的类。
浙公网安备 33010602011771号