Variable-Length CAPTCHA Recognition with Python and a Seq2Seq Transformer
This article presents an end-to-end CAPTCHA recognition approach: a convolutional network encodes the whole CAPTCHA image into a feature sequence, a Transformer decoder with positional encoding generates the result character by character, autoregressively, and the model is trained with teacher forcing and a cross-entropy loss. The method requires no character segmentation or alignment, naturally supports variable-length CAPTCHAs, and makes it easy to add language priors (e.g., character constraints or temperature sampling).
1. Method Overview
Data generation: the captcha library synthesizes random 4-to-6-character CAPTCHAs over the character set 0–9 and A–Z.
Encoder: a shallow CNN extracts features, which are flattened along the width axis into a temporal feature sequence.
Positional encoding: sinusoidal positional encodings are added to the encoder output and the decoder input.
Decoder: a multi-layer Transformer decoder autoregressively predicts the next character.
Training: target sequences are wrapped with <sos> and <eos> tokens; the decoder is trained with teacher forcing and a cross-entropy loss.
Inference: greedy or beam search, stopping once <eos> is generated.
2. Data and Vocabulary
import os, random, string
from captcha.image import ImageCaptcha
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
chars = string.digits + string.ascii_uppercase
itos = ['<pad>', '<sos>', '<eos>'] + list(chars)   # special tokens first, stable ids
stoi = {c: i for i, c in enumerate(itos)}
PAD, SOS, EOS = stoi['<pad>'], stoi['<sos>'], stoi['<eos>']
vocab_size = len(itos)
def synthesize(dst, n=3000, w=180, h=60):
    os.makedirs(dst, exist_ok=True)
    gen = ImageCaptcha(width=w, height=h)
    for i in range(n):
        L = random.randint(4, 6)                    # variable length: 4 to 6 characters
        text = ''.join(random.choices(chars, k=L))
        gen.write(text, os.path.join(dst, f'{text}_{i}.png'))
# Example (skip if you already have data)
synthesize('data/train', 6000); synthesize('data/val', 800)
class CaptchaSeqDataset(Dataset):
    def __init__(self, root, img_h=64, img_w=192):
        self.files = [os.path.join(root, f) for f in os.listdir(root) if f.endswith('.png')]
        self.tf = T.Compose([
            T.Grayscale(),
            T.Resize((img_h, img_w)),
            T.ToTensor(),
            T.Normalize((0.5,), (0.5,))
        ])

    def __len__(self):
        return len(self.files)

    def encode(self, s):
        return [SOS] + [stoi[c] for c in s] + [EOS]

    def __getitem__(self, i):
        p = self.files[i]
        label = os.path.basename(p).split('_')[0]   # the text is encoded in the filename
        img = Image.open(p).convert('RGB')
        img = self.tf(img)                          # 1 x H x W
        tgt = torch.tensor(self.encode(label), dtype=torch.long)
        return img, tgt, label
def collate(batch):
    imgs, tgts, labels = zip(*batch)
    imgs = torch.stack(imgs, 0)
    maxlen = max(t.size(0) for t in tgts)
    pad_tgts = torch.full((len(tgts), maxlen), PAD, dtype=torch.long)
    for i, t in enumerate(tgts):
        pad_tgts[i, :t.size(0)] = t
    return imgs, pad_tgts, labels
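As a quick sanity check (an illustrative snippet; it assumes the data/train directory generated above exists), the vocabulary holds 3 special tokens plus 36 characters, and collate right-pads the shorter label with PAD:

print(vocab_size)                        # 39
print(itos[PAD], itos[SOS], itos[EOS])   # <pad> <sos> <eos>

ds = CaptchaSeqDataset('data/train')
imgs, tgts, labels = collate([ds[0], ds[1]])
print(imgs.shape)   # torch.Size([2, 1, 64, 192])
print(tgts.shape)   # (2, maxlen): the shorter label is padded with PAD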
3. Model Architecture
import math
import torch.nn as nn
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))  # 1 x T x C

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
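A one-line shape check (illustrative): the stored table broadcasts over the batch dimension and is truncated to the input length.

pe = PositionalEncoding(256)
x = torch.zeros(2, 48, 256)   # B x T x C
print(pe(x).shape)            # torch.Size([2, 48, 256])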
class Encoder(nn.Module):
    def __init__(self, d_model=256, img_h=64):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),    # H/2, W/2
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),  # H/4, W/4
            nn.Conv2d(128, d_model, 3, padding=1), nn.ReLU()
        )
        # After two 2x2 poolings the feature-map height is img_h // 4, so each
        # width position carries d_model * (img_h // 4) values. The projection
        # back to d_model must be created here, not inside forward: a layer
        # built per call would get fresh random weights every step and never train.
        self.proj = nn.Linear(d_model * (img_h // 4), d_model, bias=False)
        self.pe = PositionalEncoding(d_model)

    def forward(self, x):
        f = self.cnn(x)                                              # B x C x H' x W'
        B, C, H, W = f.shape
        seq = f.permute(0, 3, 1, 2).contiguous().view(B, W, C * H)   # flatten along width
        seq = self.proj(seq)                                         # project to d_model
        return self.pe(seq)                                          # B x T x C
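A quick check of the resulting memory shape (using the img_h argument introduced above): with the default 64x192 input, two poolings leave W' = 192 / 4 = 48 time steps.

enc = Encoder(d_model=256, img_h=64)
mem = enc(torch.randn(2, 1, 64, 192))
print(mem.shape)   # torch.Size([2, 48, 256])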
class Decoder(nn.Module):
    def __init__(self, d_model=256, nhead=8, nlayers=4, vocab=vocab_size):
        super().__init__()
        self.embed = nn.Embedding(vocab, d_model, padding_idx=PAD)
        self.pe = PositionalEncoding(d_model)
        layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=512)
        self.dec = nn.TransformerDecoder(layer, num_layers=nlayers)
        self.out = nn.Linear(d_model, vocab)

    def forward(self, tgt, memory):
        # tgt: B x L
        tgt_emb = self.pe(self.embed(tgt))  # B x L x C
        # Build a causal mask so the decoder cannot peek at future positions
        L = tgt_emb.size(1)
        causal_mask = torch.triu(torch.ones(L, L, device=tgt.device) * float('-inf'), diagonal=1)
        out = self.dec(
            tgt_emb.transpose(0, 1),   # L x B x C (TransformerDecoder defaults to seq-first)
            memory.transpose(0, 1),    # T x B x C
            tgt_mask=causal_mask
        ).transpose(0, 1)              # B x L x C
        return self.out(out)           # B x L x V
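For intuition, here is what the causal mask looks like for a hypothetical 4-step sequence: -inf above the diagonal means row i may only attend to positions 0..i.

L = 4
m = torch.triu(torch.ones(L, L) * float('-inf'), diagonal=1)
print(m)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])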
class Seq2SeqCaptcha(nn.Module):
    def __init__(self, d_model=256, nhead=8, nlayers=4, vocab=vocab_size):
        super().__init__()
        self.enc = Encoder(d_model)
        self.dec = Decoder(d_model, nhead, nlayers, vocab)

    def forward(self, img, tgt_inp):
        mem = self.enc(img)
        logits = self.dec(tgt_inp, mem)
        return logits
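A smoke test of the full model (illustrative shapes matching the dataset defaults; the variable names are arbitrary):

m = Seq2SeqCaptcha()
img = torch.randn(2, 1, 64, 192)
tgt = torch.randint(0, vocab_size, (2, 7))
print(m(img, tgt).shape)                          # torch.Size([2, 7, 39])
print(sum(p.numel() for p in m.parameters()))     # total trainable parameter count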
4. Training
import torch.optim as optim
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Seq2SeqCaptcha().to(device)
opt = optim.Adam(model.parameters(), lr=2e-4)
def shift_tgt(tgt):
    # Teacher forcing: the decoder input is <sos> + chars (everything but the
    # last token), and the target is chars + <eos> (everything but the first).
    inp = tgt[:, :-1]
    out = tgt[:, 1:]
    return inp, out
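On a toy sequence the shift looks as follows (the symbolic names in the comments stand for the corresponding token ids):

t = torch.tensor([[SOS, stoi['A'], stoi['B'], EOS]])
inp, out = shift_tgt(t)
print(inp)  # [[SOS, A, B]]  -- what the decoder sees
print(out)  # [[A, B, EOS]]  -- what it must predict at each step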
def train_epoch(loader):
    model.train()
    total, n = 0.0, 0
    for imgs, tgts, _ in loader:
        imgs, tgts = imgs.to(device), tgts.to(device)
        inp, out = shift_tgt(tgts)
        logits = model(imgs, inp)  # B x (L-1) x V
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            out.reshape(-1),
            ignore_index=PAD,       # padded positions do not contribute to the loss
            label_smoothing=0.1
        )
        opt.zero_grad(); loss.backward(); opt.step()
        total += loss.item(); n += 1
    return total / max(n, 1)
# Example
train_ds = CaptchaSeqDataset('data/train')
val_ds = CaptchaSeqDataset('data/val')
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collate)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, collate_fn=collate)
for epoch in range(20):
    tr = train_epoch(train_loader)
    print(f'Epoch {epoch+1}: loss={tr:.4f}')
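The loop above never touches val_loader; a minimal held-out loss check could look like this (a sketch mirroring train_epoch; eval_epoch is a name introduced here):

@torch.no_grad()
def eval_epoch(loader):
    model.eval()
    total, n = 0.0, 0
    for imgs, tgts, _ in loader:
        imgs, tgts = imgs.to(device), tgts.to(device)
        inp, out = shift_tgt(tgts)
        logits = model(imgs, inp)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               out.reshape(-1), ignore_index=PAD)
        total += loss.item(); n += 1
    return total / max(n, 1)

# e.g. inside the epoch loop: print(f'val loss={eval_epoch(val_loader):.4f}')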
5. Inference and Decoding
@torch.no_grad()
def greedy_decode(model, img, max_len=8):
    model.eval()
    mem = model.enc(img)                                             # 1 x T x C
    ys = torch.tensor([[SOS]], device=img.device, dtype=torch.long)  # 1 x 1
    for _ in range(max_len):
        logits = model.dec(ys, mem)                          # 1 x L x V
        next_id = logits[:, -1, :].argmax(-1, keepdim=True)  # 1 x 1
        ys = torch.cat([ys, next_id], dim=1)
        if next_id.item() == EOS:
            break
    ids = ys[0, 1:]  # drop the leading <sos>
    tokens = []
    for i in ids:
        if i.item() == EOS:
            break
        tokens.append(itos[i.item()])
    return ''.join(tokens)
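Section 1 also mentioned beam search as an alternative to greedy decoding. Below is a minimal sketch (beam_decode and its beam_w parameter are illustrative additions, not part of the code above): each hypothesis accumulates token log-probabilities, and hypotheses that already end in <eos> are carried over unchanged.

@torch.no_grad()
def beam_decode(model, img, beam_w=3, max_len=8):
    model.eval()
    mem = model.enc(img)                                        # 1 x T x C
    beams = [(torch.tensor([[SOS]], device=img.device), 0.0)]   # (sequence, log-prob)
    for _ in range(max_len):
        cand = []
        for ys, score in beams:
            if ys[0, -1].item() == EOS:          # finished hypothesis: keep as-is
                cand.append((ys, score))
                continue
            logp = model.dec(ys, mem)[:, -1, :].log_softmax(-1)  # 1 x V
            topv, topi = logp.topk(beam_w)
            for v, i in zip(topv[0], topi[0]):
                cand.append((torch.cat([ys, i.view(1, 1)], dim=1), score + v.item()))
        beams = sorted(cand, key=lambda b: b[1], reverse=True)[:beam_w]
        if all(b[0][0, -1].item() == EOS for b in beams):
            break
    best = beams[0][0][0, 1:]   # best hypothesis, <sos> dropped
    return ''.join(itos[i.item()] for i in best if i.item() != EOS)

# usage, mirroring greedy_decode: pred = beam_decode(model, img.to(device))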