Text Recognition (OCR)
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import string
import random
from tqdm import tqdm
# ---------------------------- Configuration ----------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Character set (0-9, a-z, blank); the blank symbol has index 0 for CTC
CHARS = ' ' + string.digits + string.ascii_lowercase
NUM_CLASSES = len(CHARS)

# Image sizes
DETECT_INPUT_SIZE = (512, 512)   # text-detection input size
RECOG_INPUT_SIZE = (100, 32)     # text-recognition input size (width, height)

# Training hyper-parameters
BATCH_SIZE = 8
DETECT_LR = 1e-4
RECOG_LR = 1e-3
EPOCHS = 10
# ---------------------------- 1. Dataset ----------------------------
class OCRDataset(Dataset):
    """OCR dataset (with simulated detection and recognition annotations)."""

    def __init__(self, root_dir, is_train=True):
        self.root_dir = root_dir
        self.is_train = is_train
        self.image_paths = [os.path.join(root_dir, f) for f in os.listdir(root_dir)
                            if f.endswith(('.jpg', '.png'))]
        # In real use, parse an annotation file here (random data is used
        # below to simulate the labels).

    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        # Read the image
        img_path = self.image_paths[idx]
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]

        # Simulated detection annotation (in practice, read it from the label file):
        # generate 1-3 random quadrilaterals (x1,y1,x2,y2,x3,y3,x4,y4);
        # the ranges below assume the image is reasonably large
        num_boxes = random.randint(1, 3)
        boxes = []
        for _ in range(num_boxes):
            x1 = random.randint(50, w // 2)
            y1 = random.randint(50, h // 2)
            x2 = random.randint(x1 + 50, w - 50)
            y2 = y1
            x3 = x2
            y3 = random.randint(y1 + 30, h - 50)
            x4 = x1
            y4 = y3
            boxes.append([x1, y1, x2, y2, x3, y3, x4, y4])
        boxes = np.array(boxes, dtype=np.float32)

        # Simulated recognition annotation (random string without the blank symbol)
        text_len = random.randint(3, 8)
        text = ''.join(random.choices(CHARS[1:], k=text_len))

        # Data augmentation (training only). Note that a horizontal flip mirrors
        # the text and invalidates the boxes; it is tolerable here only because
        # the labels are simulated.
        if self.is_train:
            if random.random() > 0.5:
                image = cv2.flip(image, 1)  # horizontal flip

        # Preprocessing
        detect_img = self._preprocess_detect(image)
        recog_img, recog_label = self._preprocess_recog(image, text)

        return {
            "detect_img": detect_img,
            "boxes": boxes,
            "recog_img": recog_img,
            "recog_label": recog_label,
            "text_len": len(text)
        }
    def _preprocess_detect(self, image):
        """Preprocess the full image for text detection."""
        img = cv2.resize(image, DETECT_INPUT_SIZE)
        img = img.transpose(2, 0, 1) / 255.0  # (C, H, W), normalized to [0, 1]
        return torch.from_numpy(img).float()
    def _preprocess_recog(self, image, text):
        """Preprocess a text region for recognition."""
        # Randomly crop a text region (simulating a detection result);
        # upscale first if the image is smaller than the crop size
        h, w = image.shape[:2]
        if w < RECOG_INPUT_SIZE[0] or h < RECOG_INPUT_SIZE[1]:
            image = cv2.resize(image, (max(w, RECOG_INPUT_SIZE[0]),
                                       max(h, RECOG_INPUT_SIZE[1])))
            h, w = image.shape[:2]
        x1 = random.randint(0, w - RECOG_INPUT_SIZE[0])
        y1 = random.randint(0, h - RECOG_INPUT_SIZE[1])
        roi = image[y1:y1 + RECOG_INPUT_SIZE[1], x1:x1 + RECOG_INPUT_SIZE[0]]

        # Convert to grayscale and normalize
        roi_gray = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)
        roi_gray = roi_gray / 255.0
        img = torch.from_numpy(roi_gray).unsqueeze(0).float()  # (1, H, W)

        # Map characters to class indices
        label = [CHARS.index(c) for c in text]
        return img, torch.tensor(label, dtype=torch.long)
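
# The default DataLoader collate cannot batch the variable-length "boxes" and
# "recog_label" tensors produced above. The helper below is a minimal sketch of
# a custom collate_fn (the name `ocr_collate` is ours): boxes stay as a Python
# list, recognition labels are padded with the blank index 0, and the true text
# lengths are kept for the CTC loss.
def ocr_collate(batch):
    detect_imgs = torch.stack([item["detect_img"] for item in batch])
    recog_imgs = torch.stack([item["recog_img"] for item in batch])
    boxes = [item["boxes"] for item in batch]  # variable number of boxes per image
    text_lens = torch.tensor([item["text_len"] for item in batch], dtype=torch.long)
    labels = torch.zeros(len(batch), int(text_lens.max()), dtype=torch.long)  # 0 = blank padding
    for i, item in enumerate(batch):
        labels[i, :item["text_len"]] = item["recog_label"]
    return {
        "detect_img": detect_imgs,
        "boxes": boxes,
        "recog_img": recog_imgs,
        "recog_label": labels,
        "text_len": text_lens,
    }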
# ---------------------------- 2. Models ----------------------------
class EAST(nn.Module):
    """Text-detection model (simplified EAST)."""

    def __init__(self):
        super(EAST, self).__init__()
        # Feature extraction (simplified VGG-style backbone)
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),              # (64, 256, 256)
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),              # (128, 128, 128)
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),              # (256, 64, 64)
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),              # (512, 32, 32)
        )
        # Upsampling / channel-reduction layers
        self.upconv1 = nn.Conv2d(512, 256, 1)
        self.upconv2 = nn.Conv2d(256, 128, 1)
        self.upconv3 = nn.Conv2d(128, 64, 1)
        # Output heads: score map (text / non-text) + geometry (8 values: 4 corner coordinates)
        self.score_map = nn.Conv2d(64, 1, 1)
        self.geometry = nn.Conv2d(64, 8, 1)

    def forward(self, x):
        # x: (B, 3, 512, 512)
        x = self.features(x)  # (B, 512, 32, 32)
        # Upsample to 64x64
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)  # (B, 512, 64, 64)
        x = self.upconv1(x)   # (B, 256, 64, 64)
        # Upsample to 128x128
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)  # (B, 256, 128, 128)
        x = self.upconv2(x)   # (B, 128, 128, 128)
        # Upsample to 256x256
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)  # (B, 128, 256, 256)
        x = self.upconv3(x)   # (B, 64, 256, 256)
        # Outputs
        score = torch.sigmoid(self.score_map(x))  # (B, 1, 256, 256) text probability
        geometry = self.geometry(x)               # (B, 8, 256, 256) box geometry
        return score, geometry
class CRNN(nn.Module):
    """Text-recognition model (CRNN)."""

    def __init__(self, num_classes=NUM_CLASSES):
        super(CRNN, self).__init__()
        # CNN feature extraction
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),                    # (64, 16, 50)
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),                    # (128, 8, 25)
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 1), (0, 1)),  # (256, 4, 26)
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 1), (0, 1)),  # (512, 2, 27)
            nn.Conv2d(512, 512, 2, 1, 0),          # (512, 1, 26): collapse the height to 1
            nn.ReLU(),
        )
        # RNN sequence modelling
        self.rnn = nn.LSTM(
            input_size=512,
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )
        # Output layer (bidirectional LSTM outputs 512 = 256 * 2 features)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        # x: (B, 1, 32, 100)
        x = self.cnn(x)          # (B, 512, 1, 26)
        x = x.squeeze(2)         # (B, 512, 26)
        x = x.permute(0, 2, 1)   # (B, 26, 512): sequence length 26
        x, _ = self.rnn(x)       # (B, 26, 512)
        x = self.fc(x)           # (B, 26, num_classes)
        return x
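
# Optional sanity check (our addition, not part of the original code): pass random
# tensors through both models to confirm the feature-map shapes noted in the
# comments above.
def _check_model_shapes():
    with torch.no_grad():
        score, geom = EAST()(torch.randn(1, 3, 512, 512))
        out = CRNN()(torch.randn(1, 1, 32, 100))
    # expected: (1, 1, 256, 256), (1, 8, 256, 256), (1, 26, NUM_CLASSES)
    print(score.shape, geom.shape, out.shape)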
# ---------------------------- 3. Loss function ----------------------------
class EASTLoss(nn.Module):
    """Loss function for the EAST model."""

    def __init__(self):
        super(EASTLoss, self).__init__()

    def forward(self, score, geometry, score_label, geometry_label, mask):
        # Text-score loss (binary cross-entropy)
        score_loss = F.binary_cross_entropy(score * mask, score_label * mask)
        # Geometry loss (smooth L1)
        geometry_loss = F.smooth_l1_loss(geometry * mask, geometry_label * mask)
        return score_loss + 10 * geometry_loss  # the geometry term is weighted higher
# ---------------------------- 4. Training ----------------------------
def train():
    # Dataset and data loader (replace root_dir with your training-set path);
    # the custom collate defined above handles the variable-length boxes and labels
    train_dataset = OCRDataset(root_dir="train_images")
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=2, collate_fn=ocr_collate)

    # Models, loss functions and optimizers
    east_model = EAST().to(DEVICE)
    crnn_model = CRNN().to(DEVICE)
    east_criterion = EASTLoss()
    crnn_criterion = nn.CTCLoss(blank=0, reduction='mean')  # CTC loss, blank index 0
    east_optimizer = optim.Adam(east_model.parameters(), lr=DETECT_LR)
    crnn_optimizer = optim.Adam(crnn_model.parameters(), lr=RECOG_LR)

    # Training loop
    for epoch in range(EPOCHS):
        east_model.train()
        crnn_model.train()
        total_east_loss = 0.0
        total_crnn_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            # --- Text-detection step ---
            detect_img = batch["detect_img"].to(DEVICE)
            boxes = batch["boxes"]  # simulated; in practice used to build the score/geometry labels

            # Simulated labels (in practice, generate them from `boxes`);
            # the score/geometry maps are at half the 512x512 input resolution, i.e. 256x256
            B = detect_img.shape[0]
            H, W = DETECT_INPUT_SIZE
            score_label = torch.zeros(B, 1, H // 2, W // 2).to(DEVICE)    # simplified: all zeros
            geometry_label = torch.zeros(B, 8, H // 2, W // 2).to(DEVICE)
            mask = torch.ones_like(score_label).to(DEVICE)                # simplified: all ones

            east_optimizer.zero_grad()
            score, geometry = east_model(detect_img)
            east_loss = east_criterion(score, geometry, score_label, geometry_label, mask)
            east_loss.backward()
            east_optimizer.step()
            total_east_loss += east_loss.item()

            # --- Text-recognition step ---
            recog_img = batch["recog_img"].to(DEVICE)
            recog_label = batch["recog_label"].to(DEVICE)
            text_len = batch["text_len"]

            crnn_optimizer.zero_grad()
            output = crnn_model(recog_img)  # (B, T, num_classes), T = 26 for 100-pixel-wide input
            log_probs = F.log_softmax(output, dim=2).permute(1, 0, 2)  # (T, B, num_classes)
            input_lengths = torch.full((B,), output.size(1), dtype=torch.long).to(DEVICE)
            target_lengths = text_len.to(DEVICE)
            crnn_loss = crnn_criterion(log_probs, recog_label, input_lengths, target_lengths)
            crnn_loss.backward()
            crnn_optimizer.step()
            total_crnn_loss += crnn_loss.item()

        # Epoch losses
        print(f"EAST loss: {total_east_loss / len(train_loader):.4f}")
        print(f"CRNN loss: {total_crnn_loss / len(train_loader):.4f}")

    # Save the models
    torch.save(east_model.state_dict(), "east_model.pth")
    torch.save(crnn_model.state_dict(), "crnn_model.pth")
    print("Models saved")
# ---------------------------- 5. Inference ----------------------------
def detect_text(image, east_model):
    """Detect text regions with the EAST model."""
    h, w = image.shape[:2]
    img = cv2.resize(image, DETECT_INPUT_SIZE)
    img = img.transpose(2, 0, 1) / 255.0
    img = torch.from_numpy(img).unsqueeze(0).float().to(DEVICE)

    east_model.eval()
    with torch.no_grad():
        score, geometry = east_model(img)

    # Post-processing: extract text boxes (heavily simplified)
    score = score[0, 0].cpu().numpy()
    coords = np.where(score > 0.5)  # threshold the score map
    boxes = []
    scale_y, scale_x = h / score.shape[0], w / score.shape[1]
    for y, x in zip(coords[0], coords[1]):
        # Map coordinates back to the original image size
        x1, y1 = int(x * scale_x), int(y * scale_y)
        x2, y2 = int((x + 20) * scale_x), int((y + 10) * scale_y)  # simplified: fixed box size
        boxes.append([x1, y1, x2, y2])
    return boxes
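
# detect_text above emits one fixed-size box per thresholded score-map pixel,
# which yields many overlapping boxes. A slightly less naive variant (our sketch,
# not part of the original code) groups the thresholded score map into connected
# components and returns one axis-aligned box per component.
def detect_text_cc(image, east_model, thresh=0.5, min_area=10):
    h, w = image.shape[:2]
    img = cv2.resize(image, DETECT_INPUT_SIZE).transpose(2, 0, 1) / 255.0
    img = torch.from_numpy(img).unsqueeze(0).float().to(DEVICE)
    east_model.eval()
    with torch.no_grad():
        score, _ = east_model(img)
    score = score[0, 0].cpu().numpy()
    mask = (score > thresh).astype(np.uint8)
    num, _, stats, _ = cv2.connectedComponentsWithStats(mask)
    scale_y, scale_x = h / score.shape[0], w / score.shape[1]
    boxes = []
    for i in range(1, num):  # component 0 is the background
        x, y, bw, bh, area = stats[i]
        if area < min_area:  # drop tiny, noisy components
            continue
        boxes.append([int(x * scale_x), int(y * scale_y),
                      int((x + bw) * scale_x), int((y + bh) * scale_y)])
    return boxes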
def recognize_text(image, crnn_model):
    """Recognize the text in a cropped region with the CRNN model."""
    # Preprocess: grayscale, resize to (width, height) = RECOG_INPUT_SIZE, normalize
    img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    img = cv2.resize(img, RECOG_INPUT_SIZE)
    img = img / 255.0
    img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float().to(DEVICE)

    crnn_model.eval()
    with torch.no_grad():
        output = crnn_model(img)  # (1, T, num_classes)

    # Greedy CTC decoding
    output = F.log_softmax(output, dim=2).squeeze(0).cpu().numpy()
    pred = np.argmax(output, axis=1)

    # Collapse repeated symbols and drop blanks (index 0)
    result = []
    prev = -1
    for p in pred:
        if p != prev and p != 0:
            result.append(CHARS[p])
        prev = p
    return ''.join(result)
def ocr_inference(image_path):
    """Full OCR pipeline: detection followed by recognition."""
    # Load the models
    east_model = EAST().to(DEVICE)
    crnn_model = CRNN().to(DEVICE)
    east_model.load_state_dict(torch.load("east_model.pth", map_location=DEVICE))
    crnn_model.load_state_dict(torch.load("crnn_model.pth", map_location=DEVICE))

    # Read the image
    image = cv2.imread(image_path)
    if image is None:
        return ["Failed to read image"]

    # Text detection
    boxes = detect_text(image, east_model)
    if not boxes:
        return ["No text detected"]

    # Text recognition
    results = []
    for (x1, y1, x2, y2) in boxes:
        # Crop the text region
        roi = image[y1:y2, x1:x2]
        if roi.size == 0:
            continue
        # Recognize
        text = recognize_text(roi, crnn_model)
        results.append(f"Text: {text}, location: ({x1},{y1})-({x2},{y2})")
    return results
# ---------------------------- Main ----------------------------
if __name__ == "__main__":
    # Train the models (requires a prepared training set)
    # train()

    # Inference example (replace with your own test-image path)
    test_image = "test_image.jpg"
    results = ocr_inference(test_image)
    for res in results:
        print(res)
