《记从零实现手写数字识别——PyTorch实战篇》

一、环境搭建
实验采用Python3.8环境,主要依赖库:

  • PyTorch 1.12:深度学习框架
  • Torchvision 0.13:提供MNIST数据集
  • OpenCV 4.6:图像预处理

安装命令:pip install torch torchvision opencv-python

二、实战开发步骤

  1. 数据加载技巧
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # 数据集均值标准差
])

train_set = datasets.MNIST('data/', train=True, download=True, transform=transform)
test_set = datasets.MNIST('data/', train=False, transform=transform)
  1. 改进型网络设计
class EnhancedCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),  
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.Dropout(0.25)  
        )
        self.classifier = nn.Sequential(
            nn.Linear(32*7*7, 128),
            nn.ReLU(),
            nn.Linear(128, 10))
  1. 训练优化技巧
def train_model():
    model = EnhancedCNN()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
    
    for epoch in range(10):
        model.train()
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
        scheduler.step()

三、效果验证
在测试集上达到98.7%准确率的关键:

  1. 添加BatchNorm层加速收敛
  2. 使用Dropout防止过拟合
  3. 学习率阶梯下降策略

四、模型部署示例

def predict_image(img_path):
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, (28,28))
    img_tensor = transform(255 - img).unsqueeze(0)
    with torch.no_grad():
        pred = torch.argmax(model(img_tensor)).item()
    return pred

思考延伸
尝试使用数据增强(旋转、平移)提升模型鲁棒性,比较不同优化器的性能差异,思考如何将模型部署到移动端应用。

posted @ 2025-04-01 10:12  冒牌诸葛亮  阅读(56)  评论(0)    收藏  举报