使用 PyTorch + CTC Loss 实现验证码识别系统
本项目使用 PyTorch 实现一个基于 CTC Loss 的验证码识别模型，适用于长度可变或字符相互重叠的验证码图像。
1. 安装依赖
pip install torch torchvision captcha pillow numpy
2. 生成验证码数据集
from captcha.image import ImageCaptcha
import os, random, string
from PIL import Image
# Symbol alphabet: the ten digits followed by the 26 uppercase ASCII letters.
characters = "".join((string.digits, string.ascii_uppercase))
# Captcha texts get a random length in [min_len, max_len], both ends inclusive.
min_len = 4
max_len = 6
# Size of the rendered captcha images, in pixels.
width = 200
height = 60
def generate_dataset(output_dir='ctc_captcha', count=10000):
    """Render ``count`` random captchas into ``output_dir`` as PNG files.

    Each image is saved as ``<TEXT>_<i>.png``. The ground-truth text comes
    first and is separated from the running index by an underscore, which is
    how the dataset class recovers the label (``fname.split('_')[0]``). The
    alphabet ``characters`` contains no underscore, so that split is
    unambiguous.

    Args:
        output_dir: Directory to write into; created if it does not exist.
        count: Number of images to generate.
    """
    os.makedirs(output_dir, exist_ok=True)
    gen = ImageCaptcha(width=width, height=height)
    for i in range(count):
        # Variable-length labels are the reason this project uses CTC.
        length = random.randint(min_len, max_len)
        text = ''.join(random.choices(characters, k=length))
        image = gen.generate_image(text)
        # The index keeps file names unique even when the same text is drawn twice.
        image.save(os.path.join(output_dir, f'{text}_{i}.png'))
更多内容访问ttocr.com或联系1436423940
# Build the image set at script start-up (10,000 PNGs with the default arguments).
# NOTE(review): file names contain random text, so repeated runs mostly add new
# files rather than overwrite old ones; clear 'ctc_captcha' first for a fresh set.
generate_dataset()

# --- 3. Custom Dataset class ---
import torch
from torch.utils.data import Dataset
from torchvision import transforms
char_to_idx = {c: i + 1 for i, c in enumerate(characters)} # 0 for blank
idx_to_char = {i + 1: c for i, c in enumerate(characters)}
n_classes = len(characters) + 1 # +1 for CTC blank
class CTCCaptchaDataset(Dataset):
    """Captcha PNGs on disk, paired with CTC targets parsed from their file names.

    Files are expected to be named ``<TEXT>_<anything>.png`` (the layout written
    by ``generate_dataset``); the part before the first underscore is the label.
    """

    def __init__(self, path):
        # Sorted so the index -> file mapping is reproducible; os.listdir order
        # is arbitrary.
        self.images = sorted(f for f in os.listdir(path) if f.endswith('.png'))
        self.path = path
        self.transform = transforms.Compose([
            transforms.Resize((60, 200)),  # (height, width)
            transforms.ToTensor(),         # PIL image -> (C, H, W) float tensor in [0, 1]
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, target, target_length)`` for sample ``idx``.

        ``target`` is a 1-D LongTensor of class indices (1-based; 0 is the CTC
        blank) and ``target_length`` is its length as a plain int.
        """
        fname = self.images[idx]
        label_str = fname.split('_')[0]
        label = [char_to_idx[c] for c in label_str]
        # Context manager releases the file handle; convert() returns a new
        # in-memory image that stays valid after the file is closed.
        with Image.open(os.path.join(self.path, fname)) as raw:
            img = raw.convert('L')
        return self.transform(img), torch.tensor(label, dtype=torch.long), len(label_str)


# --- 4. Build the model (CNN + BiLSTM) ---
import torch.nn as nn
class CTCModel(nn.Module):
    """CNN feature extractor + bidirectional LSTM that emits per-column CTC log-probs.

    Input is a ``[B, 1, H, W]`` grayscale batch. Two 2x max-pools shrink it to
    a ``[B, 64, H//4, W//4]`` feature map, and each of the ``W//4`` columns
    becomes one time step (50 steps for the 200-px-wide captchas).

    Args:
        img_height: Height of the input images. It fixes the LSTM input size,
            so it must match what ``forward`` receives. Defaults to 60.
        num_classes: Number of output classes including the CTC blank.
            Defaults to the module-level ``n_classes``.
    """

    def __init__(self, img_height=60, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = n_classes
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # Each time step sees one full feature column: 64 channels x (H // 4) rows.
        self.rnn = nn.LSTM(64 * (img_height // 4), 128, bidirectional=True, batch_first=True)
        # 2 * 128: forward and backward hidden states are concatenated.
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return log-probabilities of shape ``[B, T, num_classes]``, T = W // 4."""
        x = self.cnn(x)             # [B, C, H', W']
        x = x.permute(0, 3, 1, 2)   # [B, W', C, H']: width becomes the time axis
        b, w, c, h = x.shape
        # reshape rather than view: robust if the permuted layout is not viewable.
        x = x.reshape(b, w, c * h)
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x.log_softmax(2)


# --- 5. Model training ---
from torch.nn.functional import ctc_loss
from torch.utils.data import DataLoader
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CTCModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dataset = CTCCaptchaDataset('ctc_captcha')
# Targets have different lengths, so keep each batch as a plain list of sample
# tuples and assemble the CTC inputs by hand below.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=lambda x: x)
for epoch in range(10):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        images, labels, label_lens = zip(*batch)
        images = torch.stack(images).to(device)
        # All targets concatenated into one 1-D tensor, plus per-sample lengths.
        labels = torch.cat(labels).to(device)
        label_lens = torch.tensor(label_lens, dtype=torch.long)
        logits = model(images)                # [N, T, C] log-probs
        log_probs = logits.permute(1, 0, 2)   # ctc_loss expects [T, N, C]
        # Every sample uses the full output length T.
        input_lens = torch.full((logits.size(0),), logits.size(1), dtype=torch.long)
        # blank=0 matches char_to_idx; zero_infinity zeroes infeasible alignments
        # instead of producing an inf loss and NaN gradients.
        loss = ctc_loss(log_probs, labels, input_lens, label_lens, blank=0, zero_infinity=True)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report the epoch mean rather than the noisy loss of the final batch.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}: Loss = {mean_loss:.4f}")

# --- 6. Decoding / prediction (greedy CTC decoding) ---
def greedy_decode(output):
    """Greedy (best-path) CTC decoding of a single sample.

    Takes the arg-max class at every time step, collapses consecutive repeats,
    then drops blanks (index 0) and maps the remaining indices to characters.
    Expects a batch of one, i.e. ``output`` shaped ``[1, T, C]``.
    """
    path = output.argmax(dim=2).squeeze(0).tolist()
    decoded = [
        idx_to_char[label]
        for step, label in enumerate(path)
        # Keep a label only if it is not blank and differs from the previous
        # step, so (with '-' as blank) "AA-A" -> "AA" while "AAA" -> "A".
        if label != 0 and (step == 0 or path[step - 1] != label)
    ]
    return ''.join(decoded)
def predict_image(image_path):
    """Run the global ``model`` on one captcha file and return the decoded text.

    The image is converted to grayscale and resized to 200x60 (width x height),
    the same geometry the training transform produces.
    """
    # Close the file handle promptly; convert()/resize() return new images.
    with Image.open(image_path) as raw:
        img = raw.convert('L').resize((200, 60))  # PIL takes (width, height)
    tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)  # [1, 1, 60, 200]
    # CTCModel currently has no dropout/batch-norm, so eval() is a no-op today,
    # but it keeps inference correct if such layers are added later.
    model.eval()
    with torch.no_grad():
        output = model(tensor)
    return greedy_decode(output)
# Generated files are named "<TEXT>_<index>.png" with random TEXT, so a
# hard-coded name such as 'ctc_captcha/9XZ2_1.png' almost never exists.
# Pick a real file instead and print its ground truth next to the prediction.
sample_name = sorted(f for f in os.listdir('ctc_captcha') if f.endswith('.png'))[0]
print(sample_name.split('_')[0], '->', predict_image(os.path.join('ctc_captcha', sample_name)))
浙公网安备 33010602011771号