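Building a CAPTCHA Recognition System with PyTorch
This tutorial walks through building an image CAPTCHA recognition model from scratch with Python and PyTorch. It is aimed at beginners who want to get up and running quickly.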
Step 1: Install dependencies
```bash
pip install torch torchvision pillow captcha
```
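To confirm that the installation worked, a quick importability check can help. The sketch below uses only the standard library and does not import the packages themselves. Note that the pip package `pillow` is imported as `PIL`. The `check_deps` helper and the `REQUIRED` mapping are hypothetical names chosen for illustration, not part of any library.

```python
import importlib.util

# pip package name -> top-level module it installs (pillow installs PIL)
REQUIRED = {'torch': 'torch', 'torchvision': 'torchvision',
            'pillow': 'PIL', 'captcha': 'captcha'}

def check_deps(modules=None):
    """Return {module_name: True/False} depending on whether it can be found."""
    modules = modules if modules is not None else list(REQUIRED.values())
    # find_spec locates a top-level module without importing it
    return {name: importlib.util.find_spec(name) is not None for name in modules}

if __name__ == '__main__':
    for name, ok in check_deps().items():
        print(f'{name:12s} {"OK" if ok else "MISSING"}')
```

Any module reported as MISSING can be installed again with the pip command above.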
Step 2: Generate CAPTCHA images
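```python
import os
import random
import string

from captcha.image import ImageCaptcha
from PIL import Image

# Character set: digits 0-9 followed by uppercase letters A-Z (36 classes)
characters = string.digits + string.ascii_uppercase
width, height, n_len = 160, 60, 4  # image size in pixels, characters per CAPTCHA

def generate_data(save_dir='data', count=5000):
    os.makedirs(save_dir, exist_ok=True)
    image_gen = ImageCaptcha(width=width, height=height)
    for i in range(count):
        text = ''.join(random.choices(characters, k=n_len))
        image = image_gen.generate_image(text)
        # The label is stored in the file name as "<text>_<index>.png"
        image.save(os.path.join(save_dir, f'{text}_{i}.png'))

generate_data()
```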
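Each label is stored only in its file name, so it is worth checking that the naming scheme round-trips. This works because the character set contains no underscore. The stdlib-only sketch below mimics the `{text}_{i}.png` pattern without drawing any images. It re-declares the character set so it runs on its own. `make_filename` and `parse_label` are hypothetical helpers.

```python
import random
import string

characters = string.digits + string.ascii_uppercase
n_len = 4

def make_filename(text, i):
    # Same pattern as generate_data(): "<label>_<index>.png"
    return f'{text}_{i}.png'

def parse_label(filename):
    # Same rule the dataset uses: everything before the first underscore
    return filename.split('_')[0]

rng = random.Random(0)  # seeded so the check is reproducible
for i in range(1000):
    text = ''.join(rng.choices(characters, k=n_len))
    assert parse_label(make_filename(text, i)) == text
print('file name round-trip OK')
```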
Step 3: Build the dataset
```python
import os

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CaptchaDataset(Dataset):
    def __init__(self, folder):
        self.files = [f for f in os.listdir(folder) if f.endswith('.png')]
        self.folder = folder
        # Map each character (from `characters` in Step 2) to a class index
        self.char2idx = {c: i for i, c in enumerate(characters)}
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # PIL image -> float tensor [3, H, W] in [0, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1]
        ])

    def __getitem__(self, idx):
        filename = self.files[idx]
        label = filename.split('_')[0]  # "G5K9_12.png" -> "G5K9"
        image = Image.open(os.path.join(self.folder, filename)).convert('RGB')
        label_tensor = torch.tensor([self.char2idx[c] for c in label], dtype=torch.long)
        return self.transform(image), label_tensor

    def __len__(self):
        return len(self.files)

dataset = CaptchaDataset('data')
```
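The dataset turns each label into class indices and rescales pixel values. A plain-Python sketch of both mappings follows. It reproduces the `char2idx` dictionary and the arithmetic that `ToTensor()` followed by `Normalize(0.5, 0.5)` applies to one 8-bit pixel value. The helper names `encode`, `decode` and `normalize_pixel` are hypothetical.

```python
import string

characters = string.digits + string.ascii_uppercase  # 36 classes, as in Step 2
char2idx = {c: i for i, c in enumerate(characters)}
idx2char = {i: c for c, i in char2idx.items()}

def encode(label):
    # The mapping CaptchaDataset applies before building label_tensor
    return [char2idx[c] for c in label]

def decode(indices):
    return ''.join(idx2char[i] for i in indices)

def normalize_pixel(p):
    # ToTensor divides 0..255 by 255; Normalize then computes (x - 0.5) / 0.5
    return (p / 255 - 0.5) / 0.5

print(encode('G5K9'))          # [16, 5, 20, 9]
print(decode([16, 5, 20, 9]))  # G5K9
print(normalize_pixel(0), normalize_pixel(255))  # -1.0 1.0
```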
Step 4: Define the model architecture
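```python
import torch.nn as nn

class CaptchaModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Two conv + 2x2 max-pool stages: [B, 3, 60, 160] -> [B, 64, 15, 40]
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # Each of the 40 feature-map columns is one time step with 64 * 15 features
        self.rnn = nn.LSTM(64 * 15, 128, bidirectional=True, batch_first=True)
        # A bidirectional LSTM outputs 2 * 128 = 256 features per time step
        self.fc = nn.Linear(256, len(characters))

    def forward(self, x):
        x = self.cnn(x)                                  # [B, C, H, W]
        b, c, h, w = x.size()
        x = x.permute(0, 3, 1, 2).reshape(b, w, c * h)  # [B, W, C*H]
        x, _ = self.rnn(x)                               # [B, W, 256]
        x = self.fc(x)                                   # [B, W, n_class]
        return x
```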
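The LSTM's `input_size` of `64 * 15` is hard-coded. It is only correct for 60-pixel-high inputs passed through exactly two conv/pool stages. The pure-Python sketch below applies the standard output-size formulas for `Conv2d(kernel 3, padding 1)` and `MaxPool2d(2)` to show where 15 (feature height) and 40 (sequence length) come from. The function names are hypothetical.

```python
def conv2d_out(size, kernel=3, padding=1, stride=1):
    # Conv2d output-size formula with dilation 1
    return (size + 2 * padding - kernel) // stride + 1

def maxpool2d_out(size, kernel=2, stride=2):
    # MaxPool2d(2): stride defaults to the kernel size, no padding
    return (size - kernel) // stride + 1

def cnn_output_shape(height=60, width=160, channels=64):
    h, w = height, width
    for _ in range(2):  # two Conv(3x3, pad 1) + MaxPool(2) stages
        h, w = conv2d_out(h), conv2d_out(w)
        h, w = maxpool2d_out(h), maxpool2d_out(w)
    return channels, h, w

c, h, w = cnn_output_shape()
print(f'feature map: {c} x {h} x {w}')                     # 64 x 15 x 40
print(f'LSTM input_size = {c * h}, sequence length = {w}')  # 960, 40
```

If the CAPTCHA height changes (for example to 64), the first argument of `nn.LSTM` has to be recomputed in the same way.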
Step 5: Train the model
```python
import torch
from torch.utils.data import DataLoader

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CaptchaModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
loader = DataLoader(dataset, batch_size=64, shuffle=True)

for epoch in range(10):
    model.train()
    total_loss = 0.0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)  # [B, W, n_class] with W = 40
        # Supervise only the first n_len time steps: step i predicts character i
        loss = sum(loss_fn(outputs[:, i, :], labels[:, i]) for i in range(n_len))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch + 1} avg batch loss: {total_loss / len(loader):.4f}')
```
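The training loss sums one cross-entropy term per character position. It uses only the first `n_len` of the W = 40 time steps the model outputs. The plain-Python sketch below computes that quantity for a single sample (`nn.CrossEntropyLoss` additionally averages over the batch). It shows that a model guessing uniformly over 36 classes scores 4 · ln 36 ≈ 14.33 per sample, a rough reference point for the loss of an untrained model. `cross_entropy` and `captcha_loss` are hypothetical helpers, not PyTorch functions.

```python
import math

def cross_entropy(logits, target):
    # -log softmax(logits)[target], computed stably by subtracting the max logit
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def captcha_loss(outputs, labels, n_len=4):
    # outputs: W time steps x n_class logits; labels: n_len class indices.
    # Only the first n_len time steps are supervised, as in the training loop.
    return sum(cross_entropy(outputs[i], labels[i]) for i in range(n_len))

n_class, seq_len = 36, 40
uniform = [[0.0] * n_class for _ in range(seq_len)]
print(captcha_loss(uniform, [16, 5, 20, 9]))  # 4 * ln(36), about 14.33
```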
Step 6: Run predictions
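```python
def predict(model, image_path):
    model.eval()
    device = next(model.parameters()).device  # run on the model's device
    image = Image.open(image_path).convert('RGB')
    tensor = dataset.transform(image).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        output = model(tensor)  # [1, W, n_class]
    # Keep only the first n_len time steps, matching how the model was trained
    pred = output.argmax(dim=2)[0, :n_len]
    return ''.join(characters[i] for i in pred.tolist())

# Replace with the path of an image that actually exists in data/
print(predict(model, 'data/G5K9_12.png'))
```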
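The training loop reports only loss. To see how well the model actually reads CAPTCHAs, predicted strings can be compared with their labels. For a fair estimate, use images the model has not trained on, for example a separate folder created with `generate_data(save_dir='test', count=500)` (the folder name is only an example). The stdlib-only sketch below shows greedy decoding: argmax per time step, keeping the first `n_len` steps, as in the training loss. It then computes two accuracies. Whole-CAPTCHA accuracy is the stricter of the two. The helper names are hypothetical.

```python
import string

characters = string.digits + string.ascii_uppercase

def greedy_decode(scores, characters, n_len=4):
    # scores: W time steps x n_class; argmax per step over the first n_len steps
    preds = [max(range(len(step)), key=step.__getitem__) for step in scores[:n_len]]
    return ''.join(characters[i] for i in preds)

def accuracy(predictions, targets):
    # Returns (per-character accuracy, whole-CAPTCHA accuracy)
    correct_chars = sum(p == t for pred, tgt in zip(predictions, targets)
                        for p, t in zip(pred, tgt))
    total_chars = sum(len(t) for t in targets)
    correct_seqs = sum(pred == tgt for pred, tgt in zip(predictions, targets))
    return correct_chars / total_chars, correct_seqs / len(targets)

print(accuracy(['G5K9', 'AB12'], ['G5K9', 'AB13']))  # (0.875, 0.5)
```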