手写数字识别
项目概述
- 目标:使用PyTorch实现一个手写数字识别系统,能够准确识别手写数字图像(0-9)。
- 应用场景:可用于数字图像识别、文档处理、验证码识别等领域。
整体架构设计
- 数据集加载:使用MNIST数据集,包含60,000张训练图像和10,000张测试图像,每张图像为28x28像素的灰度图像。
- 数据预处理:对图像进行归一化处理,将像素值缩放到[0,1]区间,增强模型的泛化能力。
- 网络结构设计:采用简单的卷积神经网络(CNN)架构,包含卷积层、池化层和全连接层。
- 损失计算与优化:使用交叉熵损失函数(CrossEntropyLoss)和随机梯度下降(SGD)优化器。
- 训练与测试:通过多次迭代训练模型,并在测试集上评估模型的性能。
- 发布程序:将训练好的模型保存并加载,用于实际的图像识别任务。
环境配置
使用的库::torch torchvision torchaudio numpy matplotlib cv2
数据集加载
- MNIST数据集:PyTorch提供了
torchvision.datasets.MNIST接口,方便加载MNIST数据集。
train_data = dataset.MNIST(root = "mnist",
train = True,
transform =transforms.ToTensor(),
download = True)
test_data = dataset.MNIST(root = "mnist",
train = False,
transform = transforms.ToTensor(),
download = True)
程序会自动在同路径下生成一个名为mnist的文件夹作为根目录存放下载的 MNIST 训练集和测试集数据
- 数据加载器:使用
DataLoader类对数据进行批处理,设置批量大小(batch size)、是否打乱数据(shuffle)等参数。
train_loader = data_utils.DataLoader(dataset = train_data,
batch_size = 64,
shuffle = True)
test_loader = data_utils.DataLoader(dataset = test_data,
batch_size = 64,
shuffle = True)
网络结构设计
- 卷积层:使用
torch.nn.Conv2d定义卷积层,提取图像的局部特征。 - 激活函数:使用ReLU激活函数,增加网络的非线性能力。
- 池化层:使用最大池化(MaxPooling)降低特征图的维度,减少计算量。
- 全连接层:将卷积层输出的特征图展平后输入到全连接层,输出10个类别(0-9)的概率分布。
class CNN(torch.nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv = torch.nn.Sequential(torch.nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size = 5, padding = 2),
#1.卷积操作
torch.nn.BatchNorm2d(32),
#2.归一化
torch.nn.ReLU(),
#3.激活层 Relu函数
torch.nn.MaxPool2d(2)
#4.最大池化
)
# 全链接
self.fc =torch.nn.Linear(14*14*32, 10)
def forward(self, x):
out = self.conv(x)
#将图像数据展开成一维
# 输入的张量(n,c,h,w)
out = out.view(out.size()[0], -1)
out = self.fc(out)
return out
CNN继承父类 torch.nn.Module 定义conv 和 forward对数据进行卷积操作和向前传播
损失计算与优化
- 损失函数:使用
torch.nn.CrossEntropyLoss - 优化器:使用
torch.optim.Adam,设置学习率(learning rate) - 训练过程:
- 对每个批次的数据进行前向传播,计算损失值。
- 通过反向传播计算梯度,并更新网络参数。
训练与测试
- 训练循环:设置训练轮数(epochs),在每个轮次中对训练数据进行迭代训练。
- 模型保存:训练完成后,使用
torch.save保存模型的参数。
for epoch in range(10):
for index1, (images, labels) in enumerate(train_loader):
outputs = cnn(images)
loss = loss_func(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(cnn, "model/mnist_model.pkl")
经过以上步骤我们的训练的手写数字识别模型就保存在了mnist_model.pkl 之中,可以通过torch.load将模型导入
浙公网安备 33010602011771号