小土堆pytorch学习—P26-网络模型的保存和读取

模型保存方式有两种,一种是保存网络模型结构+参数,另一种是保存模型的参数。

另外,还有一个针对于自己定义的模型的陷阱问题。

首先说第一种模型保存方式和读取方式——保存网络模型结构+模型参数

model_full_save.py

vgg16=torchvision.models.vgg16(pretrained = False)
torch.save(vgg16 , "model_full_save.pth")
#指定要保存的模型,以及模型的地址
#不仅保存网络模型,也保存网络模型中的参数

model_full_load.py

model = torch.load("model_full_save.pth")
print(model)
#查看网络模型结构

image-20230707094529913

方式2——保存模型参数(官方推荐)

model_param_save.py

torch.save(vgg16.state_dict(),"model_param_save.pth")
#vgg16.state_dict()方法相当于把网络模型的一种状态保存成一个字典,网络模型的参数保存成一个字典

model_param_load.py

model = torch.load('model_param_save.pth')
print(model)#可以看到是保存的网络模型参数字典

====================================
#恢复模型
model = torchvision.models.vgg16(pretrained = False)
#通过网络模型字典形式加载模型
vgg16.load__state_dict(torch.load("model_param_save.pth"))
print(model)

image-20230707095432151

image-20230707095511696

通过在终端中输入ls -all可以看到保存两种方式时模型的大小

image-20230707095721012

陷阱of方式1

自己定一个网络结构,在model_full_save.py文件中

class Tudui(nn.Module):
	def __init__(self):
    	super(Tudui , self).__init__()
    	self.conv1 = nn.Conv2d(3 , 64)
    
    def forward(self, x):
        x = self.conv1(x)
        return x
tudui = Tudui()
torch.save(tudui , "tudui_method1.pth")

用这种方式保存的模型,在model_full_load.py中加载

model = torch.load("tudui_method1.pth")
print(model)

报错提示不能得到Tudui类的属性,因为没有这个类。需要引入,要么直接复制到文件中,要么import到里面去。

image-20230707100015372

import torch
import torchvision
from P26_model_save import Tudui

model = torch.load("tudui_method1.pth")
print(model)

image-20230707100338961

posted @ 2023-07-10 16:02  西红柿爆炒鸡蛋  阅读(55)  评论(0)    收藏  举报