数字识别(非汉字版)

1 训练

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

设置中文字体

matplotlib.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei']
matplotlib.rcParams['axes.unicode_minus'] = False

定义神经网络类

class Net(torch.nn.Module):
def init(self):
super(Net, self).init()
# 四个全连接层
self.fc1 = torch.nn.Linear(2828, 64) # 输入层:2828像素
self.fc2 = torch.nn.Linear(64, 64) # 隐藏层1:64个节点
self.fc3 = torch.nn.Linear(64, 64) # 隐藏层2:64个节点
self.fc4 = torch.nn.Linear(64, 10) # 输出层:10个类别(0-9)

def forward(self, x):
    """前向传播"""
    x = torch.nn.functional.relu(self.fc1(x))
    x = torch.nn.functional.relu(self.fc2(x))
    x = torch.nn.functional.relu(self.fc3(x))
    # 输出层使用log_softmax
    x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
    return x

def get_data_loader(is_train, batch_size=15):
"""获取数据加载器"""
to_tensor = transforms.Compose([transforms.ToTensor()])
data_set = MNIST("", is_train, transform=to_tensor, download=True)
return DataLoader(data_set, batch_size=batch_size, shuffle=True)

def evaluate(test_data, net):
"""评估神经网络准确率"""
n_correct = 0
n_total = 0
net.eval() # 设置为评估模式

with torch.no_grad():
    for (x, y) in test_data:
        outputs = net(x.view(-1, 28*28))
        for i, output in enumerate(outputs):
            if torch.argmax(output) == y[i]:
                n_correct += 1
            n_total += 1

return n_correct / n_total

def show_predictions(test_data, net, num_images=3):
"""显示预测结果"""
net.eval()
fig, axes = plt.subplots(1, num_images, figsize=(12, 4))

with torch.no_grad():
    for i, (x, y) in enumerate(test_data):
        if i >= num_images:
            break
        
        # 获取预测结果
        output = net(x[0].view(-1, 28*28))
        predict = torch.argmax(output)
        
        # 显示图像和预测结果
        axes[i].imshow(x[0].view(28, 28), cmap='gray')
        axes[i].set_title(f'预测: {int(predict)}\n实际: {int(y[0])}')
        axes[i].axis('off')

plt.tight_layout()
plt.savefig('mnist_predictions.png', dpi=150, bbox_inches='tight')
print("预测结果已保存为 mnist_predictions.png")
plt.show()

def main():
"""主函数"""
print("开始MNIST手写数字识别训练...")

# 1. 获取训练集和测试集
train_data = get_data_loader(is_train=True)
test_data = get_data_loader(is_train=False)

# 2. 初始化神经网络
net = Net()
print("神经网络结构:")
print(net)

# 3. 打印初始准确率
initial_acc = evaluate(test_data, net)
print(f"初始准确率: {initial_acc:.4f}")

# 4. 训练神经网络
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

for epoch in range(2):
    net.train()  # 设置为训练模式
    total_loss = 0
    
    for batch_idx, (x, y) in enumerate(train_data):
        net.zero_grad()
        output = net(x.view(-1, 28*28))
        loss = torch.nn.functional.nll_loss(output, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
    
    # 评估当前准确率
    accuracy = evaluate(test_data, net)
    print(f"Epoch {epoch} 完成,准确率: {accuracy:.4f}")


# 6. 保存模型
torch.save(net.state_dict(), 'model.pth')
print("模型已保存为 model.pth")

if name == 'main':
main()

2 识别
import os
import torch
import tkinter as tk
from tkinter import messagebox, filedialog
from PIL import Image, ImageDraw, ImageOps
from torchvision import transforms, datasets

解决OpenMP库冲突

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

1. 定义神经网络结构(必须与训练时完全一致)

class MNISTNet(torch.nn.Module):
def init(self):
super(MNISTNet, self).init()
self.fc1 = torch.nn.Linear(28*28, 64)
self.fc2 = torch.nn.Linear(64, 64)
self.fc3 = torch.nn.Linear(64, 64)
self.fc4 = torch.nn.Linear(64, 10)

def forward(self, x):
    x = torch.nn.functional.relu(self.fc1(x))
    x = torch.nn.functional.relu(self.fc2(x))
    x = torch.nn.functional.relu(self.fc3(x))
    x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
    return x

2. 加载模型并验证

def load_and_verify_model(model_path):
"""加载模型并使用MNIST测试集验证"""
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")

model = MNISTNet()
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# 验证模型是否正常工作(使用MNIST测试集的前5张图片)
test_dataset = datasets.MNIST(
    root='./data', 
    train=False, 
    download=True,
    transform=transforms.ToTensor()
)

# 取前5张测试图片
test_images = [test_dataset[i][0] for i in range(5)]
test_labels = [test_dataset[i][1] for i in range(5)]

with torch.no_grad():
    correct = 0
    for img, label in zip(test_images, test_labels):
        output = model(img.view(-1, 28*28))
        pred = torch.argmax(output).item()
        if pred == label:
            correct += 1

if correct < 3:  # 至少正确识别3张才认为模型有效
    raise RuntimeError(f"模型验证失败,仅正确识别{correct}/5张测试图片,可能模型文件损坏或训练不充分")

print(f"模型验证成功,正确识别{correct}/5张测试图片")
return model

3. 手写识别界面(修复预处理问题)

class DrawingInterface:
def init(self, root, model):
self.root = root
self.root.title("手写数字识别(修复版)")
self.model = model

    # 画布设置(280x280,便于缩放至28x28)
    self.canvas_size = 280
    self.canvas = tk.Canvas(
        root, 
        width=self.canvas_size, 
        height=self.canvas_size, 
        bg="white", 
        cursor="cross"
    )
    self.canvas.pack(pady=20)
    
    # 创建图像对象(初始为白色背景)
    self.image = Image.new("L", (self.canvas_size, self.canvas_size), 255)
    self.draw = ImageDraw.Draw(self.image)
    
    # 绑定鼠标事件
    self.canvas.bind("<B1-Motion>", self.draw_stroke)
    self.canvas.bind("<ButtonRelease-1>", self.reset_last_pos)
    
    # 按钮区域
    self.buttons_frame = tk.Frame(root)
    self.buttons_frame.pack(pady=10)
    
    # 识别按钮
    self.recognize_btn = tk.Button(
        self.buttons_frame, 
        text="识别数字", 
        command=self.recognize_digit,
        width=15,
        font=("Arial", 12)
    )
    self.recognize_btn.grid(row=0, column=0, padx=10)
    
    # 清除按钮
    self.clear_btn = tk.Button(
        self.buttons_frame, 
        text="清除画布", 
        command=self.clear_canvas,
        width=15,
        font=("Arial", 12)
    )
    self.clear_btn.grid(row=0, column=1, padx=10)
    
    # 测试按钮(用标准图片测试)
    self.test_btn = tk.Button(
        self.buttons_frame, 
        text="测试模型", 
        command=self.test_model,
        width=15,
        font=("Arial", 12)
    )
    self.test_btn.grid(row=0, column=2, padx=10)
    
    # 结果显示
    self.result_var = tk.StringVar()
    self.result_var.set("请在画布上书写数字(0-9),然后点击识别")
    self.result_label = tk.Label(
        root, 
        textvariable=self.result_var, 
        font=("Arial", 14)
    )
    self.result_label.pack(pady=20)
    
    # 鼠标位置记录
    self.last_x = None
    self.last_y = None
    self.last_y = None

def draw_stroke(self, event):
    """绘制线条(确保笔迹粗细适中)"""
    current_x, current_y = event.x, event.y
    
    if self.last_x and self.last_y:
        # 绘制线条(宽度15,圆润笔触)
        self.canvas.create_line(
            self.last_x, self.last_y, current_x, current_y,
            width=15, fill="black", capstyle=tk.ROUND, smooth=tk.TRUE
        )
        self.draw.line(
            [(self.last_x, self.last_y), (current_x, current_y)],
            fill=0,  # 黑色(与MNIST一致)
            width=15
        )
    
    self.last_x = current_x
    self.last_y = current_y

def reset_last_pos(self, event):
    self.last_x = None
    self.last_y = None

def clear_canvas(self):
    self.canvas.delete("all")
    self.image = Image.new("L", (self.canvas_size, self.canvas_size), 255)
    self.draw = ImageDraw.Draw(self.image)
    self.result_var.set("请在画布上书写数字(0-9),然后点击识别")

def preprocess_image(self, image):
    """图像预处理(关键修复)"""
    # 1. 缩放到28x28(与MNIST尺寸一致)
    img = image.resize((28, 28), Image.LANCZOS)
    
    # 2. 确保是白底黑字(MNIST标准)
    # 反转图像(如果背景是黑色)
    img = ImageOps.invert(img) if img.getextrema()[0] < 128 else img
    
    # 3. 归一化(使用MNIST的均值和标准差)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # 必须与训练时一致
    ])
    
    return transform(img).unsqueeze(0)  # 增加批次维度

def recognize_digit(self):
    """识别手写数字(修复版)"""
    try:
        # 预处理图像
        image_tensor = self.preprocess_image(self.image)
        
        # 模型预测
        with torch.no_grad():
            output = self.model(image_tensor.view(-1, 28*28))
            probabilities = torch.exp(output)  # 转换为概率
            confidence, predicted_idx = torch.max(probabilities, 1)
            
            # 输出所有类别的概率(用于调试)
            print("各数字概率:", [f"{i}: {p:.2%}" for i, p in enumerate(probabilities[0])])
        
        # 显示结果
        result = f"识别结果: {predicted_idx.item()},置信度: {confidence.item():.2%}"
        self.result_var.set(result)
        
    except Exception as e:
        messagebox.showerror("识别错误", f"错误: {str(e)}")

def test_model(self):
    """用MNIST测试集验证模型是否正常工作"""
    try:
        # 加载MNIST测试集
        test_dataset = datasets.MNIST(
            root='./data', 
            train=False, 
            download=True,
            transform=transforms.ToTensor()
        )
        
        # 随机选一张图片测试
        import random
        idx = random.randint(0, 1000)
        img, label = test_dataset[idx]
        
        # 预测
        with torch.no_grad():
            output = self.model(img.view(-1, 28*28))
            pred = torch.argmax(output).item()
        
        result = f"模型测试: 实际数字{label},预测结果{pred},{'正确' if pred == label else '错误'}"
        self.result_var.set(result)
        
    except Exception as e:
        messagebox.showerror("测试错误", f"错误: {str(e)}")

4. 主函数

def main():
# 模型路径(请替换为你的模型文件路径)
model_path = "model.pth"

try:
    # 加载并验证模型
    print(f"正在加载模型: {model_path}")
    model = load_and_verify_model(model_path)
    print("模型加载成功")
    
    # 启动界面
    root = tk.Tk()
    app = DrawingInterface(root, model)
    root.mainloop()
    
except Exception as e:
    print(f"初始化失败: {str(e)}")
    messagebox.showerror("初始化错误", f"无法启动应用: {str(e)}")

if name == "main":
main()

posted @ 2025-10-30 23:31  飕飕  阅读(11)  评论(0)    收藏  举报