Chinese Character (Hanzi) Recognition
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont, ImageFilter
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import platform
# Work around the duplicate OpenMP runtime conflict
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Automatically select the device (CPU/GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Configuration
STUDENT_ID_LAST4 = "3029"  # last four digits of the student ID
CHARS = ['一', '二', '三', '人', '口', '手', '日', '月', '水']
TRAIN_NUM = 200
TEST_NUM = 50
IMG_SIZE = 64
DATA_SAVE_DIR = 'hanzi_data'
BATCH_SIZE = 32
EPOCHS = 30
LEARNING_RATE = 0.005
# Noise-control parameters
NOISE_PROB = 0.3          # probability of adding random noise pixels
BLUR_PROB = 0.2           # probability of applying Gaussian blur
ROTATE_RANGE = (-30, 30)  # rotation range in degrees (widened)
OFFSET_RANGE = (-5, 5)    # random character offset range in pixels
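# Reproducibility sketch (an assumption, not part of the original assignment): the dataset
# is synthesised with random noise, rotation, and blur, so seeding every RNG used below
# makes data generation and training repeatable. This helper is illustrative and is never
# called by the original main().
def set_random_seed(seed=0):
    """Seed the Python, NumPy, and PyTorch RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)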
# Locate a usable Chinese font for the current OS
def get_default_font():
    """Return the first available Chinese-capable font path on this system."""
    font_paths = [
        'simsun.ttc',                                       # Windows: SimSun
        '/System/Library/Fonts/PingFang.ttc',               # macOS: PingFang
        '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',  # Linux fallback (no CJK glyphs)
        'msyh.ttc'                                          # Windows: Microsoft YaHei
    ]
    for path in font_paths:
        try:
            ImageFont.truetype(path, 30)
            return path
        except OSError:
            continue
    return None
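# Illustrative convenience wrapper (an assumption, not in the original code): returns a font
# object directly, falling back to PIL's default bitmap font when no Chinese font is found.
# Note that the bitmap fallback has no CJK glyphs, so characters may render as boxes.
def load_font(size=40):
    """Return a truetype font at the requested size, or PIL's default font as a fallback."""
    path = get_default_font()
    return ImageFont.truetype(path, size) if path else ImageFont.load_default()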
# Print student ID information
def print_student_id():
    """Print the last four digits of the student ID."""
    print("=" * 30)
    print(f"Student ID (last 4 digits): {STUDENT_ID_LAST4}")
    print("=" * 30)
# -------------------------- Noisy hanzi image generation --------------------------
class HanziDatasetGenerator:
    def __init__(self):
        self.default_font = ImageFont.load_default()
        self.chinese_font_path = get_default_font()
        print("Note: generating noisy hanzi images to keep recognition accuracy down")

    def _add_noise(self, img):
        """Add random noise pixels to the image (probability logic corrected)."""
        img_array = np.array(img, dtype=np.uint8)
        # Replace a random subset of pixels with random grey values
        noise = np.random.randint(0, 256, size=img_array.shape, dtype=np.uint8)
        mask = np.random.random(img_array.shape) < NOISE_PROB
        img_array[mask] = noise[mask]
        return Image.fromarray(img_array)
    def _generate_single_img(self, char):
        """Generate one hanzi image with random distortions (crop bounds corrected)."""
        img = Image.new('L', (IMG_SIZE, IMG_SIZE), color=255)  # white background
        draw = ImageDraw.Draw(img)
        # Per-character base offsets so each glyph sits roughly centred
        char_offsets = {
            '一': (5, 25), '二': (5, 15), '三': (5, 10),
            '人': (10, 20), '口': (15, 15), '手': (5, 10),
            '日': (15, 15), '月': (10, 15), '水': (5, 10)
        }
        base_x, base_y = char_offsets.get(char, (10, 20))  # fallback default
        # Random offset, clamped so the glyph stays on the canvas
        x = max(0, min(IMG_SIZE - 20, base_x + random.randint(*OFFSET_RANGE)))
        y = max(0, min(IMG_SIZE - 20, base_y + random.randint(*OFFSET_RANGE)))
        # Draw the character with a random font size
        font_size = random.randint(30, 45)
        try:
            if self.chinese_font_path:
                font = ImageFont.truetype(self.chinese_font_path, size=font_size)
                draw.text((x, y), char, font=font, fill=0)
            else:
                raise RuntimeError("no usable Chinese font found")
        except Exception as e:
            print(f"Falling back to the default font: {e}")
            draw.text((x, y), char, font=self.default_font, fill=0, stroke_width=1)
        # Random rotation
        rotation = random.randint(*ROTATE_RANGE)
        img = img.rotate(rotation, expand=False, fillcolor=255)
        # Random noise
        if random.random() < NOISE_PROB:
            img = self._add_noise(img)
        # Random blur
        if random.random() < BLUR_PROB:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.5, 1.5)))
        # Random crop (bounds clamped to avoid negative or collapsed boxes)
        crop_margin = random.randint(2, 8)
        crop_box = (
            crop_margin, crop_margin,
            max(crop_margin + 10, IMG_SIZE - crop_margin),
            max(crop_margin + 10, IMG_SIZE - crop_margin)
        )
        img = img.crop(crop_box)
        img = img.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.BILINEAR)
        return img
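    # Note (added commentary): the augmentation order is draw -> rotate -> noise -> blur ->
    # crop -> resize. rotate(expand=False, fillcolor=255) keeps the 64x64 canvas and a white
    # background, so the crop-box arithmetic above always stays within the image bounds.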
    def generate_dataset(self):
        """Generate the noisy train/test dataset on disk."""
        # Remove any previously generated data
        if os.path.exists(DATA_SAVE_DIR):
            for root, dirs, files in os.walk(DATA_SAVE_DIR, topdown=False):
                for f in files:
                    os.remove(os.path.join(root, f))
                for d in dirs:
                    os.rmdir(os.path.join(root, d))
        # Create the directory layout: <root>/<split>/<char>/
        for split in ['train', 'test']:
            for char in CHARS:
                os.makedirs(os.path.join(DATA_SAVE_DIR, split, char), exist_ok=True)
        print("Generating the noisy dataset...")
        for char in CHARS:
            # Training samples
            for i in range(TRAIN_NUM):
                img = self._generate_single_img(char)
                img.save(os.path.join(DATA_SAVE_DIR, 'train', char, f'{i}.png'))
            # Test samples
            for i in range(TEST_NUM):
                img = self._generate_single_img(char)
                img.save(os.path.join(DATA_SAVE_DIR, 'test', char, f'{i}.png'))
        print(f"Dataset generated at: {os.path.abspath(DATA_SAVE_DIR)}")
# -------------------------- Dataset loading --------------------------
class HanziDataset(Dataset):
    def __init__(self, split='train'):
        self.split = split
        self.data_dir = os.path.join(DATA_SAVE_DIR, split)
        self.char_list = CHARS
        self.char2idx = {c: i for i, c in enumerate(self.char_list)}
        self.images, self.labels = self._load_data()
        # Same preprocessing as at inference time
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))  # normalise for training stability
        ])

    def _load_data(self):
        """Collect image paths and their integer labels."""
        images = []
        labels = []
        for char in self.char_list:
            char_dir = os.path.join(self.data_dir, char)
            if not os.path.exists(char_dir):
                continue
            for img_name in os.listdir(char_dir):
                if img_name.endswith('.png'):
                    images.append(os.path.join(char_dir, img_name))
                    labels.append(self.char2idx[char])
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = Image.open(self.images[idx]).convert('L')
        return self.transform(img), self.labels[idx]
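# Sanity-check sketch (an assumption, not part of the original code): confirms that each split
# contains the expected number of images, i.e. len(CHARS) * TRAIN_NUM and len(CHARS) * TEST_NUM.
def check_dataset_sizes():
    train_len = len(HanziDataset('train'))
    test_len = len(HanziDataset('test'))
    assert train_len == len(CHARS) * TRAIN_NUM, f"unexpected train size: {train_len}"
    assert test_len == len(CHARS) * TEST_NUM, f"unexpected test size: {test_len}"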
# -------------------------- Lightweight model (reduced fitting capacity) --------------------------
class FeatureCNN(nn.Module):
    def __init__(self, num_classes=9):
        super(FeatureCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),   # few filters to limit capacity
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.2),                             # dropout against overfitting
            nn.Conv2d(4, 8, kernel_size=3, padding=1),   # few filters to limit capacity
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.2),                             # dropout against overfitting
        )
        self.classifier = nn.Linear(8 * 16 * 16, num_classes)  # 8 * 16 * 16: two 2x2 poolings reduce 64 -> 32 -> 16

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 8 * 16 * 16)  # flatten
        x = self.classifier(x)
        return x
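# Shape check (illustrative assumption): a 64x64 grayscale input passes through two stride-2
# poolings (64 -> 32 -> 16), so the flattened vector fed to the classifier has
# 8 * 16 * 16 = 2048 elements. This helper is not called by the original script.
def check_feature_shape():
    dummy = torch.zeros(1, 1, IMG_SIZE, IMG_SIZE)
    with torch.no_grad():
        feats = FeatureCNN().features(dummy)
    assert feats.shape == (1, 8, 16, 16), f"unexpected feature shape: {tuple(feats.shape)}"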
# -------------------------- Training and recognition --------------------------
def main():
    # Print the student ID at startup
    print("Hanzi recognition CNN assignment")
    print_student_id()
    # Generate the noisy dataset
    generator = HanziDatasetGenerator()
    generator.generate_dataset()
    # Load the data (num_workers=0 on Windows to avoid multiprocessing issues)
    num_workers = 0 if platform.system() == 'Windows' else 2
    train_dataset = HanziDataset('train')
    test_dataset = HanziDataset('test')
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)
    # Model, loss, and optimiser
    model = FeatureCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # Training
    print(f"\nStarting training on {device} (keeping accuracy below 95%)...")
    best_acc = 0.0
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * imgs.size(0)
        avg_loss = total_loss / len(train_dataset)
        # Evaluation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                _, preds = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (preds == labels).sum().item()
        acc = 100 * correct / total
        print(f"Epoch {epoch + 1:2d}/{EPOCHS} | loss: {avg_loss:.4f} | test accuracy: {acc:.2f}%")
        # Keep the best checkpoint
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'best_model.pth')
        # Early stop to keep accuracy from climbing too high
        if acc >= 95:
            print("Accuracy reached 95%, stopping training early")
            break
    # Print the student ID again after training
    print("\nTraining finished!")
    print_student_id()
    print(f"Best test accuracy: {best_acc:.2f}%")
    # Load the best checkpoint
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    model.eval()
    # Interactive inference
    print("\nEnter an image path (q to quit):")
    while True:
        path = input("> ")
        if path.lower() == 'q':
            break
        if not os.path.exists(path):
            print("Invalid path, please try again")
            continue
        try:
            # Preprocess exactly as for the training data
            img = Image.open(path).convert('L').resize((IMG_SIZE, IMG_SIZE))
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,))
            ])
            img_tensor = transform(img).unsqueeze(0).to(device)  # add batch dimension
            with torch.no_grad():
                output = model(img_tensor)
                pred_idx = torch.argmax(output, dim=1).item()
                pred_char = CHARS[pred_idx]
                confidence = torch.softmax(output, dim=1)[0, pred_idx].item() * 100
            # Include the student ID in the recognition result
            print(f"Student ID (last 4): {STUDENT_ID_LAST4} | prediction: {pred_char} | confidence: {confidence:.2f}%")
        except Exception as e:
            print(f"Recognition failed: {str(e)}")


if __name__ == "__main__":
    main()