基于Transformer的端到端图像验证码识别模型设计与实现
图像验证码识别作为一种典型的图像到序列问题,传统的 CNN-RNN 架构逐渐暴露出性能瓶颈。本文提出一种基于 Transformer 架构的端到端验证码识别方法,摒弃循环结构,利用自注意力机制全局建模字符间的依赖关系,显著提升模型对干扰验证码和变形字符的鲁棒性。实验表明该方法在合成验证码集上具有更优性能与更快推理速度。
- 引言
验证码(CAPTCHA)作为一种图像与文字混合的防爬机制,常包含扭曲字符、背景噪声、随机长度等干扰特征。传统 CNN + LSTM + CTC 的方法虽然有效,但训练与推理速度受到序列建模方式限制。
Transformer 架构近年来在 NLP 与视觉领域取得显著成果,具备强大的长距离建模能力。本文首次将纯视觉 Transformer 应用于验证码识别任务,通过 patch embedding 及位置编码,将验证码图像转化为序列进行字符解码,简化模型流程同时增强泛化能力。
- 数据生成
使用 captcha 和 PIL 生成长度 5~6 的随机图像验证码:
更多内容访问ttocr.com或联系1436423940
from captcha.image import ImageCaptcha
import random, string, os
from PIL import Image
characters = string.ascii_uppercase + string.digits
def generate_dataset(output_dir, num_samples):
os.makedirs(output_dir, exist_ok=True)
generator = ImageCaptcha(width=160, height=60)
for i in range(num_samples):
text = ''.join(random.choices(characters, k=random.randint(5, 6)))
img = generator.generate_image(text)
img.save(f"{output_dir}/{text}.png")
generate_dataset("transformer_train", 8000)
generate_dataset("transformer_test", 1000)
3. 模型设计:Vision Transformer + 字符分类头
模型核心采用 Vision Transformer(ViT)提取图像 patch 信息,再通过 MLP 实现每个 patch 到字符的解码:
import torch.nn as nn
import torch
from einops import rearrange
class PatchEmbed(nn.Module):
def init(self, img_size=160, patch_size=20, in_channels=1, embed_dim=256):
super().init()
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x) # (B, embed_dim, H/ps, W/ps)
return rearrange(x, 'b c h w -> b (h w) c')
class TransformerEncoder(nn.Module):
def init(self, dim=256, depth=4, heads=4):
super().init()
self.layers = nn.ModuleList([
nn.TransformerEncoderLayer(d_model=dim, nhead=heads) for _ in range(depth)
])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class ViT_CAPTCHA(nn.Module):
def init(self, num_classes=36+1, max_length=6):
super().init()
self.embedding = PatchEmbed()
self.pos_embed = nn.Parameter(torch.randn(1, 24, 256))
self.encoder = TransformerEncoder()
self.decoder = nn.Linear(256, num_classes)
def forward(self, x):
x = self.embedding(x) + self.pos_embed
x = self.encoder(x)
return self.decoder(x)
输出形状为 (batch, patch_len, num_classes),其中 num_classes 包含 [A-Z0-9] 共 36 类字符 + 1 个空白符号。
- 解码与损失函数
使用 CTC Loss 实现字符对齐与训练:
ctc = nn.CTCLoss(blank=36)
def compute_loss(logits, targets, logit_lengths, target_lengths):
logits = logits.log_softmax(2).permute(1, 0, 2)
return ctc(logits, targets, logit_lengths, target_lengths)
使用 Beam Search 或 Greedy 解码生成最终字符序列。
浙公网安备 33010602011771号