手写汉字识别

点击查看代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from PIL import Image
import random

# ===================== 1. Configuration (minimal settings) =====================
DEVICE = torch.device("cpu")  # force CPU (lowest spec; runs without a GPU)
IMG_SIZE = 32  # resize images to 32x32 (smaller = fewer resources)
BATCH_SIZE = 8  # small batch size (friendly to low-memory machines)
EPOCHS = 5  # only 5 epochs (quick sanity check)
LR = 0.01  # moderate learning rate
# Number of classes: HWDB1.1 commonly covers 3755 level-1 Chinese characters;
# simplified to 100 classes here (for testing; change as needed).
NUM_CLASSES = 100
# Dataset path (replace with your own HWDB1.1 image path).
# FIX: raw string — the original plain string contained invalid escape
# sequences ("\P", "\H"), which raise SyntaxWarning on modern Python.
HWDB_ROOT = r"D:\Pysch2\Pytorch\HWDB1.1trn_gnt\HWDB1.1trm_nt"

# ===================== 2. Dataset loading (minimal version) =====================
class HWDB11Dataset(Dataset):
    """Minimal HWDB1.1 dataset: first `num_classes` class folders, 20 images each.

    Expects `root` to contain one sub-directory per character class with
    .png/.jpg images inside (HWDB1.1-style layout; verify against your export).
    """

    def __init__(self, root, img_size=32, num_classes=100):
        self.root = root
        self.img_size = img_size
        self.num_classes = num_classes
        self.img_paths, self.labels = self._load_data()

    def _load_data(self):
        """Collect image paths and integer labels for the selected classes.

        FIX: directories and files are sorted so the class->label mapping is
        deterministic — os.listdir order is arbitrary and platform-dependent,
        which would otherwise silently remap labels between runs/machines.
        """
        img_paths = []
        labels = []
        # Class folders (HWDB1.1 groups images by character code).
        class_dirs = sorted(
            d for d in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, d))
        )
        # Keep only the first num_classes classes (lower complexity).
        selected_classes = class_dirs[:self.num_classes]

        for label, cls_dir in enumerate(selected_classes):
            cls_path = os.path.join(self.root, cls_dir)
            # Only 20 images per class (minimal config, smaller dataset).
            img_files = sorted(
                f for f in os.listdir(cls_path) if f.endswith((".png", ".jpg"))
            )[:20]
            for img_file in img_files:
                img_paths.append(os.path.join(cls_path, img_file))
                labels.append(label)
        return img_paths, labels

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Minimal preprocessing: grayscale + resize + normalize.
        img_path = self.img_paths[idx]
        # FIX: context manager closes the underlying file handle promptly —
        # plain Image.open can leak handles across many small images.
        with Image.open(img_path) as img:
            gray = img.convert("L")  # grayscale (saves memory)
            gray = gray.resize((self.img_size, self.img_size))
            arr = np.array(gray)
        # (H,W) -> (1,H,W), scaled to [0,1].
        img_tensor = torch.tensor(arr, dtype=torch.float32).unsqueeze(0) / 255.0
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return img_tensor, label

# ===================== 3. Minimal CNN model (lowest spec) =====================
class SimpleHWDBModel(nn.Module):
    """Tiny two-conv CNN for low-resource handwritten-character classification."""

    def __init__(self, num_classes=100):
        super(SimpleHWDBModel, self).__init__()
        # Two conv stages, each halving the spatial size: 32 -> 16 -> 8.
        # Very few channels (8, then 16) keep the parameter count tiny.
        stages = []
        for in_ch, out_ch in ((1, 8), (8, 16)):
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        self.features = nn.Sequential(*stages)
        # Small classifier head: 16x8x8 features -> 64 -> num_classes logits.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 8 * 8, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Map a (N, 1, 32, 32) batch to (N, num_classes) logits."""
        return self.classifier(self.features(x))

# ===================== 4. Training + validation (minimal pipeline) =====================
def main():
    """Train the tiny CNN on HWDB1.1 samples and print test accuracy."""
    # --- Data: load the dataset, then a simple 80/20 random split. ---
    dataset = HWDB11Dataset(HWDB_ROOT, IMG_SIZE, NUM_CLASSES)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # --- Model, loss, optimizer (SGD is lighter on resources than Adam). ---
    model = SimpleHWDBModel(NUM_CLASSES).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=LR)

    # --- Training loop: standard forward / backward / step per batch. ---
    model.train()
    for epoch in range(EPOCHS):
        running_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)

            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # One-line log per epoch: mean batch loss.
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {avg_loss:.4f}")

    # --- Evaluation: argmax accuracy over the held-out split. ---
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            predicted = torch.max(output.data, 1)[1]
            total += target.size(0)
            correct += (predicted == target).sum().item()

    print(f"\n测试集准确率: {100 * correct / total:.2f}%")

if __name__ == "__main__":
    # Guard against a missing dataset directory before touching the model
    # (replace HWDB_ROOT with your actual HWDB1.1 path).
    if os.path.exists(HWDB_ROOT):
        main()
    else:
        print(f"错误:请将HWDB_ROOT改为你的HWDB1.1数据集路径,当前路径:{HWDB_ROOT}")
posted @ 2025-11-19 22:57  bolun123  阅读(12)  评论(0)    收藏  举报