import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from torch.nn import CTCLoss
import numpy as np
# -------------------------- Core optimization 1: beam-search decoding (replaces greedy decoding for higher accuracy) --------------------------
def beam_search_decode(output, beam_size=3, blank=0,
                       chars="0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"):
    """Decode CTC network output with a simplified beam search.

    Args:
        output: tensor of raw scores, shape (seq_len, batch, num_classes);
            softmax is applied internally.
        beam_size: number of candidate prefixes kept per time step.
        blank: index of the CTC blank class (0 by convention here).
        chars: character set; class index c (c > blank) maps to chars[c - 1]
            because index 0 is reserved for the blank.

    Returns:
        list[str]: best decoded text per batch sample.

    NOTE(review): this tracks only the emitted prefix, not whether the last
    frame was a blank, so a character repeated across a blank gap is still
    collapsed (simplification inherited from the original greedy-style dedup).
    """
    # Derive the class count from the model output rather than the charset,
    # so the decoder works for any head size (blank + N characters).
    seq_len, batch_size, num_classes = output.shape
    probs = torch.softmax(output, dim=2).cpu().numpy()

    # One candidate list per sample; prefixes are tuples so they can key a
    # dict (the original used `[([], 1.0)] * batch_size`, which produced a
    # single shared tuple per sample and crashed when unpacked).
    beams = [[((), 1.0)] for _ in range(batch_size)]

    for t in range(seq_len):
        for b in range(batch_size):
            # Merge candidates that share the same prefix: summing their
            # probabilities is correct and keeps beam slots from being
            # wasted on duplicates.
            candidates = {}
            for seq, prob in beams[b]:
                for c in range(num_classes):
                    p = probs[t, b, c]
                    if p < 1e-6:  # skip negligible-probability classes
                        continue
                    if c == blank:
                        new_seq = seq  # blank: prefix unchanged
                    elif seq and seq[-1] == c:
                        continue  # collapse consecutive repeats
                    else:
                        new_seq = seq + (c,)
                    candidates[new_seq] = candidates.get(new_seq, 0.0) + prob * p
            # Keep the beam_size most probable prefixes.
            beams[b] = sorted(candidates.items(), key=lambda x: x[1],
                              reverse=True)[:beam_size]

    texts = []
    for b in range(batch_size):
        best_seq = max(beams[b], key=lambda x: x[1])[0]
        # Class c maps to chars[c - 1]: index 0 is the blank, so the
        # original `chars[c]` was off by one.
        texts.append(''.join(chars[c - 1] for c in best_seq
                             if 0 < c <= len(chars)))
    return texts
# -------------------------- Core optimization 2: custom character-set support (incl. Chinese) --------------------------
class CustomCharDataset(Dataset):
    """Image/text dataset over an arbitrary character set (e.g. Chinese).

    NOTE: originally declared as `class CustomCharDataset(TextDataset)`, but
    TextDataset is defined *later* in this file, so the class statement raised
    NameError at import time. Inheriting `Dataset` directly and carrying the
    shared fields inline removes the forward reference without changing the
    constructor signature or the attributes subclass code relies on.
    """

    def __init__(self, img_dir, label_dict, char_set, transform=None):
        """
        Args:
            img_dir: directory scanned (non-recursively) for image files.
            label_dict: filename -> ground-truth text.
            char_set: custom character set (e.g. Chinese + digits + symbols).
            transform: optional torchvision transform applied to each image.
        """
        self.img_dir = img_dir
        self.img_names = os.listdir(img_dir)
        self.transform = transform
        self.label_dict = label_dict
        self.char_set = char_set
        # Index 0 is reserved for the CTC blank, so characters start at 1.
        self.char_to_idx = {char: idx + 1 for idx, char in enumerate(char_set)}

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        # Map the label text to charset indices; characters outside the
        # charset are silently dropped.
        label_text = self.label_dict[self.img_names[idx]]
        label = [self.char_to_idx[char] for char in label_text
                 if char in self.char_to_idx]
        return image, torch.tensor(label)
# -------------------------- Reused original modules (model + dataset base class) --------------------------
class TextDataset(Dataset):
    """Abstract base dataset: pairs image files in a directory with labels.

    Subclasses must override ``__getitem__`` to load and transform a sample;
    the base class only manages the file listing and label mapping.
    """

    def __init__(self, img_dir, label_dict, transform=None):
        self.img_dir = img_dir
        self.label_dict = label_dict
        self.transform = transform
        # Snapshot of the directory contents at construction time.
        self.img_names = os.listdir(img_dir)

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        raise NotImplementedError("需继承并重写__getitem__")
class CRNN(nn.Module):
    """CRNN text recognizer: CNN feature extractor + bidirectional LSTM + FC.

    Expects single-channel input images of height 32; the CNN reduces the
    height to 1 so the width dimension becomes the sequence axis for the RNN.
    Output shape is (seq_len, batch, num_classes), as consumed by CTCLoss.
    """

    def __init__(self, hidden_size=256, num_classes=62):
        super(CRNN, self).__init__()
        # Convolutional stack: three symmetric pools halve H and W, the two
        # (2,1) pools halve only the height, and the final unpadded 2x2 conv
        # brings the height from 2 down to 1.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),
            nn.Conv2d(512, 512, 2),
            nn.ReLU(),
        )
        # Two-layer BiLSTM over the width axis; seq-first layout (T, B, C).
        self.rnn = nn.LSTM(512, hidden_size, bidirectional=True,
                           num_layers=2, batch_first=False)
        # Bidirectional -> 2 * hidden_size features per time step.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        features = self.cnn(x)            # (B, 512, 1, W')
        features = features.squeeze(2)    # drop the unit height dim -> (B, 512, W')
        sequence = features.permute(2, 0, 1)  # (W', B, 512): time-major for LSTM
        encoded, _ = self.rnn(sequence)
        return self.fc(encoded)           # (W', B, num_classes)
# -------------------------- Main pipeline (training + inference) --------------------------
if __name__ == "__main__":
    # 1. Custom character set (digits + ASCII letters + sample Chinese numerals).
    custom_char_set = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ一二三四五六七八九十"
    num_classes = len(custom_char_set) + 1  # +1 reserves index 0 for the CTC blank

    # 2. Preprocessing: single channel (the model's first conv expects 1
    #    channel), fixed 32x128 canvas, normalized to [-1, 1].
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((32, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # 3. Dataset (synthetic labels mixing Chinese/digits/letters; replace
    #    with real data).
    img_dir = "custom_text_images"
    os.makedirs(img_dir, exist_ok=True)
    label_dict = {f"img{i}.png": f"测试{i%10}a{custom_char_set[i%20]}" for i in range(100)}
    dataset = CustomCharDataset(img_dir, label_dict, custom_char_set, transform)

    def ctc_collate(batch):
        """Stack images and flatten variable-length targets for CTCLoss.

        The default collate function tries to stack the per-sample label
        tensors, which have different lengths and would raise. CTCLoss
        accepts a concatenated 1-D target tensor plus per-sample lengths.
        """
        images, labels = zip(*batch)
        images = torch.stack(images, 0)
        target_lengths = torch.tensor([len(l) for l in labels], dtype=torch.long)
        targets = (torch.cat(labels) if int(target_lengths.sum()) > 0
                   else torch.zeros(0, dtype=torch.long))
        return images, targets, target_lengths

    train_loader = DataLoader(dataset, batch_size=8, shuffle=True,
                              collate_fn=ctc_collate)

    # 4. Model / loss / optimizer. blank=0 matches CustomCharDataset, which
    #    maps characters to indices starting at 1.
    model = CRNN(num_classes=num_classes)
    criterion = CTCLoss(blank=0)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # 5. Training.
    epochs = 15
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for images, targets, target_lengths in train_loader:
            optimizer.zero_grad()
            outputs = model(images)  # (T, B, num_classes) raw scores
            # CTCLoss expects log-probabilities; feeding raw logits (as the
            # original did) silently miscomputes the loss.
            log_probs = outputs.log_softmax(2)
            input_lengths = torch.full((images.size(0),), outputs.size(0),
                                       dtype=torch.long)
            loss = criterion(log_probs, targets, input_lengths, target_lengths)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Guard against an empty image directory (zero batches) to avoid
        # ZeroDivisionError.
        num_batches = len(train_loader)
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.4f}")

    # 6. Inference with beam-search decoding.
    # NOTE(review): beam_search_decode defaults to a hard-coded 62-character
    # set that does not match custom_char_set used for training — align the
    # two before relying on decoded text.
    model.eval()
    test_img_path = "test_chinese.png"  # test image (Chinese/digits/letters)
    if os.path.exists(test_img_path):
        test_image = Image.open(test_img_path).convert("RGB")
        test_tensor = transform(test_image).unsqueeze(0)
        with torch.no_grad():
            test_output = model(test_tensor)
        # Beam width 5: trade-off between speed and accuracy.
        result = beam_search_decode(test_output, beam_size=5)
        print(f"识别结果: {result[0]}")
    else:
        print(f"测试图像不存在: {test_img_path}")