25 现有模型的使用及参数

一、官网内容

  • 模型

image

  • 数据集

image

二、案例

分类模型:Vgg16

1.参数

  • pretrained:True->已经训练过的模型,并能在数据集上取得一个比较好的效果。
  • progress:True->显示下载进度条

image

2.代码

2.1 下载数据集ImageNet

  • 下载需要的包
pip install scipy

2.2 模型下载

  • 原始模型
  • 预训练模型

image

  • 输出预训练模型的网络结构

image

image

2.3 模型改进

  • 将输出为1000的类别改为10类,有两种改法:迁移学习(添加线性层),直接改网络结构

添加线性层

image

修改原有线性层

image

2.4 完整代码

点击查看代码
import torchvision

## 这个数据集不让直接下载,一共有100多G,自己下载也太大了,所以在这里暂时不下载了
# train_data=torchvision.datasets.ImageNet("./ImageNet_dataset",split='train',download=True,transform=torchvision.transforms.ToTensor())

#原始模型
from torch import nn

vgg16_false=torchvision.models.vgg16(pretrained=False)
#预训练后的模型
vgg16_True=torchvision.models.vgg16(pretrained=True)

#输出预训练后的模型的网络结构
print(vgg16_True)

#将网络修改为10分类,原来是1000分类
## 方法1:添加线性层
vgg16_True.classifier.add_module('add_linear',nn.Linear(1000,10))
print(vgg16_True)
## 方法2:直接修改线性层
print(vgg16_false)
vgg16_false.classifier[6]=nn.Linear(4096,10)
print(vgg16_false)

模型可以用完,将其保存到其他盘,目前在C盘。

posted @ 2022-05-24 22:05  Trouvaille_fighting  阅读(73)  评论(0)    收藏  举报