图像验证码识别:用 PyTorch Lightning 实现 CRNN 模型

本教程展示如何使用 PyTorch Lightning 实现一个图像验证码识别系统,具备清晰结构、易于扩展、便于训练的优点。

1. 安装依赖
pip install pytorch-lightning torch torchvision pillow captcha

2. 生成验证码图片
from captcha.image import ImageCaptcha
import string, random, os
from PIL import Image
# Alphabet used for CAPTCHA text: the ten digits followed by A-Z.
characters = "".join((string.digits, string.ascii_uppercase))

# Rendered image size in pixels, and number of characters per CAPTCHA.
width = 160
height = 60
captcha_length = 4

def generate_captcha(output_dir="pl_captcha", num=5000):
    """Write `num` randomly-labelled CAPTCHA PNG images into `output_dir`.

    Each image shows `captcha_length` characters drawn (with replacement)
    from `characters`, rendered at `width` x `height`. The text is embedded
    in the filename as "<text>_<i>.png". The directory is created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    renderer = ImageCaptcha(width, height)
    for i in range(num):
        text = "".join(random.choices(characters, k=captcha_length))
        renderer.generate_image(text).save(f"{output_dir}/{text}_{i}.png")
更多内容访问ttocr.com或联系1436423940
# Generate the training images with the default settings (5000 files in ./pl_captcha).
generate_captcha()

# --- 3. Define the dataset ---
from torch.utils.data import Dataset
from torchvision import transforms
import torch

class CaptchaDataset(Dataset):
    """CAPTCHA images whose label is encoded in the filename.

    Every ``*.png`` in `data_dir` is an example. The part of the filename
    before the first underscore is the label text. Each character is mapped
    to its index in `characters`.

    Items are ``(image_tensor, label_tensor)``. ``label_tensor`` is a 1-D
    ``torch.long`` tensor with one class index per character.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # Sorted so the index -> file mapping is reproducible; os.listdir
        # returns entries in arbitrary order.
        self.files = sorted(f for f in os.listdir(data_dir) if f.endswith('.png'))
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # A single mean/std value, applied to every RGB channel.
            transforms.Normalize((0.5,), (0.5,)),
        ])
        self.char_to_idx = {c: i for i, c in enumerate(characters)}

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        filename = self.files[idx]
        path = os.path.join(self.data_dir, filename)
        # Context manager closes the file handle; convert() loads the pixels.
        with Image.open(path) as img:
            image = img.convert('RGB')
        label_text = filename.split('_')[0]
        # Raises KeyError if the filename contains a character outside `characters`.
        label = torch.tensor([self.char_to_idx[c] for c in label_text], dtype=torch.long)
        return self.transform(image), label

# --- 4. Build the model (LightningModule) ---
import pytorch_lightning as pl
import torch.nn as nn
import torch

class CRNNModel(pl.LightningModule):
    """CNN feature extractor + 2-layer bidirectional LSTM + linear classifier.

    The CNN reduces height by 8x and width by 4x (floor division at each
    pool). Each remaining image column becomes one LSTM time step.
    ``forward`` returns logits of shape ``[B, W', num_classes]``.

    Training applies cross-entropy only to the first `captcha_len` time
    steps: time step i predicts character i. There is no CTC loss.

    Args:
        num_classes: size of the character alphabet.
        captcha_len: number of characters per CAPTCHA (supervised steps).
        img_height: input image height in pixels; used to size the LSTM input.
    """

    def __init__(self, num_classes=len(characters), captcha_len=4, img_height=60):
        super().__init__()
        self.save_hyperparameters()
        self.captcha_len = captcha_len
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d((2, 1)),  # halve height only, keep width resolution
        )
        # Three height-halving pools with floor division give H // 8 rows
        # (60 -> 7). Each LSTM step sees a flattened 128-channel x H//8 column.
        feat_height = img_height // 8
        self.rnn = nn.LSTM(128 * feat_height, 128, num_layers=2,
                           bidirectional=True, batch_first=True)
        self.fc = nn.Linear(256, num_classes)  # 2 directions x 128 hidden units

    def forward(self, x):
        x = self.model(x)               # [B, C, H, W]
        x = x.permute(0, 3, 1, 2)       # [B, W, C, H]
        b, w, c, h = x.shape
        x = x.reshape(b, w, c * h)      # [B, W, C*H]: one feature vector per column
        x, _ = self.rnn(x)              # [B, W, 256]
        return self.fc(x)               # [B, W, num_classes]

    def _shared_step(self, batch):
        """Return (loss, per-character accuracy) for one batch."""
        images, labels = batch
        output = self(images)[:, :self.captcha_len, :]  # [B, captcha_len, num_classes]
        # Sum of the per-position mean cross-entropy losses.
        loss = sum(
            nn.functional.cross_entropy(output[:, i, :], labels[:, i])
            for i in range(self.captcha_len)
        )
        acc = (output.argmax(dim=2) == labels).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, _ = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Without this hook, the val loader passed to Trainer.fit is never used.
        loss, acc = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# --- 5. Training entry point ---
from torch.utils.data import DataLoader, random_split

dataset = CaptchaDataset("pl_captcha")
train_size = int(0.9 * len(dataset))
# Seeded generator so the 90/10 train/validation split is identical on every run.
train_ds, val_ds = random_split(
    dataset,
    [train_size, len(dataset) - train_size],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64)

model = CRNNModel()

trainer = pl.Trainer(max_epochs=10, accelerator="auto")
trainer.fit(model, train_loader, val_loader)

# --- 6. Inference function ---
def predict(model, image_path):
    """Return the predicted CAPTCHA text for the image at `image_path`.

    Only the first ``model.captcha_len`` output steps are decoded, because
    those are the only positions the training loss supervises. If the model
    has no ``captcha_len`` attribute, ``captcha_length`` is used instead.

    Side effect: puts `model` into eval mode.
    """
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    # Put the input on the same device as the model's weights (CPU or GPU).
    device = next(model.parameters()).device
    x = transform(image).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        output = model(x)  # [1, W, num_classes]
    n_chars = getattr(model, "captcha_len", captcha_length)
    pred = output[0, :n_chars].argmax(dim=-1)
    idx_to_char = dict(enumerate(characters))
    return ''.join(idx_to_char[i] for i in pred.tolist())

# Filenames embed random text, so pick an image that actually exists
# rather than hard-coding a name. The label is the part before "_".
sample_name = sorted(f for f in os.listdir("pl_captcha") if f.endswith(".png"))[0]
print(f"expected={sample_name.split('_')[0]} "
      f"predicted={predict(model, os.path.join('pl_captcha', sample_name))}")

posted @ 2025-05-22 22:39  ttocr、com  阅读(22)  评论(0)    收藏  举报