使用 PyTorch 构建验证码识别模型
本文介绍如何使用 PyTorch 实现一个简单的图像验证码识别模型。每张验证码由 4 个字符组成，每个字符取自数字 0–9 和大写英文字母 A–Z；模型结构采用 CNN + LSTM。
一、准备工作
安装所需的 Python 库：
pip install torch torchvision pillow captcha numpy
二、生成验证码图像数据
使用 captcha 库批量生成验证码图片。文件名格式为 `<验证码文本>_<序号>.png`，后续加载数据时直接从文件名中解析出标签。
from captcha.image import ImageCaptcha
import os
import random
import string
更多内容访问ttocr.com或联系1436423940
# Character vocabulary: digits first, then uppercase ASCII letters (36 symbols).
characters = ''.join([string.digits, string.ascii_uppercase])
# Generated image size in pixels.
height = 60
width = 160
# Number of characters in every captcha.
length = 4
def generate_images(num=10000, out_dir='images'):
    """Write ``num`` random captcha images as PNG files into ``out_dir``.

    Each image shows ``length`` characters drawn uniformly (with replacement)
    from ``characters``. The file is named ``<text>_<index>.png``, so the
    label can be recovered from the filename alone; the index keeps names
    unique when the same text is drawn twice.

    Args:
        num: how many images to generate.
        out_dir: destination directory (created if missing).
    """
    os.makedirs(out_dir, exist_ok=True)
    captcha_gen = ImageCaptcha(width=width, height=height)
    for idx in range(num):
        label = ''.join(random.choices(characters, k=length))
        picture = captcha_gen.generate_image(label)
        picture.save(os.path.join(out_dir, f'{label}_{idx}.png'))
# Build the training set on disk: 10000 images under ./images.
generate_images(num=10000, out_dir='images')
三、构建数据集加载器
使用 PyTorch 的 Dataset 和 DataLoader 类进行数据封装。
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch
class CaptchaDataset(Dataset):
    """Captcha images whose label is encoded in the filename.

    Files are expected to be named ``<label>_<anything>.<ext>``; the text
    before the first underscore is the captcha label. Each item is a pair
    ``(image_tensor, label_tensor)`` where the label is a ``torch.long``
    tensor of indices into ``characters``.

    Note: the default DataLoader collation stacks labels, so every label in
    the directory must have the same length.
    """

    # Extensions treated as images; other directory entries (e.g. .DS_Store)
    # are skipped because their names cannot be parsed into a label.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, path):
        """Index the image files in ``path``.

        Args:
            path: directory containing the captcha images.
        """
        self.path = path
        # Sorted so that index -> file is deterministic across runs/platforms
        # (os.listdir order is arbitrary).
        self.files = sorted(
            name for name in os.listdir(path)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.char_to_idx = {c: i for i, c in enumerate(characters)}
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # A single mean/std value is broadcast over all RGB channels,
            # mapping [0, 1] to [-1, 1].
            transforms.Normalize((0.5,), (0.5,))
        ])

    def __getitem__(self, index):
        """Return ``(image, label)`` for the ``index``-th file."""
        filename = self.files[index]
        label_str = filename.split('_')[0]
        # Use a context manager so the underlying file handle is closed;
        # convert() returns a new, fully loaded image.
        with Image.open(os.path.join(self.path, filename)) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        label = torch.tensor([self.char_to_idx[c] for c in label_str], dtype=torch.long)
        return image, label

    def __len__(self):
        """Number of image files found in the directory."""
        return len(self.files)
# Mini-batches of 64 samples, reshuffled every epoch.
train_loader = DataLoader(
    dataset=CaptchaDataset('images'),
    batch_size=64,
    shuffle=True,
)
四、搭建模型结构
卷积层提取图像特征，并把特征图的每一列作为一个时间步送入 LSTM；LSTM 学习字符之间的序列关系，最后由全连接层对每个时间步输出字符分类结果。
import torch.nn as nn
class CaptchaModel(nn.Module):
    """CNN + bidirectional LSTM emitting per-column character logits.

    For an input of shape (B, 3, H, W) the CNN halves the height three times
    (floor each time, i.e. H // 8) and the width twice (W // 4). Each
    remaining feature-map column becomes one LSTM time step, so the output
    has shape (B, W // 4, num_classes) -- (B, 40, num_classes) for the
    default 60x160 images.

    Args:
        img_height: height of the input images. It fixes the LSTM input size,
            so it must match the images passed to ``forward``. Defaults to 60.
        num_classes: size of the character vocabulary. Defaults to
            ``len(characters)``.
    """

    def __init__(self, img_height=60, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = len(characters)
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            # Pool only along height so the sequence (width) axis stays long.
            nn.MaxPool2d((2, 1))
        )
        # Three floor-halvings of the height equal img_height // 8
        # (60 -> 30 -> 15 -> 7); each column carries 128 * that many features.
        feat_height = img_height // 8
        self.rnn = nn.LSTM(128 * feat_height, 128, num_layers=2, bidirectional=True, batch_first=True)
        # 2 directions x 128 hidden units.
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map images (B, 3, H, W) to logits (B, W // 4, num_classes)."""
        x = self.cnn(x)                # (B, 128, H // 8, W // 4)
        x = x.permute(0, 3, 1, 2)      # (B, W', C, H'): width becomes time
        b, w, c, h = x.size()
        x = x.reshape(b, w, c * h)     # one feature vector per column
        x, _ = self.rnn(x)             # (B, W', 256)
        x = self.fc(x)
        return x
五、训练模型
设置优化器和损失函数，然后开始训练。每个批次的损失是 4 个字符位置上交叉熵损失之和。
# Train on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = CaptchaModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(15):
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        # One cross-entropy term per character slot: time step `pos` of the
        # model output is scored against the `pos`-th label character.
        loss = sum(
            criterion(logits[:, pos, :], labels[:, pos])
            for pos in range(length)
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {total_loss:.4f}')
六、测试单张图像
定义预测函数,查看模型识别效果。
def predict(model, image_path):
    """Recognise the captcha text in a single image file.

    The model is switched to eval mode as a side effect.

    Args:
        model: a trained ``CaptchaModel``.
        image_path: path to an image file readable by PIL.

    Returns:
        The predicted string of ``length`` characters.
    """
    model.eval()
    # Run on whichever device holds the model's weights instead of relying
    # on a module-level ``device`` variable.
    model_device = next(model.parameters()).device
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    # Context manager closes the file handle; convert() yields a loaded copy.
    with Image.open(image_path) as img:
        image = preprocess(img.convert('RGB'))
    image = image.unsqueeze(0).to(model_device)
    with torch.no_grad():
        output = model(image)
    # The model emits one logit vector per feature-map column (40 for a
    # 160-px-wide image), but training scores only the first `length` time
    # steps against the label. Decode just those; decoding every step would
    # return a 40-character string.
    pred = torch.argmax(output[0, :length, :], dim=1)
    return ''.join(characters[i] for i in pred.tolist())
# Generated filenames are "<random text>_<index>.png", so a hard-coded name
# such as 'images/4G7D_3.png' is almost never present. Pick a file that
# actually exists and show its true label next to the prediction.
test_file = sorted(f for f in os.listdir('images') if f.endswith('.png'))[0]
test_image = os.path.join('images', test_file)
print('真实标签:', test_file.split('_')[0])
print('识别结果:', predict(model, test_image))
浙公网安备 33010602011771号