# 点击查看代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
import random
from PIL import Image, ImageDraw, ImageFont
import numpy as np
# -------------------------- 1. Synthetic handwritten-Chinese dataset generator (works across Pillow versions) --------------------------
class ChineseHandwritingDatasetGenerator:
    """Render noisy, "handwriting-like" Chinese numeral images to disk.

    Output layout: ``<save_root>/{train,test}/<class_idx>/char_<class_idx>_<i>.png``,
    where ``class_idx`` is the character's position in ``CHINESE_CHARACTERS``.

    Args:
        save_root: directory the train/test trees are written under.
        num_classes: how many characters (taken from the start of
            ``CHINESE_CHARACTERS``) to generate; must be 1..15.
        train_samples: total training images; split evenly across classes.
        test_samples: total test images; split evenly across classes.
        img_size: side length in pixels of the square images.

    Raises:
        ValueError: if ``num_classes`` is outside 1..len(CHINESE_CHARACTERS).
        FileNotFoundError: if no usable CJK font can be loaded.
    """

    # Candidate CJK fonts, tried in order. The Windows fonts keep their original
    # priority; the other paths are only consulted when those are absent.
    FONT_CANDIDATES = (
        "C:/Windows/Fonts/simhei.ttf",  # SimHei (preferred)
        "C:/Windows/Fonts/msyh.ttc",  # Microsoft YaHei
        "C:/Windows/Fonts/simsun.ttc",  # SimSun
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",  # Debian/Ubuntu fonts-noto-cjk
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",  # Debian/Ubuntu fonts-wqy-microhei
        "/System/Library/Fonts/STHeiti Light.ttc",  # macOS
    )

    def __init__(self, save_root="./generated_chinese_mnist", num_classes=15,
                 train_samples=3000, test_samples=600, img_size=64):
        # The 15 supported character classes; class_idx == list position.
        self.CHINESE_CHARACTERS = [
            '零', '一', '二', '三', '四', '五', '六', '七', '八', '九',
            '十', '百', '千', '万', '亿'
        ]
        # Validate before anything else: more classes than characters would make
        # generation write into folders that were never created, and 0 would
        # divide by zero below.
        if not 1 <= num_classes <= len(self.CHINESE_CHARACTERS):
            raise ValueError(
                f"num_classes must be between 1 and {len(self.CHINESE_CHARACTERS)}, "
                f"got {num_classes}"
            )
        self.save_root = save_root
        self.num_classes = num_classes
        self.train_samples = train_samples
        self.test_samples = test_samples
        self.img_size = img_size
        # Training / test images *per class* (attribute names kept for compatibility).
        self.classes_per_sample = train_samples // num_classes
        self.test_per_sample = test_samples // num_classes
        # Loaded fonts keyed by point size, so random sizes don't reload the file.
        self._font_cache = {}
        # Default-size font; also fails fast if no CJK font is available.
        self.font = self._get_chinese_font()

    def _get_chinese_font(self, font_size=40):
        """Return a CJK-capable TrueType font at ``font_size`` points.

        The first loadable entry of ``FONT_CANDIDATES`` is used; results are
        cached per size on the instance.

        Raises:
            FileNotFoundError: if none of the candidate fonts can be loaded.
        """
        cached = self._font_cache.get(font_size)
        if cached is not None:
            return cached
        for font_path in self.FONT_CANDIDATES:
            if not os.path.exists(font_path):
                continue
            try:
                font = ImageFont.truetype(font_path, font_size)
            except Exception as e:  # best effort: fall through to the next candidate
                print(f"加载字体 {font_path} 失败:{e}")
                continue
            self._font_cache[font_size] = font
            return font
        raise FileNotFoundError("未找到可用的中文字体!请检查Windows字体目录是否存在simhei.ttf等文件。")

    def _generate_handwriting_char(self, char):
        """Render ``char`` with random size, offset, shade, rotation, noise and shear.

        Returns:
            An RGB ``PIL.Image`` of size ``img_size`` x ``img_size``.
        """
        size = self.img_size
        # White, fully opaque canvas.
        img = Image.new("RGBA", (size, size), (255, 255, 255, 255))
        draw = ImageDraw.Draw(img)
        # Random glyph size (32-48 at the default 64 px canvas) to mimic stroke-size variation.
        font_size = random.randint(size // 2, size * 3 // 4)
        font = self._get_chinese_font(font_size)
        # Measure the glyph. textbbox exists on Pillow >= 8; older versions only have textsize.
        try:
            left, top, right, bottom = draw.textbbox((0, 0), char, font=font)
        except AttributeError:
            left, top = 0, 0
            right, bottom = draw.textsize(char, font=font)
        char_width = right - left
        char_height = bottom - top
        # The bbox of text drawn at (0, 0) starts at (left, top), not (0, 0);
        # subtracting that offset makes the *visible* glyph centred. A small
        # random jitter is added on top.
        x = (size - char_width) // 2 - left + random.randint(-3, 3)
        y = (size - char_height) // 2 - top + random.randint(-3, 3)
        # Random dark-grey ink to mimic pen colour.
        gray = random.randint(0, 60)
        draw.text((x, y), char, font=font, fill=(gray, gray, gray, 255))
        # 1. Random rotation in [-5, 5] degrees, filling exposed corners with white.
        angle = random.randint(-5, 5)
        img = img.rotate(angle, expand=False, fillcolor=(255, 255, 255, 255))
        # 2. Gaussian noise (sigma 8) on the RGB channels; int16 avoids uint8 wrap-around before clipping.
        img_np = np.array(img)
        noise = np.random.normal(0, 8, size=img_np.shape[:2]).astype(np.int16)
        img_np[..., :3] = np.clip(img_np[..., :3] + noise[..., None], 0, 255)
        img = Image.fromarray(img_np.astype(np.uint8))
        # 3. Horizontal shear via an explicit affine matrix (a, b, c, d, e, f):
        #    each output pixel (x, y) samples input (a*x + b*y + c, d*x + e*y + f).
        shear_angle = random.uniform(-0.1, 0.1)
        affine_matrix = (1.0, shear_angle, 0.0, 0.0, 1.0, 0.0)
        img = img.transform(
            size=(size, size),
            method=Image.AFFINE,
            data=affine_matrix,
            resample=Image.BILINEAR if hasattr(Image, 'BILINEAR') else Image.NEAREST,
            fillcolor=(255, 255, 255)
        )
        # Drop the alpha channel so the image is plain 3-channel RGB.
        return img.convert("RGB")

    def generate_and_save(self):
        """Generate and write the train and test splits under ``save_root``."""
        print(f"正在生成训练集({self.train_samples}个样本)...")
        self._generate_split("train", self.classes_per_sample)
        print(f"正在生成测试集({self.test_samples}个样本)...")
        self._generate_split("test", self.test_per_sample)
        print(f"数据集生成完成!保存路径:{self.save_root}")

    def _generate_split(self, split, per_class):
        """Write ``per_class`` PNGs for each of the first ``num_classes`` characters into ``save_root/split``."""
        for class_idx, char in enumerate(self.CHINESE_CHARACTERS[:self.num_classes]):
            class_dir = os.path.join(self.save_root, split, str(class_idx))
            os.makedirs(class_dir, exist_ok=True)
            for i in range(per_class):
                img = self._generate_handwriting_char(char)
                img.save(os.path.join(class_dir, f"char_{class_idx}_{i}.png"))
# -------------------------- 2. Generate the dataset and load it --------------------------
try:
    generator = ChineseHandwritingDatasetGenerator(
        save_root="./generated_chinese_mnist",
        train_samples=3000,  # 200 training samples per class
        test_samples=600,  # 40 test samples per class
    )
    generator.generate_and_save()
except Exception as e:  # top-level boundary: report and abort the script
    print(f"数据集生成失败:{e}")
    raise SystemExit(1) from e

# Preprocessing for 64x64 three-channel images (ImageNet mean/std normalisation).
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the splits from wherever the generator wrote them.
data_root = generator.save_root
train_dataset = datasets.ImageFolder(os.path.join(data_root, "train"), transform=transform)
test_dataset = datasets.ImageFolder(os.path.join(data_root, "test"), transform=transform)
if test_dataset.classes != train_dataset.classes:
    raise RuntimeError(
        f"train/test class folders differ: {train_dataset.classes} vs {test_dataset.classes}"
    )

# ImageFolder assigns label indices by sorting folder names as *strings*
# ("0", "1", "10", "11", ..., "14", "2", ...), so label k is not generator
# class k once there are more than 10 classes. Rebuild the label -> character
# table in ImageFolder's order so later CHINESE_CHARACTERS[label] lookups are right.
CHINESE_CHARACTERS = [generator.CHINESE_CHARACTERS[int(name)] for name in train_dataset.classes]

# Data loaders. num_workers=0 avoids multiprocessing problems on Windows;
# pinned memory only helps (and only avoids warnings) when CUDA is available.
batch_size = 32
use_pin_memory = torch.cuda.is_available()
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=use_pin_memory
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=use_pin_memory
)
# -------------------------- 3. CNN for Chinese character recognition --------------------------
class ChineseCharacterCNN(nn.Module):
    """Three conv/ReLU/max-pool blocks followed by a three-layer MLP classifier.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input image (3 for RGB).
        img_size: side length of the square input. Each conv block halves the
            spatial size, so the flattened feature length is
            ``256 * (img_size // 8) ** 2`` (16384 for the default 64).

    Raises:
        ValueError: if ``img_size`` is smaller than 8 (no spatial extent left).
    """

    def __init__(self, num_classes=15, in_channels=3, img_size=64):
        super().__init__()
        if img_size < 8:
            raise ValueError(f"img_size must be at least 8, got {img_size}")
        # Three 2x2 max-pools floor-divide the side length by 8 overall.
        feat_side = img_size // 8
        self.conv_layers = nn.Sequential(
            # Block 1: in_channels -> 64 channels, S -> S/2
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # Block 2: 64 -> 128 channels, S/2 -> S/4
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # Block 3: 128 -> 256 channels, S/4 -> S/8
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        # Layer order (and therefore state_dict keys) matches the original model.
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * feat_side * feat_side, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Map a batch ``(N, in_channels, img_size, img_size)`` to logits ``(N, num_classes)``."""
        return self.fc_layers(self.conv_layers(x))
# Device configuration (prefer GPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Size the output layer from the dataset rather than hard-coding 15, so it
# always matches the number of class folders ImageFolder found.
model = ChineseCharacterCNN(num_classes=len(train_dataset.classes)).to(device)
# -------------------------- 4. Training and evaluation --------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
num_epochs = 15
best_acc = 0.0
best_state = None  # in-memory copy of the best-scoring weights
print(f"\n开始训练(设备:{device})...")
for epoch in range(num_epochs):
    # Training phase
    model.train()
    total_loss = 0.0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        loss = criterion(model(imgs), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)  # mean per-batch loss
    scheduler.step()  # decay the learning rate once per epoch

    # Evaluation phase
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    acc = 100 * correct / total
    print(f"Epoch [{epoch + 1:2d}/{num_epochs}], Loss: {avg_loss:.4f}, Test Acc: {acc:.2f}%")
    # Keep the best model both on disk and in memory.
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), "best_chinese_cnn.pth")
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        print(f" → 保存最佳模型(准确率:{best_acc:.2f}%)")
print(f"\n训练完成!最佳测试准确率:{best_acc:.2f}%")
# Continue with the best weights rather than the last epoch's, so everything
# below uses the model whose accuracy was just reported.
if best_state is not None:
    model.load_state_dict(best_state)
# -------------------------- 5. Visualise test predictions --------------------------
plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render the Chinese titles
plt.rcParams['axes.unicode_minus'] = False
model.eval()
imgs, labels = next(iter(test_loader))
# Inference only: no autograd graph is needed. Results come back to the CPU for plotting.
with torch.no_grad():
    preds = model(imgs.to(device)).argmax(dim=1).cpu()
imgs = imgs.cpu()
labels = labels.cpu()

# Show up to 6 samples (fewer if the batch is smaller).
num_show = min(6, imgs.size(0))
# Undo transforms.Normalize for display: pixel = x * std + mean (channels-last).
norm_mean = torch.tensor([0.485, 0.456, 0.406])
norm_std = torch.tensor([0.229, 0.224, 0.225])
# squeeze=False keeps a 2-D axes array even when num_show == 1.
fig, axes = plt.subplots(1, num_show, figsize=(15, 5), squeeze=False)
for ax, img, label, pred in zip(axes[0], imgs[:num_show], labels[:num_show], preds[:num_show]):
    img = torch.clamp(img.permute(1, 2, 0) * norm_std + norm_mean, 0, 1)
    ax.imshow(img.numpy())
    true_char = CHINESE_CHARACTERS[label.item()]
    pred_char = CHINESE_CHARACTERS[pred.item()]
    ax.set_title(f"真实:{true_char}\n预测:{pred_char}", fontsize=12)
    ax.axis('off')
plt.tight_layout()
plt.show()
# -------------------------- 6. Predict a single handwritten character (optional) --------------------------
def predict_single_img(img_path, show=True):
    """Classify one image file and optionally display it with the prediction.

    Relies on the module-level ``model``, ``device``, ``transform`` and
    ``CHINESE_CHARACTERS`` defined earlier in this script.

    Args:
        img_path: path to an image file; it is converted to RGB and then
            resized/normalised by ``transform``.
        show: when True (default), plot the image titled with the prediction.

    Returns:
        The predicted character, or ``None`` if the image could not be opened
        or decoded (the error is printed). Other errors propagate.
    """
    try:
        # Context manager closes the underlying file once the pixels are loaded.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
    except OSError as e:  # missing file, or PIL.UnidentifiedImageError (an OSError subclass)
        print(f"预测失败:{e}")
        return None
    img_tensor = transform(img).unsqueeze(0).to(device)  # add batch dimension
    model.eval()
    with torch.no_grad():
        pred_idx = model(img_tensor).argmax(dim=1).item()
    pred_char = CHINESE_CHARACTERS[pred_idx]
    if show:
        plt.imshow(img)
        plt.title(f"预测结果:{pred_char}", fontsize=14)
        plt.axis('off')
        plt.show()
    return pred_char
# Example: replace with the path to your own handwritten character image.
# predict_single_img("my_handwriting.png")