PyTorch handwritten Chinese character recognition

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchsummary import summary
from torch.optim.lr_scheduler import StepLR  # learning-rate decay

Font setup so Matplotlib can render Chinese labels correctly

import matplotlib.pyplot as plt
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]

1. Custom dataset class

class ChineseCharDataset(Dataset):
    def __init__(self, root_dir, num_class, transforms=None):
        super().__init__()
        self.images = []      # image file paths
        self.labels = []      # class label for each image
        self.transforms = transforms

        # Make sure the root directory exists
        if not os.path.exists(root_dir):
            raise ValueError(f"Dataset directory does not exist: {root_dir}")

        # Collect the class sub-folders under root_dir
        class_folders = sorted([f for f in os.listdir(root_dir)
                                if os.path.isdir(os.path.join(root_dir, f))])

        # Make sure there is at least one class folder
        if not class_folders:
            raise ValueError(f"No class folders found in {root_dir}")

        # Keep only the first num_class classes
        selected_classes = class_folders[:num_class]

        # Walk each class folder and collect image paths and labels
        for cls_name in selected_classes:
            cls_dir = os.path.join(root_dir, cls_name)
            # All image files in this folder
            img_files = [f for f in os.listdir(cls_dir)
                         if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))]

            if not img_files:  # skip empty folders
                print(f"Warning: no image files found in class folder {cls_dir}, skipping")
                continue

            for img_name in img_files:
                img_path = os.path.join(cls_dir, img_name)
                self.images.append(img_path)
                try:
                    self.labels.append(int(cls_name))  # the numeric folder name is the label
                except ValueError:
                    raise ValueError(f"Class folder name {cls_name} is not a number and cannot be used as a label")

    def __getitem__(self, index):
        # Load and preprocess one image (fall back to the next image on failure)
        try:
            img = Image.open(self.images[index]).convert('RGB')
        except Exception as e:
            print(f"Failed to read image: {self.images[index]}, error: {e}")
            return self.__getitem__((index + 1) % len(self))

        label = self.labels[index]
        if self.transforms:
            img = self.transforms(img)
        return img, label

    def __len__(self):
        return len(self.images)
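
The class expects each character class to live in its own sub-folder whose name is the integer label (e.g. train/0/, train/1/, ...). A minimal usage sketch, reusing the D:\pytorch\data path from the configuration further below:

# Expected layout: root_dir/<label>/<image files>, e.g. train/0/001.png, train/1/002.png, ...
train_set = ChineseCharDataset(root_dir=r"D:\pytorch\data\train", num_class=100, transforms=None)
img, label = train_set[0]       # a PIL image and its integer label (no transform applied here)
print(len(train_set), label)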

2. Improved CNN model

class CharRecognitionNet(nn.Module):
    def __init__(self, num_classes=100):
        super().__init__()
        # Convolutional blocks: deeper network + BatchNorm + pooling
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),   # 1 -> 32 channels, 3x3 conv (padding keeps the size)
            nn.BatchNorm2d(32),               # normalization speeds up convergence
            nn.ReLU(),
            nn.MaxPool2d(2, 2)                # 2x2 pooling halves the spatial size
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1),  # 32 -> 64 channels
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.conv3 = nn.Sequential(           # third conv block
            nn.Conv2d(64, 128, 3, padding=1), # 64 -> 128 channels
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        # FC input size: 128 channels x 16 x 16 (128 / 2^3 = 16 after three poolings)
        self.fc_input_dim = 128 * 16 * 16
        self.fc1 = nn.Sequential(
            nn.Linear(self.fc_input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5)                   # dropout against overfitting
        )
        self.fc2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        self.fc3 = nn.Linear(256, num_classes)  # output layer matches the number of classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(-1, self.fc_input_dim)  # flatten for the fully connected layers
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
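
A quick sanity check of the flattened size (a 128×128 input is halved three times to 16×16, so 128 × 16 × 16 = 32768 features). A minimal sketch using a dummy batch, assuming the imports above:

model = CharRecognitionNet(num_classes=100)
dummy = torch.randn(2, 1, 128, 128)   # batch of 2 single-channel 128x128 images
print(model(dummy).shape)             # expected: torch.Size([2, 100])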

3. Label-smoothing loss (improves generalization)

class LabelSmoothingLoss(nn.Module):
    def __init__(self, num_classes, smoothing=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, logits, labels):
        logits = logits.log_softmax(dim=1)  # log-softmax
        # Build the smoothed target distribution
        one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
        smooth_label = one_hot * self.confidence + (1 - one_hot) * (self.smoothing / (self.num_classes - 1))
        # Cross-entropy against the smoothed targets: sum over classes, mean over the batch
        loss = (-smooth_label * logits).sum(dim=1).mean()
        return loss
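
For intuition, a minimal sketch (mirroring the scatter logic above) of what the smoothed target looks like for 5 classes and smoothing=0.1: 0.9 on the true class and 0.1/4 = 0.025 on every other class.

num_classes, smoothing = 5, 0.1
label = torch.tensor([2])
one_hot = torch.zeros(1, num_classes).scatter(1, label.unsqueeze(1), 1)
smooth = one_hot * (1 - smoothing) + (1 - one_hot) * (smoothing / (num_classes - 1))
print(smooth)   # tensor([[0.0250, 0.0250, 0.9000, 0.0250, 0.0250]])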

4. Accuracy over a full dataset

def calculate_full_accuracy(model, dataloader, device):
    model.eval()
    total_correct = 0
    total_samples = 0

    if len(dataloader.dataset) == 0:
        print("Warning: dataset is empty, cannot compute accuracy")
        return 0.0, 0, 0

    with torch.no_grad():  # no gradients needed, faster inference
        for photos, labels in dataloader:
            photos, labels = photos.to(device), labels.to(device)
            outputs = model(photos)
            _, predicted = torch.max(outputs, 1)
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()

    if total_samples == 0:
        print("Warning: no samples found, cannot compute accuracy")
        return 0.0, 0, 0

    accuracy = total_correct / total_samples
    model.train()  # back to training mode
    return accuracy, total_correct, total_samples

5. Main function (full pipeline)

def main():
    # Basic configuration (adjust for your environment)
    root = r"D:\pytorch\data"
    train_photo_dir = os.path.join(root, "train")
    test_photo_dir = os.path.join(root, "test")
    num_class = 100     # number of character classes
    batch_size = 32     # batch size (use 16 if GPU memory is tight)
    epochs = 15         # training epochs
    lr = 0.001          # initial learning rate
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}\n")

    # Data augmentation (perturbations for training, plain preprocessing for testing)
    train_transform = transforms.Compose([
        transforms.Resize((128, 128)),    # 128x128 instead of 64x64 keeps more stroke detail
        transforms.Grayscale(),           # convert to grayscale
        transforms.RandomRotation(10),    # random rotation of +/-10 degrees
        transforms.RandomAffine(0, translate=(0.1, 0.1)),     # random translation of +/-10%
        transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),  # random crop at 80%-100% scale
        transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 0.5)),  # Gaussian blur
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # normalization
    ])
    test_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
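    # Note: Normalize((0.5,), (0.5,)) maps the single grayscale channel from [0, 1] to [-1, 1].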

    # Load the datasets (with error handling)
    print("Loading image datasets...")
    try:
        train_dataset = ChineseCharDataset(
            root_dir=train_photo_dir,
            num_class=num_class,
            transforms=train_transform
        )
        test_dataset = ChineseCharDataset(
            root_dir=test_photo_dir,
            num_class=num_class,
            transforms=test_transform
        )
    except Exception as e:
        print(f"Failed to load datasets: {e}")
        return

    # Make sure both datasets are non-empty
    if len(train_dataset) == 0:
        print(f"Error: no valid images found in training directory {train_photo_dir}")
        return
    if len(test_dataset) == 0:
        print(f"Error: no valid images found in test directory {test_photo_dir}")
        return

    # Create the data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # keep 0 on Windows; 4 is fine on Linux/Mac
        pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=64,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )
    print(f"Loaded {len(train_dataset)} training images and {len(test_dataset)} test images\n")

    # Initialize the model, loss function and optimizer
    model = CharRecognitionNet(num_classes=num_class).to(device)
    summary(model, (1, 128, 128), device=device.type)  # print the network structure (device arg so this also works on CPU)

    criterion = LabelSmoothingLoss(num_classes=num_class, smoothing=0.1)  # label smoothing
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5  # weight decay against overfitting
    )
    scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # halve the learning rate every 5 epochs
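    # With lr=0.001, step_size=5 and gamma=0.5 (scheduler.step() is called once per epoch below),
    # epochs 1-5 run at 1e-3, epochs 6-10 at 5e-4, and epochs 11-15 at 2.5e-4.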

    # Training loop (keeps the best model)
    print("\nStarting training...")
    best_test_acc = 0.0  # best test accuracy so far
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0.0
        total_train_correct = 0
        total_train_samples = 0

        for step, (photos, labels) in enumerate(train_loader):
            photos, labels = photos.to(device), labels.to(device)
            batch_size_current = photos.size(0)

            # Forward pass
            outputs = model(photos)
            loss = criterion(outputs, labels)

            # Batch accuracy
            _, predicted = torch.max(outputs, 1)
            batch_correct = (predicted == labels).sum().item()
            total_train_correct += batch_correct
            total_train_samples += batch_size_current

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

            # Log every 20 batches
            if (step + 1) % 20 == 0:
                avg_loss = total_train_loss / (step + 1)
                avg_train_acc = total_train_correct / total_train_samples if total_train_samples > 0 else 0
                test_acc, _, _ = calculate_full_accuracy(model, test_loader, device)

                print(f"[Epoch {epoch+1}/{epochs} | Batch {step+1} | LR {scheduler.get_last_lr()[0]:.6f}]")
                print(f"  - average training loss: {avg_loss:.4f}")
                print(f"  - average training accuracy: {avg_train_acc:.4f} ({total_train_correct}/{total_train_samples} correct)")
                print(f"  - test accuracy: {test_acc:.4f}\n")

        # End-of-epoch bookkeeping
        scheduler.step()  # apply learning-rate decay
        final_train_acc, train_corr, train_total = calculate_full_accuracy(model, train_loader, device)
        final_test_acc, test_corr, test_total = calculate_full_accuracy(model, test_loader, device)
        print(f"=== Epoch {epoch+1} finished ===")
        print(f"Final training accuracy: {final_train_acc:.4f} ({train_corr}/{train_total} correct)")
        print(f"Final test accuracy: {final_test_acc:.4f} ({test_corr}/{test_total} correct)\n")

        # Save the best model
        if final_test_acc > best_test_acc:
            best_test_acc = final_test_acc
            save_dir = os.path.join(root, "tmp")
            os.makedirs(save_dir, exist_ok=True)
            best_model_path = os.path.join(save_dir, "best_char_recognition_model.pkl")
            torch.save(model.state_dict(), best_model_path)
            print(f"✅ Best model updated: {best_model_path} (best test accuracy so far: {best_test_acc:.4f})\n")

    # Single-image prediction with the best model
    print("=== Single-image prediction ===")
    best_model = CharRecognitionNet(num_classes=num_class)
    best_model_path = os.path.join(root, "tmp", "best_char_recognition_model.pkl")

    try:
        # Load the best weights
        if not os.path.exists(best_model_path):
            raise FileNotFoundError(f"Best model file does not exist: {best_model_path}")
        best_model.load_state_dict(torch.load(best_model_path, map_location=device))
        best_model.to(device)
        best_model.eval()  # switch to inference mode

        # Test image path (replace with your own image)
        test_single_photo = r"D:\pytorch\data\test\7\0620.png"
        if not os.path.exists(test_single_photo):
            raise FileNotFoundError(f"Prediction image does not exist: {test_single_photo}")

        # Preprocess the image exactly like the test set
        img = Image.open(test_single_photo).convert('RGB')
        img = test_transform(img).to(device)
        img = img.unsqueeze(0)  # add the batch dimension: [batch, channel, H, W]

        # Run inference
        with torch.no_grad():
            output = best_model(img)
            _, predicted_cls = torch.max(output, 1)  # class with the highest score
            confidence = F.softmax(output, dim=1)[0][predicted_cls].item()  # confidence of that class

        # Print the prediction
        print(f"Image path: {test_single_photo}")
        print(f"Predicted character class: {predicted_cls.item()}")
        print(f"Confidence: {confidence:.4f} (higher is more reliable)")

    except Exception as e:
        print(f"Single-image prediction failed: {e}")

    # Keep the Windows console open (optional)
    print("\nPress Enter to exit...")
    input()

if __name__ == "__main__":
    main()
[Screenshot of the training output, 2025-10-30]
