pytorch(5)----模型加载保存、网络模型库 torchvision.models
torchvision.models 提供了如:VGG、ResNet、Inception 等众多经典的网络模型结构。
安装: pip install torchvision 或 conda install torchvision
import torch from torch import nn from torchvision import models #直接调用 VGG16的网络结构 特征层: 13个卷积、13个激活函数、5个池化、共31层 vgg = models.vgg16() len(vgg.features) #特征层共有31层 # Vgg16的分类层:3个全连接、2个ReLU、2个Dropout;共7层 print(len(vgg.classifier)) # 索引 print(vgg.classifier[-1]) print(vgg.features[24:]) # 在fine-tune 常常使用别人预训练好的模型在自己的数据集上进行训练 # 1、加载 torchvision.models 中自带的预训练好的模型 vgg = models.vgg16(pretrained=True) ''' # 2、 加载本地的预训练模型 或 之前训练过的模型 vgg = models.vgg16() state_dict = torch.load("model path") # 使用 .load_state_dict, 遍历预训练模型的关键字,如果出现在VGG中,则加载预训练参数 vgg.load_state_dict({k:v for k,v in state_dict_items() if k in vgg.state_dict()}) ''' # 保存模型 # torch.save() # 可以保存网络模型、优化器等信息。 # 当前的状态数据可以通过 .state_dict() 函数获取 # torch.save( { 'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'model_path.pth' } )
浙公网安备 33010602011771号