使用 PyTorch Lightning 实现图像验证码识别

本文将展示如何使用 PyTorch Lightning 构建一个端到端的验证码识别模型。相比原始 PyTorch，Lightning 可以让你更专注于模型逻辑和实验。

  1. 安装依赖
    pip install pytorch-lightning torch torchvision captcha pillow numpy

  2. 生成验证码图像数据
    from captcha.image import ImageCaptcha
    import random, string, os
    from PIL import Image

# Alphabet the captcha text is drawn from: the ten digits followed by A-Z.
characters = ''.join((string.digits, string.ascii_uppercase))
# Number of characters in every captcha.
n_len = 4
# Passed to ImageCaptcha as the image width and height.
width = 160
height = 60

def generate_data(num=5000, path='captcha_imgs'):
    """Generate ``num`` random captcha images and save them under ``path``.

    Every image shows ``n_len`` characters drawn (with replacement) from
    ``characters`` and is written as ``<text>_<index>.png``. The text is the
    ground-truth label and the index keeps file names unique when the same
    text is drawn more than once.

    Args:
        num: Number of images to generate.
        path: Output directory; created if it does not exist.
    """
    os.makedirs(path, exist_ok=True)
    generator = ImageCaptcha(width, height)
    for i in range(num):
        text = ''.join(random.choices(characters, k=n_len))
        img = generator.generate_image(text)
        img.save(os.path.join(path, f'{text}_{i}.png'))

generate_data()

3. 构建数据集类
from torch.utils.data import Dataset
from torchvision import transforms
import torch

# Reverse lookup for `characters`: maps each symbol to its position in the alphabet.
char_to_idx = dict(zip(characters, range(len(characters))))

class CaptchaDataset(Dataset):
    """Captcha images stored as ``<label>_<index>.png`` files in one directory.

    Each item is an ``(image, target)`` pair: ``image`` is the transformed RGB
    image tensor and ``target`` is a ``torch.long`` tensor holding the index in
    ``characters`` of every character of the label.
    """

    def __init__(self, root):
        """Index every ``.png`` file located directly inside ``root``.

        Args:
            root: Directory containing the captcha images.
        """
        self.root = root
        # os.listdir returns entries in arbitrary order; sorting makes a given
        # index always refer to the same file.
        self.files = sorted(f for f in os.listdir(root) if f.endswith('.png'))
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # One mean/std entry per RGB channel.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and encode its label.

        Returns:
            A tuple ``(img, target)`` as described in the class docstring.

        Raises:
            KeyError: If the label contains a character missing from
                ``char_to_idx``.
        """
        file = self.files[idx]
        # The label is the part of the file name before the first underscore.
        label = file.split('_')[0]
        target = torch.tensor([char_to_idx[c] for c in label], dtype=torch.long)
        # convert() returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(os.path.join(self.root, file)) as raw:
            img = raw.convert('RGB')
        img = self.transform(img)
        return img, target


# 4. Build the Lightning model

import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F

class CaptchaModel(pl.LightningModule):
    """Small CNN with one classification head per captcha character.

    The input is passed through two conv/ReLU/max-pool stages, flattened,
    projected to a 128-dimensional feature vector, and then fed to ``n_len``
    independent linear heads that each score ``n_class`` characters.
    """

    def __init__(self, n_class=len(characters), n_len=4, height=60, width=160):
        """Build the network.

        Args:
            n_class: Number of distinct characters each head chooses from.
            n_len: Number of characters per captcha (one head per position).
            height: Height of the input images.
            width: Width of the input images.
        """
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # The convolutions keep the spatial size (kernel 3, stride 1, padding 1);
        # the two 2x2 poolings divide each dimension by 4 (60x160 -> 15x40).
        self.fc = nn.Linear(64 * (height // 4) * (width // 4), 128)
        self.heads = nn.ModuleList([nn.Linear(128, n_class) for _ in range(n_len)])
        self.n_len = n_len

    def forward(self, x):
        """Return a list of ``n_len`` logit tensors, one per character position."""
        x = self.conv(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc(x))
        return [head(x) for head in self.heads]

    def training_step(self, batch, batch_idx):
        """Compute and log the summed per-position cross-entropy loss."""
        x, y = batch
        logits = self(x)
        # Column i of y holds the target class for character position i.
        loss = sum(F.cross_entropy(logit, y[:, i]) for i, logit in enumerate(logits))
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# 5. Start training

from torch.utils.data import DataLoader

# Load the generated images and feed them to the model in shuffled batches of 64.
dataset = CaptchaDataset('captcha_imgs')
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Train for 10 epochs with the default Trainer settings.
model = CaptchaModel()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader)


# 6. Verify predictions
def decode_prediction(preds, charset=None):
    """Turn per-position logits into the predicted captcha string.

    Args:
        preds: Iterable of logit tensors, one per character position. Each
            tensor has the class scores on its last dimension; when a batch
            dimension is present only the first sample is decoded.
        charset: Sequence mapping a class index to its character. Defaults to
            the module-level ``characters``.

    Returns:
        The decoded text, one character per entry of ``preds``.
    """
    if charset is None:
        charset = characters
    # argmax over the class dimension only: a flat argmax() would return an
    # index into the flattened tensor and be wrong for batches larger than one.
    return ''.join(
        charset[p.argmax(dim=-1).reshape(-1)[0].item()] for p in preds
    )

def predict_image(model, path):
    """Predict the captcha text shown in the image file at ``path``.

    The image is preprocessed the same way as the training data (RGB,
    ``ToTensor`` followed by ``Normalize`` with mean and std 0.5), run through
    ``model`` as a batch of one, and decoded with ``decode_prediction``.

    Args:
        model: Trained model whose forward pass returns one logit tensor per
            character position.
        path: Path of the image file to recognise.

    Returns:
        The predicted text.

    Note:
        The model is switched to evaluation mode and left in that mode.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Close the file handle as soon as the pixel data has been converted.
    with Image.open(path) as raw:
        img = preprocess(raw.convert('RGB')).unsqueeze(0)
    # Run the input on whatever device the model's weights live on.
    img = img.to(next(model.parameters()).device)
    model.eval()
    with torch.no_grad():
        output = model(img)
    return decode_prediction(output)

# The generated file names contain random text, so pick a file that actually
# exists instead of hard-coding a name that was probably never created.
sample_name = next(f for f in sorted(os.listdir('captcha_imgs')) if f.endswith('.png'))
print(predict_image(model, os.path.join('captcha_imgs', sample_name)))

posted @ 2025-05-07 15:56  ttocr、com  阅读(27)  评论(0)    收藏  举报