# 手写汉字识别 (Handwritten Chinese character recognition)
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 13 17:43:36 2025
@author: HP
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
from PIL import Image
def classes_txt(root, out_path, num_class=None):
    """Append the image paths found under ``root`` to the index file ``out_path``.

    ``root`` is scanned for class sub-directories; for each selected class,
    one line per contained file (``<root>/<class_dir>/<file>``) is appended
    to ``out_path``. The call is resumable: the parent-directory name on the
    last non-empty line already in the file is parsed as an integer, and
    writing continues from the class after it.

    NOTE(review): the resume logic uses that integer as a position in the
    sorted directory list, so it assumes class directories are named with
    consecutive zero-based integers (e.g. ``00000``, ``00001``, ...) --
    confirm against the dataset layout.

    Args:
        root: Directory containing one sub-directory per class.
        out_path: Index file to create or extend.
        num_class: Number of leading classes (in sorted order) to index.
            ``None`` or ``0`` means every class directory in ``root``.
    """
    # Only real directories are classes; stray files would break listdir().
    class_dirs = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )
    if not num_class:
        num_class = len(class_dirs)

    # dirname() is '' for a bare file name; makedirs('') would raise.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(out_path, 'a+') as f:
        # 'a+' opens positioned at the end; rewind to read what is there.
        f.seek(0)
        existing = [line.strip() for line in f.readlines()]
        last_path = next((line for line in reversed(existing) if line), '')
        try:
            # Normalise separators so indexes written on Windows ('\\')
            # and POSIX ('/') are both understood.
            parent = last_path.replace('\\', '/').split('/')[-2]
            end = int(parent) + 1
        except (IndexError, ValueError):
            # Empty file or unparsable last line: start from the first class.
            end = 0

        # An empty slice (end >= num_class) means everything is indexed.
        for class_dir in class_dirs[end:num_class]:
            dir_path = os.path.join(root, class_dir)
            f.writelines(
                os.path.join(dir_path, file_name) + '\n'
                for file_name in sorted(os.listdir(dir_path))
            )
class MyDataset(Dataset):
    """Image dataset driven by an index file, labelled by parent folder name.

    Each non-empty line of the index file is an image path whose parent
    directory name is parsed as the integer class label (for example
    ``.../00008/816.png`` gets label 8). Samples whose label is not below
    ``num_class`` are dropped.

    Args:
        txt_path: Index file with one image path per line.
        num_class: Exclusive upper bound on the labels that are kept.
        transforms: Optional callable applied to each loaded PIL image.

    Raises:
        IndexError: If a line has no parent-directory component.
        ValueError: If the parent-directory name is not an integer.
    """

    def __init__(self, txt_path, num_class, transforms=None):
        super().__init__()
        self.images = []
        self.labels = []
        self.transforms = transforms
        with open(txt_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                # Accept both '\\' (Windows) and '/' (POSIX) separators so
                # an index written on either platform can be read.
                label = int(line.replace('\\', '/').split('/')[-2])
                if label < num_class:
                    self.images.append(line)
                    self.labels.append(label)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        # convert() returns a fully loaded copy, so the file can be closed.
        with Image.open(self.images[index]) as raw:
            image = raw.convert('RGB')
        label = self.labels[index]
        if self.transforms is not None:
            image = self.transforms(image)
        return image, label

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.labels)
# ---------------------------------------------------------------------------
# Data preparation: build the per-split index files, choose the compute
# device, and wrap both splits in DataLoaders.
# ---------------------------------------------------------------------------
root = '../data'

# One index file per split, each line being the path of one image.
for split in ('train', 'test'):
    classes_txt(root + '/' + split, root + '/' + split + '.txt')

# Use the GPU when one is visible to PyTorch, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"使用设备: {device}")

# Preprocessing applied to every image: resize to 64x64, convert to
# grayscale, then turn the PIL image into a tensor.
preprocessing_steps = [
    transforms.Resize((64, 64)),
    transforms.Grayscale(),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# Only samples whose label is below this bound are loaded.
num_class = 100

train_set = MyDataset(root + '/train.txt', num_class=num_class, transforms=transform)
test_set = MyDataset(root + '/test.txt', num_class=num_class, transforms=transform)

# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_loader = DataLoader(train_set, shuffle=True, batch_size=64)
test_loader = DataLoader(test_set, shuffle=False, batch_size=64)
class MYNET(nn.Module):
    """Small CNN classifier for single-channel 64x64 images.

    Three ``3x3 conv -> ReLU -> 2x2 max-pool`` stages take the input from
    1x64x64 to 64x8x8, followed by two fully connected layers that produce
    one raw score (logit) per class.

    Args:
        num_classes: Size of the output layer. ``None`` (the default) falls
            back to the module-level ``num_class`` so existing ``MYNET()``
            calls keep working.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = num_class  # module-level default
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # 64 channels * 8 * 8 spatial positions after three 2x poolings.
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits for a batch ``x``; one row per input sample."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample. Keeping the batch dimension fixed makes a
        # wrongly sized input fail at fc1 instead of being silently
        # re-batched by view(-1, ...).
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# ---------------------------------------------------------------------------
# Training: fit the network for EPOCH passes over the training set, report
# test accuracy after every epoch, then save the learned weights.
# ---------------------------------------------------------------------------
model = MYNET().to(device)
summary(model, (1, 64, 64))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()
EPOCH = 10

# Create the checkpoint directory up front so a missing folder cannot make
# torch.save() fail after all the training time has been spent.
os.makedirs('../tmp', exist_ok=True)

for epoch in range(EPOCH):
    model.train()
    # Running sum of batch losses; accumulated but not reported here.
    train_loss = 0.0
    for step, (x, y) in enumerate(train_loader):
        images, labels = x.to(device), y.to(device)
        outputs = model(images)
        loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        if step % 50 == 0:
            print(f'第 {epoch+1} 轮, 第 {step+1} 步, 训练损失: {loss.item():.4f}')

    # Evaluate on the test split with gradients disabled.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            images, labels = x.to(device), y.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # An empty test set would otherwise raise ZeroDivisionError.
    acc = 100 * correct / total if total else 0.0
    print(f'第 {epoch+1} 轮结束, 测试准确率: {acc:.2f}%\n')

torch.save(model.state_dict(), '../tmp/model.pkl')
print('模型已保存')
# Rebuild the network on the CPU and restore the trained weights.
# map_location='cpu' lets a checkpoint saved from a CUDA model load on a
# machine without a GPU, and avoids staging the tensors on the GPU for a
# model that lives on the CPU.
model = MYNET()
model.load_state_dict(torch.load('../tmp/model.pkl', map_location='cpu'))
model.eval()
def predict_image(image_path):
    """Return the predicted class index for the image at ``image_path``.

    The image is preprocessed with the module-level ``transform`` and passed
    through the module-level ``model`` with gradients disabled.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        int: Index of the highest-scoring class.
    """
    # convert() returns a fully loaded copy, so the file can be closed.
    with Image.open(image_path) as raw:
        img = transform(raw.convert('RGB'))
    # Add the batch dimension without hard-coding the image size, and place
    # the input on whatever device the model's weights currently live on.
    img = img.unsqueeze(0).to(next(model.parameters()).device)
    with torch.no_grad():
        output = model(img)
    return output.argmax(dim=1).item()
# Smoke test: classify one image from the test split and print its class.
sample_path = '../data/test/00008/816.png'
result = predict_image(sample_path)
print(f'预测类别: {result}')

# 浙公网安备 33010602011771号  (web-page footer residue from the original post; not part of the program)