25 现有模型的使用及参数
一、官网内容
- 模型

- 数据集

二、案例
分类模型:Vgg16
1.参数
- pretrained:True->已经训练过的模型,并能在数据集上取得一个比较好的效果。
- progress:True->显示下载进度条

2.代码
2.1 下载数据集ImageNet
- 下载需要的包
pip install scipy
2.2 模型下载
- 原始模型
- 预训练模型

- 输出预训练模型的网络结构


2.3 模型改进
- 将输出为1000的类别改为10类,有两种改法:迁移学习(添加线性层),直接改网络结构
添加线性层

修改原有线性层

2.4 完整代码
点击查看代码
import torchvision
## 这个数据集不让直接下载,一共有100多G,自己下载也太大了,所以在这里暂时不下载了
# train_data=torchvision.datasets.ImageNet("./ImageNet_dataset",split='train',download=True,transform=torchvision.transforms.ToTensor())
#原始模型
from torch import nn
vgg16_false=torchvision.models.vgg16(pretrained=False)
#预训练后的模型
vgg16_True=torchvision.models.vgg16(pretrained=True)
#输出预训练后的模型的网络结构
print(vgg16_True)
#将网络修改为10分类,原来是1000分类
## 方法1:添加线性层
vgg16_True.classifier.add_module('add_linear',nn.Linear(1000,10))
print(vgg16_True)
## 方法2:直接修改线性层
print(vgg16_false)
vgg16_false.classifier[6]=nn.Linear(4096,10)
print(vgg16_false)
模型可以用完,将其保存到其他盘,目前在C盘。

浙公网安备 33010602011771号