# Handwritten Chinese character recognition (手写汉字识别)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from collections import defaultdict
# --------------------------
# 1. Custom dataset (keeps only the specified characters)
# --------------------------
class CharDataset(Dataset):
    """Image dataset restricted to a caller-chosen set of character classes.

    Image files live directly in ``root_dir`` and are named
    ``<label>_<anything>.png`` or ``.jpg``; the label is the text before the
    first underscore. Files whose label is not in ``keep_classes`` are ignored.
    """

    def __init__(self, root_dir, keep_classes, transform=None):
        """Scan ``root_dir`` and index the images that belong to ``keep_classes``.

        Args:
            root_dir: Directory containing the image files.
            keep_classes: Collection of labels (filename prefixes) to keep.
            transform: Optional callable applied to each PIL image on access.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.keep_classes = keep_classes  # labels selected by the caller
        self.img_list = self._filter_images()
        self.labels = [self._get_label(img) for img in self.img_list]
        # Class indices follow the sorted order of the labels actually found.
        self.classes = sorted(set(self.labels))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

    def _filter_images(self):
        """Return the file names in ``root_dir`` that belong to a kept class."""
        # os.listdir order is arbitrary; sorting keeps sample order reproducible.
        return [
            img_name
            for img_name in sorted(os.listdir(self.root_dir))
            if img_name.lower().endswith(('.png', '.jpg'))
            and self._get_label(img_name) in self.keep_classes
        ]

    def _get_label(self, img_name):
        """Return the label encoded in a file name (text before the first '_')."""
        return img_name.split('_')[0]

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load image ``idx`` as grayscale and return ``(image, class_index)``."""
        img_name = self.img_list[idx]
        img_path = os.path.join(self.root_dir, img_name)
        img = Image.open(img_path).convert('L')  # 'L' = single-channel grayscale
        if self.transform:
            img = self.transform(img)
        label = self.class_to_idx[self._get_label(img_name)]
        return img, label
# --------------------------
# 2. Model sized for a small number of classes (more focused learning)
# --------------------------
class FocusedCharCNN(nn.Module):
    """Small two-stage CNN classifier for single-channel character images."""

    def __init__(self, num_classes, img_size=64):
        """Build the convolutional feature extractor and the classifier head.

        Args:
            num_classes: Number of output classes (logits per sample).
            img_size: Side length of the square input images. Defaults to 64,
                which gives the 32 * 16 * 16 flattened feature size.
        """
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # The padded 3x3 convolutions keep the spatial size; each of the two
        # 2x2 poolings halves it (flooring), so the feature map side is
        # img_size // 4.
        feat_side = img_size // 4
        self.fc = nn.Sequential(
            nn.Linear(32 * feat_side * feat_side, 64),  # deliberately small head
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Map a batch of shape (N, 1, img_size, img_size) to (N, num_classes) logits."""
        x = self.cnn(x)
        x = x.view(x.size(0), -1)  # flatten everything except the batch dim
        x = self.fc(x)
        return x
# --------------------------
# 3. Main training routine (trains only on the specified characters)
# --------------------------
def main():
    """Train and evaluate the CNN on a hand-picked subset of characters.

    Builds two ``CharDataset`` views over the same directory, splits the files
    per class into train/test parts, trains for a fixed number of epochs and
    prints the loss and test accuracy after every epoch.

    Raises:
        RuntimeError: If the split leaves the train or test set empty.
    """
    data_dir = "D:/py1/deeps/手写汉书"  # dataset location; adjust to your machine
    keep_classes = ["人", "明"]  # the only characters to recognise
    img_size = 64  # the model's fully connected layer assumes 64x64 inputs
    batch_size = 4
    epochs = 60
    lr = 0.0005
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Mild augmentation for training only.
    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomRotation(degrees=10),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    # Evaluation must be deterministic: no random rotation or translation,
    # otherwise the reported accuracy changes from run to run.
    eval_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    # Both datasets index the same files; the split below narrows each one.
    train_dataset = CharDataset(data_dir, keep_classes, train_transform)
    test_dataset = CharDataset(data_dir, keep_classes, eval_transform)

    def split_data(img_list, train_ratio=0.7):
        """Split file names per class so every class appears in both parts."""
        by_class = defaultdict(list)
        for img_name in img_list:
            by_class[train_dataset._get_label(img_name)].append(img_name)

        train_part, test_part = [], []
        for names in by_class.values():
            np.random.shuffle(names)
            # Keep at least one training sample per class ...
            cut = max(int(len(names) * train_ratio), 1)
            # ... and at least one test sample when the class has two or more.
            if len(names) > 1:
                cut = min(cut, len(names) - 1)
            train_part.extend(names[:cut])
            test_part.extend(names[cut:])
        return train_part, test_part

    train_imgs, test_imgs = split_data(train_dataset.img_list)
    if not train_imgs or not test_imgs:
        # Fail early with a clear message instead of a DataLoader error or a
        # division by zero when computing the accuracy.
        raise RuntimeError(
            f"在 {data_dir} 中未找到足够的样本,无法划分训练集和测试集"
        )
    train_dataset.img_list = train_imgs
    test_dataset.img_list = test_imgs

    num_classes = len(keep_classes)
    print(f"仅识别汉字:{keep_classes}(共{num_classes}个)")
    print(f"训练集样本数:{len(train_dataset)},测试集样本数:{len(test_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = FocusedCharCNN(num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_acc = 0.0
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            train_loss += loss.item() * imgs.size(0)
        train_loss /= len(train_dataset)

        # ---- evaluation pass ----
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                predicted = model(imgs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_acc = 100 * correct / total
        if test_acc > best_acc:
            best_acc = test_acc

        print(f"\nEpoch {epoch+1}/{epochs}")
        print(f"训练损失:{train_loss:.4f} | 测试准确率:{test_acc:.2f}%")
        print("-" * 60)

    print("\n🎉 训练完成!")
    print(f"📊 最终最佳测试准确率:{best_acc:.2f}%")
    if best_acc >= 80:
        print("✅ 达到目标准确率(80-90%)!")
    else:
        print("⚠️ 若仍未达标,可再补充“人”“明”的样本数(每个字5-8张不同写法)。")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

# (web page footer picked up when the code was copied) 浙公网安备 33010602011771号