使用 PyTorch 构建图像验证码识别系统(CTC Loss 实战)

图像验证码识别是一种典型的图像转文本任务。本文将带你构建一个基于 CNN + BiLSTM + CTC 的验证码识别模型,支持不定长字符识别,具有较强泛化能力。

  1. 环境依赖
    pip install torch torchvision pillow captcha
  2. 生成验证码图片
    我们使用 `captcha` 库生成随机长度的验证码(长度 3~5):
import os
import random
import string
from itertools import groupby

from captcha.image import ImageCaptcha
from PIL import Image

# Symbols a captcha may contain: the digits 0-9 followed by A-Z (36 in total).
CHARSET = ''.join([string.digits, string.ascii_uppercase])
# Size of every generated captcha image, in pixels.
WIDTH = 160
HEIGHT = 60

def generate_images(path='data', count=8000, min_len=3, max_len=5):
    """Write ``count`` random captcha PNGs of size WIDTH x HEIGHT into ``path``.

    Each label has between ``min_len`` and ``max_len`` characters (inclusive),
    drawn from ``CHARSET``. The image is saved as ``{label}_{i}.png`` so the
    label can be recovered from the filename later.

    Existing files are not removed, so calling this again adds new images next
    to the old ones (a file is only overwritten if the same label and index recur).
    """
    os.makedirs(path, exist_ok=True)
    generator = ImageCaptcha(width=WIDTH, height=HEIGHT)
    for i in range(count):
        length = random.randint(min_len, max_len)
        label = ''.join(random.choices(CHARSET, k=length))
        image = generator.generate_image(label)
        # The trailing index keeps filenames unique when the same label is drawn twice.
        image.save(os.path.join(path, f'{label}_{i}.png'))


generate_images()
# 3. Build the PyTorch dataset
import torch
from torch.utils.data import Dataset
from torchvision import transforms

# Index 0 is reserved for the CTC blank symbol, so real characters are numbered from 1.
char2idx = {ch: pos for pos, ch in enumerate(CHARSET, start=1)}
# Inverse lookup used when decoding model output back into text.
idx2char = {pos: ch for ch, pos in char2idx.items()}

class CaptchaCTCDataset(Dataset):
    """Grayscale captcha images paired with CTC target index sequences.

    Every ``.png`` file in ``root`` is one sample. Its label is the part of the
    filename before the first ``_`` (the ``{label}_{i}.png`` naming scheme).
    """

    def __init__(self, root):
        self.root = root
        # os.listdir returns entries in arbitrary order; sort so dataset[i]
        # refers to the same file on every run.
        self.files = sorted(f for f in os.listdir(root) if f.endswith('.png'))
        self.transform = transforms.Compose([
            transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
            transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_indices)`` for sample ``idx``.

        ``label_indices`` is a 1-D ``torch.long`` tensor of ``char2idx`` values.
        These are all >= 1, because 0 is the CTC blank.
        """
        filename = self.files[idx]
        label_str = filename.split('_')[0]
        label = torch.tensor([char2idx[c] for c in label_str], dtype=torch.long)
        # Close the file handle promptly; convert('L') returns a new, loaded image.
        with Image.open(os.path.join(self.root, filename)) as im:
            img = im.convert('L')
        return self.transform(img), label


# 4. Model architecture: CNN + BiLSTM

import torch.nn as nn

class CRNN_CTC(nn.Module):
    """CNN feature extractor + 2-layer BiLSTM producing per-column class scores for CTC.

    Args:
        num_classes: number of output classes per time step, including the CTC blank.
        img_height: input image height in pixels. The two 2x2 max-pools divide
            it by 4 (rounding down), so the LSTM input size is
            ``256 * (img_height // 4)``. The default of 60 matches ``HEIGHT``.
    """

    def __init__(self, num_classes, img_height=60):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(),
        )
        # Each remaining image column becomes one LSTM time step; its features
        # are all 256 channels x (img_height // 4) rows of that column.
        self.rnn = nn.LSTM(256 * (img_height // 4), 128, num_layers=2, bidirectional=True)
        self.fc = nn.Linear(128 * 2, num_classes)

    def forward(self, x):
        """Map a ``[B, 1, H, W]`` batch to unnormalized scores ``[W // 4, B, num_classes]``."""
        x = self.cnn(x)            # [B, 256, H // 4, W // 4]
        b, c, h, w = x.size()
        x = x.permute(3, 0, 1, 2)  # [W', B, C, H']: width becomes the time axis
        x = x.reshape(w, b, c * h)
        x, _ = self.rnn(x)         # [W', B, 2 * 128]
        return self.fc(x)          # [W', B, num_classes]


# 5. Train the model

from torch.nn import CTCLoss
from torch.utils.data import DataLoader

def _collate(batch):
    """Stack the equal-sized images; keep the variable-length labels as a tuple."""
    imgs, labels = zip(*batch)
    return torch.stack(imgs), labels


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = CaptchaCTCDataset('data')
# A named module-level collate function (unlike a lambda) can be pickled,
# so this loader also works with num_workers > 0.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=_collate)

# +1 output class for the CTC blank at index 0.
model = CRNN_CTC(num_classes=len(CHARSET) + 1).to(device)
criterion = CTCLoss(blank=0, zero_infinity=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    model.train()
    total_loss, num_batches = 0.0, 0
    for imgs, labels in dataloader:
        imgs = imgs.to(device)
        targets = torch.cat(labels)  # all label sequences concatenated into one 1-D tensor
        target_lengths = torch.tensor([len(lab) for lab in labels], dtype=torch.long)

        logits = model(imgs)               # [T, B, C]
        log_probs = logits.log_softmax(2)
        # Every sample uses all T time steps; take T from the output instead of
        # re-deriving it from the CNN's downsampling factor.
        input_lengths = torch.full((imgs.size(0),), log_probs.size(0), dtype=torch.long)
        loss = criterion(log_probs, targets, input_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
    # Report the mean over the epoch; the last batch's loss alone is noisy.
    print(f'Epoch {epoch+1}, Loss: {total_loss / max(num_batches, 1):.4f}')


# 6. Decode predictions

def greedy_decode(output, blank=0, mapping=None):
    """Best-path (greedy) CTC decoding of model scores into strings.

    Args:
        output: ``[T, B, C]`` scores (logits or log-probabilities), as returned by ``CRNN_CTC``.
        blank: class index of the CTC blank symbol.
        mapping: dict from class index to character; defaults to ``idx2char``.

    Returns:
        A list of ``B`` decoded strings.

    The argmax class is taken at each time step. Runs of identical classes are
    collapsed, then blanks are dropped. So a blank between two equal
    characters keeps both (``A - A`` -> ``"AA"``), while ``A A`` -> ``"A"``.
    """
    if mapping is None:
        mapping = idx2char
    best = output.argmax(dim=2).permute(1, 0)  # [B, T]
    results = []
    for seq in best.tolist():
        # groupby yields one key per run of consecutive equal indices.
        kept = (k for k, _ in groupby(seq) if k != blank)
        results.append(''.join(mapping[k] for k in kept))
    return results

model.eval()
sample, _ = dataset[0]
# Inference only: skip building the autograd graph.
with torch.no_grad():
    pred = model(sample.unsqueeze(0).to(device))  # add a batch dimension of 1
print("预测结果:", greedy_decode(pred))

posted @ 2025-05-08 17:36  ttocr、com  阅读(36)  评论(0)    收藏  举报