Samar-blog

导航

P25_网络模型的保存与读取

25.1网络模型的保存

(1)保存方式1:模型结构+模型参数

点击查看代码
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)

#保存方式1,模型结构+模型参数
torch.save(vgg16,"vgg16_method1.pth")

不仅保存了网络模型的结构,还保存了网络模型中的参数

保存结果:
P25_保存方式1

(2)保存方式2:模型参数(官方推荐)

点击查看代码
#保存方式2,模型参数(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")

(3)查看保存方式1和2的文件大小

【老师这里是mac,命令是ls(也就是list),window中是dir(directory的意思),或者以下命令:】

Get-ChildItem -Force | Format-List
P25_保存方式1和2的文件大小

25.2网络模型的读取

(1)两种读取方式

点击查看代码
import torch

#方式1->保存方式1,加载模型model1
model1 = torch.load("vgg16_method1.pth")
print(model1)

#方式2->保存方式2,加载模型model2
model2 = torch.load("vgg16_method2.pth")
print(model2)

(2)读取结果

①方式1保存的是模型:
点击查看代码
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    ......
②方式2保存的是字典形式,不再是网络模型:
点击查看代码
OrderedDict([('features.0.weight', tensor([[[[-1.3508e-01, -6.5499e-02,  3.7206e-02],
          [ 7.5496e-02,  4.2653e-02, -4.6485e-02],
          [-4.0531e-02, -1.8580e-02,  2.3787e-02]],

         [[-4.1484e-02, -4.8372e-02,  4.7685e-02],
          [-4.8050e-02, -7.0310e-02, -5.8736e-02],
          [-6.9456e-02,  5.2128e-04,  5.4732e-02]],
          ......

(3)想把方式2保存的字典格式恢复成网络模型

点击查看代码
#①新建网络模型结构
vgg16 = torchvision.models.vgg16(pretrained=False)
#②调用网络模型
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

25.3总结

(1)模型的保存与读取1:

torch.save(实例, 保存名称), torch.load(实例, 保存名称)

【保存:模型结构+模型参数】

(2)模型的保存与读取2:

torch.save(实例.state_dict(), 保存名称), torch.load(实例.state_dict(), 保存名称)

【保存:模型参数】(官方推荐)

(3)方法1的陷阱:

用方法1的时候要确保读取模型的文件里有定义该模型的类。

posted on 2025-11-24 19:57  风居住的街道DYL  阅读(0)  评论(0)    收藏  举报