Handwritten Chinese Character Recognition

import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import random

# ===================== 1. Global configuration (epochs set to 20) =====================
class Config:
    root_dir = "D:/Pysch2/Pytorch/MSRA-TD500"  # dataset root directory
    batch_size = 8
    epochs = 20  # number of training epochs
    lr = 1e-4
    num_workers = 0  # 0 is recommended on Windows
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    char2idx = None
    idx2char = None
    num_classes = None
    # synthetic text pool (extend as needed)
    text_pool = [
        "一二三四", "五六七八", "九十百千", "甲乙丙丁", "金木水火",
        "天地人和", "上下左右", "前后内外", "春夏秋冬", "东南西北",
        "ABCDE", "FGHIJ", "12345", "67890", "XYZUV",
        "测试文本", "识别训练", "流程验证", "数据合成", "模型测试"
    ]

# ===================== 2. Dataset loading (synthetic text labels added automatically) =====================
class MSRA_TD500_Synth_Dataset(Dataset):
    def __init__(self, root_dir, is_train=True, transform=None):
        self.root_dir = root_dir
        self.is_train = is_train
        self.transform = transform
        self.data_dir = os.path.join(root_dir, "train" if is_train else "test")
        self.image_names = [f for f in os.listdir(self.data_dir) if f.endswith(('.jpg', '.JPG'))]
        random.seed(42)  # fix the global random seed for reproducibility

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        # load the image
        img_name = self.image_names[idx]
        img_path = os.path.join(self.data_dir, img_name)
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"图像文件不存在:{img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]

        # build the annotation file path
        img_base_name = os.path.splitext(img_name)[0]
        label_path = os.path.join(self.data_dir, f"{img_base_name}.gt")
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"标注文件不存在:{label_path}")

        # parse the .gt file and attach synthetic text labels
        boxes, texts = self.parse_gt_with_synth_text(label_path)

        # crop the first text region
        if len(boxes) > 0:
            box = boxes[0]
            x_min = min(box[::2])
            x_max = max(box[::2])
            y_min = min(box[1::2])
            y_max = max(box[1::2])
            x_min, x_max = max(0, x_min), min(w, x_max)
            y_min, y_max = max(0, y_min), min(h, y_max)
            crop_img = image[y_min:y_max, x_min:x_max]
        else:
            crop_img = image

        if self.transform:
            crop_img = self.transform(crop_img)

        return crop_img, texts[0] if texts else ""

    @staticmethod
    def parse_gt_with_synth_text(gt_path):
        """Parse a coordinates-only .gt file and attach synthetic text labels.

        A per-file random generator keeps each image's synthetic label stable
        across epochs, so the image-to-text mapping is reproducible.
        """
        boxes = []
        texts = []
        rng = random.Random(os.path.basename(gt_path))  # deterministic per annotation file
        with open(gt_path, 'r', encoding='gbk') as f:
            lines = f.readlines()
            for line in lines:
                line = line.strip()
                if not line:
                    continue
                parts = line.split(',')
                if len(parts) == 8:
                    try:
                        coords = list(map(int, parts))
                        boxes.append(coords)
                        texts.append(rng.choice(Config.text_pool))
                    except ValueError:
                        continue
        return boxes, texts
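    # Illustrative note (the numbers are made up): a .gt line such as
    # "10,20,110,20,110,60,10,60" is parsed into an 8-value box
    # [x1, y1, x2, y2, x3, y3, x4, y4] and paired with one text drawn from
    # Config.text_pool, e.g. "天地人和".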

# ===================== 3. Character map generation (based on the synthetic text pool) =====================
def generate_synth_char_map():
    char_set = set()
    for text in Config.text_pool:
        char_set.update(list(text))
    char_list = sorted(list(char_set))
    Config.char2idx = {char: idx+1 for idx, char in enumerate(char_list)}
    Config.char2idx['<blank>'] = 0
    Config.idx2char = {v: k for k, v in Config.char2idx.items()}
    Config.num_classes = len(Config.char2idx)
    print(f"合成字符集大小:{Config.num_classes}(含空白字符)")
    print(f"字符集内容:{char_list}")

# ===================== 4. Data preprocessing =====================
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((32, 128)),
    transforms.RandomAffine(degrees=5, translate=(0.05, 0.05)),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5))], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((32, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
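# Both pipelines resize crops to 32x128 (height x width), the usual CRNN input size,
# and normalize with ImageNet statistics; the training pipeline additionally applies
# a small random affine jitter and occasional Gaussian blur as augmentation.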

# ===================== 5. CRNN model definition (fixes the LSTM tuple-output issue) =====================
class CRNN(nn.Module):
    def __init__(self, num_classes):
        super(CRNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),
            nn.Conv2d(512, 512, kernel_size=2, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        # define the LSTM layers separately instead of wrapping them in nn.Sequential
        self.lstm1 = nn.LSTM(512, 256, bidirectional=True, batch_first=True)
        self.lstm2 = nn.LSTM(512, 256, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
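        # Shape walkthrough for a 3x32x128 input:
        #   conv:              (N, 3, 32, 128) -> (N, 512, 1, 31)
        #   squeeze + permute: -> (N, 31, 512), a 31-step feature sequence
        #   two BiLSTMs:       -> (N, 31, 512)
        #   fc:                -> (N, 31, num_classes)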
        conv_out = self.conv(x)
        conv_out = conv_out.squeeze(2)
        conv_out = conv_out.permute(0, 2, 1)
        # unpack the LSTM outputs manually, keeping only the sequence output
        rnn_out, _ = self.lstm1(conv_out)
        rnn_out, _ = self.lstm2(rnn_out)
        logits = self.fc(rnn_out)
        return logits

# ===================== 6. Training and evaluation utilities =====================
def encode_text(texts):
    labels = []
    lengths = []
    for text in texts:
        label = [Config.char2idx[c] for c in text if c in Config.char2idx]
        labels.extend(label)
        lengths.append(len(label))
    return torch.tensor(labels, dtype=torch.long), torch.tensor(lengths, dtype=torch.long)
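# encode_text returns targets in the concatenated form accepted by nn.CTCLoss:
# one flat 1-D tensor of label indices plus a per-sample length tensor. For a
# batch of ["一二三四", "五六七八"] it returns 8 indices and lengths tensor([4, 4]).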

def ctc_decode(log_probs):
    """Greedy CTC decoding: frame-wise argmax, collapse consecutive repeats, drop blanks."""
    outputs = []
    for prob in log_probs:
        pred = torch.argmax(prob, dim=1).cpu().numpy()
        text = []
        prev_char = None
        for c in pred:
            if c != prev_char and c != 0:
                text.append(Config.idx2char[int(c)])
            prev_char = c  # track every frame so a blank can separate repeated characters
        outputs.append(''.join(text))
    return outputs
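# Example: a frame-wise argmax sequence [3, 3, 0, 3, 0, 7] decodes to the characters
# at indices 3, 3 and 7: consecutive repeats merge, the blank in between lets the
# second 3 survive, and blank frames themselves are dropped.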

# ===================== 7. Main training pipeline =====================
def main():
    generate_synth_char_map()
    if Config.num_classes < 2:
        print("错误:字符集为空,请扩展Config.text_pool!")
        return
    
    # initialize the datasets
    try:
        train_dataset = MSRA_TD500_Synth_Dataset(
            root_dir=Config.root_dir, is_train=True, transform=train_transform
        )
        test_dataset = MSRA_TD500_Synth_Dataset(
            root_dir=Config.root_dir, is_train=False, transform=test_transform
        )
    except Exception as e:
        print(f"数据集初始化失败:{e}")
        return

    # data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=Config.batch_size, shuffle=True,
        num_workers=Config.num_workers, pin_memory=True, drop_last=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=Config.batch_size, shuffle=False,
        num_workers=Config.num_workers, pin_memory=True
    )
    
    # model, loss function, optimizer
    model = CRNN(num_classes=Config.num_classes).to(Config.device)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
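    # blank=0 matches Config.char2idx['<blank>'] = 0; zero_infinity=True zeroes out
    # infinite losses that occur when a target is longer than the output sequence.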
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.epochs)
    
    # training loop
    print(f"Starting training (device: {Config.device})")
    print(f"Training samples: {len(train_dataset)}, test samples: {len(test_dataset)}")
    for epoch in range(Config.epochs):
        model.train()
        total_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{Config.epochs}]")
        
        for images, texts in pbar:
            images = images.to(Config.device)
            labels, label_lengths = encode_text(texts)
            labels = labels.to(Config.device)
            label_lengths = label_lengths.to(Config.device)
            
            # forward pass
            logits = model(images)  # (N, T, C)
            log_probs = F.log_softmax(logits, dim=2)
            input_lengths = torch.full(
                (logits.size(0),), logits.size(1), dtype=torch.long
            ).to(Config.device)

            # compute the CTC loss (nn.CTCLoss expects log-probs shaped (T, N, C))
            if label_lengths.sum() == 0:
                continue
            loss = criterion(log_probs.permute(1, 0, 2), labels, input_lengths, label_lengths)
            
            # backward pass and optimizer step
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            total_loss += loss.item() * images.size(0)
            pbar.set_postfix({"Loss": loss.item()})
        
        # learning-rate scheduling
        scheduler.step()
        
        # report the epoch's average loss
        avg_loss = total_loss / len(train_loader.dataset)
        print(f"Epoch [{epoch+1}/{Config.epochs}], Average Loss: {avg_loss:.4f}")
        
        # evaluate on the test set every 5 epochs
        if (epoch + 1) % 5 == 0:
            model.eval()
            correct = 0
            total = 0
            
            with torch.no_grad():
                for images, texts in test_loader:
                    images = images.to(Config.device)
                    logits = model(images)
                    log_probs = F.log_softmax(logits, dim=2)
                    preds = ctc_decode(log_probs)
                    
                    # print a few sample predictions
                    for i in range(min(3, len(preds))):
                        print(f"真实文本:{texts[i]},预测文本:{preds[i]}")
                    
                    # accumulate exact-match accuracy
                    for pred, text in zip(preds, texts):
                        if pred == text:
                            correct += 1
                        total += 1
            
            if total > 0:
                acc = correct / total
                print(f"Test Accuracy: {acc:.4f}")
            else:
                print("测试集无有效样本")
    
    # save the model
    torch.save(model.state_dict(), "crnn_msra_td500_synth_epoch20.pth")
    print("Training finished; the 20-epoch synthetic-text model has been saved!")

if __name__ == "__main__":
    main()