使用 PyTorch Lightning 实现图像验证码识别
本教程将展示如何使用 PyTorch Lightning 实现一个高效的图像验证码识别系统。相比原生 PyTorch,Lightning 提供了更清晰的训练流程和模块结构。
1. 安装依赖
pip install pytorch-lightning torch torchvision pillow captcha
2. 生成验证码图像
from captcha.image import ImageCaptcha
import os
import random
import string
# Alphabet the captcha text is drawn from: digits 0-9 followed by A-Z.
characters = string.digits + string.ascii_uppercase
# Number of characters in every generated captcha.
captcha_length = 4
def generate_dataset(output_dir="captcha_images", count=10000):
    """Render random captcha images into ``output_dir``.

    Each image is saved as ``<text>_<index>.png`` so the label can be
    recovered from the file name.

    Args:
        output_dir: Target directory; created if it does not exist.
        count: Number of images to generate.
    """
    os.makedirs(output_dir, exist_ok=True)
    renderer = ImageCaptcha(width=160, height=60)
    for index in range(count):
        label = ''.join(random.choices(characters, k=captcha_length))
        renderer.generate_image(label).save(f"{output_dir}/{label}_{index}.png")
更多内容访问ttocr.com或联系1436423940
generate_dataset()
3. 数据集定义
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
class CaptchaDataset(Dataset):
    """Dataset of captcha PNG files whose label is encoded in the file name.

    The part of each file name before the first underscore is taken as the
    label text and mapped to class indices through ``characters``.
    """

    def __init__(self, root):
        """Index every ``.png`` file directly under ``root``.

        Args:
            root: Directory containing the captcha images.
        """
        # sorted() gives a reproducible sample order; os.listdir order is arbitrary.
        self.files = sorted(
            os.path.join(root, f) for f in os.listdir(root) if f.endswith('.png')
        )
        self.char_to_idx = {c: i for i, c in enumerate(characters)}
        self.transform = transforms.Compose([
            transforms.Resize((60, 160)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Position in the indexed file list.

        Returns:
            A ``(image, label)`` pair: the transformed image tensor and a
            ``torch.long`` tensor with one class index per label character.

        Raises:
            KeyError: If the label contains a character not in ``characters``.
        """
        path = self.files[idx]
        label_str = os.path.basename(path).split('_')[0]
        label = torch.tensor([self.char_to_idx[c] for c in label_str], dtype=torch.long)
        # The context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(path) as raw:
            image = raw.convert('RGB')
        return self.transform(image), label


# 4. Model definition
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
class CaptchaModel(pl.LightningModule):
    """CNN + bidirectional LSTM that scores a character class per image column.

    The convolutional stack halves height and width twice, so a 3x60x160
    input becomes a 128x15x40 feature map. Each of the 40 columns is
    flattened to a 128 * 15 vector and fed to the LSTM as one time step.
    """

    def __init__(self):
        """Build the convolutional, recurrent and classification layers."""
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
        )
        # Input size is channels * feature-map height (128 * 15 for 60-pixel-high images).
        self.rnn = nn.LSTM(128 * 15, 128, bidirectional=True, batch_first=True, num_layers=2)
        # 256 = 2 directions * 128 hidden units.
        self.classifier = nn.Linear(256, len(characters))

    def forward(self, x):
        """Return character logits for every column of the feature map.

        Args:
            x: Image batch of shape [B, 3, H, W]; H is expected to be 60 so
                the flattened column size matches the LSTM input size.

        Returns:
            Tensor of shape [B, W // 4, len(characters)].
        """
        x = self.conv(x)
        x = x.permute(0, 3, 1, 2)  # [B, W, C, H]
        B, W, C, H = x.shape
        x = x.reshape(B, W, C * H)
        x, _ = self.rnn(x)
        x = self.classifier(x)
        return x

    def training_step(self, batch, batch_idx):
        """Compute and log the summed per-character cross-entropy loss.

        Only the first ``captcha_length`` time steps of the output are
        supervised: step ``i`` is trained to predict label character ``i``.
        """
        images, labels = batch
        logits = self(images)
        loss = sum(
            F.cross_entropy(logits[:, i, :], labels[:, i])
            for i in range(captcha_length)
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# 5. Train the model
from torch.utils.data import DataLoader
# Build the dataset from the images in captcha_images/ and train for 10 epochs.
# NOTE(review): with num_workers > 0 on platforms that start workers via
# "spawn" (Windows, macOS), script-level code like this normally has to live
# under an `if __name__ == "__main__":` guard -- confirm for your platform.
dataset = CaptchaDataset("captcha_images")
train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
model = CaptchaModel()
trainer = pl.Trainer(max_epochs=10, accelerator='auto')
trainer.fit(model, train_loader)

# 6. Check predictions
def decode_prediction(logits):
    """Decode the first sample of a logits batch into captcha text.

    Args:
        logits: Tensor of shape [B, T, len(characters)].

    Returns:
        The predicted string, at most ``captcha_length`` characters long.
    """
    # The model emits one time step per feature-map column, but the training
    # loss only supervises the first captcha_length steps, so only those
    # carry a meaningful prediction.
    pred = torch.argmax(logits, dim=2)[0, :captcha_length]
    return ''.join(characters[i] for i in pred.tolist())
def predict(model, image_path):
    """Return the text the model predicts for a single captcha image.

    Args:
        model: Trained model; it is switched to eval mode and left there.
        image_path: Path to the image file to classify.

    Returns:
        The string produced by ``decode_prediction`` for the model output.
    """
    # The context manager closes the file handle; convert()/resize() return loaded copies.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB').resize((160, 60))
    batch = transforms.ToTensor()(image).unsqueeze(0)
    # Run the input on whatever device the model weights live on.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        batch = batch.to(first_param.device)
    model.eval()
    with torch.no_grad():
        output = model(batch)
    return decode_prediction(output)
# NOTE: generate_dataset() names files from random text, so this exact path is
# only an example; substitute a file that actually exists in captcha_images/.
print(predict(model, "captcha_images/Z8K5_123.png"))
浙公网安备 33010602011771号