# 手写汉字识别 (Handwritten Chinese character recognition)

# -*- coding: utf-8 -*-

"""
Created on Thu Nov 13 17:43:36 2025

@author: HP
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
from PIL import Image

def classes_txt(root, out_path, num_class=None):
    """Append the image paths of class folders under ``root`` to an index file.

    Each sub-directory of ``root`` is treated as one class. The folders are
    sorted by name and the first ``num_class`` of them are indexed, one image
    path per line. The function is resumable: the class folder name of the
    last line already present in ``out_path`` is parsed as an integer, and
    only folders after that position are appended.

    Args:
        root: Directory that contains one sub-directory per class.
        out_path: Index text file to create or extend.
        num_class: Number of class folders to index; a falsy value means
            "all folders found under ``root``".
    """
    # Only directories are class folders; stray files in ``root`` would make
    # the inner os.listdir() call fail.
    dirs = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )
    if not num_class:
        num_class = len(dirs)

    # dirname is '' when out_path is a bare file name; makedirs('') raises.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(out_path, 'a+') as f:
        # 'a+' positions the stream at the end; rewind to read existing lines.
        f.seek(0)
        lines = f.readlines()
        try:
            last_line = lines[-1].strip()
            # Accept both '\\' and '/' so index files are parsed the same way
            # on every platform.
            end = int(last_line.replace('\\', '/').split('/')[-2]) + 1
        except (IndexError, ValueError):
            # Empty file, or a last line without a numeric class folder.
            end = 0

        if end < num_class:
            for class_dir in dirs[end:num_class]:
                dir_path = os.path.join(root, class_dir)
                for file_name in os.listdir(dir_path):
                    f.write(os.path.join(dir_path, file_name) + '\n')

class MyDataset(Dataset):
    """Image dataset backed by an index file with one image path per line.

    The label of an image is the name of its parent directory parsed as an
    integer. Only entries whose label is below ``num_class`` are kept.
    """

    def __init__(self, txt_path, num_class, transforms=None):
        """Read the index file and collect the (path, label) pairs.

        Args:
            txt_path: Index file with one image path per line.
            num_class: Entries with a label >= this value are skipped.
            transforms: Optional callable applied to each loaded PIL image.
        """
        super().__init__()
        self.images = []
        self.labels = []
        with open(txt_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines in the index file.
                    continue
                # The parent directory name is the label; accept both path
                # separators so the index can be read on any platform.
                label = int(line.replace('\\', '/').split('/')[-2])
                if label < num_class:
                    self.images.append(line)
                    self.labels.append(label)
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        # convert() returns a new in-memory image, so the file handle can be
        # closed right away instead of being leaked.
        with Image.open(self.images[index]) as src:
            image = src.convert('RGB')
        label = self.labels[index]
        if self.transforms is not None:
            image = self.transforms(image)
        return image, label

    def __len__(self):
        """Return the number of samples kept from the index file."""
        return len(self.labels)

# Build (or extend) the index files for both splits before any dataset reads them.
root = '../data'
for split in ('/train', '/test'):
    classes_txt(root + split, root + split + '.txt')

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"使用设备: {device}")

# Preprocessing: resize to 64x64, collapse to one channel, convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ]
)

# Only samples whose label is below num_class are kept by MyDataset.
num_class = 100
train_set = MyDataset(root + '/train.txt', num_class=num_class, transforms=transform)
test_set = MyDataset(root + '/test.txt', num_class=num_class, transforms=transform)

# Shuffle the training split only; the test split keeps its order.
train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=64, shuffle=False)

class MYNET(nn.Module):
    """Small CNN: three conv + max-pool stages followed by two linear layers.

    The first linear layer expects 64 * 8 * 8 features, which is what three
    2x2 poolings produce from a 64x64 input.
    """

    def __init__(self, num_classes=None):
        """Create the layers.

        Args:
            num_classes: Size of the output layer. When omitted, the
                module-level ``num_class`` is used, so ``MYNET()`` behaves as
                before.
        """
        super().__init__()
        if num_classes is None:
            num_classes = num_class
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return the raw class scores (logits) for a batch of images."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample. Unlike view(-1, ...), this cannot silently fold
        # a wrong spatial size into the batch dimension; fc1 raises instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Build the network on the selected device and print a layer summary.
model = MYNET().to(device)
summary(model, (1, 64, 64))

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

EPOCH = 10
for epoch in range(EPOCH):
    # ---- training pass ----
    model.train()
    train_loss = 0.0  # running sum of the batch losses for this epoch
    for step, (x, y) in enumerate(train_loader):
        images, labels = x.to(device), y.to(device)

        outputs = model(images)
        loss = loss_func(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        if step % 50 == 0:
            print(f'第 {epoch+1} 轮, 第 {step+1} 步, 训练损失: {loss.item():.4f}')

    # ---- evaluation pass, once per epoch ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            images, labels = x.to(device), y.to(device)
            outputs = model(images)
            # Gradients are already disabled, so the deprecated .data
            # accessor is not needed.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty test set instead of dividing by zero.
    acc = 100 * correct / total if total else 0.0
    print(f'第 {epoch+1} 轮结束, 测试准确率: {acc:.2f}%\n')

# Persist the trained weights; create the target directory first because
# torch.save does not create missing parent directories.
model_path = '../tmp/model.pkl'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)
print('模型已保存')

# Reload into a fresh CPU model for inference. map_location keeps the load
# working even when the checkpoint tensors were saved from a GPU.
model = MYNET()
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

def predict_image(image_path):
    """Classify a single image file and return the predicted class index.

    The image is preprocessed with the module-level ``transform`` and fed to
    the module-level ``model``.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        The index of the highest-scoring class as a Python int.
    """
    # convert() yields an in-memory copy, so the file can be closed at once.
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    # Add the batch dimension: (1, 64, 64) -> (1, 1, 64, 64).
    img = transform(img).unsqueeze(0)
    # Run on whatever device the model currently lives on.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        output = model(img.to(model_device))
    return output.argmax(dim=1).item()

# Run one sample from the test split through the reloaded model.
sample_path = '../data/test/00008/816.png'
result = predict_image(sample_path)
print(f'预测类别: {result}')
# [image]

# posted @ 2025-11-13 18:38  YMH^_^  阅读(6)  评论(0)    收藏  举报