# 数字识别(非汉字版)
# 1 训练
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
# 设置中文字体
# Select a font that contains CJK glyphs (the plot titles are Chinese) and
# turn off the Unicode minus sign for axis labels.
matplotlib.rcParams.update({
    'font.sans-serif': ['WenQuanYi Zen Hei'],
    'axes.unicode_minus': False,
})
# 定义神经网络类
class Net(torch.nn.Module):
    """Fully connected digit classifier: 784 -> 64 -> 64 -> 64 -> 10."""

    def __init__(self):
        """Create the four linear layers."""
        super().__init__()
        # Input layer takes one flattened 28x28 image (784 values).
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        self.fc2 = torch.nn.Linear(64, 64)  # hidden layer 1: 64 units
        self.fc3 = torch.nn.Linear(64, 64)  # hidden layer 2: 64 units
        self.fc4 = torch.nn.Linear(64, 10)  # output layer: one score per digit 0-9

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Tensor of shape (batch, 784).

        Returns:
            Tensor of shape (batch, 10) holding log-probabilities
            (log_softmax over dim 1), suitable for ``nll_loss``.
        """
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        # The output layer emits log-probabilities rather than raw scores.
        return torch.nn.functional.log_softmax(self.fc4(x), dim=1)
def get_data_loader(is_train, batch_size=15):
    """Return a shuffled DataLoader over the MNIST train or test split.

    Args:
        is_train: True for the training split, False for the test split.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` whose images are converted by ``ToTensor`` only;
        no normalization is applied.
    """
    # The dataset is downloaded when missing. The root is the empty string,
    # presumably resolving relative to the working directory - TODO confirm.
    mnist = MNIST(
        root="",
        train=is_train,
        transform=transforms.Compose([transforms.ToTensor()]),
        download=True,
    )
    return DataLoader(mnist, batch_size=batch_size, shuffle=True)
def evaluate(test_data, net):
    """Return the classification accuracy of ``net`` over ``test_data``.

    Args:
        test_data: Iterable of ``(images, labels)`` batches. Each image batch
            is flattened to (batch, 784) before being fed to the network.
        net: Module mapping a (batch, 784) tensor to (batch, 10) scores.

    Returns:
        Fraction of correctly classified samples as a float in [0, 1], or
        0.0 when ``test_data`` yields no samples.

    Note:
        The network is switched to evaluation mode and left in that mode.
    """
    n_correct = 0
    n_total = 0
    net.eval()
    with torch.no_grad():
        for x, y in test_data:
            outputs = net(x.view(-1, 28 * 28))
            # Compare the whole batch at once instead of sample by sample.
            predictions = torch.argmax(outputs, dim=1)
            n_correct += int((predictions == y).sum().item())
            n_total += outputs.size(0)
    if n_total == 0:
        # An empty loader would otherwise raise ZeroDivisionError.
        return 0.0
    return n_correct / n_total
def show_predictions(test_data, net, num_images=3):
    """Plot the first image of each of the first ``num_images`` batches.

    Each subplot title shows the predicted and the actual digit. The figure
    is saved to ``mnist_predictions.png`` and then displayed.

    Args:
        test_data: Iterable of ``(images, labels)`` batches.
        net: Module mapping a (batch, 784) tensor to (batch, 10) scores.
        num_images: Number of subplots, i.e. number of batches sampled.
    """
    net.eval()
    # squeeze=False always returns a 2-D array of axes. Without it,
    # plt.subplots(1, 1) returns a bare Axes and indexing fails for
    # num_images == 1.
    fig, axes = plt.subplots(1, num_images, figsize=(12, 4), squeeze=False)
    axes_row = axes[0]
    with torch.no_grad():
        for i, (x, y) in enumerate(test_data):
            if i >= num_images:
                break
            # Predict on the first sample of the batch only.
            output = net(x[0].view(-1, 28 * 28))
            predict = torch.argmax(output)
            # Draw the image with its predicted / actual label.
            axes_row[i].imshow(x[0].view(28, 28), cmap='gray')
            axes_row[i].set_title(f'预测: {int(predict)}\n实际: {int(y[0])}')
            axes_row[i].axis('off')
    plt.tight_layout()
    plt.savefig('mnist_predictions.png', dpi=150, bbox_inches='tight')
    print("预测结果已保存为 mnist_predictions.png")
    plt.show()
def main():
    """Train the network on MNIST, report test accuracy, and save the weights.

    Prints the accuracy of the untrained network, trains for two epochs with
    Adam (logging the loss every 100 batches and the test accuracy after each
    epoch), then writes the state dict to ``model.pth``.
    """
    print("开始MNIST手写数字识别训练...")

    # 1. Build the training and test loaders.
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)

    # 2. Create the network.
    net = Net()
    print("神经网络结构:")
    print(net)

    # 3. Accuracy before any training, as a baseline.
    initial_acc = evaluate(test_data, net)
    print(f"初始准确率: {initial_acc:.4f}")

    # 4. Train.
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(2):
        # evaluate() leaves the network in eval mode, so switch back each epoch.
        net.train()
        for batch_idx, (x, y) in enumerate(train_data):
            net.zero_grad()
            output = net(x.view(-1, 28 * 28))
            # The network outputs log-probabilities, hence nll_loss.
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

        # Test accuracy after this epoch.
        accuracy = evaluate(test_data, net)
        print(f"Epoch {epoch} 完成,准确率: {accuracy:.4f}")

    # 5. Save only the weights; the loader must rebuild the same architecture.
    torch.save(net.state_dict(), 'model.pth')
    print("模型已保存为 model.pth")
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
# 2 识别
import os
import torch
import tkinter as tk
from tkinter import messagebox, filedialog
from PIL import Image, ImageDraw, ImageOps
from torchvision import transforms, datasets
# 解决OpenMP库冲突
# Allow duplicate OpenMP runtimes to coexist in this process.
# NOTE(review): this runs after `import torch`; confirm the variable is set
# early enough for the runtime to honor it.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})
# 1. 定义神经网络结构(必须与训练时完全一致)
class MNISTNet(torch.nn.Module):
    """Fully connected digit classifier: 784 -> 64 -> 64 -> 64 -> 10.

    The attribute names ``fc1`` .. ``fc4`` and the layer sizes must match the
    network that produced the checkpoint, otherwise ``load_state_dict`` fails.
    """

    def __init__(self):
        """Create the four linear layers."""
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Map a (batch, 784) tensor to (batch, 10) log-probabilities."""
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        return torch.nn.functional.log_softmax(self.fc4(x), dim=1)
# 2. 加载模型并验证
def load_and_verify_model(model_path):
    """Load a checkpoint into ``MNISTNet`` and sanity-check it on MNIST.

    Args:
        model_path: Path of the saved state dict.

    Returns:
        The loaded model, in evaluation mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: If fewer than 3 of the first 5 MNIST test images are
            classified correctly.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型文件不存在: {model_path}")

    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model = MNISTNet()
    model.load_state_dict(state_dict)
    model.eval()

    # Smoke test: classify the first five images of the MNIST test split.
    test_dataset = datasets.MNIST(
        root='./data',
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    )
    samples = [test_dataset[index] for index in range(5)]

    correct = 0
    with torch.no_grad():
        for img, label in samples:
            scores = model(img.view(-1, 28 * 28))
            if torch.argmax(scores).item() == label:
                correct += 1

    # At least 3 of 5 must be right for the checkpoint to count as usable.
    if correct < 3:
        raise RuntimeError(f"模型验证失败,仅正确识别{correct}/5张测试图片,可能模型文件损坏或训练不充分")

    print(f"模型验证成功,正确识别{correct}/5张测试图片")
    return model
# 3. 手写识别界面(修复预处理问题)
class DrawingInterface:
    """Tkinter window for drawing a digit and classifying it with a model.

    Every stroke is drawn twice: on the visible canvas and on an off-screen
    PIL image of the same size. The off-screen image is what gets
    preprocessed and fed to the model.
    """

    def __init__(self, root, model):
        """Build the canvas, the three buttons, and the result label.

        Args:
            root: Tk window that hosts the widgets.
            model: Network used for prediction; it is called with a
                (1, 784) float tensor and must return log-probabilities.
        """
        self.root = root
        self.root.title("手写数字识别(修复版)")
        self.model = model

        # 280x280 canvas: ten times the 28x28 model input.
        self.canvas_size = 280
        self.canvas = tk.Canvas(
            root,
            width=self.canvas_size,
            height=self.canvas_size,
            bg="white",
            cursor="cross",
        )
        self.canvas.pack(pady=20)

        # Off-screen grayscale image, white background, mirroring the canvas.
        self.image = Image.new("L", (self.canvas_size, self.canvas_size), 255)
        self.draw = ImageDraw.Draw(self.image)

        # Dragging with the left button draws; releasing ends the stroke.
        self.canvas.bind("<B1-Motion>", self.draw_stroke)
        self.canvas.bind("<ButtonRelease-1>", self.reset_last_pos)

        # Button row.
        self.buttons_frame = tk.Frame(root)
        self.buttons_frame.pack(pady=10)

        # "Recognize" button.
        self.recognize_btn = tk.Button(
            self.buttons_frame,
            text="识别数字",
            command=self.recognize_digit,
            width=15,
            font=("Arial", 12),
        )
        self.recognize_btn.grid(row=0, column=0, padx=10)

        # "Clear" button.
        self.clear_btn = tk.Button(
            self.buttons_frame,
            text="清除画布",
            command=self.clear_canvas,
            width=15,
            font=("Arial", 12),
        )
        self.clear_btn.grid(row=0, column=1, padx=10)

        # "Test model" button (classifies a random MNIST test image).
        self.test_btn = tk.Button(
            self.buttons_frame,
            text="测试模型",
            command=self.test_model,
            width=15,
            font=("Arial", 12),
        )
        self.test_btn.grid(row=0, column=2, padx=10)

        # Result / status line.
        self.result_var = tk.StringVar()
        self.result_var.set("请在画布上书写数字(0-9),然后点击识别")
        self.result_label = tk.Label(
            root,
            textvariable=self.result_var,
            font=("Arial", 14),
        )
        self.result_label.pack(pady=20)

        # Previous pointer position of the stroke in progress (None = no stroke).
        self.last_x = None
        self.last_y = None

    def draw_stroke(self, event):
        """Extend the current stroke to the pointer position in ``event``.

        The first motion event of a stroke only records the position; each
        later event draws a 15-pixel-wide black segment from the previous
        position on both the canvas and the off-screen image.
        """
        current_x, current_y = event.x, event.y
        # Compare against None explicitly: 0 is a valid coordinate on the
        # top/left edge and must not be treated as "no previous point".
        if self.last_x is not None and self.last_y is not None:
            self.canvas.create_line(
                self.last_x, self.last_y, current_x, current_y,
                width=15, fill="black", capstyle="round", smooth=True,
            )
            self.draw.line(
                [(self.last_x, self.last_y), (current_x, current_y)],
                fill=0,  # black ink; polarity is flipped during preprocessing
                width=15,
            )
        self.last_x = current_x
        self.last_y = current_y

    def reset_last_pos(self, event):
        """End the current stroke so the next drag starts a new one."""
        self.last_x = None
        self.last_y = None

    def clear_canvas(self):
        """Erase the canvas and the off-screen image; reset the status text."""
        self.canvas.delete("all")
        self.image = Image.new("L", (self.canvas_size, self.canvas_size), 255)
        self.draw = ImageDraw.Draw(self.image)
        self.result_var.set("请在画布上书写数字(0-9),然后点击识别")

    def preprocess_image(self, image):
        """Convert a drawing into the tensor format the model was trained on.

        Args:
            image: PIL image of the drawing.

        Returns:
            Float tensor of shape (1, 1, 28, 28) with values in [0, 1],
            bright digit on dark background.
        """
        # 1. Grayscale, scaled down to the 28x28 model input size.
        img = image.convert("L").resize((28, 28), Image.LANCZOS)

        # 2. The MNIST tensors used for training are bright digits on a dark
        #    background. Invert when most pixels are bright, i.e. dark ink on
        #    a light canvas. A blank canvas becomes all zeros, and an image
        #    that is already bright-on-dark is left alone.
        histogram = img.histogram()
        if sum(histogram[128:]) > sum(histogram[:128]):
            img = ImageOps.invert(img)

        # 3. ToTensor only, matching the transform the checkpoint was trained
        #    and verified with. Adding mean/std normalization here would feed
        #    the model inputs on a different scale than it was trained on.
        return transforms.ToTensor()(img).unsqueeze(0)  # add the batch dimension

    def recognize_digit(self):
        """Classify the current drawing and show the digit with its confidence."""
        try:
            image_tensor = self.preprocess_image(self.image)

            with torch.no_grad():
                output = self.model(image_tensor.view(-1, 28 * 28))
                probabilities = torch.exp(output)  # log-probabilities -> probabilities
                confidence, predicted_idx = torch.max(probabilities, 1)

            # Per-class probabilities on stdout, for debugging.
            print("各数字概率:", [f"{i}: {p:.2%}" for i, p in enumerate(probabilities[0])])

            result = f"识别结果: {predicted_idx.item()},置信度: {confidence.item():.2%}"
            self.result_var.set(result)
        except Exception as e:
            # UI boundary: report the failure in a dialog instead of crashing.
            messagebox.showerror("识别错误", f"错误: {str(e)}")

    def test_model(self):
        """Classify one random MNIST test image and show whether it was right."""
        import random

        try:
            test_dataset = datasets.MNIST(
                root='./data',
                train=False,
                download=True,
                transform=transforms.ToTensor(),
            )

            # Pick one image among test-set indices 0..1000 (inclusive).
            idx = random.randint(0, 1000)
            img, label = test_dataset[idx]

            with torch.no_grad():
                output = self.model(img.view(-1, 28 * 28))
                pred = torch.argmax(output).item()

            result = f"模型测试: 实际数字{label},预测结果{pred},{'正确' if pred == label else '错误'}"
            self.result_var.set(result)
        except Exception as e:
            # UI boundary: report the failure in a dialog instead of crashing.
            messagebox.showerror("测试错误", f"错误: {str(e)}")
# 4. 主函数
def main():
    """Load and verify the checkpoint, then run the drawing window.

    Any exception raised while loading the model or running the GUI is
    printed and also shown in an error dialog.
    """
    # Checkpoint location; replace with the path of your own model file.
    model_path = "model.pth"
    try:
        print(f"正在加载模型: {model_path}")
        model = load_and_verify_model(model_path)
        print("模型加载成功")

        # Build the window and hand control to the Tk event loop.
        window = tk.Tk()
        app = DrawingInterface(window, model)
        window.mainloop()
    except Exception as exc:
        reason = str(exc)
        print(f"初始化失败: {reason}")
        messagebox.showerror("初始化错误", f"无法启动应用: {reason}")
# Start the GUI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# 浙公网安备 33010602011771号