使用 PyTorch Lightning 实现图像验证码识别

本教程将展示如何使用 PyTorch Lightning 实现一个高效的图像验证码识别系统。相比原生 PyTorch，Lightning 提供了更清晰的训练流程和模块结构。

  1. 安装依赖
    pip install pytorch-lightning torch torchvision pillow captcha

  2. 生成验证码图像
    from captcha.image import ImageCaptcha
    import os
    import random
    import string

# Alphabet of the captcha text: the ten digits followed by the uppercase letters A-Z.
characters = "".join((string.digits, string.ascii_uppercase))
# Number of characters in each captcha.
captcha_length = 4

def generate_dataset(output_dir="captcha_images", count=10000):
    """Render random captcha images into ``output_dir``.

    Every image is 160x60 pixels and shows ``captcha_length`` characters drawn
    with replacement from ``characters``. Files are named
    ``<text>_<index>.png``, so the text of an image can be read back from its
    file name.

    Args:
        output_dir: Directory that receives the PNG files; created if missing.
        count: Number of images to generate.
    """
    os.makedirs(output_dir, exist_ok=True)
    gen = ImageCaptcha(width=160, height=60)
    for i in range(count):
        text = ''.join(random.choices(characters, k=captcha_length))
        img = gen.generate_image(text)
        # The running index keeps names unique when the same text is drawn twice.
        img.save(os.path.join(output_dir, f"{text}_{i}.png"))
更多内容访问ttocr.com或联系1436423940
# Write the captcha images to disk using the function's default arguments.
generate_dataset()

# 3. Dataset definition
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class CaptchaDataset(Dataset):
    """Captcha images stored as ``<label>_<index>.png`` files in one directory.

    Each item is an ``(image, label)`` pair. The image is opened as RGB,
    resized to 60x160 (height x width) and converted with ``ToTensor``. The
    label is a ``torch.long`` tensor holding, for every character of the
    captcha text, its index in ``characters``.
    """

    def __init__(self, root):
        """Collect the PNG files under ``root`` and build the preprocessing.

        Args:
            root: Directory containing the captcha PNG files.
        """
        # The constructor has to be spelled ``__init__``; a method named
        # ``init`` is never invoked by ``CaptchaDataset(root)``.
        # Sorting makes the sample order independent of the file system.
        self.files = sorted(
            os.path.join(root, f) for f in os.listdir(root) if f.endswith('.png')
        )
        self.char_to_idx = {c: i for i, c in enumerate(characters)}
        self.transform = transforms.Compose([
            transforms.Resize((60, 160)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of image files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``."""
        path = self.files[idx]
        # The label is the part of the file name before the first underscore.
        label_str = os.path.basename(path).split('_')[0]
        label = torch.tensor([self.char_to_idx[c] for c in label_str], dtype=torch.long)
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(path) as img:
            image = img.convert('RGB')
        return self.transform(image), label


# 4. Model definition

import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F

class CaptchaModel(pl.LightningModule):
    """CNN feature extractor followed by a bidirectional LSTM and a per-step classifier.

    For a ``[B, 3, 60, 160]`` input the convolutional stack keeps the spatial
    size in each convolution (kernel 3, padding 1) and halves it in each of the
    two max-pool layers, giving ``[B, 128, 15, 40]``. Every one of the 40
    columns becomes one LSTM time step with ``128 * 15`` features, so the
    output is ``[B, 40, len(characters)]``. The LSTM input size fixes the
    input height at 60.
    """

    def __init__(self):
        """Build the convolutional stack, the LSTM and the classifier."""
        # Both the method name and the super() call need the double
        # underscores, otherwise the nn.Module machinery is never set up.
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
        )
        self.rnn = nn.LSTM(128 * 15, 128, bidirectional=True, batch_first=True, num_layers=2)
        # 256 = 128 hidden units for each of the two LSTM directions.
        self.classifier = nn.Linear(256, len(characters))

    def forward(self, x):
        """Return per-step class logits of shape ``[B, W, len(characters)]``."""
        x = self.conv(x)
        x = x.permute(0, 3, 1, 2)  # [B, W, C, H]
        B, W, C, H = x.shape
        # Flatten channels and height into the feature vector of each column.
        x = x.reshape(B, W, C * H)
        x, _ = self.rnn(x)
        x = self.classifier(x)
        return x

    def training_step(self, batch, batch_idx):
        """Compute and log the summed cross-entropy over the captcha positions."""
        images, labels = batch
        logits = self(images)
        # Only the first ``captcha_length`` time steps are supervised; step i
        # is trained to predict character i of the label.
        loss = sum(F.cross_entropy(logits[:, i, :], labels[:, i]) for i in range(captcha_length))
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# 5. Train the model

from torch.utils.data import DataLoader

# Load the captcha images from disk and batch them for training.
dataset = CaptchaDataset("captcha_images")
train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

model = CaptchaModel()
# accelerator='auto' leaves the choice of hardware to Lightning.
trainer = pl.Trainer(max_epochs=10, accelerator='auto')
trainer.fit(model, train_loader)

# 6. Check the predictions
def decode_prediction(logits, length=None, charset=None):
    """Turn model logits into the predicted captcha text.

    Args:
        logits: Tensor of shape ``[batch, steps, classes]``; only the first
            sample of the batch is decoded.
        length: Number of leading time steps to decode. Defaults to the
            module-level ``captcha_length``.
        charset: Sequence mapping a class index to its character. Defaults to
            the module-level ``characters``.

    Returns:
        The decoded string, at most ``length`` characters long.
    """
    if length is None:
        length = captcha_length
    if charset is None:
        charset = characters
    # The model emits one step per feature-map column (more than there are
    # captcha characters) and only the leading steps are trained, so decoding
    # every step would return a string far longer than the captcha.
    pred = torch.argmax(logits, dim=2)[0, :length]
    return ''.join(charset[i] for i in pred.tolist())

def predict(model, image_path):
    """Predict the captcha text shown in the image at ``image_path``.

    Args:
        model: Network to run; it is switched to eval mode and called on a
            ``[1, 3, 60, 160]`` tensor.
        image_path: Path of the image file to read.

    Returns:
        The string produced by ``decode_prediction`` for the model output.
    """
    # PIL's resize takes (width, height), i.e. a 60x160 (H x W) image.
    with Image.open(image_path) as img:
        image = img.convert('RGB').resize((160, 60))
    image = transforms.ToTensor()(image).unsqueeze(0)
    # Put the input on the same device as the model's weights.
    image = image.to(next(model.parameters()).device)
    model.eval()
    with torch.no_grad():
        output = model(image)
    return decode_prediction(output)

# Run the model on a single image path and print the decoded text.
sample_path = "captcha_images/Z8K5_123.png"
print(predict(model, sample_path))

posted @ 2025-04-29 17:45  ttocr、com  阅读(20)  评论(0)    收藏  举报