PyTorch 构建轻量级验证码识别模型

一、引言
在数字化浪潮下,验证码作为网站防爬和防机器人机制的重要组成部分,其识别问题一直是图像处理与深度学习领域的热点。与直接使用传统 OCR 工具不同,本文提出一种基于 PyTorch 的轻量级卷积神经网络,用于处理干扰较强的字符型验证码。

二、项目目标
实现验证码图像的生成与标注自动化
更多内容访问ttocr.com或联系1436423940
构建轻量级 CNN 模型,兼顾识别率与速度

支持训练与评估流程全自动化

验证码示例特征如下:

尺寸:140x50 像素

长度:5 位,由大写字母和数字组成(共 36 种字符)

干扰:背景噪声、扭曲、线条干扰

三、环境配置

pip install torch torchvision matplotlib pillow captcha
四、验证码生成脚本 gen_data.py

from captcha.image import ImageCaptcha
import random, os, string
from PIL import Image

def generate_dataset(n=5000, path='dataset'):
    """Render ``n`` random CAPTCHA images and save them as PNG files.

    Every image is 140x50 pixels and shows 5 characters drawn (with
    replacement) from the uppercase ASCII letters and the digits. The label
    is stored in the file name, which has the form ``<text>_<index>.png``.

    Args:
        n: Number of images to generate.
        path: Output directory; it is created when it does not exist yet.
    """
    os.makedirs(path, exist_ok=True)
    chars = string.ascii_uppercase + string.digits
    captcha = ImageCaptcha(width=140, height=50)
    for i in range(n):
        text = ''.join(random.choices(chars, k=5))
        img = captcha.generate_image(text)
        # The running index keeps names unique when the same text is drawn
        # more than once.
        img.save(os.path.join(path, f"{text}_{i}.png"))


if __name__ == "__main__":
    # Only generate data when run as a script, not when imported.
    generate_dataset()
五、数据读取模块 dataset.py

import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import string

# Label alphabet: the 26 uppercase ASCII letters followed by the 10 digits.
CHARS = string.ascii_uppercase + string.digits
# Maps each character to its class index, i.e. its position inside CHARS.
char2idx = dict(zip(CHARS, range(len(CHARS))))

class CaptchaDataset(Dataset):
    """Dataset of CAPTCHA images whose label is encoded in the file name.

    A file is expected to be named ``<label>_<suffix>``; everything before
    the first underscore is taken as the label text.
    """

    def __init__(self, root):
        """Index the files located directly inside ``root``.

        Args:
            root: Directory that holds the CAPTCHA image files.
        """
        # os.listdir returns entries in arbitrary order, so sort them to get
        # a reproducible sample order; skip sub-directories because they
        # cannot be opened as images.
        candidates = (os.path.join(root, name) for name in sorted(os.listdir(root)))
        self.files = [p for p in candidates if os.path.isfile(p)]

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Position of the sample in ``self.files``.

        Returns:
            A ``(img, label_tensor)`` pair: ``img`` is a float32 tensor
            viewed as ``(1, 50, 140)`` and divided by 255; ``label_tensor``
            is a long tensor with one ``char2idx`` index per label character.

        Raises:
            KeyError: If the label contains a character missing from
                ``char2idx``.
        """
        path = self.files[idx]
        label = os.path.basename(path).split('_')[0]
        # The context manager releases the file handle as soon as the
        # grayscale copy has been produced.
        with Image.open(path) as raw:
            img = raw.convert('L').resize((140, 50))
        # resize() takes (width, height) while the tensor is laid out as
        # (channel, height, width), hence 140x50 -> (1, 50, 140).
        img = torch.tensor(list(img.getdata()), dtype=torch.float32).view(1, 50, 140) / 255
        label_tensor = torch.tensor([char2idx[c] for c in label], dtype=torch.long)
        return img, label_tensor

六、CNN模型定义 model.py

import torch.nn as nn

class SimpleCNN(nn.Module):
    """Small CNN with one classification head per CAPTCHA character.

    The first linear layer is sized for ``64 * 12 * 35`` features, which is
    what two 2x2 max-poolings produce from a ``(N, 1, 50, 140)`` input
    (50 -> 25 -> 12 rows, 140 -> 70 -> 35 columns).
    """

    def __init__(self, num_classes=36, num_chars=5):
        """Build the convolutional trunk and the per-character heads.

        Args:
            num_classes: Number of classes predicted by every head.
            num_chars: Number of character positions, i.e. number of heads.
        """
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 12 * 35, 256),
            nn.ReLU(),
        )
        # One independent linear classifier per character position, all fed
        # by the same 256-dimensional feature vector.
        self.out = nn.ModuleList(
            [nn.Linear(256, num_classes) for _ in range(num_chars)]
        )

    def forward(self, x):
        """Return a list with one ``(N, num_classes)`` logit tensor per head.

        Args:
            x: Image batch of shape ``(N, 1, 50, 140)``.
        """
        x = self.conv(x)
        x = self.fc(x)
        return [o(x) for o in self.out]

七、训练逻辑 train.py

from model import SimpleCNN
from dataset import CaptchaDataset, CHARS
import torch, torch.nn.functional as F
from torch.utils.data import DataLoader

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
dataset = CaptchaDataset("dataset")
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    model.train()
    total_loss = 0
    for img, label in train_loader:
        img, label = img.to(device), label.to(device)
        preds = model(img)
        # The model returns one logit tensor per character position; column
        # i of the label batch is the target for head i, and the per-head
        # cross-entropy losses are summed into a single training loss.
        loss = sum(F.cross_entropy(p, label[:, i]) for i, p in enumerate(preds))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Reported value is the sum of the batch losses of this epoch.
    print(f"Epoch {epoch+1}, Loss: {total_loss:.2f}")

# Persist only the weights once all epochs have finished.
torch.save(model.state_dict(), "captcha_model.pth")
八、预测验证 predict.py

from model import SimpleCNN
from dataset import CHARS
from PIL import Image
import torch
import os

# Rebuild the network, load the saved weights onto the CPU and put the
# model into evaluation mode.
model = SimpleCNN()
saved_weights = torch.load("captcha_model.pth", map_location="cpu")
model.load_state_dict(saved_weights)
model.eval()

def predict(path):
    """Return the text predicted for the CAPTCHA image stored at ``path``.

    The image is converted to grayscale, resized to 140x50, scaled by 1/255
    and passed through the module-level ``model`` as a batch of one.

    Args:
        path: Location of the image file.

    Returns:
        A string with one character from ``CHARS`` per model output head.
    """
    # The context manager releases the file handle as soon as the grayscale
    # copy has been produced.
    with Image.open(path) as raw:
        img = raw.convert('L').resize((140, 50))
    img = torch.tensor(list(img.getdata()), dtype=torch.float32).view(1, 1, 50, 140) / 255
    with torch.no_grad():
        outputs = model(img)
    # Each head yields logits for the single image in the batch; .item()
    # converts the winning class index into a plain int for indexing CHARS.
    pred = ''.join(CHARS[o.argmax(dim=1).item()] for o in outputs)
    return pred


if __name__ == "__main__":
    print(predict("dataset/A7XLP_0.png"))

posted @ 2025-07-29 18:54  ttocr、com  阅读(18)  评论(0)    收藏  举报