使用 PyTorch + CTC Loss 实现验证码识别系统

本项目使用 PyTorch 实现一个使用 CTC Loss 的验证码识别模型,适用于变长或字符重叠的验证码图像。

  1. 安装依赖
    pip install torch torchvision captcha pillow numpy

  2. 生成验证码数据集
    from captcha.image import ImageCaptcha
    import os, random, string
    from PIL import Image

# Alphabet the captcha texts are drawn from: the ten digits followed by A-Z.
characters = ''.join((string.digits, string.ascii_uppercase))

# Shortest and longest captcha text, in characters.
min_len = 4
max_len = 6

# Size of every generated image, in pixels.
width = 200
height = 60

def generate_dataset(output_dir='ctc_captcha', count=10000):
    """Render ``count`` random captcha images into ``output_dir``.

    The directory is created if it does not exist.  Each captcha text has a
    random length between ``min_len`` and ``max_len`` (inclusive) and is drawn
    from ``characters``.  Files are named ``<text>_<index>.png``; the index
    keeps names unique when the same text is drawn more than once.

    Args:
        output_dir: Directory that receives the PNG files.
        count: Number of images to generate.
    """
    os.makedirs(output_dir, exist_ok=True)
    gen = ImageCaptcha(width=width, height=height)
    for i in range(count):
        length = random.randint(min_len, max_len)
        text = ''.join(random.choices(characters, k=length))
        image = gen.generate_image(text)
        image.save(os.path.join(output_dir, f'{text}_{i}.png'))
更多内容访问ttocr.com或联系1436423940
# Write the training images to disk using the function's default arguments.
generate_dataset()


# Section 3: custom Dataset class
import torch
from torch.utils.data import Dataset
from torchvision import transforms

# Index 0 is reserved for the CTC blank, so real characters are numbered from 1.
char_to_idx = dict(zip(characters, range(1, len(characters) + 1)))
idx_to_char = dict(enumerate(characters, start=1))

# Size of the model's output layer: one class per character plus the blank.
n_classes = 1 + len(characters)

class CTCCaptchaDataset(Dataset):
    """Captcha images on disk, labelled through their file names.

    Every ``.png`` file directly inside ``path`` is one sample.  The part of
    the file name before the first underscore is taken as the captcha text.
    """

    def __init__(self, path):
        """Index the ``.png`` files in ``path`` and build the image transform.

        Args:
            path: Directory containing the captcha images.
        """
        self.path = path
        self.images = [f for f in os.listdir(path) if f.endswith('.png')]
        self.transform = transforms.Compose([
            transforms.Resize((60, 200)),  # torchvision expects (height, width)
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A tuple ``(image, label, label_length)``: the transformed
            single-channel image, a 1-D ``torch.long`` tensor holding the
            ``char_to_idx`` code of every label character, and the number of
            characters in the label.

        Raises:
            KeyError: If the file name contains a character that is not in
                ``char_to_idx``.
        """
        fname = self.images[idx]
        label_str = fname.split('_')[0]
        label = [char_to_idx[c] for c in label_str]
        # Convert inside the context manager so the file handle is released
        # as soon as the pixel data has been read.
        with Image.open(os.path.join(self.path, fname)) as raw:
            img = raw.convert('L')
        return self.transform(img), torch.tensor(label, dtype=torch.long), len(label_str)


# Section 4: build the model (CNN + BiLSTM)

import torch.nn as nn

class CTCModel(nn.Module):
    """CNN feature extractor followed by a bidirectional LSTM for CTC training.

    The two ``MaxPool2d(2)`` stages shrink height and width by a factor of 4.
    Every remaining image column becomes one time step whose feature vector is
    the 64 channels times the pooled height.
    """

    def __init__(self, num_classes=None, img_height=60):
        """Build the layers.

        Args:
            num_classes: Size of the output layer, blank included.  Defaults
                to the module-level ``n_classes``.
            img_height: Height in pixels of the input images; it fixes the
                LSTM input size.  The default matches the 60-pixel captchas.
        """
        super().__init__()
        if num_classes is None:
            num_classes = n_classes
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # 64 channels x pooled height (60 // 4 == 15 for the default size).
        self.rnn = nn.LSTM(64 * (img_height // 4), 128, bidirectional=True, batch_first=True)
        # 256 = 128 hidden units x 2 directions.
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map images ``[B, 1, H, W]`` to log-probabilities ``[B, W // 4, num_classes]``."""
        x = self.cnn(x)
        x = x.permute(0, 3, 1, 2)  # [B, W, C, H]
        b, w, c, h = x.shape
        # reshape (rather than view) does not depend on the memory layout
        # left behind by permute.
        x = x.reshape(b, w, c * h)
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x.log_softmax(2)


# Section 5: model training

from torch.nn.functional import ctc_loss
from torch.utils.data import DataLoader

# Use the GPU when one is available; the model and every image batch are
# moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CTCModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

dataset = CTCCaptchaDataset('ctc_captcha')
# Labels differ in length, so the samples are not stacked by the loader: the
# identity collate hands the raw list of samples to the training loop.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=lambda x: x)

for epoch in range(10):
    model.train()
    for batch in dataloader:
        images, labels, label_lens = zip(*batch)
        images = torch.stack(images).to(device)
        # All targets of the batch concatenated into one 1-D tensor;
        # label_lens tells ctc_loss where each sample's targets end.
        labels = torch.cat(labels).to(device)
        label_lens = torch.tensor(label_lens, dtype=torch.long)

        logits = model(images)
        log_probs = logits.permute(1, 0, 2)  # T, N, C

        # Every sample uses the full sequence length produced by the model.
        input_lens = torch.full((logits.size(0),), logits.size(1), dtype=torch.long)
        loss = ctc_loss(log_probs, labels, input_lens, label_lens)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Reports the loss of the last batch of the epoch, not an epoch average.
    print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")


# Section 6: decoding predictions (CTC decoding)

def greedy_decode(output, blank=0, index_map=None):
    """Greedy (best-path) CTC decoding of a single sequence.

    Takes the most likely class at every time step, collapses runs of the
    same class into one symbol and drops blanks.

    Args:
        output: Tensor of per-class scores with shape ``[1, T, C]``.
        blank: Class index of the CTC blank.
        index_map: Mapping from class index to character.  Defaults to the
            module-level ``idx_to_char``.

    Returns:
        The decoded string; empty when every time step is blank.
    """
    if index_map is None:
        index_map = idx_to_char
    # tolist() yields plain Python ints, so no NumPy round-trip is needed.
    best_path = torch.argmax(output, dim=2).squeeze(0).tolist()
    decoded = []
    prev = None
    for idx in best_path:
        # A blank between two equal indices resets prev, which is how a
        # genuinely repeated character survives the collapse.
        if idx != prev and idx != blank:
            decoded.append(index_map[idx])
        prev = idx
    return ''.join(decoded)

def predict_image(image_path):
    """Recognise the text of one captcha image file.

    The image is converted to grayscale, resized to 200x60, passed through the
    module-level ``model`` on ``device`` and decoded with ``greedy_decode``.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The decoded captcha text.
    """
    with Image.open(image_path) as raw:
        img = raw.convert('L').resize((200, 60))  # PIL expects (width, height)
    tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)
    # Inference should not run in training mode.
    model.eval()
    with torch.no_grad():
        output = model(tensor)
    return greedy_decode(output)

# Demo: recognise one image from the dataset directory and print the text.
# NOTE(review): generate_dataset() draws its texts at random, so this exact
# file name presumably will not exist after a fresh run; substitute a file
# that is actually present in ctc_captcha/.
print(predict_image('ctc_captcha/9XZ2_1.png'))

posted @ 2025-04-30 16:02  ttocr、com  阅读(14)  评论(0)    收藏  举报