pytorch(5)----模型加载保存、网络模型库 torchvision.models

torchvision.models 提供了如:VGG、ResNet、Inception 等众多经典的网络模型结构。

安装: pip install torchvision  或 conda install torchvision

import torch
from torch import nn
from torchvision import models

#直接调用 VGG16的网络结构 特征层: 13个卷积、13个激活函数、5个池化、共31层
vgg = models.vgg16()
len(vgg.features) #特征层共有31层

# Vgg16的分类层:3个全连接、2个ReLU、2个Dropout;共7层
print(len(vgg.classifier))
# 索引
print(vgg.classifier[-1])
print(vgg.features[24:])

# 在fine-tune 常常使用别人预训练好的模型在自己的数据集上进行训练
# 1、加载 torchvision.models 中自带的预训练好的模型
vgg = models.vgg16(pretrained=True)

'''
# 2、 加载本地的预训练模型 或 之前训练过的模型
vgg = models.vgg16()
state_dict = torch.load("model path")
# 使用 .load_state_dict, 遍历预训练模型的关键字,如果出现在VGG中,则加载预训练参数
vgg.load_state_dict({k:v for k,v in state_dict_items() if k in vgg.state_dict()})

'''
# 保存模型
# torch.save()
# 可以保存网络模型、优化器等信息。
# 当前的状态数据可以通过 .state_dict() 函数获取
# torch.save( { 'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'model_path.pth' } )

 

posted on 2020-02-14 18:11  feihu_h  阅读(590)  评论(0)    收藏  举报

导航