
Text Recognition System

2025-11-19 22:27  nm1137

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt

# -------------------------- 1. Chinese-specific configuration --------------------------

class Config:
    # Data configuration (Chinese-specific)
    img_height = 32   # Fixed height; 32 is the usual choice for Chinese text recognition
    img_width = 128   # Wider images (a Chinese character is wider than an English one)
    # Common Chinese characters + digits + English letters (extend toward 3500+ characters to cover everyday scenarios)
    vocab = '''一丨丶丿乀乁乙二亍厂匚刂卜冂亻八人入勹冫几凵刀力勑匕匸十卜卩厂厶廴又口囗土士夂夊夕大女子宀寸小尢尸屮山巛工己巾干幺广廾弋弓彐彡彳心戈戸手支攴文斗斤方无日曰月木欠止歹殳毋比毛氏气水火爪父爻爿片牙牛犬玄玉瓜瓦甘生用田由甲申电白目矛矢石示禸禾穴立世皿皮癶矛耒老而耒耳臣西覀页至臼舌竹米糸缶网羊羽而耒肉臣自
0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'''
    # The triple-quoted literal contains an embedded newline and a few duplicate
    # characters; strip both so every vocab entry maps to exactly one index
    vocab = ''.join(dict.fromkeys(vocab.replace('\n', '')))
    num_classes = len(vocab) + 1  # +1 for the CTC blank label
    max_len = 15  # Maximum label length (Chinese labels run longer than English ones)

    # Model configuration (tuned for Chinese)
    hidden_size = 512  # Larger hidden size (Chinese semantics are more complex)
    num_layers = 2     # Number of bidirectional LSTM layers

    # Training configuration
    epochs = 80  # Chinese needs more data, so train for more epochs
    batch_size = 32
    lr = 1e-3
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data_dir = 'chinese_ocr_data'  # Chinese dataset directory
    checkpoint_path = 'chinese_crnn_ocr.pth'  # Model checkpoint path

config = Config()

# Character lookup tables (Chinese + digits + English)

char2idx = {char: idx + 1 for idx, char in enumerate(config.vocab)}  # blank = 0
idx2char = {idx + 1: char for idx, char in enumerate(config.vocab)}
idx2char[0] = ''  # the blank label decodes to the empty string
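
# Sanity-check sketch (illustrative, not part of the original pipeline): encode a
# label with char2idx and map it back with idx2char to confirm the two tables are
# inverses. '山水' uses only characters that are present in config.vocab.
def _demo_label_roundtrip(text='山水'):
    indices = [char2idx[c] for c in text]
    restored = ''.join(idx2char[i] for i in indices)
    print(indices, '->', restored)  # prints the index list followed by 山水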

# -------------------------- 2. Automatic Chinese dataset generation (no manual labeling needed) --------------------------

def generate_chinese_image(text, save_path, img_size=(config.img_width, config.img_height)):
    """Render a Chinese text image (black text on a white background)."""
    img = Image.new('RGB', img_size, color='white')
    draw = ImageDraw.Draw(img)
    # Try each OS's Chinese font in turn (crucial: avoids garbled Chinese glyphs)
    try:
        # Windows: SimHei
        font = ImageFont.truetype('simhei.ttf', 22)
    except OSError:
        try:
            # macOS: PingFang
            font = ImageFont.truetype('/System/Library/Fonts/PingFang.ttc', 22)
        except OSError:
            # Linux: WenQuanYi Zen Hei (must be installed beforehand)
            font = ImageFont.truetype('WenQuanYi-Zen Hei.ttf', 22)
    # Center the text so it does not spill outside the image
    text_width, text_height = draw.textbbox((0, 0), text, font=font)[2:]
    x = (img_size[0] - text_width) / 2
    y = (img_size[1] - text_height) / 2
    draw.text((x, y), text, font=font, fill='black')
    img.save(save_path)
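
# Usage sketch (illustrative; 'sample.png' is a hypothetical path): render one
# two-character sample so the font fallback above can be checked by eye.
def _demo_generate_image(path='sample.png'):
    generate_chinese_image('山水', path)  # both characters are in config.vocab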

def generate_chinese_dataset(train_num=1500, val_num=300):
    """Generate Chinese training and validation sets from random character combinations."""
    # Keep only the pure Chinese characters (used to compose the texts)
    chinese_chars = [c for c in config.vocab if '\u4e00' <= c <= '\u9fff']
    # Create the directories
    train_dir = os.path.join(config.data_dir, 'train')
    val_dir = os.path.join(config.data_dir, 'val')
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # Training set (texts of 2-8 Chinese characters)
    for i in range(train_num):
        text_length = np.random.randint(2, 9)
        text = ''.join(np.random.choice(chinese_chars, size=text_length))
        save_path = os.path.join(train_dir, f'{text}_{i}.png')
        generate_chinese_image(text, save_path)

    # Validation set
    for i in range(val_num):
        text_length = np.random.randint(2, 9)
        text = ''.join(np.random.choice(chinese_chars, size=text_length))
        save_path = os.path.join(val_dir, f'{text}_{i}.png')
        generate_chinese_image(text, save_path)

    print(f"Chinese dataset generated!\nTraining set: {train_num} images ({train_dir})\nValidation set: {val_num} images ({val_dir})")

# -------------------------- 3. Chinese dataset loader --------------------------

class ChineseOCRDataset(Dataset):
    def __init__(self, data_dir, transform=None, is_train=True):
        self.data_dir = os.path.join(data_dir, 'train' if is_train else 'val')
        # Collect every .png image (Chinese file names are supported)
        self.image_paths = [os.path.join(self.data_dir, f) for f in os.listdir(self.data_dir) if f.endswith('.png')]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image (works with Chinese paths)
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert('L')  # To grayscale
        img = Image.eval(img, lambda x: 255 - x)  # Invert (black-on-white → white-on-black, improves recognition)

        # Extract the Chinese label from the file name (e.g. "你好世界_123.png" → "你好世界")
        label = os.path.splitext(os.path.basename(img_path))[0].split('_')[0]  # Drop the extension and index
        # Filter out invalid characters (keep only characters in the vocab)
        label = ''.join([c for c in label if c in config.vocab])

        # Image preprocessing
        if self.transform:
            img = self.transform(img)

        # Encode the label (characters → indices) and pad it to max_len
        label_indices = [char2idx[char] for char in label]
        label_length = len(label_indices)
        padded_label = torch.zeros(config.max_len, dtype=torch.long)
        padded_label[:label_length] = torch.tensor(label_indices)

        return img, padded_label, label_length

# Image preprocessing (matches the Chinese image size)

transform = transforms.Compose([
    transforms.Resize((config.img_height, config.img_width)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize the grayscale image to [-1, 1]
])
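
# Fetch sketch (illustrative; assumes generate_chinese_dataset() has already run):
# one dataset item should be a (1, 32, 128) tensor scaled to [-1, 1], a label
# padded to max_len, and the true label length.
def _demo_dataset_item():
    ds = ChineseOCRDataset(config.data_dir, transform=transform, is_train=True)
    img, label, length = ds[0]
    print(img.shape, label.shape, length)  # torch.Size([1, 32, 128]) torch.Size([15]) ...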

# -------------------------- 4. CRNN model adapted for Chinese (fixes the dimension mismatch) --------------------------

class ChineseCRNN(nn.Module):
    def __init__(self, num_classes, hidden_size, num_layers):
        super(ChineseCRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # CNN feature extractor (input: 1×32×128 → output: 512×2×30)
        # The input is 128 wide for Chinese; the later pools use stride 1 in width
        # so the sequence keeps enough time steps (2×30 = 60 once flattened)
        self.cnn = nn.Sequential(
            # Conv block 1: 1→64, 3×3, padding=1
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # Output: 64×16×64

            # Conv block 2: 64→128, 3×3, padding=1
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # Output: 128×8×32

            # Conv block 3: 128→256, 3×3, padding=1
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # Conv block 4: 256→256, 3×3, padding=1
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1)),  # Output: 256×4×31

            # Conv block 5: 256→512, 3×3, padding=1
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # Conv block 6: 512→512, 3×3, padding=1
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1)),  # Output: 512×2×30

            # Conv block 7: 512→512, 1×1 (channel mixing)
            nn.Conv2d(512, 512, kernel_size=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),  # Final output: 512×2×30 (60 time steps once flattened)
        )

        # Bidirectional LSTM (handles the longer Chinese sequences)
        self.rnn = nn.LSTM(
            input_size=512,  # Feature size coming out of the CNN (512 per time step)
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=0.1  # Dropout between LSTM layers to curb overfitting (Chinese data is complex)
        )

        # Fully connected layer (bidirectional LSTM output → character classes)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        # CNN feature extraction: (batch, 1, 32, 128) → (batch, 512, 2, 30)
        cnn_out = self.cnn(x)

        # Fix the dimension mismatch: flatten the remaining 2×30 spatial grid into a
        # 60-step sequence with view, instead of relying on squeeze, whose result
        # depends on which dimensions happen to equal 1
        batch_size = x.size(0)
        rnn_in = cnn_out.view(batch_size, 512, -1).permute(0, 2, 1)  # (batch, 60, 512)

        # RNN sequence modeling: (batch, 60, 512) → (batch, 60, 2×hidden_size)
        rnn_out, _ = self.rnn(rnn_in)

        # Fully connected layer: (batch, 60, 1024) → (batch, 60, num_classes)
        output = self.fc(rnn_out)

        # CTC loss expects (time_steps, batch, num_classes)
        return output.permute(1, 0, 2)

# Initialize the model

model = ChineseCRNN(
    num_classes=config.num_classes,
    hidden_size=config.hidden_size,
    num_layers=config.num_layers
).to(config.device)
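
# Shape sanity check (illustrative): push a dummy batch through the model and
# confirm it comes out as (time_steps, batch, num_classes); with 32×128 inputs
# the CNN above leaves a 2×30 grid, i.e. 60 time steps.
def _demo_model_shapes():
    dummy = torch.randn(2, 1, config.img_height, config.img_width).to(config.device)
    with torch.no_grad():
        print(model(dummy).shape)  # torch.Size([60, 2, num_classes])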

# -------------------------- 5. Loss function and optimizer --------------------------

# CTC loss (handles variable-length Chinese sequences)

criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

# Optimizer (Adam, suits this Chinese training setup)

optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=1e-5)  # Weight decay guards against overfitting

# Learning-rate scheduler (the verbose argument is omitted for compatibility with older PyTorch versions)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=8, factor=0.5)
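
# Shape sketch for nn.CTCLoss (illustrative, with random tensors): it expects
# log-probabilities of shape (T, N, C), padded targets of shape (N, S), plus one
# input length and one target length per sample.
def _demo_ctc_shapes():
    T, N, C, S = 60, 4, config.num_classes, config.max_len
    log_probs = torch.randn(T, N, C).log_softmax(2)
    targets = torch.randint(1, C, (N, S), dtype=torch.long)  # index 0 is reserved for blank
    input_lengths = torch.full((N,), T, dtype=torch.long)
    target_lengths = torch.randint(2, S + 1, (N,), dtype=torch.long)
    print(criterion(log_probs, targets, input_lengths, target_lengths).item())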

# -------------------------- 6. CTC decoding for Chinese --------------------------

def ctc_chinese_decode(outputs, idx2char):
    """Greedy CTC decoding (collapses consecutive repeats, then drops blanks)."""
    outputs = outputs.log_softmax(2).detach().cpu().numpy()
    batch_size = outputs.shape[1]
    decoded_results = []

    for i in range(batch_size):
        seq = outputs[:, i, :]
        pred_indices = np.argmax(seq, axis=1)  # Most probable character per time step
        # Decoding rule: skip a step if it repeats the previous index or is blank (index 0)
        decoded = []
        prev_idx = None
        for idx in pred_indices:
            if idx != prev_idx and idx != 0:
                decoded.append(idx2char[idx])
            prev_idx = idx
        decoded_results.append(''.join(decoded))

    return decoded_results
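
# Toy decode example (illustrative): hand-built logits whose argmax sequence is
# [a, a, blank, b] should collapse to 'ab' — the repeat merges, the blank drops.
def _demo_ctc_decode():
    a, b = char2idx['a'], char2idx['b']
    logits = torch.full((4, 1, config.num_classes), -10.0)  # (T=4, N=1, C)
    for t, idx in enumerate([a, a, 0, b]):
        logits[t, 0, idx] = 10.0
    print(ctc_chinese_decode(logits, idx2char))  # ['ab']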

# -------------------------- 7. Training and validation functions --------------------------

def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    for batch_idx, (images, labels, label_lengths) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)  # (time_steps=60, batch, num_classes)

        # CTC loss arguments
        batch_size = images.size(0)
        input_lengths = torch.full((batch_size,), outputs.size(0), dtype=torch.long).to(device)  # every sample uses all 60 time steps
        target_lengths = label_lengths.to(device)

        # Compute the loss
        loss = criterion(outputs.log_softmax(2), labels, input_lengths, target_lengths)

        # Backward pass + optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Progress logging
        if (batch_idx + 1) % 15 == 0:
            print(f'Batch [{batch_idx+1}/{len(loader)}], Loss: {loss.item():.4f}')

    avg_loss = total_loss / len(loader)
    return avg_loss

def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for images, labels, label_lengths in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            batch_size = images.size(0)
            input_lengths = torch.full((batch_size,), outputs.size(0), dtype=torch.long).to(device)
            target_lengths = label_lengths.to(device)

            # Compute the loss
            loss = criterion(outputs.log_softmax(2), labels, input_lengths, target_lengths)
            total_loss += loss.item()

            # Decode the predictions
            decoded = ctc_chinese_decode(outputs, idx2char)

            # Accuracy (a sample counts only on an exact match)
            for i in range(batch_size):
                true_label = ''.join([idx2char[idx.item()] for idx in labels[i][:target_lengths[i]]])
                pred_label = decoded[i]
                if true_label == pred_label:
                    total_correct += 1
                total_samples += 1

    avg_loss = total_loss / len(loader)
    accuracy = total_correct / total_samples
    return avg_loss, accuracy

# -------------------------- 8. Main training routine --------------------------

def train_chinese_ocr():
    # Step 1: generate the Chinese dataset
    generate_chinese_dataset()

    # Step 2: build the datasets and data loaders (num_workers=0 on Windows)
    train_dataset = ChineseOCRDataset(config.data_dir, transform=transform, is_train=True)
    val_dataset = ChineseOCRDataset(config.data_dir, transform=transform, is_train=False)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0 if os.name == 'nt' else 2  # Disable multiprocessing on Windows
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=0 if os.name == 'nt' else 2
    )

    # Make sure the datasets loaded correctly
    print(f"Training samples: {len(train_dataset)}, validation samples: {len(val_dataset)}")
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError("Failed to load the dataset! Check the directory paths and image format")

    # Step 3: training loop
    best_accuracy = 0.0
    train_losses = []
    val_losses = []
    val_accuracies = []

    for epoch in range(config.epochs):
        print(f'\nEpoch [{epoch+1}/{config.epochs}]')

        # Train
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, config.device)
        train_losses.append(train_loss)

        # Validate
        val_loss, val_accuracy = validate(model, val_loader, criterion, config.device)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        # Learning-rate scheduling
        scheduler.step(val_loss)

        # Keep the best checkpoint
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), config.checkpoint_path)
            print(f'Saved best model (accuracy: {best_accuracy:.4f}) to {config.checkpoint_path}')

        # Per-epoch summary
        print(f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')

    # Plot the training curves
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(val_accuracies, label='Val Accuracy', color='orange')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()

# -------------------------- 9. Chinese inference (single-image recognition) --------------------------

def predict_chinese(image_path, model, transform, device):
    """Run inference on a single Chinese text image."""
    # Image preprocessing
    img = Image.open(image_path).convert('L')
    img = Image.eval(img, lambda x: 255 - x)  # Invert, same as during training
    img = transform(img).unsqueeze(0).to(device)  # Add the batch dimension

    # Model inference
    model.eval()
    with torch.no_grad():
        outputs = model(img)
        decoded = ctc_chinese_decode(outputs, idx2char)

    return decoded[0]

# -------------------------- 10. Entry point --------------------------

if __name__ == '__main__':
    # Option 1: train the Chinese OCR model (run this on first use)
    train_chinese_ocr()

    # Option 2: load the trained model for inference (after training, comment out
    # train_chinese_ocr() above and uncomment the lines below)
    # model.load_state_dict(torch.load(config.checkpoint_path, map_location=config.device))
    # test_image_path = 'test_chinese.png'  # Path to your Chinese test image
    # pred_text = predict_chinese(test_image_path, model, transform, config.device)
    # print(f'Recognized text: {pred_text}')