基于Python与ResNet-CTC的变长验证码识别系统设计
针对变长字符验证码识别的难点,本文提出一种结合 ResNet18 特征提取与 CTC(Connectionist Temporal Classification)解码的识别方法,并配合 Albumentations 数据增强,以提升模型对扭曲、粘连、模糊验证码的识别鲁棒性,适用于实际验证码识别场景。
一、引言
传统验证码识别方法多针对固定长度的字符序列,然而在真实业务场景中,验证码字符长度往往不固定,还可能存在字符粘连、旋转等问题。CTC 损失不需要逐字符的位置对齐标注,能够直接处理输出长度可变的序列识别问题;本文据此结合 ResNet18、CTC 与数据增强训练构建验证码识别系统,以应对上述挑战。
二、系统框架
整体架构如下:
数据生成:模拟变长验证码
数据增强:使用 Albumentations
模型设计:ResNet18 Backbone + CTC Head
训练与评估
预测与解码输出
三、验证码数据生成(变长)
import random, os, string
from captcha.image import ImageCaptcha
from PIL import Image
def gen_var_length_captcha(save_path, count=3000, min_len=4, max_len=6):
    """Render ``count`` random variable-length captcha images into ``save_path``.

    Texts are drawn from digits and upper-case ASCII letters; each text's
    length is drawn uniformly from ``[min_len, max_len]``. Every image is
    saved as ``<text>.png``, so the label can be recovered from the file name.

    Because the file name *is* the label, a repeated text would overwrite an
    earlier image and leave fewer than ``count`` files on disk. Duplicate
    texts are therefore re-drawn, so exactly ``count`` distinct images are
    written.

    Args:
        save_path: Output directory; created if it does not exist.
        count: Number of distinct images to write.
        min_len: Minimum text length (inclusive).
        max_len: Maximum text length (inclusive).

    Raises:
        ValueError: If ``count`` exceeds the number of distinct texts the
            charset and length range can produce (otherwise the de-duplicating
            loop could never finish).
    """
    charset = string.digits + string.ascii_uppercase
    # Number of distinct texts available for the requested length range.
    capacity = sum(len(charset) ** n for n in range(min_len, max_len + 1))
    if count > capacity:
        raise ValueError(
            f"count={count} exceeds the {capacity} distinct texts available "
            f"for lengths {min_len}-{max_len}"
        )

    os.makedirs(save_path, exist_ok=True)
    generator = ImageCaptcha(width=200, height=60)
    used = set()
    while len(used) < count:
        length = random.randint(min_len, max_len)
        text = ''.join(random.choices(charset, k=length))
        if text in used:
            # Same text -> same file name; skip instead of overwriting.
            continue
        used.add(text)
        img = generator.generate_image(text)
        img.save(os.path.join(save_path, f"{text}.png"))
# Build the training and held-out splits. The two calls draw texts
# independently, so the same captcha text may occur in both directories.
gen_var_length_captcha("data/train", count=5000)
gen_var_length_captcha("data/test", count=1000)
四、增强与数据加载(Albumentations)
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Training-time augmentation pipeline: resize every image to height 64 x
# width 200, apply mild geometric / photometric jitter (rotation up to
# +/-15 degrees, brightness/contrast, Gaussian noise), then normalise and
# convert to a tensor. A.Normalize() is used with its library defaults
# (presumably ImageNet statistics, matching the pretrained backbone --
# confirm against the installed Albumentations version).
train_transform = A.Compose(
    transforms=[
        A.Resize(height=64, width=200),
        A.Rotate(limit=15),
        A.RandomBrightnessContrast(),
        A.GaussNoise(),
        A.Normalize(),
        ToTensorV2(),
    ]
)
五、模型结构(ResNet18 + CTC)
import torch.nn as nn
from torchvision.models import resnet18
class CTCModel(nn.Module):
    """CRNN-style recognizer: ResNet18 features -> BiLSTM -> per-frame classifier.

    Maps an image batch to a sequence of per-column class scores for CTC
    training. The sequence (time) axis is the width of the final feature map.
    The feature-map height is pooled to 1, so the model works for any input
    height.

    NOTE(review): ResNet18's stem and layers 2-4 downsample width by 32
    overall, so a 200-px-wide input yields only ~7 time steps. CTC needs at
    least (label length + number of adjacent repeated characters) steps. If
    long labels with repeats fail to train (inf/NaN loss), widen the input or
    reduce the horizontal stride of layer3/layer4.

    Args:
        num_classes: Number of output classes per time step, including the
            CTC blank class.
        pretrained: Load ImageNet-pretrained ResNet18 weights (default
            ``True``, the original behaviour). Pass ``False`` to build the
            model without downloading weights.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.base = resnet18(pretrained=pretrained)
        # The classification head is never used: forward() stops after layer4.
        self.base.fc = nn.Identity()
        self.conv = nn.Conv2d(512, 128, kernel_size=3, padding=1)
        # Collapse the feature-map height to 1 (width untouched). The LSTM
        # input size is then independent of the input image height. The old
        # hard-coded 128*8 only matched an 8-row map; a 64-px input gives 2 rows.
        self.pool = nn.AdaptiveAvgPool2d((1, None))
        self.rnn = nn.LSTM(128, 256, bidirectional=True, num_layers=2, batch_first=True)
        # 2 directions x 256 hidden units = 512 features per time step.
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return per-time-step class scores of shape (B, W', num_classes).

        ``x`` is an image batch (B, 3, H, W) -- assumed RGB to match the
        ResNet18 stem. W' is the width of the final feature map.
        """
        x = self.base.conv1(x)
        x = self.base.bn1(x)
        x = self.base.relu(x)
        x = self.base.maxpool(x)
        x = self.base.layer1(x)
        x = self.base.layer2(x)
        x = self.base.layer3(x)
        x = self.base.layer4(x)
        x = self.conv(x)  # (B, 128, H', W')
        x = self.pool(x)  # (B, 128, 1, W')
        x = x.squeeze(2).permute(0, 2, 1).contiguous()  # (B, W', 128)
        x, _ = self.rnn(x)  # (B, W', 512)
        return self.fc(x)  # (B, W', num_classes)
六、CTC 损失与训练设置
import torch.nn.functional as F
def train_step(model, batch, criterion, optimizer, device):
    """Run a single optimisation step on ``batch`` and return the loss value.

    Args:
        model: Module mapping images to per-frame scores shaped (B, T, C).
        batch: Tuple ``(imgs, targets, input_lengths, target_lengths)``. Only
            ``imgs`` and ``targets`` are moved to ``device``; the two length
            tensors are passed to ``criterion`` as-is.
        criterion: Loss called as
            ``criterion(log_probs, targets, input_lengths, target_lengths)``
            with ``log_probs`` shaped (T, B, C) -- the argument order and
            layout used by ``torch.nn.CTCLoss``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Target device for images and targets.

    Returns:
        The loss as a Python float.
    """
    model.train()
    images, labels, in_lens, tgt_lens = batch
    images = images.to(device)
    labels = labels.to(device)

    # Normalise over the class axis, then reorder batch-first (B, T, C)
    # to the time-first (T, B, C) layout the CTC criterion consumes.
    log_probs = F.log_softmax(model(images), dim=2).transpose(0, 1)
    loss = criterion(log_probs, labels, in_lens, tgt_lens)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
七、预测与解码
def ctc_greedy_decoder(output, charset, blank=0):
    """Greedy (best-path) CTC decoding of a batch of per-frame class scores.

    For each sequence: take the arg-max class at every time step, collapse
    runs of identical consecutive classes into one, then drop blank entries.

    Args:
        output: Scores of shape (B, T, C) supporting ``.argmax(-1)`` and
            ``.tolist()`` (e.g. a ``torch.Tensor`` or ``numpy.ndarray``).
        charset: Indexable mapping from class id to character. It is indexed
            directly with the class id, so with ``blank=0`` position 0 is
            never emitted. NOTE(review): callers presumably pass a charset
            with a placeholder at the blank index (e.g. ``"-" + chars``);
            confirm this matches how training labels were encoded.
        blank: Class id of the CTC blank symbol.

    Returns:
        A list of B decoded strings.
    """
    from itertools import groupby

    # Convert to plain Python ints once. Comparing/indexing 0-d tensors one
    # element at a time is slow and, on CUDA, syncs the device per element.
    best_paths = output.argmax(-1).tolist()
    return [
        ''.join(charset[cls] for cls, _ in groupby(path) if cls != blank)
        for path in best_paths
    ]