手写汉字识别

点击查看代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
import random
from PIL import Image, ImageDraw, ImageFont
import numpy as np


# -------------------------- 1. 自动生成手写汉字数据集(兼容所有Pillow版本) --------------------------
class ChineseHandwritingDatasetGenerator:
    """Render synthetic "handwritten" Chinese numerals into an ImageFolder-style tree.

    Files are written to ``<save_root>/<split>/<class_idx>/char_<class_idx>_<i>.png``
    for ``split`` in ``("train", "test")``, where ``class_idx`` is the character's
    position in ``self.CHINESE_CHARACTERS``. Note that torchvision's ``ImageFolder``
    sorts these folder names as strings ("0", "1", "10", ...), so its label indices
    differ from ``class_idx`` once there are more than 10 classes; decode labels
    through ``dataset.classes``.

    Args:
        save_root: output directory.
        num_classes: how many characters (from the start of the built-in
            15-character list) to generate; must be in 1..15.
        train_samples: total training images; each class gets
            ``train_samples // num_classes`` (the remainder is dropped).
        test_samples: total test images, split the same way.
        img_size: side length of the square output images, in pixels.

    Raises:
        ValueError: ``num_classes`` out of range, or a split would get fewer
            than one image per class.
        RuntimeError: no usable CJK font was found in ``FONT_CANDIDATES``.
    """

    # Candidate CJK font files, tried in order. The Windows fonts come first
    # (the original behaviour); common macOS/Linux locations are appended so the
    # generator also works off Windows. Paths that do not exist are skipped.
    FONT_CANDIDATES = (
        "C:/Windows/Fonts/simhei.ttf",  # SimHei (preferred)
        "C:/Windows/Fonts/msyh.ttc",  # Microsoft YaHei
        "C:/Windows/Fonts/simsun.ttc",  # SimSun
        "/System/Library/Fonts/STHeiti Medium.ttc",  # macOS
        "/System/Library/Fonts/Hiragino Sans GB.ttc",  # macOS
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",  # Debian/Ubuntu
        "/usr/share/fonts/noto-cjk/NotoSansCJK-Regular.ttc",  # Arch
        "/usr/share/fonts/google-noto-cjk/NotoSansCJK-Regular.ttc",  # Fedora
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",  # WenQuanYi Micro Hei
    )

    def __init__(self, save_root="./generated_chinese_mnist", num_classes=15,
                 train_samples=3000, test_samples=600, img_size=64):
        # 15 common Chinese numeral characters; class_idx == position in this list.
        all_characters = [
            '零', '一', '二', '三', '四', '五', '六', '七', '八', '九',
            '十', '百', '千', '万', '亿'
        ]
        if not 1 <= num_classes <= len(all_characters):
            raise ValueError(
                f"num_classes 必须在 1~{len(all_characters)} 之间,当前为 {num_classes}"
            )

        self.save_root = save_root
        self.num_classes = num_classes
        self.train_samples = train_samples
        self.test_samples = test_samples
        self.img_size = img_size
        # Images per class. The attribute names are historical and kept for
        # compatibility: "classes_per_sample" really means samples per class.
        self.classes_per_sample = train_samples // num_classes
        self.test_per_sample = test_samples // num_classes
        if self.classes_per_sample < 1 or self.test_per_sample < 1:
            # An empty class folder would make ImageFolder fail later anyway.
            raise ValueError(
                f"每个类别至少需要1个样本:train_samples={train_samples}, "
                f"test_samples={test_samples}, num_classes={num_classes}"
            )

        # Use exactly num_classes characters so the class folders created below
        # and the character list always agree.
        self.CHINESE_CHARACTERS = all_characters[:num_classes]

        # Loaded fonts keyed by point size: re-reading a font file for every
        # generated image is needlessly slow.
        self._font_cache = {}
        # Resolve a font eagerly so a missing font fails at construction time.
        self.font = self._get_chinese_font()

    def _get_chinese_font(self, font_size=40):
        """Return a CJK-capable ``ImageFont`` at ``font_size``, cached per size.

        Tries each path in ``FONT_CANDIDATES`` in order; a file that exists but
        fails to load is reported and skipped.

        Raises:
            RuntimeError: none of the candidate fonts could be loaded.
        """
        cached = self._font_cache.get(font_size)
        if cached is not None:
            return cached

        for font_path in self.FONT_CANDIDATES:
            if not os.path.exists(font_path):
                continue
            try:
                font = ImageFont.truetype(font_path, font_size)
            except OSError as e:  # corrupt/unsupported file: try the next one
                print(f"加载字体 {font_path} 失败:{e}")
                continue
            self._font_cache[font_size] = font
            return font

        raise RuntimeError("未找到可用的中文字体!请检查Windows字体目录是否存在simhei.ttf等文件。")

    @staticmethod
    def _centered_origin(bbox, img_size):
        """Return the (x, y) draw origin that centres a glyph on an ``img_size`` canvas.

        ``bbox`` is the (left, top, right, bottom) box the glyph covers when drawn
        at (0, 0). Subtracting ``left``/``top`` compensates for the font's side
        bearing and ascender gap, which ``(size - width) // 2`` alone ignores.
        """
        left, top, right, bottom = bbox
        x = (img_size - (right - left)) // 2 - left
        y = (img_size - (bottom - top)) // 2 - top
        return x, y

    def _generate_handwriting_char(self, char):
        """Render ``char`` as a randomly perturbed "handwritten" image.

        Perturbations: font size 32-48, +/-3 px jitter around the centre, dark
        gray ink (0-60), +/-5 degree rotation, Gaussian pixel noise (sigma 8) and
        a small horizontal shear about the image centre.

        Returns:
            An RGB ``PIL.Image`` of size ``img_size x img_size``.
        """
        # White, fully opaque canvas.
        img = Image.new("RGBA", (self.img_size, self.img_size), (255, 255, 255, 255))
        draw = ImageDraw.Draw(img)

        # Random size imitates differences in stroke weight.
        font = self._get_chinese_font(random.randint(32, 48))

        try:
            bbox = draw.textbbox((0, 0), char, font=font)
        except AttributeError:  # Pillow < 8.0 has no textbbox
            width, height = draw.textsize(char, font=font)
            bbox = (0, 0, width, height)
        x, y = self._centered_origin(bbox, self.img_size)
        x += random.randint(-3, 3)  # random offset
        y += random.randint(-3, 3)

        # Random dark-gray "pen" colour.
        gray = random.randint(0, 60)
        draw.text((x, y), char, font=font, fill=(gray, gray, gray, 255))

        # 1. Random rotation (-5..5 degrees), background filled with white.
        angle = random.randint(-5, 5)
        img = img.rotate(angle, expand=False, fillcolor=(255, 255, 255, 255))

        # 2. Gaussian noise on the RGB channels. uint8 + int16 promotes to int16,
        #    so values are clipped before being stored back as uint8.
        img_np = np.array(img)
        noise = np.random.normal(0, 8, size=img_np.shape[:2]).astype(np.int16)
        img_np[..., :3] = np.clip(img_np[..., :3] + noise[..., None], 0, 255)
        img = Image.fromarray(img_np.astype(np.uint8))

        # 3. Horizontal shear. Output pixel (x, y) samples input (x + s*y + c, y);
        #    c = -s*size/2 keeps the vertical centre fixed, so the glyph is sheared
        #    in place instead of being shifted sideways (and possibly clipped).
        shear = random.uniform(-0.1, 0.1)
        affine_matrix = (1.0, shear, -shear * self.img_size / 2, 0.0, 1.0, 0.0)
        img = img.transform(
            size=(self.img_size, self.img_size),
            method=Image.AFFINE,
            data=affine_matrix,
            resample=Image.BILINEAR if hasattr(Image, 'BILINEAR') else Image.NEAREST,
            fillcolor=(255, 255, 255, 255),
        )

        # Drop the alpha channel to match the 3-channel model input.
        return img.convert("RGB")

    def _generate_split(self, split, per_class):
        """Render ``per_class`` images of every character into ``<save_root>/<split>/<class_idx>/``."""
        for class_idx, char in enumerate(self.CHINESE_CHARACTERS):
            class_dir = os.path.join(self.save_root, split, str(class_idx))
            os.makedirs(class_dir, exist_ok=True)
            for i in range(per_class):
                img = self._generate_handwriting_char(char)
                img.save(os.path.join(class_dir, f"char_{class_idx}_{i}.png"))

    def generate_and_save(self):
        """Generate and save the train and test splits under ``save_root``."""
        num_chars = len(self.CHINESE_CHARACTERS)

        print(f"正在生成训练集({self.classes_per_sample * num_chars}个样本)...")
        self._generate_split("train", self.classes_per_sample)

        print(f"正在生成测试集({self.test_per_sample * num_chars}个样本)...")
        self._generate_split("test", self.test_per_sample)

        print(f"数据集生成完成!保存路径:{self.save_root}")


# -------------------------- 2. 生成数据集并加载 --------------------------
try:
    generator = ChineseHandwritingDatasetGenerator(
        save_root="./generated_chinese_mnist",
        train_samples=3000,  # 200 training samples per class
        test_samples=600,  # 40 test samples per class
    )
    generator.generate_and_save()
except Exception as e:  # top-level boundary: report and stop the script
    print(f"数据集生成失败:{e}")
    # Raise SystemExit directly: the bare exit() helper is injected by the
    # site module and is not guaranteed to exist (python -S, frozen apps).
    raise SystemExit(1)

# Preprocessing for the 64x64 three-channel images (ImageNet mean/std).
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the splits from wherever the generator actually wrote them.
data_root = generator.save_root
train_dataset = datasets.ImageFolder(os.path.join(data_root, "train"), transform=transform)
test_dataset = datasets.ImageFolder(os.path.join(data_root, "test"), transform=transform)
if train_dataset.classes != test_dataset.classes:
    # Different folder sets would give the two splits different label spaces.
    raise RuntimeError(
        f"train/test 类别目录不一致:{train_dataset.classes} vs {test_dataset.classes}"
    )

# ImageFolder numbers classes by sorting folder names as strings
# ("0", "1", "10", ..., "14", "2", ...), so label k is folder classes[k],
# not the k-th generated character. Reorder the table into label order so
# that CHINESE_CHARACTERS[label] decodes labels and predictions correctly.
CHINESE_CHARACTERS = [generator.CHINESE_CHARACTERS[int(name)] for name in train_dataset.classes]

# Data loaders (num_workers=0 avoids worker-process errors on Windows).
batch_size = 32
# Pinned memory only speeds up host->GPU copies; on CPU-only machines it warns.
pin_memory = torch.cuda.is_available()
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory
)


# -------------------------- 3. 汉字识别CNN模型 --------------------------
class ChineseCharacterCNN(nn.Module):
    """Three-block CNN classifier for square character images.

    Each conv block is Conv3x3(padding=1) -> ReLU -> MaxPool2d(2), which halves
    the spatial size (with flooring) three times. The fully connected head is
    sized from ``img_size`` rather than assuming 64x64. With the default
    arguments the layers (and therefore the ``state_dict`` keys and shapes) are
    the same as the original 64x64 RGB model, so existing checkpoints still load.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input image (3 for RGB).
        img_size: input height/width in pixels; must be at least 8.

    Raises:
        ValueError: ``img_size`` is smaller than 8.
    """

    def __init__(self, num_classes=15, in_channels=3, img_size=64):
        super().__init__()
        if img_size < 8:
            raise ValueError(f"img_size must be at least 8, got {img_size}")
        # Three floor-halvings: floor(floor(floor(n/2)/2)/2) == n // 8.
        feat_size = img_size // 8

        self.conv_layers = nn.Sequential(
            # Block 1: in_channels -> 64 channels, size -> size/2
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # Block 2: 64 -> 128 channels, size/2 -> size/4
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # Block 3: 128 -> 256 channels, size/4 -> size/8
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * feat_size * feat_size, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Map a (batch, in_channels, img_size, img_size) tensor to (batch, num_classes) logits."""
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x


# 设备配置(GPU优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Size the classifier from the loaded dataset instead of hard-coding 15.
model = ChineseCharacterCNN(num_classes=len(train_dataset.classes)).to(device)

# -------------------------- 4. Model training and evaluation --------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
num_epochs = 15
best_acc = 0.0
best_state = None  # in-memory copy of the best weights seen so far

print(f"\n开始训练(设备:{device})...")
for epoch in range(num_epochs):
    # Training phase
    model.train()
    total_loss = 0.0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight each batch mean by its size so a short final batch does not
        # skew the epoch average.
        total_loss += loss.item() * labels.size(0)

    avg_loss = total_loss / max(len(train_loader.dataset), 1)
    scheduler.step()  # learning-rate decay

    # Evaluation phase
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

    acc = 100 * correct / total if total else 0.0
    print(f"Epoch [{epoch + 1:2d}/{num_epochs}], Loss: {avg_loss:.4f}, Test Acc: {acc:.2f}%")

    # Keep (and persist) the best model
    if acc > best_acc:
        best_acc = acc
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        torch.save(best_state, "best_chinese_cnn.pth")
        print(f"  → 保存最佳模型(准确率:{best_acc:.2f}%)")

print(f"\n训练完成!最佳测试准确率:{best_acc:.2f}%")

# The loop leaves the last-epoch weights in `model`; restore the best ones so
# the visualisation and prediction code below use the reported best model.
if best_state is not None:
    model.load_state_dict(best_state)

# -------------------------- 5. 可视化测试结果 --------------------------
# Use a CJK-capable font for the Chinese titles; matplotlib tries these in
# order, so non-Windows systems fall back to common Noto/WenQuanYi/macOS fonts.
plt.rcParams['font.sans-serif'] = [
    'SimHei', 'Microsoft YaHei', 'PingFang SC', 'Heiti TC',
    'Noto Sans CJK SC', 'WenQuanYi Micro Hei', 'DejaVu Sans',
]
plt.rcParams['axes.unicode_minus'] = False

# ImageFolder label k is folder test_dataset.classes[k]; the folder name is the
# generator's index into its character list.
label_chars = [generator.CHINESE_CHARACTERS[int(name)] for name in test_dataset.classes]

# test_loader is unshuffled and ordered by class, so its first batch would show
# a single class only; draw random test samples instead.
num_show = min(6, len(test_dataset))
sample_indices = random.sample(range(len(test_dataset)), num_show)
samples = [test_dataset[idx] for idx in sample_indices]
imgs = torch.stack([sample_img for sample_img, _ in samples])
labels = torch.tensor([sample_label for _, sample_label in samples])

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    preds = model(imgs.to(device)).argmax(dim=1).cpu()

# Undo the Normalize() step for display.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# squeeze=False keeps `axes` 2-D even when only one sample is shown.
fig, axes = plt.subplots(1, num_show, figsize=(15, 5), squeeze=False)
for ax, img, label, pred in zip(axes[0], imgs, labels, preds):
    img = torch.clamp(img.permute(1, 2, 0) * std + mean, 0, 1)
    ax.imshow(img)
    true_char = label_chars[label.item()]
    pred_char = label_chars[pred.item()]
    ax.set_title(f"真实:{true_char}\n预测:{pred_char}", fontsize=12)
    ax.axis('off')

plt.tight_layout()
plt.show()


# -------------------------- 6. 单张手写汉字预测(可选) --------------------------
def predict_single_img(img_path):
    """Predict the character in one handwritten image, show it, and return it.

    Uses the script's globals: ``model``/``device`` for inference, ``transform``
    (the training preprocessing pipeline), and ``generator``/``train_dataset``
    to turn the ImageFolder label back into a character.

    Transparent pixels are flattened onto white before conversion, because the
    training images are dark strokes on a white background and
    ``convert("RGB")`` alone would turn a transparent background black.

    Args:
        img_path: path to an image file that PIL can open.

    Returns:
        The predicted character, or ``None`` if loading or prediction failed
        (the error is printed; this helper is best-effort by design).
    """
    try:
        # Context manager closes the underlying file handle.
        with Image.open(img_path) as src:
            rgba = src.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        img = Image.alpha_composite(white, rgba).convert("RGB")

        # Same preprocessing as the training/test data.
        img_tensor = transform(img).unsqueeze(0).to(device)

        model.eval()
        with torch.no_grad():
            pred_idx = model(img_tensor).argmax(dim=1).item()

        # ImageFolder label -> folder name -> generator's character index.
        pred_char = generator.CHINESE_CHARACTERS[int(train_dataset.classes[pred_idx])]

        # Show the result
        plt.imshow(img)
        plt.title(f"预测结果:{pred_char}", fontsize=14)
        plt.axis('off')
        plt.show()
        return pred_char
    except Exception as e:  # best-effort helper: report instead of crashing
        print(f"预测失败:{e}")
        return None

# 示例:替换为你的手写汉字图片路径
# predict_single_img("my_handwriting.png")
posted @ 2025-11-06 17:56  四季歌镜  阅读(11)  评论(0)    收藏  举报