基于 PyTorch 的多任务验证码识别模型设计与实现
一、研究背景
验证码识别作为深度学习在计算机视觉领域的重要应用之一,广泛用于登录认证与反爬虫机制。传统 OCR 技术在应对复杂干扰和多字符位置对齐时效果欠佳。为提升识别效率与准确率,本文设计了一种多任务学习架构,将验证码识别建模为字符位置并行分类问题。
二、系统架构概述
本项目采用 PyTorch 框架,整体流程包括:
验证码图像生成:合成训练样本,控制背景干扰与扭曲难度;
数据预处理:图像归一化,标签编码;
更多内容访问ttocr.com或联系1436423940
多头 CNN 模型设计:共享特征提取层,分支输出每个字符位置;
训练与评估流程:交叉熵损失、多输出联合优化。
三、数据集准备
我们使用 captcha 库生成 6000 张验证码图像,其中:
每张图片包含 5 个字符(数字+字母)
干扰包括噪点、曲线、旋转等
分辨率:160×60 像素,灰度图
from captcha.image import ImageCaptcha
import string, random, os
def generate_captcha(path="data/train", count=6000):
os.makedirs(path, exist_ok=True)
captcha = ImageCaptcha(width=160, height=60)
charset = string.ascii_uppercase + string.digits
for i in range(count):
text = ''.join(random.choices(charset, k=5))
captcha.write(text, f"{path}/{text}_{i}.png")
generate_captcha()
四、模型设计
采用轻量 CNN + 多输出结构,每个输出对应一个字符位,具体结构如下:
import torch.nn as nn
class MultiHeadCaptchaModel(nn.Module):
def init(self, num_classes=36, num_chars=5):
super().init()
self.backbone = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
)
self.fc = nn.Sequential(
nn.Flatten(),
nn.Linear(128 * 7 * 20, 512),
nn.ReLU()
)
self.heads = nn.ModuleList([nn.Linear(512, num_classes) for _ in range(num_chars)])
def forward(self, x):
x = self.backbone(x)
x = self.fc(x)
return [head(x) for head in self.heads]
五、训练过程
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from dataset import CaptchaDataset # 自定义读取类
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiHeadCaptchaModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_loader = DataLoader(CaptchaDataset("data/train"), batch_size=64, shuffle=True)
for epoch in range(20):
model.train()
total_loss = 0
for x, y in train_loader:
x, y = x.to(device), y.to(device)
outputs = model(x)
loss = sum(F.cross_entropy(o, y[:, i]) for i, o in enumerate(outputs))
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")
六、识别流程
def predict(img_path):
model.eval()
img = Image.open(img_path).convert("L").resize((160, 60))
img = torch.tensor(np.array(img) / 255.0, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
with torch.no_grad():
outputs = model(img)
pred = ''.join([CHARS[o.argmax()] for o in outputs])
return pred