# 验证码识别对比
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import sys
sys.path.append('.')
import numpy as np
import matplotlib.pyplot as plt
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from captcha.image import ImageCaptcha
import warnings
warnings.filterwarnings('ignore')
# ======================== 1. 全局常量配置 ========================
# Character sets a captcha can be drawn from. Order matters: a character's
# position in the list is the class index used for training and decoding.
number = [str(i) for i in range(10)]
alphabet = [chr(ord('A') + i) for i in range(26)] + [chr(ord('a') + i) for i in range(26)]
alphanum = number + alphabet  # 10 digits + 26 upper + 26 lower = 62 classes

# Captcha geometry: characters per image and image size in pixels.
CAPTCHA_SIZE = 4
IMG_HEIGHT = 60
IMG_WIDTH = 160

# Shown in plot titles and console summaries.
STUDENT_ID = "3020"

# A raw string literal cannot end with a single backslash (it would escape the
# closing quote and fail to parse), so the root has no trailing separator;
# os.path.join inserts the separator where needed.
DATA_ROOT = r'C:\Users\zjl20\Desktop\pytorch666\yanzhengma\data'
IMG_SAVE_PATH_DIGIT = os.path.join(DATA_ROOT, "image_digit")
IMG_SAVE_PATH_ALPHA = os.path.join(DATA_ROOT, "image_alphanum")
LABEL_PATH_DIGIT = os.path.join(DATA_ROOT, "label_digit.txt")
LABEL_PATH_ALPHA = os.path.join(DATA_ROOT, "label_alphanum.txt")

# Create the image folders up front so dataset generation can write into them.
os.makedirs(IMG_SAVE_PATH_DIGIT, exist_ok=True)
os.makedirs(IMG_SAVE_PATH_ALPHA, exist_ok=True)

# Training hyper-parameters shared by all three models.
EPOCHS = 10
BATCH_SIZE = 16
LR = 0.0008

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"当前使用设备: {DEVICE}")
# ======================== 2. 验证码生成函数 ========================
def random_captcha_text(char_set=number, captcha_size=CAPTCHA_SIZE):
    """Draw a random captcha as a list of single characters.

    Args:
        char_set: Sequence of candidate characters (defaults to the digits).
        captcha_size: Number of characters to draw (defaults to CAPTCHA_SIZE).

    Returns:
        A list of ``captcha_size`` characters, each chosen independently and
        uniformly from ``char_set`` (repeats are possible).
    """
    return [random.choice(char_set) for _ in range(captcha_size)]
def gen_captcha_text_and_image(index, char_set, img_save_path, label_path):
    """Generate one captcha image and append its label to the label file.

    Args:
        index: Sample index; the image is saved as ``<index:04d>.jpg``.
        char_set: Characters the captcha text is drawn from.
        img_save_path: Directory the image is written to.
        label_path: Text file the label line is appended to.

    The label line holds the characters separated by single spaces (for
    example ``1 2 3 4``); the image itself renders them without spaces.
    """
    captcha_text = random_captcha_text(char_set)
    img_path = os.path.join(img_save_path, f"{index:04d}.jpg")

    image_generator = ImageCaptcha(width=IMG_WIDTH, height=IMG_HEIGHT)
    image_generator.write("".join(captcha_text), img_path)

    # Append mode: line i of the label file describes image i, so samples must
    # be generated in index order.
    with open(label_path, "a", encoding="utf-8") as f:
        f.write(" ".join(captcha_text) + "\n")
def generate_dataset(total_num=220, char_type='digit'):
    """Generate a captcha dataset unless its label file already exists.

    Args:
        total_num: Number of images to generate.
        char_type: ``'digit'`` for the digit-only dataset; any other value
            selects the alphanumeric dataset.

    NOTE(review): the presence of the label file is the only "already
    generated" check. A partially written label file (e.g. from an interrupted
    run) or one produced with a different ``total_num`` is accepted as-is;
    delete it to force regeneration.
    """
    if char_type == 'digit':
        char_set = number
        img_save_path = IMG_SAVE_PATH_DIGIT
        label_path = LABEL_PATH_DIGIT
    else:
        char_set = alphanum
        img_save_path = IMG_SAVE_PATH_ALPHA
        label_path = LABEL_PATH_ALPHA

    if os.path.exists(label_path):
        print(f"{char_type}标签文件已存在,跳过生成")
        return

    for i in range(total_num):
        gen_captcha_text_and_image(i, char_set, img_save_path, label_path)
    print(f"{char_type}数据集生成完成!共{total_num}张图片")
# ======================== 3. 自定义数据集类 ========================
class CaptchaDataset(Dataset):
    """Captcha images paired with per-character class-index labels.

    Sample ``idx`` is the image ``<idx:04d>.jpg`` inside ``root_dir`` together
    with line ``idx`` of the label file.
    """

    def __init__(self, root_dir, label_file, transform=None, char_type='digit'):
        """Load all labels into memory.

        Args:
            root_dir: Directory containing the ``NNNN.jpg`` images.
            label_file: Text file with one space-separated label per line.
            transform: Optional callable applied to each PIL image.
            char_type: ``'digit'`` for digit-only labels; any other value
                selects the alphanumeric character set.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.char_type = char_type
        self.char_set = number if char_type == 'digit' else alphanum
        if char_type == 'digit':
            # ndmin=2 keeps one row per sample even when the file holds a
            # single line (loadtxt would otherwise collapse it to 1-D and
            # len() would report the number of characters, not samples).
            self.labels = np.loadtxt(label_file, dtype=np.int64, ndmin=2)
        else:
            with open(label_file, 'r', encoding='utf-8') as f:
                self.labels = [line.strip().replace(' ', '') for line in f]

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``label`` is a long tensor with one class index per character.
        """
        img_name = os.path.join(self.root_dir, f"{idx:04d}.jpg")
        image = Image.open(img_name).convert("RGB")
        if self.char_type == 'digit':
            # Digit labels are already class indices.
            label = torch.tensor(self.labels[idx], dtype=torch.long)
        else:
            # Map each character to its position in the character set.
            label = torch.tensor(
                [self.char_set.index(c) for c in self.labels[idx]],
                dtype=torch.long,
            )
        if self.transform:
            image = self.transform(image)
        return image, label
# ======================== 4. 3种核心算法 ========================
# --- 算法1:CNN模型 ---
class ConvNet(nn.Module):
    """Three-stage convolutional classifier for fixed-size captcha images.

    Each stage is Conv2d -> BatchNorm2d -> LeakyReLU(0.2) -> MaxPool2d(2).
    The flatten size ``64 * 7 * 20`` is what a 3x60x160 input
    (IMG_HEIGHT x IMG_WIDTH) yields after the three stages; a different input
    size requires changing the first Linear layer.
    """

    def __init__(self, out_dim=CAPTCHA_SIZE*10):
        """Build the network.

        Args:
            out_dim: Number of output logits (characters per captcha times
                classes per character).
        """
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=4, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(64 * 7 * 20, 500),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(500, out_dim),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, out_dim)`` for an image batch."""
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)  # flatten feature maps per sample
        return self.fc_layers(x)
# --- 算法2:Sigmoid激活 全连接神经网络 ---
class SigmoidMLP(nn.Module):
    """Fully connected classifier with Sigmoid activations.

    The image is flattened to ``3 * IMG_HEIGHT * IMG_WIDTH`` features and
    passed through 2048 -> 1024 -> 512 hidden layers, each followed by
    BatchNorm1d and Sigmoid (Dropout 0.2 after the first two).
    """

    def __init__(self, out_dim=CAPTCHA_SIZE*10):
        """Build the network.

        Args:
            out_dim: Number of output logits (characters per captcha times
                classes per character).
        """
        super().__init__()
        self.input_dim = 3 * IMG_HEIGHT * IMG_WIDTH
        self.fc_layers = nn.Sequential(
            nn.Linear(self.input_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.Sigmoid(),
            nn.Dropout(0.2),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.Sigmoid(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Sigmoid(),
            nn.Linear(512, out_dim),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, out_dim)`` for an image batch."""
        x = x.view(x.size(0), -1)  # flatten each image to a feature vector
        return self.fc_layers(x)
# --- 算法3: Tanh激活 全连接神经网络 ---
class TanhMLP(nn.Module):
    """Fully connected classifier with Tanh activations.

    The image is flattened to ``3 * IMG_HEIGHT * IMG_WIDTH`` features and
    passed through 2048 -> 1024 -> 512 hidden layers, each followed by
    BatchNorm1d and Tanh (Dropout 0.2 after the first two).
    """

    def __init__(self, out_dim=CAPTCHA_SIZE*10):
        """Build the network.

        Args:
            out_dim: Number of output logits (characters per captcha times
                classes per character).
        """
        super().__init__()
        self.input_dim = 3 * IMG_HEIGHT * IMG_WIDTH
        self.fc_layers = nn.Sequential(
            nn.Linear(self.input_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.Tanh(),
            nn.Dropout(0.2),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.Tanh(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Linear(512, out_dim),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, out_dim)`` for an image batch."""
        x = x.view(x.size(0), -1)  # flatten each image to a feature vector
        return self.fc_layers(x)
# ======================== 5. 训练函数 ========================
def train_model(model, dataloader, out_dim, char_size, save_path):
    """Train ``model`` for EPOCHS epochs with Adam and save its weights.

    Args:
        model: Network producing ``(batch, n_chars * char_size)`` logits.
        dataloader: Yields ``(images, labels)`` batches; labels hold one class
            index per character.
        out_dim: Total number of output logits. Not used by the training loop;
            kept so existing callers keep working.
        char_size: Number of classes per character position.
        save_path: File the trained ``state_dict`` is written to.

    Returns:
        The trained model (the same object, moved to DEVICE).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.999))
    model.to(DEVICE)
    model.train()

    for epoch in range(EPOCHS):
        total_loss = 0.0
        for images, labels in dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE).long()

            outputs = model(images)
            # Treat every character position as an independent classification:
            # logits become (batch * n_chars, char_size), labels (batch * n_chars,).
            loss = criterion(outputs.view(-1, char_size), labels.view(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per sample.
            total_loss += loss.item() * images.size(0)

        avg_loss = total_loss / len(dataloader.dataset)
        print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {avg_loss:.4f}")

    torch.save(model.state_dict(), save_path)
    print(f"模型已保存至: {save_path}")
    return model
# ======================== 6. 四项指标计算函数 (准确率/精确率/召回率/F1) ========================
def calculate_metrics(y_true, y_pred):
    """Compute accuracy and macro-averaged precision, recall and F1.

    Args:
        y_true: 1-D array-like of ground-truth class labels.
        y_pred: 1-D array-like of predicted class labels, same length.

    Returns:
        ``(acc, prec, rec, f1)``. Precision and recall are averaged over every
        label that occurs in either input; a label with an empty denominator
        contributes 0. F1 is the harmonic mean of the macro precision and
        macro recall (not the mean of per-class F1 scores). Empty input yields
        all zeros.
    """
    # Accept plain lists too: comparing two lists with == yields a single
    # bool, which would silently produce a wrong accuracy.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.size == 0:
        return 0.0, 0.0, 0.0, 0.0

    acc = np.mean(y_true == y_pred)

    precision_list = []
    recall_list = []
    for label in np.unique(np.concatenate([y_true, y_pred])):
        is_true = y_true == label
        is_pred = y_pred == label
        tp = np.sum(is_true & is_pred)
        fp = np.sum(~is_true & is_pred)
        fn = np.sum(is_true & ~is_pred)
        precision_list.append(tp / (tp + fp) if (tp + fp) != 0 else 0.0)
        recall_list.append(tp / (tp + fn) if (tp + fn) != 0 else 0.0)

    prec = np.mean(precision_list)
    rec = np.mean(recall_list)
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) != 0 else 0.0
    return acc, prec, rec, f1
def eval_model(model, dataloader, char_size):
    """Evaluate ``model`` per character and return the four metrics.

    Args:
        model: Network producing ``(batch, n_chars * char_size)`` logits.
        dataloader: Yields ``(images, labels)`` batches.
        char_size: Number of classes per character position.

    Returns:
        ``(acc, prec, rec, f1)`` from ``calculate_metrics``, computed over all
        individual characters (not whole captchas).

    The model is switched to eval mode and left in that mode.
    """
    model.eval()
    all_true = []
    all_pred = []
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE).long()
            outputs = model(images)
            # One argmax per character position: (batch * n_chars,)
            pred = torch.argmax(outputs.view(-1, char_size), dim=1)
            all_true.extend(labels.view(-1).cpu().numpy())
            all_pred.extend(pred.cpu().numpy())

    return calculate_metrics(np.array(all_true), np.array(all_pred))
# ======================== 7. 绘图函数 ========================
# Global matplotlib style. SimHei is a CJK font, presumably selected so the
# Chinese titles and labels render; it must be installed on the machine.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 180
plt.rcParams['font.size'] = 11

# Bar colours and display names, index-aligned. The names double as the keys
# of the metrics dictionaries passed to the plotting functions.
algo_colors = ['#1f77b4', '#d62728', '#2ca02c']
algo_names = ['CNN', 'Sigmoid全连接', 'Tanh全连接']


def plot_single_metric_compare(metrics_dict, metric_name, task_name, save_path):
    """Save a bar chart comparing one metric across the three algorithms.

    Args:
        metrics_dict: ``{algo_name: {metric_name: value}}`` with an entry for
            every name in ``algo_names``.
        metric_name: Metric to plot (a key of the inner dictionaries).
        task_name: Task description used in the chart title.
        save_path: Output image path.
    """
    values = [metrics_dict[algo][metric_name] for algo in algo_names]

    plt.figure(figsize=(9, 6))
    bars = plt.bar(algo_names, values, color=algo_colors, alpha=0.9,
                   edgecolor='black', linewidth=0.8, width=0.6)

    # Print each value just above its bar.
    for bar, val in zip(bars, values):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height()+0.003,
                 f'{val:.4f}', ha='center', va='bottom',
                 fontsize=11, fontweight='bold')

    plt.title(f'{task_name} - {metric_name} 算法对比图 (学号:{STUDENT_ID})', pad=20, fontsize=13)
    plt.ylabel(f'{metric_name} (取值范围:0~1)', fontsize=11)
    plt.xlabel('算法模型', fontsize=11)
    plt.ylim(0, 1.05)
    plt.grid(axis='y', linestyle='--', alpha=0.5, linewidth=0.6)

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()  # release the figure so repeated calls do not accumulate
# 批量生成4个指标的独立对比图
def generate_all_separate_metrics(metrics_dict, task_name, save_prefix):
    """Save one comparison chart per metric under DATA_ROOT.

    Args:
        metrics_dict: ``{algo_name: {metric_name: value}}`` for all algorithms.
        task_name: Task description used in chart titles and the final message.
        save_prefix: Prefix of each output file name; files are named
            ``<save_prefix>_<metric>_对比图.png``.
    """
    metrics_list = ['准确率', '精确率', '召回率', 'F1值']
    for metric in metrics_list:
        plot_single_metric_compare(
            metrics_dict,
            metric,
            task_name,
            os.path.join(DATA_ROOT, f'{save_prefix}_{metric}_对比图.png'),
        )
    print(f"✅ {task_name} → 4个指标分开对比图 全部生成完成!")
# ======================== 8. 核心任务执行 ========================
def run_task(char_type='digit', total_num=220):
    """Generate data, train the three models, evaluate them and plot metrics.

    Args:
        char_type: ``'digit'`` for the digit-only task; any other value runs
            the alphanumeric task.
        total_num: Number of captcha images to generate if none exist yet.

    Returns:
        ``{algo_name: {metric_name: value}}`` for the three algorithms.

    Side effects: may write the dataset, saves one ``.pth`` file per model and
    four comparison charts under DATA_ROOT, and prints progress.
    """
    if char_type == 'digit':
        char_size = 10
        img_path = IMG_SAVE_PATH_DIGIT
        label_path = LABEL_PATH_DIGIT
        task_name = "4位纯数字验证码识别"
        save_prefix = "纯数字"
    else:
        char_size = 62
        img_path = IMG_SAVE_PATH_ALPHA
        label_path = LABEL_PATH_ALPHA
        task_name = "4位字母数字混合验证码识别"
        save_prefix = "字母数字混合"
    out_dim = CAPTCHA_SIZE * char_size

    generate_dataset(total_num, char_type)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = CaptchaDataset(img_path, label_path, transform, char_type)
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    # NOTE(review): the "test" loader iterates the same samples used for
    # training, so the reported metrics are training-set metrics, not a
    # held-out estimate. Kept as in the original; split the dataset if a
    # generalization measure is wanted.
    test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    # (display name / metrics key, model, weight-file tag). All models are
    # built before any training starts, as in the original sequence.
    experiments = [
        ('CNN', ConvNet(out_dim), 'CNN'),
        ('Sigmoid全连接', SigmoidMLP(out_dim), 'SigmoidMLP'),
        ('Tanh全连接', TanhMLP(out_dim), 'TanhMLP'),
    ]

    metrics_dict = {}
    for algo, model, file_tag in experiments:
        print(f"\n===== {task_name} - 训练 {algo} =====")
        model = train_model(
            model, train_loader, out_dim, char_size,
            os.path.join(DATA_ROOT, f'{save_prefix}_{file_tag}.pth'),
        )
        acc, prec, rec, f1 = eval_model(model, test_loader, char_size)
        metrics_dict[algo] = {'准确率': acc, '精确率': prec, '召回率': rec, 'F1值': f1}

    generate_all_separate_metrics(metrics_dict, task_name, save_prefix)

    print(f"\n===== {task_name} - 三种算法 最终四项指标 (学号:{STUDENT_ID}) =====")
    for algo, metric in metrics_dict.items():
        print(f"{algo} → 准确率:{metric['准确率']:.4f} | 精确率:{metric['精确率']:.4f} | 召回率:{metric['召回率']:.4f} | F1值:{metric['F1值']:.4f}")

    return metrics_dict
# ======================== 9. 推理函数 ========================
def predict_captcha(img_path, model_weight_path, char_type='digit'):
    """Predict the text of one captcha image with a saved ConvNet.

    Args:
        img_path: Path of the captcha image.
        model_weight_path: Path of a ``state_dict`` saved from a ConvNet whose
            output size matches ``char_type``.
        char_type: ``'digit'`` for the 10-class digit model; any other value
            selects the 62-class alphanumeric model.

    Returns:
        The predicted captcha string. The image is also displayed with the
        prediction as its title, and the prediction is printed.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    char_set = number if char_type == 'digit' else alphanum
    char_size = 10 if char_type == 'digit' else 62
    out_dim = CAPTCHA_SIZE * char_size

    model = ConvNet(out_dim)
    model.load_state_dict(torch.load(model_weight_path, map_location=DEVICE))
    model = model.to(DEVICE)
    model.eval()

    image = Image.open(img_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)  # add batch dim

    with torch.no_grad():
        outputs = model(image_tensor)
    outputs = outputs.view(-1, CAPTCHA_SIZE, char_size)
    # Take the single batch entry explicitly; squeeze() would also drop the
    # character axis if CAPTCHA_SIZE were 1 and break the iteration below.
    pred_indices = torch.argmax(outputs, dim=2)[0].cpu().numpy()
    pred_captcha = "".join([char_set[int(i)] for i in pred_indices])

    plt.figure(figsize=(6, 4))
    plt.imshow(image)
    plt.title(f"验证码预测结果: {pred_captcha} (学号:{STUDENT_ID})", fontsize=12)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    print(f"推理完成 → 验证码预测值:{pred_captcha}")
    return pred_captcha
# ======================== 10. 主函数 ========================
if __name__ == "__main__":
    # Task 1: digit-only captcha recognition + one comparison chart per metric.
    print("=" * 75)
    print("【任务1】4位纯数字验证码识别 - 三种优化算法 分开指标对比")
    print("=" * 75)
    run_task(char_type='digit', total_num=220)

    # Task 2: alphanumeric captcha recognition + one comparison chart per metric.
    print("\n" + "=" * 75)
    print("【任务2】4位字母数字混合验证码识别 - 三种算法 分开指标对比")
    print("=" * 75)
    run_task(char_type='alphanum', total_num=220)

    # Inference smoke test: run the trained digit CNN on the first digit image.
    predict_captcha(
        os.path.join(IMG_SAVE_PATH_DIGIT, "0000.jpg"),
        os.path.join(DATA_ROOT, "纯数字_CNN.pth"),
        'digit',
    )
# 浙公网安备 33010602011771号