使用 PyTorch Lightning 实现图像验证码识别
本文将展示如何使用 PyTorch Lightning 构建一个端到端的验证码识别模型。相比原始 PyTorch,Lightning 可以让你更专注于模型逻辑和实验。
1. 安装依赖
pip install pytorch-lightning torch torchvision captcha pillow numpy
2. 生成验证码图像数据
from captcha.image import ImageCaptcha
import random, string, os
from PIL import Image
# Alphabet a captcha may use: the ten digits followed by the uppercase letters.
characters = "".join([string.digits, string.ascii_uppercase])
# Number of characters in every captcha.
n_len = 4
# Captcha image size in pixels.
width = 160
height = 60
def generate_data(num=5000, path='captcha_imgs'):
    """Generate random captcha images and save them as PNG files.

    Each image shows ``n_len`` characters drawn (with replacement) from
    ``characters`` and is written to ``<path>/<text>_<i>.png``, so the label
    can later be recovered from the part of the file name before the ``_``.

    Args:
        num: Number of images to generate.
        path: Output directory; created if it does not exist.
    """
    os.makedirs(path, exist_ok=True)
    generator = ImageCaptcha(width, height)
    for i in range(num):
        text = ''.join(random.choices(characters, k=n_len))
        img = generator.generate_image(text)
        # The running index keeps file names unique when a text repeats.
        img.save(os.path.join(path, f'{text}_{i}.png'))
generate_data()
3. 构建数据集类
from torch.utils.data import Dataset
from torchvision import transforms
import torch
# Maps each captcha character to its class index (its position in `characters`).
char_to_idx = dict(zip(characters, range(len(characters))))
class CaptchaDataset(Dataset):
    """Dataset of captcha images stored as ``<label>_<index>.png`` files.

    Each item is an ``(image, target)`` pair. The image is loaded as RGB,
    converted to a tensor and normalised with mean 0.5 and std 0.5. The target
    is a ``torch.long`` tensor holding one class index per label character.
    """

    def __init__(self, root):
        """Index every ``.png`` file found directly inside ``root``.

        Args:
            root: Directory containing the captcha images.
        """
        self.root = root
        # os.listdir order is arbitrary; sorting keeps index -> file stable.
        self.files = sorted(f for f in os.listdir(root) if f.endswith('.png'))
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and encode its label.

        Returns:
            Tuple ``(img, target)``: the transformed image tensor and the
            label's class indices as a ``torch.long`` tensor.

        Raises:
            KeyError: If the label contains a character not in ``char_to_idx``.
        """
        file = self.files[idx]
        # The label is the part of the file name before the first underscore.
        label = file.split('_')[0]
        target = torch.tensor([char_to_idx[c] for c in label], dtype=torch.long)
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(os.path.join(self.root, file)) as raw:
            img = raw.convert('RGB')
        img = self.transform(img)
        return img, target


# 4. Build the Lightning model
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
class CaptchaModel(pl.LightningModule):
    """Small CNN with one classification head per captcha character.

    Two conv/ReLU/max-pool stages feed a shared 128-unit fully connected
    layer, followed by ``n_len`` independent linear heads that each output
    ``n_class`` logits.
    """

    def __init__(self, n_class=len(characters), n_len=4, height=60, width=160):
        """Build the layers.

        Args:
            n_class: Number of classes each head predicts.
            n_len: Number of characters per captcha (one head each).
            height: Input image height in pixels.
            width: Input image width in pixels.
        """
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2)
        )
        # The 3x3/stride-1/padding-1 convs keep the spatial size and each
        # MaxPool2d(2) halves it (rounding down): 60x160 -> 15x40 by default.
        feat_h = height // 2 // 2
        feat_w = width // 2 // 2
        self.fc = nn.Linear(64 * feat_h * feat_w, 128)
        self.heads = nn.ModuleList([nn.Linear(128, n_class) for _ in range(n_len)])
        self.n_len = n_len

    def forward(self, x):
        """Return a list with one logits tensor per character position."""
        x = self.conv(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc(x))
        return [head(x) for head in self.heads]

    def training_step(self, batch, batch_idx):
        """Compute and log the summed cross-entropy over all positions."""
        x, y = batch
        logits = self(x)
        # Column i of y holds the target class for character position i.
        loss = sum(F.cross_entropy(logit, y[:, i]) for i, logit in enumerate(logits))
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Use Adam with a learning rate of 1e-3 over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# 5. Start training
from torch.utils.data import DataLoader
# Load the generated images in shuffled mini-batches of 64 and train for 10 epochs.
dataset = CaptchaDataset('captcha_imgs')
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
model = CaptchaModel()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader)

# 6. Verify predictions
def decode_prediction(preds):
    """Turn per-position score tensors into the predicted captcha string.

    For every tensor in ``preds`` the index of its largest value is looked up
    in ``characters``; the resulting characters are concatenated in order.
    """
    decoded = []
    for scores in preds:
        best = scores.argmax().item()
        decoded.append(characters[best])
    return ''.join(decoded)
def predict_image(model, path):
    """Predict the text of the captcha image stored at ``path``.

    The image is loaded as RGB and preprocessed with ToTensor followed by
    Normalize with mean 0.5 and std 0.5, then run through ``model`` in eval
    mode without gradient tracking.

    Args:
        model: Module whose call returns one score tensor per character.
        path: Path of the image file to classify.

    Returns:
        The decoded captcha string.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    # convert() returns a new image, so the file can be closed right away.
    with Image.open(path) as raw:
        img = preprocess(raw.convert('RGB')).unsqueeze(0)  # add batch dimension
    # Run the input on the model's device to avoid a CPU/accelerator mismatch.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        img = img.to(first_param.device)
    model.eval()
    with torch.no_grad():
        output = model(img)
    return decode_prediction(output)
# File names embed a random label, so a hard-coded name is unlikely to exist;
# use the first generated image instead.
sample_name = next(f for f in sorted(os.listdir('captcha_imgs')) if f.endswith('.png'))
print(predict_image(model, os.path.join('captcha_imgs', sample_name)))
浙公网安备 33010602011771号