保存模型
代码如下（读取脚本通过 `from test_model_save import *` 导入本脚本，因此应保存为 test_model_save.py）：
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
# Build a VGG16 with randomly initialised (untrained) parameters.
# `weights=None` is the torchvision>=0.13 replacement for the deprecated
# `pretrained=False`, which now only emits a deprecation warning.
vgg16_false = torchvision.models.vgg16(weights=None)

# Saving method 1: pickle the whole module (structure + parameters).
# Loading this file back requires the defining classes to be importable and,
# on PyTorch>=2.6, an explicit torch.load(..., weights_only=False).
torch.save(vgg16_false, "vgg16_method1.pth")

# Saving method 2: save only the parameters (the state_dict of tensors).
# Loading it requires rebuilding the architecture first, but the file holds
# only tensors, so it can be read with torch.load(..., weights_only=True).
torch.save(vgg16_false.state_dict(), "vgg16_method2.pth")
# Common pitfall: when a model is saved with torch.save(model, ...), the file
# stores a reference to this class, not its code, so the class definition must
# be importable wherever that file is loaded.
class Test(nn.Module):
    """Minimal demo network: one 3x3 convolution, 1 input channel -> 1 output channel."""

    def __init__(self):
        super().__init__()
        # NOTE(review): "covn1" looks like a typo for "conv1", but the attribute
        # name becomes the state_dict key prefix ("covn1.weight"), so renaming it
        # would break checkpoints that were already saved.
        self.covn1 = Conv2d(in_channels=1, out_channels=1, kernel_size=3)

    def forward(self, input):
        """Return the result of applying the convolution to ``input``."""
        return self.covn1(input)
# Instantiate the demo network, pickle the whole module (class reference plus
# parameters) to disk, then print the module's structure.
test = Test()
torch.save(obj=test, f="test.pth")
print(test)
读取模型
代码如下：
import torch
import torchvision
from torch import nn
from test_model_save import *
# Loading method 1: the file holds a pickled nn.Module (structure + parameters).
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
# unpickle arbitrary classes such as torchvision's VGG, so opt out explicitly.
# Only do this for files you created yourself or otherwise trust: unpickling
# can execute arbitrary code.
model = torch.load("vgg16_method1.pth", weights_only=False)
# print(model)

# Loading method 2: rebuild the architecture, then load the saved state_dict.
# A state_dict contains only tensors, so the restricted (safe) loader suffices.
model = torchvision.models.vgg16(weights=None)
model.load_state_dict(torch.load("vgg16_method2.pth", weights_only=True))
# print(model)

# Common pitfall: a whole-model pickle stores a reference to the Test class,
# not its code. The name Test must be resolvable here; the
# `from test_model_save import *` at the top of this script provides it.
# As above, full unpickling needs weights_only=False on PyTorch>=2.6.
model = torch.load("test.pth", weights_only=False)
print(model)