手写数字识别

项目概述

  • 目标:使用PyTorch实现一个手写数字识别系统,能够准确识别手写数字图像(0-9)。
  • 应用场景:可用于数字图像识别、文档处理、验证码识别等领域。

整体架构设计

  • 数据集加载:使用MNIST数据集,包含60,000张训练图像和10,000张测试图像,每张图像为28x28像素的灰度图像。
  • 数据预处理:对图像进行归一化处理,将像素值缩放到[0,1]区间,增强模型的泛化能力。
  • 网络结构设计:采用简单的卷积神经网络(CNN)架构,包含卷积层、池化层和全连接层。
  • 损失计算与优化:使用交叉熵损失函数(CrossEntropyLoss)和随机梯度下降(SGD)优化器。
  • 训练与测试:通过多次迭代训练模型,并在测试集上评估模型的性能。
  • 发布程序:将训练好的模型保存并加载,用于实际的图像识别任务。

环境配置

​ 使用的库::torch torchvision torchaudio numpy matplotlib cv2

数据集加载

  • MNIST数据集:PyTorch提供了torchvision.datasets.MNIST接口,方便加载MNIST数据集。
train_data = dataset.MNIST(root = "mnist",
                           train = True,
                           transform =transforms.ToTensor(),
                           download = True)
test_data = dataset.MNIST(root = "mnist",
                          train = False,
                          transform = transforms.ToTensor(),
                          download = True)

程序会自动在同路径下生成一个名为mnist的文件夹作为根目录存放下载的 MNIST 训练集和测试集数据

  • 数据加载器:使用DataLoader类对数据进行批处理,设置批量大小(batch size)、是否打乱数据(shuffle)等参数。
train_loader = data_utils.DataLoader(dataset = train_data,
                                     batch_size = 64,
                                     shuffle = True)
test_loader = data_utils.DataLoader(dataset = test_data,
                                     batch_size = 64,
                                     shuffle = True)

网络结构设计

  • 卷积层:使用torch.nn.Conv2d定义卷积层,提取图像的局部特征。
  • 激活函数:使用ReLU激活函数,增加网络的非线性能力。
  • 池化层:使用最大池化(MaxPooling)降低特征图的维度,减少计算量。
  • 全连接层:将卷积层输出的特征图展平后输入到全连接层,输出10个类别(0-9)的概率分布。
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = torch.nn.Sequential(torch.nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size = 5, padding = 2),
                                        #1.卷积操作
                                        torch.nn.BatchNorm2d(32),
                                        #2.归一化
                                        torch.nn.ReLU(),
                                        #3.激活层 Relu函数
                                        torch.nn.MaxPool2d(2)
                                        #4.最大池化
                                        )
        
        # 全链接
        self.fc =torch.nn.Linear(14*14*32, 10)

    def forward(self, x):
        out = self.conv(x)
        #将图像数据展开成一维
        # 输入的张量(n,c,h,w)
        out = out.view(out.size()[0], -1)
        out = self.fc(out)

        return out

CNN继承父类 torch.nn.Module 定义convforward对数据进行卷积操作和向前传播

损失计算与优化

  • 损失函数:使用torch.nn.CrossEntropyLoss
  • 优化器:使用torch.optim.Adam,设置学习率(learning rate)
  • 训练过程
    • 对每个批次的数据进行前向传播,计算损失值。
    • 通过反向传播计算梯度,并更新网络参数。

训练与测试

  • 训练循环:设置训练轮数(epochs),在每个轮次中对训练数据进行迭代训练。
  • 模型保存:训练完成后,使用torch.save保存模型的参数。
for epoch in range(10):
    for index1, (images, labels) in enumerate(train_loader):
        
        outputs = cnn(images)

        loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
torch.save(cnn, "model/mnist_model.pkl")

经过以上步骤我们的训练的手写数字识别模型就保存在了mnist_model.pkl 之中,可以通过torch.load将模型导入

posted on 2025-03-31 20:59  nooobuuu  阅读(135)  评论(0)    收藏  举报

导航