PyTorch 构建轻量级 CNN 模型实现验证码图像识别
本文介绍如何使用 Python 和 PyTorch 实现一个轻量级的卷积神经网络(CNN)模型,用于识别英文数字混合的图像验证码。通过合成数据集、搭建模型、训练和测试等步骤,实现对 5 位验证码的准确识别,过程简明,便于在实际应用中快速部署。
1. 项目背景
验证码图像识别是深度学习应用中的经典案例,涉及图像处理、OCR 与模型训练等内容。相比传统的字符切割方法,现代 CNN 模型可直接进行端到端识别,大幅提升识别效率和准确率。
2. 数据集构建
使用 captcha 库生成英文+数字验证码,统一为 5 个字符:
更多内容访问ttocr.com或联系1436423940
from captcha.image import ImageCaptcha
import string, random
from PIL import Image
import numpy as np
# Captcha vocabulary: the 52 ASCII letters (lowercase, then uppercase)
# followed by the 10 digits, 62 symbols in total.
ALL_CHARS = "".join([string.ascii_letters, string.digits])
def gen_text(length=5):
    """Return a random string of ``length`` characters drawn from ALL_CHARS.

    Characters are sampled with replacement from the module-level ``random``
    generator, which is adequate for synthesizing training data.
    """
    picked = random.choices(ALL_CHARS, k=length)
    return "".join(picked)
def gen_image(text):
    """Render ``text`` as a 160x60 captcha and return it converted to mode 'L'.

    A new ``ImageCaptcha`` generator is created on every call.
    """
    renderer = ImageCaptcha(width=160, height=60)
    # Convert to single-channel 8-bit grayscale, as the model expects 1 channel.
    return renderer.generate_image(text).convert('L')
def process_img(img):
    """Resize ``img`` to 160x60 and return its pixels as float32 scaled to [0, 1].

    ``img.resize`` receives ``(width, height) = (160, 60)``, and the result is
    converted with ``np.array`` before dividing by 255.
    """
    target_size = (160, 60)
    pixels = np.array(img.resize(target_size), dtype=np.float32)
    # 8-bit pixel values 0..255 map onto 0.0..1.0.
    return pixels / 255.0
3. 标签编码与解码
# Forward lookup: character -> class index (its position in ALL_CHARS).
char2idx = {ch: pos for pos, ch in enumerate(ALL_CHARS)}
# Reverse lookup: class index -> character, built by inverting char2idx.
idx2char = dict(zip(char2idx.values(), char2idx.keys()))
def encode_label(text):
    """Return the list of class indices for the characters of ``text``.

    Raises KeyError for any character that is not in ``char2idx``.
    """
    return list(map(char2idx.__getitem__, text))
def decode_output(output):
    """Decode logits into strings by taking the arg-max class at each position.

    ``output`` is reduced with ``argmax(dim=2)``, so it must be rank 3. It is
    presumably the (batch, length, num_classes) tensor produced by
    ``SimpleCNN.forward``. Returns one decoded string per batch row.
    """
    best_classes = output.argmax(dim=2)
    return [''.join(map(idx2char.__getitem__, row)) for row in best_classes.tolist()]
4. 模型结构设计
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Two-stage CNN that predicts every captcha character in a single pass.

    Input:  (batch, 1, img_height, img_width) tensor.
    Output: (batch, length, num_classes) logits, one distribution per
    character position.

    Args:
        num_classes: Size of the character vocabulary (62 = letters + digits).
        length: Number of characters per captcha.
        img_height: Input image height in pixels (default 60).
        img_width: Input image width in pixels (default 160).
    """

    def __init__(self, num_classes=62, length=5, img_height=60, img_width=160):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # 3x3 convs with padding 1 keep the spatial size. Each MaxPool2d(2)
        # halves it (floor), so both dims shrink by // 4. For the default
        # 60x160 input that is 64 * 15 * 40 = 38400 flattened features.
        flat_features = 64 * (img_height // 4) * (img_width // 4)
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 1024), nn.ReLU(),
            nn.Linear(1024, length * num_classes),
        )
        self.length = length
        self.num_classes = num_classes

    def forward(self, x):
        """Map a (batch, 1, H, W) image batch to (batch, length, num_classes) logits."""
        x = self.conv(x)
        x = self.fc(x)
        # Reshape with the real batch size so a size mismatch raises an error
        # instead of being silently folded into the batch dimension.
        return x.view(x.size(0), self.length, self.num_classes)
5. 训练流程
import torch
import torch.nn.functional as F
import torch.optim as optim
# Build the network with its default constructor arguments
# (num_classes=62, length=5) and optimize it with Adam, learning rate 1e-3.
model = SimpleCNN()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
def compute_loss(pred, label):
    """Return the sum of per-position cross-entropy losses.

    Args:
        pred: Logits indexed as ``pred[:, i]`` per character position,
            e.g. the (batch, length, num_classes) output of SimpleCNN.
        label: Class indices indexed as ``label[:, i]``, shape (batch, length).

    Returns:
        Scalar tensor: for each character position, the batch-mean
        cross-entropy (PyTorch's default reduction), summed over positions.
    """
    # Take the number of positions from the prediction instead of hard-coding
    # 5, so models built with a different ``length`` are handled correctly.
    num_positions = pred.size(1)
    return sum(F.cross_entropy(pred[:, i], label[:, i]) for i in range(num_positions))
训练循环:
# Train for 10 epochs. ``dataloader`` is not defined in this snippet. It is
# assumed to yield (images, labels) batches with images shaped (batch, H, W)
# and labels holding class indices shaped (batch, length). TODO: confirm.
model.train()  # explicit training mode, in case the model was left in eval mode
for epoch in range(10):
    for images, labels in dataloader:
        # Add the channel dimension Conv2d expects: (B, H, W) -> (B, 1, H, W).
        images = images.unsqueeze(1)
        preds = model(images)
        loss = compute_loss(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
6. 测试与评估
# Whole-string accuracy: a sample counts as correct only if every character
# matches. ``test_loader`` is assumed to yield (images, labels) in the same
# format as the training loader. TODO: confirm.
model.eval()  # disable training-only behaviour before evaluating
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        preds = model(images.unsqueeze(1))
        texts = decode_output(preds)
        reals = [''.join(idx2char[i.item()] for i in row) for row in labels]
        correct += sum(p == r for p, r in zip(texts, reals))
        total += len(texts)
# An empty test set would otherwise raise ZeroDivisionError.
accuracy = correct / total * 100 if total else 0.0
print(f"识别准确率: {accuracy:.2f}%")
浙公网安备 33010602011771号