手写数字识别系统(待定)
待识别的图片：

源代码
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
# 1. LeNet-5 model (classic architecture for handwritten-digit recognition)
class LeNet5(nn.Module):
    """LeNet-5 style CNN for single-channel 28x28 digit images.

    Two conv -> ReLU -> max-pool stages turn a (1, 28, 28) image into a
    (16, 4, 4) feature map, which three fully connected layers map to
    ``num_classes`` raw logits.

    The attribute names ``conv_layers`` / ``fc_layers`` and the order of the
    layers inside them determine the ``state_dict`` keys, so changing them
    would break loading of previously saved weight files.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Spatial size: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4
        feature_stack = [
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.conv_layers = nn.Sequential(*feature_stack)

        flat_features = 16 * 4 * 4  # channels * height * width after the last pool
        self.fc_layers = nn.Sequential(
            nn.Linear(flat_features, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        features = self.conv_layers(x)
        batch = features.size(0)
        flat = features.view(batch, -1)  # flatten everything except the batch axis
        return self.fc_layers(flat)
# 2. Image preprocessing (shape the input the way the model expects)
def preprocess_image(image_path, invert=False):
    """Load an image file and convert it into a model-ready tensor.

    Pipeline: grayscale -> resize to 28x28 -> tensor in [0, 1] ->
    (optional) invert -> normalize with the MNIST mean/std (0.1307, 0.3081).

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        invert: If True, flip intensities (``1 - x``) before normalizing.
            MNIST digits are light strokes on a dark background, whereas a
            photo/scan of dark ink on light paper is the opposite, so such
            images usually need ``invert=True`` to resemble MNIST input.
            Defaults to False, which reproduces the previous behavior.

    Returns:
        (image_tensor, image): a float tensor of shape (1, 1, 28, 28) and the
        grayscale ("L" mode) PIL image before resizing.
    """
    steps = [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),  # note: non-square images are squashed, not padded
        transforms.ToTensor(),  # (1, 28, 28) floats in [0, 1]
    ]
    if invert:
        steps.append(transforms.Lambda(lambda t: 1.0 - t))
    steps.append(transforms.Normalize((0.1307,), (0.3081,)))
    transform = transforms.Compose(steps)

    # The context manager closes the file handle deterministically; convert()
    # returns a new, fully loaded image that remains valid after the file closes.
    with Image.open(image_path) as src:
        image = src.convert('L')
    image_tensor = transform(image).unsqueeze(0)  # add batch dimension (batch=1)
    return image_tensor, image
# 3. Build the model and optionally load trained weights
def load_pretrained_model(model_path=None):
    """Create a LeNet5 in eval mode, loading weights from ``model_path`` if given.

    Args:
        model_path: Path to a ``state_dict`` saved with ``torch.save``. If it
            is falsy (None or ""), no weights are loaded and the model keeps
            its random initialization. Its predictions are then meaningless,
            so a RuntimeWarning is emitted instead of continuing silently.

    Returns:
        The LeNet5 model switched to evaluation mode.
    """
    model = LeNet5()
    if model_path:
        # NOTE(review): torch.load unpickles the file, so only load weight
        # files from trusted sources; on torch >= 1.13 consider passing
        # weights_only=True to restrict what can be deserialized.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    else:
        import warnings

        warnings.warn(
            "No model_path given: LeNet5 is using random, untrained weights, "
            "so its predictions are effectively random. Pass a trained .pth file.",
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()  # switch to inference mode
    return model
# 4. Core recognition function
def recognize_handwritten_digit(image_path, model, true_label=5):
    """Classify one handwritten-digit image and compare it with its label.

    Args:
        image_path: Image file passed to ``preprocess_image``.
        model: Callable mapping a (1, 1, 28, 28) tensor to (1, num_classes)
            logits, e.g. the model returned by ``load_pretrained_model``.
        true_label: Ground-truth digit (0-9) of the image. It must come from
            whoever labelled the image; the default 5 is only the example
            value that was previously hard-coded inside this function.

    Returns:
        (result, original_image): ``result`` is a dict holding the true label
        and the predicted label (both as strings) plus the single-image
        accuracy ("1.00" if correct, else "0.00"); ``original_image`` is the
        grayscale PIL image returned by ``preprocess_image``.
    """
    image_tensor, original_image = preprocess_image(image_path)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(image_tensor)
    predicted_label = logits.argmax(dim=1).item()  # index of the highest logit

    # Single-image accuracy: 1 if the prediction matches the label, else 0.
    accuracy = 1.0 if predicted_label == true_label else 0.0
    result = {
        "真实文字": str(true_label),
        "识别结果": str(predicted_label),
        "单图准确率": f"{accuracy:.2f}",
    }
    return result, original_image
# 5. Script entry point
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Recognize a handwritten digit image with LeNet-5."
    )
    parser.add_argument(
        "image_path",
        nargs="?",
        default=r"C:\Users\黄楚玉\Desktop\杂七杂八\手写数字图.jpg",
        help="image to recognize (defaults to the original hard-coded path)",
    )
    parser.add_argument(
        "--model-path",
        default=None,
        help="trained LeNet5 state_dict (.pth); without it the weights are random",
    )
    args = parser.parse_args()

    # Without --model-path this matches the original load_pretrained_model() call.
    model = load_pretrained_model(args.model_path)
    try:
        recognition_result, _ = recognize_handwritten_digit(args.image_path, model)
    except FileNotFoundError:
        parser.error(f"image not found: {args.image_path}")

    print("=" * 30)
    print("手写数字识别结果")
    print("=" * 30)
    for key, value in recognition_result.items():
        print(f"{key}: {value}")
    print("=" * 30)
运行结果如下：

浙公网安备 33010602011771号