21、现有网络模型的使用以及修改

1、网络模型在pytorch里面的torchvision里面torchvision.models,是关于图像类的网络模型

2、简单以一个分类模型为例子:   VGG(最常用的是VGG16和VGG19)

 

 pretrained:

   如果是true的话,说明在ImageNet数据集上,模型的参数是都训练好的; 如果是False的话,说明模型的参数是初始化的,没有训练好。

vgg16_false=torchvision.models.vgg16(pretrained=False)      #当pretrained为 False的时候只是加载网络模型。是不需要对网络模型的参数进行下载的
vgg16_true=torchvision.models.vgg16(pretrained=True)   #pretrained=True时,需要下载网络模型,下载模型里的参数

print(vgg16_true)

 

 

progress:

  如果是True,显示下载进度条; False则不显示

3、ImageNet数据集:

 

 

 

 4、修改现有模型

train_data=torchvision.datasets.CIFAR10('../../dataset/CIFAR10',train=False,
                                        transform=torchvision.transforms.ToTensor(),download=True)
'''如何利用现有的网络模型,去改动它的结构;比如说想让VGG是10分类任务,也就是让输出特征是10;可以有两种'''
#1、再添加一个线性层
vgg16_true.add_module('add_linear',nn.Linear(in_features=1000,out_features=10))
#add_module()里面两个参数,一个是字符串型,给要加的模块起个名字,第二个是要加的模块,可以直接是一层网络,也可以是一个序列
print(vgg16_true)

输出:

 

 

# 2、如果想在序列里面添加可以这样网络模型.想要加的位置.add_moudle()
vgg16_true.classifier.add_module('add_linear',nn.Linear(in_features=1000,out_features=10))
print(vgg16_true)
# 3、不想添加的话,可以进行修改
#对模型中的classifier中的第6层进行修改
vgg16_false.classifier[6]=nn.Linear(4096,10)

 

posted @ 2023-02-24 18:43  bokeAR  阅读(175)  评论(0)    收藏  举报