# gxh6666
#
# 汉字识别 (Chinese character recognition)

import os
import random
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageFilter
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import platform

# ===================== 配置参数 =====================

# Last four digits of the student ID; echoed in the start/end banners and in
# every recognition result.
STUDENT_ID = "3029"

# Characters to recognize. A character's position in this list is its class label.
CHARS = ['一', '二', '三', '人', '口', '手', '日', '月', '水']

# Side length (pixels) of the square grayscale images that are generated.
IMG_SIZE = 64

# Number of generated images per character for the training / test splits.
TRAIN_NUM = 200
TEST_NUM = 50

# Training hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 30
LEARNING_RATE = 0.005

# Directory (relative to the current working directory) for the generated dataset.
DATA_DIR = "hanzi_recognize_data"

# Augmentation parameters applied when rendering each image.
NOISE_PROB = 0.2  # chance an image gets speckle noise; also the per-pixel noise fraction
BLUR_PROB = 0.15  # chance an image gets a Gaussian blur
ROTATE_RANGE = (-20, 20)  # inclusive rotation range in whole degrees (fed to random.randint)

# Resolve OpenMP runtime conflicts: let a duplicate libiomp copy load instead of
# aborting the process. NOTE(review): this must be set before the second OpenMP
# runtime initializes; move it above `import torch` if the error still appears.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Run on CUDA when available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===================== 工具函数 =====================

def get_chinese_font(size=40):
    """Return a font that can draw the Chinese training characters.

    Candidate font files are tried in order and the first one Pillow can open
    is returned. For bare file names such as ``"simsun.ttc"`` Pillow also
    searches the platform's font directories, so the Windows entries need no
    full path.

    Args:
        size: Font size passed to ``ImageFont.truetype``. Defaults to 40.

    Returns:
        A Pillow font object. If no CJK-capable candidate can be opened, a
        ``RuntimeWarning`` is emitted and a non-CJK fallback is returned
        (DejaVuSans, else ``ImageFont.load_default()``). Such a fallback
        presumably draws placeholder boxes instead of Chinese glyphs.
    """
    import warnings

    cjk_candidates = (
        "simsun.ttc",  # Windows: SimSun
        "msyh.ttc",  # Windows: Microsoft YaHei
        "/System/Library/Fonts/PingFang.ttc",  # macOS: PingFang
        "/System/Library/Fonts/STHeiti Light.ttc",  # macOS: STHeiti
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",  # Debian/Ubuntu fonts-noto-cjk
        "/usr/share/fonts/noto-cjk/NotoSansCJK-Regular.ttc",  # Arch noto-fonts-cjk
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",  # Debian/Ubuntu fonts-wqy-microhei
        "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",  # Debian/Ubuntu fonts-wqy-zenhei
    )
    # Linux last resort kept from the original list; it has no CJK glyphs.
    latin_fallback = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"

    for path in cjk_candidates:
        try:
            return ImageFont.truetype(path, size)
        except OSError:
            continue  # missing or unreadable font file: try the next candidate

    warnings.warn(
        "No CJK font found; generated images will not contain legible Chinese "
        "characters. Install a CJK font (e.g. fonts-noto-cjk or fonts-wqy-microhei).",
        RuntimeWarning,
        stacklevel=2,
    )
    try:
        return ImageFont.truetype(latin_fallback, size)
    except OSError:
        return ImageFont.load_default()

def print_student_info():
    """Print a three-line banner announcing the assignment and the student ID suffix."""
    rule = "=" * 40
    banner_lines = (rule, f"汉字识别作业 - 学号后四位:{STUDENT_ID}", rule)
    for line in banner_lines:
        print(line)

# ===================== 汉字图像生成器 =====================

class HanziImageGenerator:
    """Render augmented grayscale images of the characters in CHARS and write them to disk."""

    def __init__(self):
        self.font = get_chinese_font()
        # Hand-picked base top-left drawing offset (x, y) for each character;
        # characters missing from this table fall back to (10, 20).
        self.char_pos = {
            '一': (10, 25), '二': (10, 20), '三': (10, 15),
            '人': (15, 20), '口': (15, 15), '手': (10, 15),
            '日': (15, 15), '月': (12, 15), '水': (10, 15),
        }

    def add_noise(self, img):
        """Return a copy of *img* with about NOISE_PROB of its pixels set to random 0-255 values.

        Args:
            img: An 8-bit image (mode "L" everywhere in this class).

        Returns:
            PIL.Image.Image: the noisy image; *img* itself is not modified.
        """
        img_array = np.array(img, dtype=np.uint8)  # np.array copies, so *img* is untouched
        noise = np.random.randint(0, 256, img_array.shape, dtype=np.uint8)
        mask = np.random.random(img_array.shape) < NOISE_PROB
        img_array[mask] = noise[mask]
        return Image.fromarray(img_array)

    def generate_image(self, char):
        """Render one augmented image of *char*.

        The glyph is drawn black on white at its base offset, jittered by up
        to 5 px in each direction. It is then rotated by a random whole-degree
        angle from ROTATE_RANGE, with the exposed corners filled white.
        Finally, with probability NOISE_PROB speckle noise is added, and with
        probability BLUR_PROB a Gaussian blur of radius 0.3-1.0 is applied.

        Returns:
            PIL.Image.Image: an IMG_SIZE x IMG_SIZE grayscale ("L") image.
        """
        img = Image.new("L", (IMG_SIZE, IMG_SIZE), 255)
        draw = ImageDraw.Draw(img)

        x, y = self.char_pos.get(char, (10, 20))
        x += random.randint(-5, 5)
        y += random.randint(-5, 5)
        draw.text((x, y), char, font=self.font, fill=0)

        # expand=False keeps the canvas at IMG_SIZE x IMG_SIZE.
        img = img.rotate(random.randint(*ROTATE_RANGE), expand=False, fillcolor=255)

        if random.random() < NOISE_PROB:
            img = self.add_noise(img)
        if random.random() < BLUR_PROB:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.3, 1.0)))
        return img

    def generate_dataset(self):
        """Recreate DATA_DIR/{train,test}/<char>/<i>.png from scratch.

        Any existing DATA_DIR is deleted first so images from earlier runs
        (e.g. with a larger TRAIN_NUM) cannot leak into the new splits. Then
        TRAIN_NUM training images and TEST_NUM test images are written for
        every character in CHARS.
        """
        import shutil

        if os.path.exists(DATA_DIR):
            shutil.rmtree(DATA_DIR)

        splits = (("train", TRAIN_NUM), ("test", TEST_NUM))
        # Create every folder up front; the dataset loader lists each one.
        for split, _ in splits:
            for char in CHARS:
                os.makedirs(os.path.join(DATA_DIR, split, char), exist_ok=True)

        print("正在生成汉字数据集...")
        for char in CHARS:
            for split, count in splits:
                out_dir = os.path.join(DATA_DIR, split, char)
                for i in range(count):
                    self.generate_image(char).save(os.path.join(out_dir, f"{i}.png"))
        print(f"数据集生成完成,保存路径:{os.path.abspath(DATA_DIR)}")

# ===================== 数据集加载器 =====================

class HanziDataset(Dataset):
    """Image-folder dataset reading DATA_DIR/<split>/<char>/*.png.

    Each sample is ``(tensor, label)``:
    - ``tensor`` is a (1, H, W) float tensor scaled to [-1, 1].
    - ``label`` is the character's index in CHARS.

    Args:
        split: Sub-directory of DATA_DIR to read, e.g. "train" or "test".
    """

    def __init__(self, split="train"):
        self.split = split
        self.data_path = os.path.join(DATA_DIR, split)
        self.char2idx = {c: i for i, c in enumerate(CHARS)}
        self.images, self.labels = self.load_data()

        # ToTensor maps 8-bit pixels to [0, 1]; Normalize(0.5, 0.5) then maps
        # them to [-1, 1]. Inference-time preprocessing must match this.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def load_data(self):
        """Collect the path and label of every PNG in each character folder.

        Folder listings are sorted because ``os.listdir`` order is arbitrary;
        sorting keeps sample order reproducible across runs and platforms.

        Returns:
            tuple[list[str], list[int]]: parallel lists of image paths and labels.

        Raises:
            FileNotFoundError: if a character folder under ``self.data_path`` is missing.
        """
        images = []
        labels = []
        for char in CHARS:
            char_dir = os.path.join(self.data_path, char)
            for img_name in sorted(os.listdir(char_dir)):
                if img_name.endswith(".png"):
                    images.append(os.path.join(char_dir, img_name))
                    labels.append(self.char2idx[char])
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Open with a context manager so the file handle is released right
        # away; convert() returns a new, fully loaded image.
        with Image.open(self.images[idx]) as raw:
            img = raw.convert("L")
        return self.transform(img), self.labels[idx]

# ===================== 轻量化CNN模型 =====================

class HanziCNN(nn.Module):
    """Lightweight two-stage CNN classifier for single-channel square images.

    Args:
        num_classes: Number of output logits. Defaults to 9.
        img_size: Side length of the square input images. Defaults to 64,
            which reproduces the original 16*16*16 flattened feature size.
            The two 2x2 poolings floor-divide the side by 4.

    The layer layout (and therefore the ``state_dict`` keys) is unchanged,
    so previously saved checkpoints still load.
    """

    def __init__(self, num_classes=9, img_size=64):
        super().__init__()
        # Feature extractor: (N, 1, S, S) -> (N, 16, S//4, S//4)
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.2),
        )
        side = img_size // 2 // 2  # spatial size after the two poolings
        # Classifier head on the flattened feature map.
        self.classifier = nn.Sequential(
            nn.Linear(16 * side * side, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Map a (N, 1, img_size, img_size) batch to (N, num_classes) logits."""
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        return self.classifier(x)

# ===================== 训练与识别主逻辑 =====================

def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over *loader* and return the sample-weighted summed loss."""
    model.train()
    total_loss = 0.0
    for imgs, labels in loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

        outputs = model(imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss averages over the batch by default; re-weight by
        # batch size so the epoch mean is exact even for a short last batch.
        total_loss += loss.item() * imgs.size(0)
    return total_loss


def _evaluate(model, loader):
    """Return top-1 accuracy (percent) of *model* on *loader*; 0.0 if the loader is empty."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            preds = model(imgs).argmax(dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    return 100 * correct / total if total else 0.0


def train_model():
    """Generate a fresh synthetic dataset, train a HanziCNN on it and return the model.

    After every epoch that improves test accuracy (and always after the first
    epoch), the weights are saved to ``hanzi_best_model.pth``. Training stops
    early once test accuracy reaches 95%.

    Returns:
        The model as it stands after the last epoch run. This is not
        necessarily the best one; reload ``hanzi_best_model.pth`` for that.
    """
    generator = HanziImageGenerator()
    generator.generate_dataset()

    # No worker subprocesses on Windows (presumably to avoid spawn-based
    # multiprocessing issues there); two workers elsewhere.
    num_workers = 0 if platform.system() == "Windows" else 2
    train_dataset = HanziDataset("train")
    test_dataset = HanziDataset("test")
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)

    # Size the output layer from CHARS so the logits always match the label indices.
    model = HanziCNN(num_classes=len(CHARS)).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    ckpt_path = "hanzi_best_model.pth"
    print(f"\n开始训练(使用{DEVICE})...")
    best_acc = 0.0
    for epoch in range(EPOCHS):
        total_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
        acc = _evaluate(model, test_loader)

        avg_loss = total_loss / max(len(train_dataset), 1)
        print(f"第{epoch+1:2d}轮 | 损失:{avg_loss:.4f} | 测试准确率:{acc:.2f}%")

        # Always checkpoint the first epoch so a file exists for the caller
        # to load even if accuracy never rises above 0%.
        if epoch == 0 or acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), ckpt_path)

        if acc >= 95:
            print("准确率达到95%,提前结束训练")
            break

    print(f"\n训练完成!最优测试准确率:{best_acc:.2f}%")
    return model

def recognize_hanzi(model):
    """Interactively classify user-supplied character images until the user quits.

    Image paths are read from stdin in a loop. Entering ``q`` (any case) or
    closing stdin (EOF) ends the session. Each image is converted to
    grayscale, resized to IMG_SIZE x IMG_SIZE and normalised to [-1, 1] before
    it is fed to *model*.

    Args:
        model: A classifier whose output index i corresponds to CHARS[i].
    """
    print("\n===== 汉字识别功能 =====")
    model.eval()
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    while True:
        try:
            raw = input("\n请输入汉字图片路径(输入q退出):")
        except EOFError:
            break  # stdin closed: end the session like "q" instead of crashing
        # Tolerate surrounding whitespace and the quotes some terminals add
        # when a file is dragged in or a path is copied.
        path = raw.strip().strip('"').strip("'")
        if path.lower() == "q":
            break
        if not os.path.isfile(path):
            print("路径错误,请重新输入!")
            continue

        try:
            with Image.open(path) as raw_img:
                img = raw_img.convert("L")
            img_tensor = transform(img).unsqueeze(0).to(DEVICE)

            with torch.no_grad():
                probs = torch.softmax(model(img_tensor), dim=1)[0]
            pred_idx = int(probs.argmax())
            pred_char = CHARS[pred_idx]
            confidence = probs[pred_idx].item() * 100

            print(f"学号后四位:{STUDENT_ID} | 识别结果:{pred_char} | 可信度:{confidence:.2f}%")
        except Exception as e:
            # Top-level boundary of the interactive loop: report the failure
            # and keep the session alive rather than abort on one bad file.
            print(f"识别失败:{str(e)}")

def main():
    """Show the student banner, train, reload the best checkpoint, then run interactive recognition."""
    print_student_info()

    trained = train_model()
    # train_model() hands back the final-epoch weights; swap in the best
    # checkpoint it saved during training.
    best_state = torch.load("hanzi_best_model.pth", map_location=DEVICE)
    trained.load_state_dict(best_state)

    recognize_hanzi(trained)

    # Closing banner repeats the student ID.
    rule = "=" * 40
    print("\n" + rule)
    print(f"程序结束 - 学号后四位:{STUDENT_ID}")
    print(rule)

# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

# posted on 2025-12-26 12:05  吗假期