《记从零实现手写数字识别——PyTorch实战篇》
一、环境搭建
实验采用Python3.8环境,主要依赖库:
- PyTorch 1.12:深度学习框架
- Torchvision 0.13:提供MNIST数据集
- OpenCV 4.6:图像预处理
安装命令:pip install torch torchvision opencv-python
二、实战开发步骤
- 数据加载技巧
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # 数据集均值标准差
])
train_set = datasets.MNIST('data/', train=True, download=True, transform=transform)
test_set = datasets.MNIST('data/', train=False, transform=transform)
- 改进型网络设计
class EnhancedCNN(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 16, 3, padding=1),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(16, 32, 3, padding=1),
nn.Dropout(0.25)
)
self.classifier = nn.Sequential(
nn.Linear(32*7*7, 128),
nn.ReLU(),
nn.Linear(128, 10))
- 训练优化技巧
def train_model():
model = EnhancedCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
for epoch in range(10):
model.train()
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
scheduler.step()
三、效果验证
在测试集上达到98.7%准确率的关键:
- 添加BatchNorm层加速收敛
- 使用Dropout防止过拟合
- 学习率阶梯下降策略
四、模型部署示例
def predict_image(img_path):
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (28,28))
img_tensor = transform(255 - img).unsqueeze(0)
with torch.no_grad():
pred = torch.argmax(model(img_tensor)).item()
return pred
思考延伸
尝试使用数据增强(旋转、平移)提升模型鲁棒性,比较不同优化器的性能差异,思考如何将模型部署到移动端应用。

浙公网安备 33010602011771号