CNN实战

一.使用VGG模型进行猫狗大战:

1.下载数据:将图片放置于本地文件夹

2.数据处理:

设置VGG格式,将图片进行归一化处理

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
vgg_format = transforms.Compose([
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])
data_dir = './dogscats'
dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), vgg_format)
         for x in ['train', 'valid']}
dset_sizes = {x: len(dsets[x]) for x in ['train', 'valid']}
dset_classes = dsets['train'].classes

3.创建 VGC Model

我们直接使用预训练好的 VGG 模型。在这部分代码中,对输入的5个图片利用VGG模型进行预测,同时,使用softmax对结果进行处理,随后展示了识别结果。可以看到,识别结果是比较非常准确的。

 4.修改最后一层,冻结前面层的参数

VGG 模型由三种元素组成:

卷积层(CONV)

全连接层(FC)

池化(Pool)

我们的目标是使用预训练好的模型,因此,需将nn.linear最后一层的参数改为2,即为2类(cat & dog),同时冻结前面的参数,这样在反向传播时只会更新最后一层的参数,不影响预训练结果。

print(model_vgg)
model_vgg_new = model_vgg;
for param in model_vgg_new.parameters():
    param.requires_grad = False
model_vgg_new.classifier._modules['6'] = nn.Linear(4096, 2)
model_vgg_new.classifier._modules['7'] = torch.nn.LogSoftmax(dim = 1)
model_vgg_new = model_vgg_new.to(device)
print(model_vgg_new.classifier)

5.训练并测试全连接层

 

 经过优化之后,准确率达到了96.6%。

 

posted @ 2021-10-24 13:53  罗宇浩11  阅读(104)  评论(1)    收藏  举报