手把手教你用深度学习识别验证码
验证码识别是计算机视觉中的经典问题,今天我将带你用Python和PyTorch从零实现一个验证码识别系统。无需复杂理论,跟着做就能得到实际可用的模型!
一、5分钟搭建开发环境
首先安装必要的库:
bash
# tqdm is included because the training loop imports it for the progress bar.
pip install torch torchvision pillow opencv-python numpy matplotlib tqdm
二、快速生成训练数据
我们可以用Python自动生成带标签的验证码:
python
from PIL import Image, ImageDraw, ImageFont
import random
import os
def generate_captcha(text, width=160, height=60,
                     font_path='arial.ttf', font_size=36):
    """Render ``text`` as a single noisy CAPTCHA image.

    Each character is drawn in black on a white canvas with a small random
    offset, then five randomly coloured interference lines are drawn on top.

    Args:
        text: Characters to draw, left to right.
        width: Canvas width in pixels.
        height: Canvas height in pixels.
        font_path: TrueType font file to load.
        font_size: Font size passed to the TrueType loader.

    Returns:
        The rendered ``PIL.Image.Image`` in RGB mode.
    """
    img = Image.new('RGB', (width, height), color=(255, 255, 255))
    try:
        font = ImageFont.truetype(font_path, font_size)
    except OSError:
        # The requested font file is missing on this machine; fall back to
        # Pillow's bundled font so generation still works.
        font = ImageFont.load_default()
    d = ImageDraw.Draw(img)

    # Jitter every glyph by up to 5 px in both directions.
    for i, char in enumerate(text):
        x = 20 + i * 25 + random.randint(-5, 5)
        y = 10 + random.randint(-5, 5)
        d.text((x, y), char, font=font, fill=(0, 0, 0))

    # Interference lines: each one needs two separate (x, y) end points.
    for _ in range(5):
        start = (random.randint(0, width), random.randint(0, height))
        end = (random.randint(0, width), random.randint(0, height))
        color = (random.randint(50, 200),
                 random.randint(50, 200),
                 random.randint(50, 200))
        d.line([start, end], fill=color, width=2)
    return img
# Generate 1000 training images.
chars = '0123456789ABCDEFGHJKLMNPQRSTUVWXYZ'  # easily confused characters removed
# Image.save() does not create missing directories, so make sure it exists.
os.makedirs('train', exist_ok=True)
for _ in range(1000):
    text = ''.join(random.choices(chars, k=5))
    img = generate_captcha(text)
    # The label is stored in the file name; a repeated text overwrites the
    # earlier image of the same name.
    img.save(f'train/{text}.png')
三、极简数据预处理
python
import cv2
import numpy as np
def preprocess(img_path):
    """Load a CAPTCHA image and turn it into a normalised binary array.

    The image is read as grayscale, binarised with an inverted Otsu
    threshold, resized to 160x60 (width x height) and scaled to [0, 1].

    Args:
        img_path: Path of the image file to read.

    Returns:
        A ``float32`` NumPy array of shape (60, 160).

    Raises:
        FileNotFoundError: If the file is missing or cannot be decoded.
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'Cannot read image: {img_path}')
    # threshold() returns (threshold_value, image); only the image is needed.
    img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)[1]
    img = cv2.resize(img, (160, 60))
    return img.astype(np.float32) / 255.0
# Example
import matplotlib.pyplot as plt  # the earlier snippets never import plt

sample = preprocess('train/ABC12.png')
plt.imshow(sample, cmap='gray')
plt.show()
四、10行代码构建模型
使用PyTorch搭建轻量级CNN:
python
import torch
import torch.nn as nn
class CaptchaNet(nn.Module):
    """Small CNN that classifies every character position of a CAPTCHA.

    Expects input of shape (batch, 1, 60, 160) and returns logits of shape
    (batch, num_positions, num_chars).

    Args:
        num_chars: Size of the character set (number of classes per position).
        num_positions: Number of characters in one CAPTCHA.
    """

    def __init__(self, num_chars, num_positions=5):
        super().__init__()
        self.num_chars = num_chars
        self.num_positions = num_positions
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Two 2x2 poolings shrink the 60x160 input to 15x40 with 64 channels.
        self.fc = nn.Linear(64 * 40 * 15, num_chars * num_positions)

    def forward(self, x):
        """Return per-position class logits for a batch of images."""
        x = self.cnn(x)
        x = x.flatten(1)
        x = self.fc(x)
        # Reshape to (batch, positions, classes) using the sizes given to the
        # constructor rather than a module-level global.
        return x.view(-1, self.num_positions, self.num_chars)
五、训练代码(带进度条)
python
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
class CaptchaDataset(Dataset):
    """Dataset of CAPTCHA images whose file name (minus extension) is the label.

    Args:
        folder: Directory containing the ``.png`` training images.
    """

    def __init__(self, folder):
        # Sorted so that index -> file mapping does not depend on the
        # arbitrary order returned by os.listdir().
        self.files = sorted(f for f in os.listdir(folder) if f.endswith('.png'))
        self.folder = folder

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` tensors for the file at ``idx``.

        The image gets a leading channel axis; the label holds one index
        into ``chars`` per character of the file name.
        """
        filename = self.files[idx]
        img_path = os.path.join(self.folder, filename)
        label = filename.split('.')[0]
        img = preprocess(img_path)
        # Convert the label characters to class indices.
        label_idx = [chars.index(c) for c in label]
        return torch.tensor(img).unsqueeze(0), torch.tensor(label_idx)
# Load the data.
dataset = CaptchaDataset('train')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Initialise the model, loss and optimiser.
model = CaptchaNet(len(chars))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop.
num_epochs = 10
model.train()  # make sure the model is in training mode
for epoch in range(num_epochs):
    loop = tqdm(loader)
    loop.set_description(f'Epoch [{epoch+1}/{num_epochs}]')
    for images, labels in loop:
        outputs = model(images)
        # CrossEntropyLoss expects the class axis second, so reorder
        # (batch, positions, classes) to (batch, classes, positions).
        loss = criterion(outputs.permute(0, 2, 1), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())
六、实际使用示例
训练完成后,可以这样使用模型:
python
def predict(image_path):
    """Return the text the trained model reads from the image at ``image_path``."""
    model.eval()
    pixels = preprocess(image_path)
    # Add the batch and channel axes the network expects.
    batch = torch.tensor(pixels).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
        # Most likely class per character position.
        indices = logits.argmax(dim=2).squeeze().numpy()
    return ''.join(chars[k] for k in indices)
# Test
print(predict('test/XY89Z.png'))  # print the recognised text
七、性能优化技巧
数据增强:添加随机旋转、噪声等
python
# Add this when generating a CAPTCHA.
# fillcolor keeps the corners exposed by the rotation white like the
# background; the default fill is black, which the inverted threshold in
# preprocessing would turn into foreground pixels.
img = img.rotate(random.randint(-15, 15), expand=True,
                 fillcolor=(255, 255, 255))
模型量化:减小模型体积
python
# Dynamically quantise the nn.Linear layers to qint8.
model = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={nn.Linear},
    dtype=torch.qint8,
)
多进程加载:加速训练
python
# Keep shuffle=True so this loader matches the one used for training;
# num_workers loads batches in background worker processes.
loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    num_workers=4, pin_memory=True)
浙公网安备 33010602011771号