CNN 核心组成与作用
在上一章里,我们通过实验验证了 MLP 在 MNIST 上表现优异,但在平移、旋转、噪声干扰以及跨数据集时暴露出明显局限性。这也就引出了我们要学习的下一个重要模型 —— 卷积神经网络(CNN, Convolutional Neural Network)。CNN 的出现,正是为了解决 MLP 丢失空间结构、参数冗余、泛化能力差 等问题。下面,我们用最直观、最通俗的话来理解 CNN 的核心组成。
- 卷积层(Convolution Layer)
• 是什么:就像拿着一个小“滤镜”在图片上滑动,局部扫描像素。
• 能干嘛:找到局部特征,比如边缘、角点、纹理。
• 为什么有用:
• 相比 MLP 的“拉平处理”,卷积保留了图像的空间结构。
• 卷积核参数共享,能大幅减少模型参数。
• 天然具有 平移等变性,物体挪一点位置,特征也跟着挪,不会完全当成新模式。 - 激活函数(ReLU 等)
• 是什么:给线性的输出加一个“弯曲”,常用的 ReLU 就是把负数压成 0。
• 能干嘛:让模型能学到复杂的非线性边界。
• 为什么有用:如果没有激活函数,不管堆多少层卷积,最终还是线性模型。 - 池化层(Pooling Layer)
• 是什么:在一个小区域里,取最大值(Max Pool)或平均值(Avg Pool)。
• 能干嘛:缩小图片尺寸,保留主要特征,丢掉冗余细节。
• 为什么有用:让模型对小范围的平移和噪声更不敏感,也能减少计算量。 - 归一化层(Batch Normalization)
• 是什么:把中间特征做标准化,再通过可学习的参数缩放/平移。
• 能干嘛:让训练更稳定,梯度不容易爆炸或消失。
• 为什么有用:加快收敛,减少过拟合,是现代 CNN 的“标配”。 - 残差连接(Residual / Skip Connection)
• 是什么:输入绕个“捷径”加到输出上。
• 能干嘛:解决网络太深时训练变差的问题。
• 为什么有用:ResNet 的核心设计,让 CNN 可以轻松堆到上百层。 - 全连接层(Fully Connected Layer)
• 是什么:在 CNN 最后阶段,用全连接层把特征映射成具体的类别分数。
• 能干嘛:完成最终的分类或回归任务。
• 为什么有用:卷积负责“特征提取”,全连接负责“决策”。 - 数据增强(Data Augmentation)
• 是什么:对输入图片做随机裁剪、翻转、旋转、加噪声等。
• 能干嘛:让模型“见多识广”,学会忽略无关扰动。
• 为什么有用:直接提升泛化能力,比单纯依赖正则化更有效。
CNN 的关键点就是:- 局部感受野 —— 专注小区域,逐层叠加形成全局理解;
- 权重共享 —— 大幅减少参数;
- 平移等变性 + 池化近似不变性 —— 对位置、尺度更鲁棒;
- 层次化特征 —— 从边缘到纹理,再到整体语义。
这就是为什么 CNN 在视觉领域一举超越 MLP,成为计算机视觉的基石。
import torch, torch.nn as nn, torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# ==== 配置参数 ====
DATASET = "mnist" # "mnist" 或 "fashion"
EPOCHS = 5
BATCH_SIZE = 128
LR = 1e-3
USE_AUG = True
# ==== 设备选择 ====
device = torch.device("cuda" if torch.cuda.is_available() else
"mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
else "cpu")
print("Using device:", device)
# ==== CNN 模型 ====
class SmallCNN(nn.Module):
def __init__(self, in_ch=1, num_classes=10):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(in_ch, 32, 3, padding=1), nn.ReLU(),
nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
nn.MaxPool2d(2), # 28->14
nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
nn.MaxPool2d(2), # 14->7
)
self.fc = nn.Sequential(
nn.Flatten(),
nn.Linear(64*7*7, 128), nn.ReLU(),
nn.Linear(128, num_classes)
)
def forward(self, x): return self.fc(self.features(x))
# ==== 数据加载 ====
if DATASET == "mnist":
TrainSet = torchvision.datasets.MNIST
TestSet = torchvision.datasets.MNIST
else:
TrainSet = torchvision.datasets.FashionMNIST
TestSet = torchvision.datasets.FashionMNIST
train_tf = T.Compose([
T.RandomAffine(degrees=10, translate=(0.05,0.05)) if USE_AUG else T.Lambda(lambda x:x),
T.ToTensor()
])
test_tf = T.ToTensor()
train_set = TrainSet("./data", train=True, download=True, transform=train_tf)
test_set = TestSet("./data", train=False, download=True, transform=test_tf)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=512, shuffle=False)
# ==== 训练与评估函数 ====
def train_one_epoch(model, loader, opt):
model.train(); total, correct, loss_sum = 0,0,0
for x,y in loader:
x,y = x.to(device), y.to(device)
opt.zero_grad()
out = model(x)
loss = F.cross_entropy(out,y)
loss.backward(); opt.step()
loss_sum += loss.item()*y.size(0)
pred = out.argmax(1); correct += (pred==y).sum().item(); total+=y.size(0)
return loss_sum/total, correct/total
@torch.no_grad()
def evaluate(model, loader):
model.eval(); total, correct, loss_sum = 0,0,0
for x,y in loader:
x,y = x.to(device), y.to(device)
out = model(x)
loss = F.cross_entropy(out,y)
loss_sum += loss.item()*y.size(0)
pred = out.argmax(1); correct += (pred==y).sum().item(); total+=y.size(0)
return loss_sum/total, correct/total
# ==== 主训练 ====
model = SmallCNN(in_ch=1,num_classes=10).to(device)
opt = torch.optim.Adam(model.parameters(), lr=LR)
for ep in range(1,EPOCHS+1):
tr_loss,tr_acc = train_one_epoch(model,train_loader,opt)
te_loss,te_acc = evaluate(model,test_loader)
print(f"Epoch {ep}/{EPOCHS} | "
f"train acc {tr_acc*100:.2f}% | test acc {te_acc*100:.2f}%")
# ==== 可视化部分预测 ====
x,y = next(iter(test_loader))
x_vis = x[:12]; y = y[:12]
with torch.no_grad():
pred = model(x_vis.to(device)).argmax(1).cpu()
plt.figure(figsize=(10,4))
for i in range(12):
plt.subplot(2,6,i+1)
plt.imshow(x_vis[i,0],cmap="gray"); plt.axis("off")
plt.title(f"P:{pred[i].item()} / T:{y[i].item()}",fontsize=9)
plt.show()


浙公网安备 33010602011771号