Text Recognition (OCR)

import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import string
import random
from tqdm import tqdm

# ---------------------------- Configuration ----------------------------

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {DEVICE}")

# Character set (0-9, a-z, blank); the CTC blank token sits at index 0
CHARS = ' ' + string.digits + string.ascii_lowercase
NUM_CLASSES = len(CHARS)

# Image sizes
DETECT_INPUT_SIZE = (512, 512)  # text detection input size
RECOG_INPUT_SIZE = (100, 32)    # text recognition input size (width, height)

# Training hyperparameters

BATCH_SIZE = 8
DETECT_LR = 1e-4
RECOG_LR = 1e-3
EPOCHS = 10

# ---------------------------- 1. Dataset ----------------------------

class OCRDataset(Dataset):
    """OCR dataset (text detection and recognition annotations)."""
    def __init__(self, root_dir, is_train=True):
        self.root_dir = root_dir
        self.is_train = is_train
        self.image_paths = [os.path.join(root_dir, f) for f in os.listdir(root_dir)
                            if f.endswith(('.jpg', '.png'))]
        # In practice, load real annotation files here (random data is used below as a stand-in)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Read the image
        img_path = self.image_paths[idx]
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]

        # Simulated detection annotations (read these from real label files in practice):
        # generate 1-3 random quadrilateral boxes (x1,y1,x2,y2,x3,y3,x4,y4)
        num_boxes = random.randint(1, 3)
        boxes = []
        for _ in range(num_boxes):
            x1 = random.randint(50, w // 2)
            y1 = random.randint(50, h // 2)
            x2 = random.randint(x1 + 50, w - 50)
            y2 = y1
            x3 = x2
            y3 = random.randint(y1 + 30, h - 50)
            x4 = x1
            y4 = y3
            boxes.append([x1, y1, x2, y2, x3, y3, x4, y4])
        boxes = np.array(boxes, dtype=np.float32)

        # Simulated recognition annotation (random string)
        text_len = random.randint(3, 8)
        text = ''.join(random.choices(CHARS[1:], k=text_len))  # exclude the blank token

        # Data augmentation (training only)
        if self.is_train:
            if random.random() > 0.5:
                image = cv2.flip(image, 1)  # horizontal flip (boxes are simulated here, so they are not mirrored)

        # Preprocessing
        detect_img = self._preprocess_detect(image)
        recog_img, recog_label = self._preprocess_recog(image, text)

        # Note: `boxes` and `recog_label` vary in length per sample, so batching
        # requires a custom collate_fn (see the ocr_collate sketch below).
        return {
            "detect_img": detect_img,
            "boxes": boxes,
            "recog_img": recog_img,
            "recog_label": recog_label,
            "text_len": len(text)
        }

    def _preprocess_detect(self, image):
        """Preprocess an image for text detection."""
        img = cv2.resize(image, DETECT_INPUT_SIZE)
        img = img.transpose(2, 0, 1) / 255.0  # (C, H, W), normalized to [0, 1]
        return torch.from_numpy(img).float()

    def _preprocess_recog(self, image, text):
        """Preprocess an image crop for text recognition."""
        # Randomly crop a text region (stands in for a detection result)
        h, w = image.shape[:2]
        x1 = random.randint(0, max(0, w - RECOG_INPUT_SIZE[0]))
        y1 = random.randint(0, max(0, h - RECOG_INPUT_SIZE[1]))
        roi = image[y1:y1 + RECOG_INPUT_SIZE[1], x1:x1 + RECOG_INPUT_SIZE[0]]
        # Resize so every sample has the same (H, W), even for crops near image borders
        roi = cv2.resize(roi, RECOG_INPUT_SIZE)

        # Convert to grayscale
        roi_gray = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)
        roi_gray = roi_gray / 255.0  # normalize
        img = torch.from_numpy(roi_gray).unsqueeze(0).float()  # (1, H, W)

        # Map text to label indices (character -> index)
        label = [CHARS.index(c) for c in text]
        return img, torch.tensor(label, dtype=torch.long)
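
# --- Assumed helper (not in the original post): a minimal collate sketch. ---
# The default DataLoader collate cannot batch OCRDataset samples, because `boxes`
# has a different shape per image and `recog_label` a different length per sample.
# This sketch stacks the fixed-size image tensors, keeps the boxes as a plain list,
# and concatenates the recognition labels into the 1-D target tensor expected by
# nn.CTCLoss.
def ocr_collate(batch):
    return {
        "detect_img": torch.stack([item["detect_img"] for item in batch]),
        "recog_img": torch.stack([item["recog_img"] for item in batch]),
        "boxes": [item["boxes"] for item in batch],  # variable number of quads per image
        "recog_label": torch.cat([item["recog_label"] for item in batch]),  # concatenated CTC targets
        "text_len": torch.tensor([item["text_len"] for item in batch], dtype=torch.long),
    }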

# ---------------------------- 2. Models ----------------------------

class EAST(nn.Module):
    """Text detection model (simplified EAST)."""
    def __init__(self):
        super(EAST, self).__init__()
        # Feature extraction (simplified VGG-style backbone)
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # (64, 256, 256)

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # (128, 128, 128)

            nn.Conv2d(128, 256, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # (256, 64, 64)

            nn.Conv2d(256, 512, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # (512, 32, 32)
        )

        # Upsampling head (no skip-connection fusion in this simplified version)
        self.upconv1 = nn.Conv2d(512, 256, 1)
        self.upconv2 = nn.Conv2d(256, 128, 1)
        self.upconv3 = nn.Conv2d(128, 64, 1)

        # Output heads: score map (text / non-text) + geometry (8 values: 4 corner coordinates)
        self.score_map = nn.Conv2d(64, 1, 1)
        self.geometry = nn.Conv2d(64, 8, 1)

    def forward(self, x):
        # x: (B, 3, 512, 512)
        x = self.features(x)  # (B, 512, 32, 32)

        # Upsample to 64x64
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)  # (B, 512, 64, 64)
        x = self.upconv1(x)  # (B, 256, 64, 64)

        # Upsample to 128x128
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)  # (B, 256, 128, 128)
        x = self.upconv2(x)  # (B, 128, 128, 128)

        # Upsample to 256x256
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)  # (B, 128, 256, 256)
        x = self.upconv3(x)  # (B, 64, 256, 256)

        # Outputs
        score = torch.sigmoid(self.score_map(x))  # (B, 1, 256, 256) text probability
        geometry = self.geometry(x)  # (B, 8, 256, 256) box coordinates
        return score, geometry

class CRNN(nn.Module):
    """Text recognition model (CRNN)."""
    def __init__(self, num_classes=NUM_CLASSES):
        super(CRNN, self).__init__()
        # CNN feature extraction
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # (64, 16, 50)

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # (128, 8, 25)

            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 1), (0, 1)),  # (256, 4, 26)

            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.Conv2d(512, 512, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 1), (0, 1)),  # (512, 2, 27)

            nn.Conv2d(512, 512, 2, 1, 0),  # collapse the remaining height -> (512, 1, 26)
            nn.ReLU(),
        )

        # RNN sequence modeling
        self.rnn = nn.LSTM(
            input_size=512,
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )

        # Output layer
        self.fc = nn.Linear(512, num_classes)  # bidirectional LSTM output is 512 (256 * 2)

    def forward(self, x):
        # x: (B, 1, 32, 100)
        x = self.cnn(x)  # (B, 512, 1, 26)
        x = x.squeeze(2)  # (B, 512, 26)
        x = x.permute(0, 2, 1)  # (B, 26, 512), sequence length 26

        x, _ = self.rnn(x)  # (B, 26, 512)
        x = self.fc(x)  # (B, 26, num_classes)
        return x
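
# --- Optional sanity check (not in the original post): run both models on dummy ---
# inputs to confirm the tensor shapes noted in the comments above. Kept inside a
# function so that importing this file does not execute it.
def _check_model_shapes():
    east, crnn = EAST(), CRNN()
    with torch.no_grad():
        score, geometry = east(torch.randn(1, 3, 512, 512))
        seq = crnn(torch.randn(1, 1, 32, 100))
    print(score.shape)     # expected: torch.Size([1, 1, 256, 256])
    print(geometry.shape)  # expected: torch.Size([1, 8, 256, 256])
    print(seq.shape)       # expected: torch.Size([1, 26, 37]), i.e. (B, T, NUM_CLASSES)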

# ---------------------------- 3. Loss functions ----------------------------

class EASTLoss(nn.Module):
    """Loss function for the EAST model."""
    def __init__(self):
        super(EASTLoss, self).__init__()

    def forward(self, score, geometry, score_label, geometry_label, mask):
        # Text score loss (binary cross-entropy)
        score_loss = F.binary_cross_entropy(score * mask, score_label * mask)

        # Geometry loss (smooth L1)
        geometry_loss = F.smooth_l1_loss(geometry * mask, geometry_label * mask)
        return score_loss + 10 * geometry_loss  # geometry loss is weighted more heavily
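
# --- Assumed helper (not in the original post): a rough sketch of building a ---
# score-map target for EASTLoss from the quadrilateral boxes. The training loop
# below currently feeds all-zero placeholder labels; a real pipeline would rasterise
# each quad onto the downscaled score map, e.g. with cv2.fillPoly:
def make_score_label(boxes, src_hw, out_size=(256, 256)):
    """boxes: (N, 8) quads in original-image pixels; src_hw: (h, w) of that image."""
    mask = np.zeros(out_size, dtype=np.uint8)
    sy, sx = out_size[0] / src_hw[0], out_size[1] / src_hw[1]
    for box in np.asarray(boxes, dtype=np.float32):
        quad = box.reshape(4, 2) * np.array([sx, sy], dtype=np.float32)
        cv2.fillPoly(mask, [np.round(quad).astype(np.int32)], 1)
    return torch.from_numpy(mask.astype(np.float32)).unsqueeze(0)  # (1, 256, 256), values in {0, 1}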

# ---------------------------- 4. Training ----------------------------

def train():
    # Dataset and data loader
    train_dataset = OCRDataset(root_dir="train_images")  # replace with your own training-set path
    # The default collate cannot batch the variable-length boxes/labels, so the
    # ocr_collate sketch defined above is passed in here.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=2, collate_fn=ocr_collate)

    # Initialize models, loss functions, and optimizers
    east_model = EAST().to(DEVICE)
    crnn_model = CRNN().to(DEVICE)
    east_criterion = EASTLoss()
    crnn_criterion = nn.CTCLoss(blank=0, reduction='mean')  # CTC loss
    east_optimizer = optim.Adam(east_model.parameters(), lr=DETECT_LR)
    crnn_optimizer = optim.Adam(crnn_model.parameters(), lr=RECOG_LR)

    # Training loop
    for epoch in range(EPOCHS):
        east_model.train()
        crnn_model.train()
        total_east_loss = 0.0
        total_crnn_loss = 0.0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            # Text detection training
            detect_img = batch["detect_img"].to(DEVICE)
            boxes = batch["boxes"]  # simulated data; real training needs labels built from these

            # Placeholder labels (build them from `boxes` in practice, e.g. with the make_score_label sketch above)
            B, H, W = detect_img.shape[0], DETECT_INPUT_SIZE[0], DETECT_INPUT_SIZE[1]
            score_label = torch.zeros(B, 1, H // 2, W // 2).to(DEVICE)  # simplified: all zeros
            geometry_label = torch.zeros(B, 8, H // 2, W // 2).to(DEVICE)
            mask = torch.ones_like(score_label).to(DEVICE)  # simplified: all ones

            east_optimizer.zero_grad()
            score, geometry = east_model(detect_img)
            east_loss = east_criterion(score, geometry, score_label, geometry_label, mask)
            east_loss.backward()
            east_optimizer.step()
            total_east_loss += east_loss.item()

            # Text recognition training
            recog_img = batch["recog_img"].to(DEVICE)
            recog_label = batch["recog_label"]  # concatenated 1-D CTC targets
            text_len = batch["text_len"]

            crnn_optimizer.zero_grad()
            output = crnn_model(recog_img)  # (B, T, num_classes)
            log_probs = F.log_softmax(output, dim=2).permute(1, 0, 2)  # (T, B, num_classes)
            input_lengths = torch.full((B,), output.size(1), dtype=torch.long)  # model sequence length T
            target_lengths = text_len

            crnn_loss = crnn_criterion(log_probs, recog_label, input_lengths, target_lengths)
            crnn_loss.backward()
            crnn_optimizer.step()
            total_crnn_loss += crnn_loss.item()

        # Print epoch losses
        print(f"East Loss: {total_east_loss/len(train_loader):.4f}")
        print(f"CRNN Loss: {total_crnn_loss/len(train_loader):.4f}")

    # Save models
    torch.save(east_model.state_dict(), "east_model.pth")
    torch.save(crnn_model.state_dict(), "crnn_model.pth")
    print("Models saved")

# ---------------------------- 5. Inference ----------------------------

def detect_text(image, east_model):
    """Detect text regions with the EAST model."""
    h, w = image.shape[:2]
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # match the RGB ordering used during training
    img = cv2.resize(image, DETECT_INPUT_SIZE)
    img = img.transpose(2, 0, 1) / 255.0
    img = torch.from_numpy(img).unsqueeze(0).float().to(DEVICE)

    east_model.eval()
    with torch.no_grad():
        score, geometry = east_model(img)

    # Post-processing: extract text boxes (heavily simplified; see the
    # connected-components sketch after this function for a less noisy variant)
    score = score[0, 0].cpu().numpy()
    coords = np.where(score > 0.5)  # threshold filtering
    boxes = []
    for y, x in zip(coords[0], coords[1]):
        # Map coordinates back to the original image size
        scale_y, scale_x = h / score.shape[0], w / score.shape[1]
        x1, y1 = int(x * scale_x), int(y * scale_y)
        x2, y2 = int((x + 20) * scale_x), int((y + 10) * scale_y)  # simplified: fixed box size
        boxes.append([x1, y1, x2, y2])
    return boxes
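
# --- Assumed alternative (not in the original post): the loop above emits one ---
# fixed-size box per high-score pixel, which is very noisy. A still-simple but less
# noisy post-processing step groups high-score pixels into connected regions and
# returns one axis-aligned box per region:
def score_map_to_boxes(score_map, orig_hw, thresh=0.5):
    """score_map: (H, W) array of text probabilities; orig_hw: (h, w) of the source image."""
    mask = (score_map > thresh).astype(np.uint8)
    num_labels, labels = cv2.connectedComponents(mask)
    sy, sx = orig_hw[0] / score_map.shape[0], orig_hw[1] / score_map.shape[1]
    boxes = []
    for i in range(1, num_labels):  # label 0 is the background
        ys, xs = np.where(labels == i)
        boxes.append([int(xs.min() * sx), int(ys.min() * sy),
                      int(xs.max() * sx), int(ys.max() * sy)])
    return boxes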

def recognize_text(image, crnn_model):
    """Recognize text in a cropped region with the CRNN model."""
    # Preprocessing
    img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    img = cv2.resize(img, RECOG_INPUT_SIZE)
    img = img / 255.0
    img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float().to(DEVICE)

    crnn_model.eval()
    with torch.no_grad():
        output = crnn_model(img)  # (1, T, num_classes)

    # CTC decoding
    output = F.log_softmax(output, dim=2).squeeze(0).cpu().numpy()
    pred = np.argmax(output, axis=1)  # greedy decoding

    # Collapse repeats and drop blanks
    result = []
    prev = -1
    for p in pred:
        if p != prev and p != 0:
            result.append(CHARS[p])
        prev = p
    return ''.join(result)
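
# Worked example of the greedy CTC decode above (illustration, not in the original post):
# with CHARS[0] as the blank, a frame-wise argmax of [0, 3, 3, 0, 12, 12] decodes to "2b":
# the repeated 3s collapse to a single '2' (CHARS[3]), blanks are dropped, and the run of
# 12s collapses to a single 'b' (CHARS[12]). A blank between two identical indices
# (e.g. [3, 0, 3]) keeps both characters and decodes to "22".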

def ocr_inference(image_path):
    """Full OCR inference pipeline."""
    # Load models
    east_model = EAST().to(DEVICE)
    crnn_model = CRNN().to(DEVICE)
    east_model.load_state_dict(torch.load("east_model.pth", map_location=DEVICE))
    crnn_model.load_state_dict(torch.load("crnn_model.pth", map_location=DEVICE))

    # Read the image
    image = cv2.imread(image_path)
    if image is None:
        return ["Failed to read image"]

    # Text detection
    boxes = detect_text(image, east_model)
    if not boxes:
        return ["No text detected"]

    # Text recognition
    results = []
    for (x1, y1, x2, y2) in boxes:
        # Crop the text region
        roi = image[y1:y2, x1:x2]
        if roi.size == 0:
            continue
        # Recognize
        text = recognize_text(roi, crnn_model)
        results.append(f"text: {text}, box: ({x1},{y1})-({x2},{y2})")

    return results

# ---------------------------- Main ----------------------------

if __name__ == "__main__":
    # Train the models (requires a prepared training set)
    # train()

    # Inference example (replace with your own test image path)
    test_image = "test_image.jpg"
    results = ocr_inference(test_image)
    for res in results:
        print(res)