CNN实战
一.使用VGG模型进行猫狗大战:
1.下载数据:将图片放置于本地文件夹
2.数据处理:
设置VGG格式,将图片进行归一化处理
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
vgg_format = transforms.Compose([
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])
data_dir = './dogscats'
dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), vgg_format)
for x in ['train', 'valid']}
dset_sizes = {x: len(dsets[x]) for x in ['train', 'valid']}
dset_classes = dsets['train'].classes
3.创建 VGC Model
我们直接使用预训练好的 VGG 模型。在这部分代码中,对输入的5个图片利用VGG模型进行预测,同时,使用softmax对结果进行处理,随后展示了识别结果。可以看到,识别结果是比较非常准确的。


4.修改最后一层,冻结前面层的参数
VGG 模型由三种元素组成:
卷积层(CONV)
全连接层(FC)
池化(Pool)
我们的目标是使用预训练好的模型,因此,需将nn.linear最后一层的参数改为2,即为2类(cat & dog),同时冻结前面的参数,这样在反向传播时只会更新最后一层的参数,不影响预训练结果。
print(model_vgg) model_vgg_new = model_vgg; for param in model_vgg_new.parameters(): param.requires_grad = False model_vgg_new.classifier._modules['6'] = nn.Linear(4096, 2) model_vgg_new.classifier._modules['7'] = torch.nn.LogSoftmax(dim = 1) model_vgg_new = model_vgg_new.to(device) print(model_vgg_new.classifier)
5.训练并测试全连接层


经过优化之后,准确率达到了96.6%。

浙公网安备 33010602011771号