图像验证码识别:基于 PyTorch 的简单实现
本文将介绍如何使用 Python 和 PyTorch 构建一个简单的图像验证码识别系统。主要包括数据生成、模型训练和预测三个部分。
1. 安装依赖
pip install torch torchvision pillow captcha numpy
2. 生成验证码图片
from captcha.image import ImageCaptcha
import random, string, os
from PIL import Image
# Alphabet of symbols that may appear in a captcha: digits 0-9 plus A-Z.
characters = string.digits + string.ascii_uppercase
# Number of characters rendered in each captcha.
n_len = 4
# Image size in pixels, passed to ImageCaptcha when generating the dataset.
width, height = 160, 60
def generate_dataset(output_dir='data', count=5000):
    """Render `count` random captcha images into `output_dir`.

    Each image shows `n_len` symbols drawn uniformly (with replacement) from
    `characters`. The file name encodes the ground-truth text followed by the
    sample index, e.g. ``AB12_7.png``.
    """
    os.makedirs(output_dir, exist_ok=True)
    captcha_maker = ImageCaptcha(width, height)
    for sample_no in range(count):
        label = ''.join(random.choices(characters, k=n_len))
        target_path = os.path.join(output_dir, f'{label}_{sample_no}.png')
        captcha_maker.generate_image(label).save(target_path)


generate_dataset()
3. 构建数据集类
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch
# Reverse lookup: symbol -> its position in `characters` (used as the class index).
char_to_idx = dict(zip(characters, range(len(characters))))
class CaptchaDataset(Dataset):
    """Dataset of captcha PNGs whose labels are encoded in their file names.

    A file named ``<TEXT>_<index>.png`` yields ``(image, label)``, where
    ``image`` is the transformed RGB image and ``label`` is a ``torch.long``
    tensor holding ``char_to_idx`` of every character of ``<TEXT>``.

    Args:
        folder: Directory containing the ``.png`` files.
        transform: Optional callable applied to each PIL image. Defaults to
            ``ToTensor`` followed by ``Normalize((0.5,), (0.5,))``, i.e. each
            value x becomes (x - 0.5) / 0.5.
    """

    def __init__(self, folder, transform=None):
        self.folder = folder
        # Sorted so sample indices are reproducible; os.listdir order is arbitrary.
        self.files = sorted(f for f in os.listdir(folder) if f.endswith('.png'))
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ])
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        filename = self.files[idx]
        # Text before the first underscore is the captcha's ground truth.
        label = filename.split('_')[0]
        # Context manager closes the underlying file handle after conversion.
        with Image.open(os.path.join(self.folder, filename)) as raw:
            image = raw.convert('RGB')
        image_tensor = self.transform(image)
        target = torch.tensor([char_to_idx[c] for c in label], dtype=torch.long)
        return image_tensor, target
4. 构建模型(CNN)
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Two-stage CNN that predicts every captcha position in one forward pass.

    Input: images of shape (batch, 3, img_height, img_width).
    Output: logits of shape (batch, seq_len, num_classes).

    Args:
        num_classes: Size of the symbol alphabet.
        seq_len: Number of characters per captcha.
        img_height: Input image height; the default 60 together with
            img_width=160 reproduces the original ``64 * 15 * 40`` flatten size.
        img_width: Input image width.
    """

    def __init__(self, num_classes, seq_len, img_height=60, img_width=160):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
        )
        # 3x3 convs with padding=1 keep the spatial size; each MaxPool2d(2)
        # floor-halves it, so two pools leave (H // 4) x (W // 4) feature maps.
        flat_features = 64 * (img_height // 4) * (img_width // 4)
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 1024), nn.ReLU(),
            nn.Linear(1024, seq_len * num_classes),
        )
        self.seq_len = seq_len
        self.num_classes = num_classes

    def forward(self, x):
        features = self.conv(x)
        logits = self.fc(features)
        # Use the real batch size so a size mismatch raises instead of silently reshaping.
        return logits.view(x.size(0), self.seq_len, self.num_classes)
5. 训练模型
from torch.utils.data import DataLoader
import torch.optim as optim
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = CaptchaDataset('data')
loader = DataLoader(dataset, batch_size=64, shuffle=True)
model = SimpleCNN(num_classes=len(characters), seq_len=n_len).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images)  # (batch, n_len, num_classes)
        # One cross-entropy term per character position, summed into one loss.
        loss = sum(criterion(preds[:, i, :], labels[:, i]) for i in range(n_len))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean per-batch loss so the number does not grow with dataset size.
    avg_loss = running_loss / max(len(loader), 1)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
6. 测试预测
def predict_image(model, image_path, transform=None):
    """Return the captcha text predicted by `model` for the image at `image_path`.

    Args:
        model: Network returning logits of shape (batch, seq_len, len(characters)).
        image_path: Path to the image file.
        transform: Preprocessing applied to the PIL image; defaults to the
            training dataset's transform so inputs match what the model saw.

    Note: switches `model` to eval mode and leaves it there.
    """
    if transform is None:
        transform = dataset.transform
    model.eval()
    # Run on whatever device the model's weights live on.
    model_device = next(model.parameters()).device
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    batch = transform(image).unsqueeze(0).to(model_device)
    with torch.no_grad():
        logits = model(batch)
    # Convert to plain ints before indexing into the `characters` string.
    indices = logits.argmax(dim=2)[0].tolist()
    return ''.join(characters[i] for i in indices)
# Use a file that is guaranteed to exist: generated names are random, so a
# hard-coded name such as 'data/B8F3_1.png' almost never matches.
test_file = dataset.files[0]
test_img = os.path.join(dataset.folder, test_file)
print("真实标签:", test_file.split('_')[0])
print("预测结果:", predict_image(model, test_img))
浙公网安备 33010602011771号