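Image CAPTCHA Recognition with PyTorch
This article shows how to build an image CAPTCHA recognition system with Python and PyTorch, covering dataset generation, model construction, training, and testing.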
1. Install the required libraries
Install the dependencies with pip:
pip install torch torchvision pillow captcha numpy
2. Generate the CAPTCHA dataset
Use the captcha library to automatically generate CAPTCHA images made of digits and uppercase letters.
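The generator code stores each image's label only in its filename, as `{text}_{idx}.png`; the index suffix keeps two images with the same text from overwriting each other. A minimal pure-Python sketch of that naming scheme and its inverse, assuming the same 36-character alphabet. The helpers `make_filename` and `parse_label` are hypothetical illustrations, not part of the captcha library. Splitting on `_` is safe only because the alphabet contains no underscore.

```python
import random
import string

# Same alphabet as the article: 10 digits + 26 uppercase letters = 36 classes
characters = string.digits + string.ascii_uppercase
captcha_length = 4

def make_filename(text: str, idx: int) -> str:
    # Hypothetical helper: label first, then a unique index to avoid collisions
    return f"{text}_{idx}.png"

def parse_label(filename: str) -> str:
    # Hypothetical helper: the label is everything before the first underscore
    return filename.split("_")[0]

rng = random.Random(0)  # seeded so the sketch is reproducible
text = "".join(rng.choices(characters, k=captcha_length))
name = make_filename(text, 42)
print(name, "->", parse_label(name))
```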
from captcha.image import ImageCaptcha
import os
import random
import string

# 10 digits + 26 uppercase letters = 36 character classes
characters = string.digits + string.ascii_uppercase
captcha_length = 4
image_width, image_height = 160, 60

def generate_dataset(save_dir="captcha_data", num_samples=10000):
    os.makedirs(save_dir, exist_ok=True)
    generator = ImageCaptcha(width=image_width, height=image_height)
    for idx in range(num_samples):
        text = ''.join(random.choices(characters, k=captcha_length))
        img = generator.generate_image(text)
        # The label lives in the filename; idx keeps duplicate texts from overwriting each other
        img.save(os.path.join(save_dir, f"{text}_{idx}.png"))

generate_dataset()
3. Build the data loader
Define a custom Dataset class that reads each image and derives its label from the filename.
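Inside the dataset, each label string becomes a sequence of class indices, and `Normalize` with mean 0.5 and std 0.5 maps the [0, 1] pixel values produced by `ToTensor` to [-1, 1] via (x - 0.5) / 0.5. A pure-Python sketch of both steps; `encode`, `decode`, and `normalize` are illustrative names, not torchvision functions.

```python
import string

characters = string.digits + string.ascii_uppercase
char_to_idx = {ch: idx for idx, ch in enumerate(characters)}

def encode(label):
    # Map each character to its class index, e.g. "7XQ9" -> [7, 33, 26, 9]
    return [char_to_idx[c] for c in label]

def decode(indices):
    # Inverse mapping: class indices back to a string
    return "".join(characters[i] for i in indices)

def normalize(x, mean=0.5, std=0.5):
    # Same per-value arithmetic as transforms.Normalize: (x - mean) / std
    return (x - mean) / std

print(encode("7XQ9"))                  # [7, 33, 26, 9]
print(decode(encode("7XQ9")))          # 7XQ9
print(normalize(0.0), normalize(1.0))  # -1.0 1.0
```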
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch

class CaptchaDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        # Only PNG files, sorted so the order is reproducible
        self.files = sorted(f for f in os.listdir(root_dir) if f.endswith('.png'))
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], shape (3, H, W)
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # rescale to [-1, 1]
        ])
        self.char_to_idx = {ch: idx for idx, ch in enumerate(characters)}

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        filename = self.files[index]
        label = filename.split('_')[0]  # "7XQ9_0.png" -> "7XQ9"
        image = Image.open(os.path.join(self.root_dir, filename)).convert('RGB')
        image = self.transform(image)
        # One class index per character, shape (captcha_length,)
        label_tensor = torch.tensor([self.char_to_idx[c] for c in label], dtype=torch.long)
        return image, label_tensor
dataset = CaptchaDataset("captcha_data")
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
4. Build the recognition model
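A convolutional neural network (CNN) extracts image features, and a bidirectional LSTM reads the resulting feature columns as a sequence.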
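The `128*7` input size of the LSTM follows from the CNN's output shape. With a 3×3 kernel and `padding=1`, each convolution keeps height and width unchanged, so only the pooling layers shrink the 60×160 input: two `MaxPool2d(2)` layers, then `MaxPool2d((2,1))`, which halves the height only. A small pure-Python sketch that tracks the shape; the helper names are hypothetical, and pooling floors odd sizes as PyTorch does with its default `ceil_mode=False`.

```python
def conv_out(size, kernel=3, padding=1, stride=1):
    # Output-size formula for Conv2d with dilation 1
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel):
    # MaxPool2d with stride equal to kernel size and no padding (floor)
    return (size - kernel) // kernel + 1

def cnn_output_shape(h=60, w=160):
    h, w = conv_out(h), conv_out(w)        # Conv2d(3, 32)
    h, w = pool_out(h, 2), pool_out(w, 2)  # MaxPool2d(2)
    h, w = conv_out(h), conv_out(w)        # Conv2d(32, 64)
    h, w = pool_out(h, 2), pool_out(w, 2)  # MaxPool2d(2)
    h, w = conv_out(h), conv_out(w)        # Conv2d(64, 128)
    h, w = pool_out(h, 2), pool_out(w, 1)  # MaxPool2d((2, 1)): height only
    return 128, h, w                        # (channels, height, width)

c, h, w = cnn_output_shape()
print(f"feature map: {c}x{h}x{w}")                   # feature map: 128x7x40
print(f"LSTM input size: {c * h}, time steps: {w}")  # LSTM input size: 896, time steps: 40
```

So the network emits 40 time steps per image while each label has only 4 characters; the training code supervises only the first `captcha_length` of those steps.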
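import torch.nn as nn

class CaptchaModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Input (B, 3, 60, 160) -> feature map (B, 128, 7, 40)
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d((2, 1))  # halve the height only, keep all 40 columns
        )
        # Each of the 40 columns is one time step with 128 * 7 features
        self.rnn = nn.LSTM(128 * 7, 128, num_layers=2, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(256, len(characters))  # 2 directions x 128 hidden units

    def forward(self, x):
        x = self.cnn(x)             # (B, C, H, W)
        x = x.permute(0, 3, 1, 2)   # (B, W, C, H): width becomes the time axis
        b, w, c, h = x.size()
        x = x.reshape(b, w, c * h)  # (B, W, C*H)
        x, _ = self.rnn(x)          # (B, W, 256)
        x = self.fc(x)              # (B, W, num_classes)
        return x
5. Train the model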
Set up the device, optimizer, and loss function, then train the model.
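The loss in the training loop adds one cross-entropy term per character position: `outputs[:, i, :]` holds the class scores at step i and `labels[:, i]` the true class. For a single sample, the arithmetic reduces to the pure-Python sketch below. `cross_entropy` and `sequence_loss` are illustrative names; `nn.CrossEntropyLoss` additionally averages over the batch by default.

```python
import math

def cross_entropy(logits, target):
    # -log softmax(logits)[target], computed stably via log-sum-exp
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def sequence_loss(step_logits, targets):
    # One term per character position, summed as in the training loop
    return sum(cross_entropy(step_logits[i], targets[i]) for i in range(len(targets)))

# Toy example: 3 classes, 2 positions (not the real 36-class output)
uniform = [0.0, 0.0, 0.0]
print(cross_entropy(uniform, 1))                  # log(3) ≈ 1.0986
print(sequence_loss([uniform, uniform], [0, 2]))  # 2 * log(3) ≈ 2.1972
```

As a reference point, a model whose scores are uniform over all 36 classes gets 4 · log(36) ≈ 14.3 for one batch's summed loss, so values well below that indicate it is learning.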
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CaptchaModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(20):
    model.train()
    total_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)  # (B, 40, num_classes)
        # One cross-entropy term per character, using the first captcha_length time steps
        loss = sum(criterion(outputs[:, i, :], labels[:, i]) for i in range(captcha_length))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Average loss per batch for this epoch
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")
6. Test the model
Predict the text of a single CAPTCHA image.
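Prediction is greedy: take the highest-scoring class at each time step and map the indices back to characters. Because the model emits 40 time steps but the loss only supervises the first `captcha_length`, only those steps carry a trained answer. A pure-Python sketch of that decoding on a list of per-step scores; `greedy_decode` and `one_hot_scores` are hypothetical helpers.

```python
import string

characters = string.digits + string.ascii_uppercase
captcha_length = 4

def greedy_decode(step_scores, length=captcha_length):
    # argmax over classes at each step, keeping only the first `length` steps
    indices = [max(range(len(scores)), key=scores.__getitem__) for scores in step_scores[:length]]
    return "".join(characters[i] for i in indices)

def one_hot_scores(ch):
    # Toy scores that peak at one character's class index
    scores = [0.0] * len(characters)
    scores[characters.index(ch)] = 1.0
    return scores

# Toy output with 40 steps: the first 4 spell "7XQ9", the remaining 36 are ignored
steps = [one_hot_scores(c) for c in "7XQ9"] + [one_hot_scores("A")] * 36
print(greedy_decode(steps))  # 7XQ9
```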
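def predict(model, image_path):
    model.eval()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # must match training
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        output = model(image)  # (1, 40, num_classes)
    # Keep only the first captcha_length steps, the ones the loss trained
    prediction = output.argmax(dim=2)[0, :captcha_length]
    return ''.join(characters[idx] for idx in prediction.tolist())

# Filenames are random, so pick one that exists (note: it was also used for training)
test_image = os.path.join("captcha_data", dataset.files[0])
print("Predicted:", predict(model, test_image))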
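Checking one image, especially one from the training directory, says little about overall quality. A common summary is sequence accuracy (all four characters correct) alongside per-character accuracy, ideally measured on freshly generated images the model never saw. A minimal sketch, assuming predictions and true labels have already been collected as strings, for example by calling `predict` on each file; `captcha_accuracy` is a hypothetical helper.

```python
def captcha_accuracy(predictions, labels):
    """Return (sequence accuracy, character accuracy) for lists of equal-length strings."""
    if not labels or len(predictions) != len(labels):
        raise ValueError("predictions and labels must be non-empty lists of the same length")
    # A sequence counts only if every character matches
    seq_hits = sum(p == t for p, t in zip(predictions, labels))
    # Character-level matches, position by position
    char_hits = sum(pc == tc for p, t in zip(predictions, labels) for pc, tc in zip(p, t))
    total_chars = sum(len(t) for t in labels)
    return seq_hits / len(labels), char_hits / total_chars

# Toy strings for illustration, not real model output
preds = ["7XQ9", "AB12", "K0ZZ"]
truth = ["7XQ9", "AB13", "K0ZZ"]
seq_acc, char_acc = captcha_accuracy(preds, truth)
print(f"sequence accuracy: {seq_acc:.3f}, character accuracy: {char_acc:.3f}")
# sequence accuracy: 0.667, character accuracy: 0.917
```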