CRNN Captcha Recognition with PyTorch: An End-to-End Walkthrough

Data generation (synthetic captchas)

Dataset and DataLoader (with collate)

Model implementation (CRNN: CNN + BiLSTM + CTC)

Training script (loss / checkpointing)

Evaluation and inference (greedy decode with examples)

A minimal Flask serving endpoint

Training hyperparameters and experiment tips

Optimization strategies and common pitfalls

Legal and ethical reminders

1 Environment setup
Python 3.8+ is recommended; a GPU speeds up training.

Install the required dependencies (example):

pip install torch torchvision pillow captcha flask
If you use a GPU, install the torch build that matches your machine (see the official install selector).
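For example, a CUDA 12.1 build can be installed via the index URL documented on pytorch.org (treat this as illustrative and check the official selector for the command matching your CUDA version):

pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121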

2 Data generation (synthetic captchas)
Use the captcha library to quickly generate synthetic data for training. The label is embedded in each file name so the Dataset can recover it later.

Save as data_gen.py:

data_gen.py

from captcha.image import ImageCaptcha
import random, string, os

CHARS = string.digits + string.ascii_uppercase  # 0-9 and A-Z

def generate_one(text, path):
    image = ImageCaptcha(width=160, height=60)
    image.write(text, path)

def random_text(min_len=4, max_len=5):
    length = random.randint(min_len, max_len)
    return ''.join(random.choices(CHARS, k=length))

def generate_dataset(out_dir='data/train', n=5000):
    os.makedirs(out_dir, exist_ok=True)
    for i in range(n):
        txt = random_text()
        fname = f"{txt}_{i}.png"
        generate_one(txt, os.path.join(out_dir, fname))

if __name__ == '__main__':
    generate_dataset('data/train', 8000)
    generate_dataset('data/val', 2000)
    print('done')
Running python data_gen.py generates the training and validation sets.
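As a quick sanity check after generation, you can confirm the file count and that labels parse back out of the file names (a small throwaway snippet, assuming the directory layout above):

import glob, os

files = glob.glob('data/train/*.png')
print(len(files), 'training images')
# the label is everything before the first underscore in the file name
sample = os.path.basename(files[0])
print('sample file:', sample, '-> label:', sample.split('_')[0])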

3 Dataset and DataLoader (with collate)
Use a PyTorch Dataset to read the images and encode each label as a sequence of integers (note that CTC expects the targets concatenated into a single 1-D tensor).

Save as dataset.py:

dataset.py

import torch
from torch.utils.data import Dataset
from PIL import Image
import glob
import os
import string

CHARS = string.digits + string.ascii_uppercase
CHAR2IDX = {c: i+1 for i, c in enumerate(CHARS)} # reserve 0 for blank in CTC
IDX2CHAR = {i+1: c for i, c in enumerate(CHARS)}

import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize((32, 160)),  # height 32, width 160, fixed for simplicity
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

class CaptchaDataset(Dataset):
    def __init__(self, folder, transform=transform):
        self.files = glob.glob(os.path.join(folder, '*.png'))
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        fname = os.path.basename(path)
        label_str = fname.split('_')[0]
        img = Image.open(path).convert('L')  # grayscale
        img = self.transform(img)
        # convert the label to a list of integer indices
        label = [CHAR2IDX[c] for c in label_str]
        label = torch.tensor(label, dtype=torch.long)
        return img, label, label_str

def collate_fn(batch):
    # batch: list of (img, label, label_str)
    imgs = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    label_strs = [item[2] for item in batch]
    imgs = torch.stack(imgs, dim=0)
    # CTC needs the targets flattened into one 1-D tensor plus per-sample lengths
    targets = torch.cat(labels).to(torch.long)
    target_lengths = torch.tensor([len(l) for l in labels], dtype=torch.long)
    # input_lengths for CTC will be computed from the model's output time steps
    return imgs, targets, target_lengths, label_strs
Notes:

blank is reserved at index 0, so character indices start from 1.

For simplicity, all images are resized to a fixed 32x160. To support variable widths, switch to an aspect-preserving resize plus padding and adjust the collate function, as in the sketch below.
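A minimal sketch of that variable-width variant, assuming a fixed height of 32 and right-padding every image to the widest one in the batch. The names resize_keep_ratio and pad_collate are illustrative, not part of the code above, and normalization is omitted for brevity; the training loop would also need to consume the extra widths value:

import torch
import torchvision.transforms.functional as TF

def resize_keep_ratio(img, target_h=32):
    # scale so that height == target_h while keeping the aspect ratio
    w, h = img.size
    new_w = max(1, int(w * target_h / h))
    return TF.to_tensor(TF.resize(img, [target_h, new_w]))

def pad_collate(batch):
    # batch: list of (img_tensor (1, 32, W_i), label, label_str)
    max_w = max(item[0].size(2) for item in batch)
    imgs = torch.stack([
        TF.pad(item[0], [0, 0, max_w - item[0].size(2), 0])  # pad on the right
        for item in batch
    ])
    targets = torch.cat([item[1] for item in batch])
    target_lengths = torch.tensor([len(item[1]) for item in batch])
    # keep original widths so per-sample input_lengths can be derived later
    widths = torch.tensor([item[0].size(2) for item in batch])
    return imgs, targets, target_lengths, [item[2] for item in batch], widths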

4 Model implementation (CRNN)
The CRNN is a simple CNN feature extractor + BiLSTM + a linear projection onto the character classes (including blank). CTC consumes log-probabilities of shape (T, N, C).

Save as model.py:

model.py

import torch
import torch.nn as nn

class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)  # (T, b, nHidden*2)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # (T*b, nOut)
        output = output.view(T, b, -1)
        return output

class CRNN(nn.Module):
    def __init__(self, imgH=32, nc=1, nclass=1+36, nh=256):
        # nclass = num_chars + 1 (blank)
        super(CRNN, self).__init__()
        self.nclass = nclass
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        def conv_relu(i, batch_norm=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            layers = [nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i])]
            if batch_norm:
                layers.append(nn.BatchNorm2d(nOut))
            layers.append(nn.ReLU(True))
            return layers

        # build conv layers similar to the CRNN paper
        self.cnn = nn.Sequential(
            *conv_relu(0, False),
            nn.MaxPool2d(2, 2),            # 64 x 16 x 80
            *conv_relu(1, False),
            nn.MaxPool2d(2, 2),            # 128 x 8 x 40
            *conv_relu(2, False),
            *conv_relu(3, False),
            nn.MaxPool2d((2, 1), (2, 1)),  # 256 x 4 x 40
            *conv_relu(4, True),
            *conv_relu(5, True),
            nn.MaxPool2d((2, 1), (2, 1)),  # 512 x 2 x 40
            *conv_relu(6, True),
            # final feature map: (batch, 512, 1, width')
        )

        # RNN expects (seq_len, batch, input_size)
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass)
        )

    def forward(self, x):
        # x: (B, C, H, W)
        conv = self.cnn(x)
        b, c, h, w = conv.size()
        assert h == 1, "expected feature map height 1; check input height / pooling"
        # collapse the height dimension
        conv = conv.squeeze(2)        # (b, c, w)
        conv = conv.permute(2, 0, 1)  # (w, b, c)
        output = self.rnn(conv)       # (w, b, nclass)
        return output  # logits (not softmax)

Notes:

nclass = len(CHARS) + 1; the +1 is CTC's blank class.

forward returns logits of shape (T, N, C); apply log_softmax before passing them to nn.CTCLoss (PyTorch expects log-probabilities).
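As a quick shape sanity check (a throwaway snippet, not part of model.py): with a 32x160 input, the two 2x2 pools halve the width twice and the final kernel-2, pad-0 conv removes one column, so T = 160/4 - 1 = 39 time steps:

import torch
from model import CRNN

model = CRNN(nclass=1 + 36)
x = torch.randn(2, 1, 32, 160)  # (B, C, H, W)
out = model(x)
print(out.shape)  # expected: torch.Size([39, 2, 37])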

5 Training script (train.py)
Key points of the training loop:

the model output has shape (T, N, C) and must go through log_softmax before being passed to CTCLoss;

input_lengths is the number of time steps per sample, here the constant T (or computed per image);

targets is one flat 1-D vector, and target_lengths holds each label's length.

Save as train.py:

train.py

import torch
from torch.utils.data import DataLoader
from dataset import CaptchaDataset, collate_fn
from model import CRNN
import torch.nn.functional as F
import torch.optim as optim
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train():
    train_dataset = CaptchaDataset('data/train')
    val_dataset = CaptchaDataset('data/val')
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

    nclass = 1 + 36
    model = CRNN(nclass=nclass).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    ctc_loss = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

    best_acc = 0.0
    for epoch in range(1, 31):
        model.train()
        total_loss = 0
        for imgs, targets, target_lengths, _ in train_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            preds = model(imgs)  # (T, N, C)
            preds_log_softmax = F.log_softmax(preds, dim=2)
            T, N, C = preds.size()
            input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
            loss = ctc_loss(preds_log_softmax, targets, input_lengths, target_lengths)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        val_acc = validate(model, val_loader)
        print(f"Epoch {epoch} Loss {avg_loss:.4f} ValAcc {val_acc:.4f}")
        # save the best checkpoint
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_crnn.pth")
    print("training finished")

def decode_greedy(preds):
    # preds: (T, N, C) logits or probabilities; argmax is the same either way
    _, max_idx = preds.max(2)  # (T, N)
    max_idx = max_idx.transpose(0, 1).cpu().numpy()  # (N, T)
    results = []
    from dataset import IDX2CHAR
    for seq in max_idx:
        # collapse repeats and remove blanks (0)
        prev = 0
        out = []
        for idx in seq:
            if idx != prev and idx != 0:
                out.append(IDX2CHAR.get(int(idx), ''))
            prev = idx
        results.append(''.join(out))
    return results

def validate(model, val_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, targets, target_lengths, label_strs in val_loader:
            imgs = imgs.to(device)
            preds = model(imgs)
            preds = F.log_softmax(preds, dim=2)
            preds_raw = preds.exp()  # probabilities
            pred_texts = decode_greedy(preds_raw)
            for p, gt in zip(pred_texts, label_strs):
                if p == gt:
                    correct += 1
                total += 1
    return correct / total if total > 0 else 0

if __name__ == '__main__':
    train()
Notes:

input_lengths is set to T here (every sample in the batch shares the same number of time steps). With variable-width images you would instead compute each image's time steps from its post-conv width, as in the helper sketch below.

decode_greedy performs greedy decoding, collapsing consecutive repeats and removing blanks.
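A small helper sketch for that variable-width case: with this particular conv stack (two 2x2 pools, width-preserving (2,1) pools, and a final kernel-2, pad-0 conv), an input width W maps to W // 4 - 1 time steps. The name conv_output_length is illustrative:

def conv_output_length(width):
    # two 2x2 max-pools halve the width twice; the (2,1) pools keep it;
    # the last conv (kernel 2, pad 0) removes one column
    return width // 4 - 1

# e.g. per-sample lengths from the original widths kept by a custom collate:
# input_lengths = torch.tensor([conv_output_length(int(w)) for w in widths])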

6 Evaluation and inference
Inference example, infer.py:

infer.py

import torch
from model import CRNN
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([
    transforms.Resize((32, 160)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

def load_model(path='best_crnn.pth'):
    model = CRNN(nclass=1+36).to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model

def predict(model, image_path):
    img = Image.open(image_path).convert('L')
    x = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = model(x)  # (T, N, C)
        probs = F.softmax(preds, dim=2)
    # greedy decode, same as decode_greedy in train.py but for a single sample
    _, max_idx = probs.max(2)  # (T, N)
    seq = max_idx[:, 0].cpu().numpy().tolist()
    from dataset import IDX2CHAR
    prev = 0
    out = []
    for idx in seq:
        if idx != prev and idx != 0:
            out.append(IDX2CHAR.get(int(idx), ''))
        prev = idx
    return ''.join(out)

if __name__ == '__main__':
    model = load_model('best_crnn.pth')
    print(predict(model, 'data/val/SAMPLE.png'))
7 A minimal Flask serving endpoint
Expose the model as a simple HTTP endpoint, app.py:

app.py

from flask import Flask, request, jsonify
from infer import load_model, predict
import os

app = Flask(__name__)
model = load_model('best_crnn.pth')

@app.route('/predict', methods=['POST'])
def predict_api():
    if 'file' not in request.files:
        return jsonify({'error': 'no file'}), 400
    f = request.files['file']
    fname = 'tmp_upload.png'
    f.save(fname)
    text = predict(model, fname)
    os.remove(fname)
    return jsonify({'text': text})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
Once the server is running, test it by uploading an image with curl:

curl -F "file=@captcha.png" http://127.0.0.1:5000/predict
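Or, equivalently, from Python with the requests library (assuming the server above is running locally):

import requests

with open('captcha.png', 'rb') as f:
    resp = requests.post('http://127.0.0.1:5000/predict', files={'file': f})
print(resp.json())  # e.g. {'text': 'A3F9'}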
