使用 PyTorch 构建图像验证码识别系统(CTC Loss 实战)
图像验证码识别是一种典型的图像转文本任务。本文将带你构建一个基于 CNN + BiLSTM + CTC 的验证码识别模型,支持不定长字符识别,具有较强泛化能力。
1. 环境依赖
pip install torch torchvision pillow captcha
2. 生成验证码图片
我们使用 captcha 库生成随机长度的验证码(长度 3~5):
import os, random, string
from captcha.image import ImageCaptcha
from PIL import Image
# Label alphabet: the ten digits followed by the 26 upper-case letters.
CHARSET = "".join([string.digits, string.ascii_uppercase])
# Size, in pixels, of every rendered captcha image.
WIDTH = 160
HEIGHT = 60
def generate_images(path='data', count=8000, min_len=3, max_len=5):
    """Render ``count`` random captcha images into the directory ``path``.

    Each label has between ``min_len`` and ``max_len`` characters (inclusive),
    drawn from CHARSET. Images are saved as ``<label>_<i>.png``. The label can
    therefore be recovered from the file name, and the ``_<i>`` suffix keeps
    names unique even when two random labels collide.
    """
    os.makedirs(path, exist_ok=True)
    generator = ImageCaptcha(width=WIDTH, height=HEIGHT)
    for i in range(count):
        length = random.randint(min_len, max_len)
        label = ''.join(random.choices(CHARSET, k=length))
        image = generator.generate_image(label)
        image.save(os.path.join(path, f'{label}_{i}.png'))
# Generate the images only when 'data' holds no PNGs yet. Labels are random, so
# every extra run would otherwise add a fresh batch of files on top of the old ones.
if not (os.path.isdir('data') and any(f.endswith('.png') for f in os.listdir('data'))):
    generate_images()
# 3. Build the PyTorch dataset
import torch
from torch.utils.data import Dataset
from torchvision import transforms
# Class 0 is reserved for the CTC blank, so characters are numbered from 1.
char2idx = {ch: pos for pos, ch in enumerate(CHARSET, start=1)}
# Inverse lookup, used when decoding model output back into text.
idx2char = dict(zip(char2idx.values(), char2idx.keys()))
class CaptchaCTCDataset(Dataset):
    """Captcha images paired with label-index sequences for CTC training.

    Every ``.png`` in ``root`` is used. The label is the part of the file name
    before the first underscore, matching ``<label>_<i>.png`` names such as
    those written by ``generate_images``.
    """

    def __init__(self, root):
        self.root = root
        # Sorted so that dataset indices are reproducible across runs and
        # platforms (os.listdir returns entries in arbitrary order).
        self.files = sorted(f for f in os.listdir(root) if f.endswith('.png'))
        self.transform = transforms.Compose([
            transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
            transforms.Normalize((0.5,), (0.5,)),  # shift/scale to [-1, 1]
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_indices)`` for the idx-th file."""
        filename = self.files[idx]
        label_str = filename.split('_')[0]
        # char2idx numbers characters from 1; index 0 is the CTC blank.
        label = torch.tensor([char2idx[c] for c in label_str], dtype=torch.long)
        # The context manager closes the underlying file handle promptly.
        with Image.open(os.path.join(self.root, filename)) as raw:
            img = raw.convert('L')
        return self.transform(img), label
# 4. Model architecture: CNN + BiLSTM
import torch.nn as nn
class CRNN_CTC(nn.Module):
    """CNN feature extractor + 2-layer BiLSTM + per-time-step linear classifier.

    Args:
        num_classes: number of output classes, including the CTC blank.
        img_height: height of the input images in pixels. The CNN applies two
            2x2 max-pools, so the LSTM sees ``256 * (img_height // 4)`` features
            per time step. The default of 60 reproduces the original
            hard-coded ``256 * 15``.
    """

    def __init__(self, num_classes, img_height=60):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(),
        )
        feat_height = img_height // 4  # height left after the two max-pools
        self.rnn = nn.LSTM(256 * feat_height, 128, num_layers=2, bidirectional=True)
        self.fc = nn.Linear(128 * 2, num_classes)

    def forward(self, x):
        """Map a batch of images to per-time-step class logits.

        Args:
            x: images shaped [B, 1, H, W], with H equal to ``img_height``.

        Returns:
            Unnormalized logits shaped [T, B, num_classes], where T = W // 4.
            Each column of the final feature map becomes one time step.
        """
        x = self.cnn(x)  # [B, 256, H//4, W//4]
        b, c, h, w = x.size()
        x = x.permute(3, 0, 1, 2)   # [W', B, C, H']: width becomes the sequence axis
        x = x.reshape(w, b, c * h)  # [T, B, C*H'] (LSTM default is sequence-first)
        x, _ = self.rnn(x)
        return self.fc(x)
# 5. Model training
from torch.nn import CTCLoss
from torch.utils.data import DataLoader
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = CaptchaCTCDataset('data')


def ctc_collate(batch):
    """Collate (image, label) pairs into the tensors CTCLoss consumes.

    Returns:
        imgs: images stacked along a new leading batch dimension.
        targets: all label sequences concatenated into one 1-D tensor.
        target_lengths: length of each label, used by CTCLoss to split ``targets``.
    """
    imgs, labels = zip(*batch)
    imgs = torch.stack(imgs)
    targets = torch.cat(labels)
    target_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
    return imgs, targets, target_lengths


# A named module-level collate function (unlike a lambda) can be pickled, so
# the loader keeps working if num_workers > 0 is enabled later.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=ctc_collate)
model = CRNN_CTC(num_classes=len(CHARSET) + 1).to(device)  # +1 for the CTC blank
criterion = CTCLoss(blank=0, zero_infinity=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    model.train()
    running_loss = 0.0
    for imgs, targets, target_lengths in dataloader:
        imgs = imgs.to(device)
        logits = model(imgs)  # [T, B, C]
        log_probs = logits.log_softmax(2)
        # Every sample yields the same number of time steps. Take T from the
        # model output rather than re-deriving it from the CNN's pooling.
        input_lengths = torch.full(
            size=(log_probs.size(1),), fill_value=log_probs.size(0), dtype=torch.long
        )
        loss = criterion(log_probs, targets, input_lengths, target_lengths)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean over the epoch, not just the last (noisy) batch.
    print(f'Epoch {epoch+1}, Loss: {running_loss / max(len(dataloader), 1):.4f}')
# 6. Decode predictions
def greedy_decode(output, idx2char_map=None, blank=0):
    """Greedy (best-path) CTC decoding.

    Takes the arg-max class at every time step, collapses consecutive repeats
    and drops blank symbols. A blank between two identical symbols keeps both,
    so "A-A" decodes to "AA".

    Args:
        output: model output laid out as [T, B, C] (time, batch, classes).
        idx2char_map: mapping from class index to character. Defaults to the
            module-level ``idx2char``.
        blank: index of the CTC blank class.

    Returns:
        A list of B decoded strings.
    """
    if idx2char_map is None:
        idx2char_map = idx2char
    # [T, B, C] -> [B, T] best path. Converting to nested Python lists once
    # avoids per-element tensor comparisons (and device syncs) in the loop.
    best_paths = output.argmax(dim=2).permute(1, 0).tolist()
    results = []
    for seq in best_paths:
        chars = []
        prev = blank
        for idx in seq:
            if idx != prev and idx != blank:
                chars.append(idx2char_map[idx])
            prev = idx
        results.append(''.join(chars))
    return results
model.eval()
sample, _ = dataset[0]
# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    pred = model(sample.unsqueeze(0).to(device))
print("预测结果:", greedy_decode(pred))
浙公网安备 33010602011771号