使用 PyTorch 构建图像验证码识别系统
本教程介绍如何使用 PyTorch 实现一个完整的图像验证码识别流程,包括数据生成、模型设计、训练与预测。
1. 安装依赖
pip install torch torchvision pillow captcha numpy
2. 生成验证码图片数据
使用 captcha 库生成合成验证码图片。
from captcha.image import ImageCaptcha
import os, random, string
from PIL import Image
# Symbol alphabet: the ten digits followed by the 26 uppercase Latin letters.
# A symbol's position in this string is used as its class index.
characters = string.digits + string.ascii_uppercase

# Size, in pixels, of every rendered CAPTCHA image.
img_width = 160
img_height = 60

# Every CAPTCHA string contains exactly this many symbols.
captcha_length = 4
def generate_images(count=5000, output_dir='data', seed=None):
    """Render ``count`` random CAPTCHA images into ``output_dir``.

    Each file is named ``"<TEXT>_<i>.png"``. ``TEXT`` is the
    ``captcha_length``-symbol string drawn on the image, sampled with
    replacement from ``characters``. ``i`` is the running index, which keeps
    file names unique when the same text is drawn twice. The label is later
    recovered from the part of the name before the first ``_``.

    Args:
        count: Number of images to generate.
        output_dir: Target directory; created if it does not exist.
        seed: Optional seed for a private RNG, making the sequence of label
            texts reproducible. ``None`` (the default) samples from the
            module-level ``random`` state, exactly as before. This does not
            control the image-rendering randomness inside ``ImageCaptcha``.
    """
    os.makedirs(output_dir, exist_ok=True)
    gen = ImageCaptcha(width=img_width, height=img_height)
    # Both the ``random`` module and ``random.Random`` instances provide choices().
    rng = random if seed is None else random.Random(seed)
    for i in range(count):
        text = ''.join(rng.choices(characters, k=captcha_length))
        image = gen.generate_image(text)
        image.save(os.path.join(output_dir, f"{text}_{i}.png"))
# Build the training set: 5000 labelled CAPTCHA PNGs under ./data.
generate_images(count=5000, output_dir='data')
3. 自定义数据集类
import torch
from torch.utils.data import Dataset
from torchvision import transforms
class CaptchaDataset(Dataset):
    """Map-style dataset over CAPTCHA PNG files in a directory.

    Each sample's label is taken from its file name: everything before the
    first ``_`` (``"7G3B_42.png"`` -> ``"7G3B"``). Items are
    ``(image_tensor, label_tensor)`` pairs, where ``label_tensor`` holds each
    symbol's index in ``characters``.
    """

    def __init__(self, root):
        """Index every ``.png`` file directly inside ``root``."""
        self.root = root
        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices are stable across runs and machines.
        self.files = sorted(f for f in os.listdir(root) if f.endswith('.png'))
        self.char2idx = {c: i for i, c in enumerate(characters)}
        # ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the normalized ``[3, H, W]`` float tensor. ``label`` is a
        ``torch.long`` tensor of ``captcha_length`` class indices.

        Raises:
            ValueError: if the file name does not start with a
                ``captcha_length``-symbol label drawn from ``characters``.
        """
        fname = self.files[idx]
        label = fname.split('_')[0]
        # Reject stray files early with a clear message instead of a KeyError
        # here or a tensor-size mismatch later in the DataLoader's collate step.
        if len(label) != captcha_length or not all(c in self.char2idx for c in label):
            raise ValueError(
                f"cannot derive a {captcha_length}-character label from file name {fname!r}"
            )
        with Image.open(os.path.join(self.root, fname)) as img:
            # convert() returns a new, fully loaded image, so it stays usable
            # after the context manager closes the underlying file.
            image = img.convert('RGB')
        label_tensor = torch.tensor([self.char2idx[c] for c in label], dtype=torch.long)
        return self.transform(image), label_tensor

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.files)
4. 定义模型结构
结合 CNN 和 LSTM 网络:CNN 提取图像特征,再把特征图按宽度方向逐列展开,作为双向 LSTM 的时间序列输入。
import torch.nn as nn
class CaptchaModel(nn.Module):
    """CNN feature extractor followed by a 2-layer bidirectional LSTM.

    Two conv/ReLU/2x2-max-pool stages shrink an ``[B, 3, H, W]`` image to a
    ``[B, 64, H//4, W//4]`` feature map. Each of the ``W//4`` columns is
    flattened into one LSTM time step. A linear layer then scores every step
    over ``num_classes`` symbols, giving logits of shape
    ``[B, W//4, num_classes]`` (``[B, 40, 36]`` for the default 160x60 images).
    """

    def __init__(self, num_classes=None, input_height=None):
        """Build the network.

        Args:
            num_classes: Output classes per time step; defaults to
                ``len(characters)``.
            input_height: Height in pixels of the images the model will see;
                defaults to ``img_height``. It fixes the LSTM input size,
                ``64 * (input_height // 4)``.
        """
        super().__init__()
        if num_classes is None:
            num_classes = len(characters)
        if input_height is None:
            input_height = img_height
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # Two MaxPool2d(2) stages floor-halve the height twice, i.e. // 4.
        # Deriving this replaces a hard-coded 64 * 15 that only fit 60-px images.
        feat_height = input_height // 4
        self.lstm = nn.LSTM(input_size=64 * feat_height, hidden_size=128, num_layers=2,
                            bidirectional=True, batch_first=True)
        # The bidirectional LSTM concatenates both directions: 2 * hidden_size.
        self.fc = nn.Linear(2 * 128, num_classes)

    def forward(self, x):
        """Map ``[B, 3, H, W]`` images to ``[B, W//4, num_classes]`` logits."""
        x = self.cnn(x)  # [B, C, H', W']
        b, c, h, w = x.size()
        # Treat the width axis as time: [B, C, H', W'] -> [B, W', C * H'].
        x = x.permute(0, 3, 1, 2).reshape(b, w, c * h)
        x, _ = self.lstm(x)
        return self.fc(x)
5. 模型训练
from torch.utils.data import DataLoader
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = CaptchaDataset('data')
loader = DataLoader(dataset, batch_size=64, shuffle=True)
model = CaptchaModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    total_loss = 0.0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        out = model(imgs)  # [B, T, num_classes]; T = number of feature columns
        # Supervise only the first captcha_length time steps, one per symbol.
        # Because the LSTM is bidirectional, these steps also see later columns.
        loss = sum(loss_fn(out[:, i], labels[:, i]) for i in range(captcha_length))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the mean loss per batch; a raw sum scales with dataset and batch
    # size, so it is not comparable across runs. max() guards an empty loader.
    avg_loss = total_loss / max(len(loader), 1)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
6. 验证模型效果
def predict_image(model, image_path, transform=None):
    """Predict the text of the CAPTCHA image at ``image_path``.

    Args:
        model: A trained ``CaptchaModel``; it is switched to eval mode here.
        image_path: Path to the image file.
        transform: Preprocessing applied to the RGB image. Defaults to
            ``dataset.transform`` so that inference matches training.

    Returns:
        The predicted string of ``captcha_length`` symbols.
    """
    if transform is None:
        transform = dataset.transform
    model.eval()
    # Run on whatever device the model's weights live on.
    model_device = next(model.parameters()).device
    with Image.open(image_path) as img:
        tensor = transform(img.convert('RGB')).unsqueeze(0).to(model_device)
    with torch.no_grad():
        out = model(tensor)  # [1, T, num_classes]
    # The model emits one prediction per feature column (T = 40 for 160-px
    # images), but training supervises only the first captcha_length steps.
    # Decode just those; the remaining steps are untrained.
    pred = out.argmax(2)[0, :captcha_length].tolist()
    return ''.join(characters[i] for i in pred)
# Generated file names are random ("<TEXT>_<i>.png"), so a hard-coded name such
# as 'data/7G3B_42.png' usually does not exist. Use a file from the dataset;
# the text before "_" is its ground-truth label.
sample_name = dataset.files[0]
result = predict_image(model, os.path.join(dataset.root, sample_name))
print("Predicted:", result)
print("Expected: ", sample_name.split('_')[0])
浙公网安备 33010602011771号