手写数字识别系统(待定)

待识别的图片:

手写数字图

源代码

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

# 1. 定义LeNet-5模型(手写数字识别经典模型)
class LeNet5(nn.Module):
    """LeNet-5 style CNN for single-channel 28x28 handwritten-digit images.

    Two convolution -> ReLU -> max-pool stages extract features; three fully
    connected layers map the flattened 16x4x4 feature map to class logits.

    Args:
        num_classes: Number of output logits (10 for digits 0-9).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extractor. With a 1x28x28 input the spatial size goes
        # 28 -> 24 (conv 5x5) -> 12 (pool) -> 8 (conv 5x5) -> 4 (pool).
        feature_stack = [
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.conv_layers = nn.Sequential(*feature_stack)

        # Classifier head: 16 channels * 4 * 4 flattened features -> logits.
        flat_features = 16 * 4 * 4
        self.fc_layers = nn.Sequential(
            nn.Linear(flat_features, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        """Return raw (unnormalized) logits of shape (batch, num_classes)."""
        features = self.conv_layers(x)
        # Flatten every dimension except the batch dimension.
        flattened = features.view(features.shape[0], -1)
        return self.fc_layers(flattened)

# 2. 图像预处理(匹配模型输入要求)
def preprocess_image(image_path, *, invert=False):
    """Load an image and turn it into a normalized tensor for LeNet5.

    The image is converted to 8-bit grayscale and resized to 28x28, the size
    LeNet5's 16*4*4 flatten expects. It is then converted to a [0, 1] float
    tensor and normalized with mean 0.1307 / std 0.3081 (the MNIST statistics
    the original code uses).

    Args:
        image_path: Path or file object accepted by ``PIL.Image.open``.
        invert: If True, invert pixel intensities (255 - p) before resizing.
            MNIST-trained models expect light strokes on a dark background,
            so set this for photos/scans of dark ink on white paper.
            Defaults to False (original behavior).

    Returns:
        ``(image_tensor, image)``: ``image_tensor`` has shape (1, 1, 28, 28);
        ``image`` is the grayscale PIL image before inversion and resizing.
    """
    # Resize -> tensor -> normalize. Resize((28, 28)) scales each axis
    # independently, so the aspect ratio is not preserved.
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # Close the file handle promptly (an open handle locks the file on
    # Windows). convert() returns a fully loaded copy, so `image` stays usable
    # after the with-block. Converting to 'L' already yields one channel, so
    # no separate Grayscale transform is needed.
    with Image.open(image_path) as img:
        image = img.convert('L')
    model_input = Image.eval(image, lambda p: 255 - p) if invert else image
    image_tensor = transform(model_input).unsqueeze(0)  # add batch dim (batch=1)
    return image_tensor, image

# 3. 加载模型权重(需提供用MNIST训练好的.pth权重文件;未提供时为随机初始化权重,仅用于演示)
def load_pretrained_model(model_path=None):
    """Build a LeNet5 in evaluation mode, optionally loading trained weights.

    Args:
        model_path: Path to a ``state_dict`` file saved with ``torch.save``.
            If falsy (None or ""), no weights are loaded: the model keeps its
            random initialization and a warning is emitted, because its
            predictions are then meaningless.

    Returns:
        The ``LeNet5`` instance after ``model.eval()``. Loaded weights are
        mapped to the CPU.
    """
    model = LeNet5()
    if model_path:
        # NOTE(review): torch.load unpickles the file and can execute arbitrary
        # code; only load checkpoints from trusted sources (newer torch
        # versions support ``weights_only=True`` to restrict this).
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    else:
        import warnings

        # Previously a silent `pass`: make the untrained fallback visible.
        warnings.warn(
            "未提供权重文件(model_path),模型使用随机初始化权重,识别结果无实际意义;"
            "请传入用MNIST训练好的.pth文件。",
            stacklevel=2,
        )
    model.eval()  # disable training-only behavior before inference
    return model

# 4. 核心识别函数
def recognize_handwritten_digit(image_path, model, true_label=5):
    """Classify one handwritten-digit image and compare it with its label.

    Args:
        image_path: Image path forwarded to ``preprocess_image``.
        model: Callable returning logits of shape (batch, num_classes), e.g.
            the model from ``load_pretrained_model`` (expected in eval mode).
        true_label: The digit actually written in the image. Defaults to 5 to
            keep the original demo behavior; pass the real label for your
            image, or None when it is unknown.

    Returns:
        ``(result, original_image)``. ``result`` maps "真实文字" (true label),
        "识别结果" (predicted digit) and "单图准确率" ("1.00" if correct, else
        "0.00") to strings. When ``true_label`` is None, only "识别结果" is
        present. ``original_image`` is the grayscale PIL image.
    """
    image_tensor, original_image = preprocess_image(image_path)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(image_tensor)
    predicted_digit = logits.argmax(dim=1).item()  # index of the largest logit

    if true_label is None:
        return {"识别结果": str(predicted_digit)}, original_image

    # Single-image "accuracy" is simply 1 for a correct prediction, else 0.
    accuracy = 1.0 if predicted_digit == true_label else 0.0
    result = {
        "真实文字": str(true_label),
        "识别结果": str(predicted_digit),
        "单图准确率": f"{accuracy:.2f}",
    }
    return result, original_image

# 5. 主函数(执行入口)
if __name__ == "__main__":
    import os
    import sys

    # Usage: python script.py [image_path] [model_path]
    # The image path falls back to the original author's demo path; the
    # optional second argument is a trained .pth state_dict.
    image_path = (
        sys.argv[1] if len(sys.argv) > 1
        else r"C:\Users\黄楚玉\Desktop\杂七杂八\手写数字图.jpg"
    )
    model_path = sys.argv[2] if len(sys.argv) > 2 else None

    # Fail with a readable message instead of a PIL traceback.
    if not os.path.isfile(image_path):
        sys.exit(f"找不到图片文件: {image_path}")

    model = load_pretrained_model(model_path)
    recognition_result, _ = recognize_handwritten_digit(image_path, model)

    separator = "=" * 30
    print(separator)
    print("手写数字识别结果")
    print(separator)
    for key, value in recognition_result.items():
        print(f"{key}: {value}")
    print(separator)

运行结果如下:

手写数字识别系统

posted @ 2025-12-12 13:45  与尔5  阅读(0)  评论(0)    收藏  举报