# "Click to view code" (caption left over from the blog's code widget; not part of the program)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageFilter
import matplotlib.pyplot as plt
import random
import os
from torchvision import transforms
import matplotlib
# Figure text: keep the ASCII minus sign on axes and prefer CJK-capable
# font families so the Chinese titles render properly.
matplotlib.rcParams["axes.unicode_minus"] = False
matplotlib.rcParams["font.family"] = ["SimHei", "Microsoft YaHei"]
# 1. Synthetic dataset of perturbed, handwriting-like Chinese characters
class ChineseCharacterDataset(Dataset):
    """In-memory dataset of synthetic Chinese character images.

    Every sample is a grayscale ('L' mode) PIL image of one character from a
    fixed character set, rendered with a randomly chosen font and font size,
    then perturbed (blur, stretch, rotation, noise, stroke breaks) to look
    more like handwriting. All samples are generated eagerly in ``__init__``
    with the global ``random`` / ``numpy.random`` state, so results are only
    reproducible if the caller seeds both beforehand.

    Args:
        num_samples: Number of images to generate.
        img_size: ``(width, height)`` of every generated image.
        transform: Optional callable applied to the PIL image in
            ``__getitem__``.
        font_paths: Optional iterable of font files to try. Defaults to
            ``DEFAULT_FONT_PATHS``. Files that cannot be opened are skipped.

    Attributes:
        chars: The distinct characters, as one string.
        classes: The distinct characters, as a list (index == label).
        num_classes: ``len(classes)``.
        char_to_idx: Mapping from character to label index.
        images, labels: The generated samples.
    """

    # Font files tried when the caller does not pass ``font_paths``.
    DEFAULT_FONT_PATHS = (
        "C:/Windows/Fonts/simhei.ttf",  # SimHei
        "C:/Windows/Fonts/simsun.ttc",  # SimSun
        "C:/Windows/Fonts/msyh.ttc",  # Microsoft YaHei
        "C:/Windows/Fonts/msyhbd.ttc",  # Microsoft YaHei Bold
    )
    # Inclusive range the per-sample font size is drawn from.
    FONT_SIZE_RANGE = (32, 42)

    def __init__(self, num_samples=10000, img_size=(64, 64), transform=None,
                 font_paths=None):
        self.num_samples = num_samples
        self.img_size = img_size
        self.transform = transform
        if font_paths is None:
            self.font_paths = self.DEFAULT_FONT_PATHS
        else:
            self.font_paths = tuple(font_paths)
        # The raw list repeats one character; duplicates must be dropped,
        # otherwise one class index can never be produced as a label while
        # still being counted in num_classes.
        raw_chars = "一二三四五六七八九十甲乙丙丁戊己庚辛壬癸金木水火土天地日月风云雨雪山水田人上下前后左右大小多少"
        self.classes = self._unique_chars(raw_chars)
        self.chars = "".join(self.classes)
        self.num_classes = len(self.classes)
        self.char_to_idx = {char: i for i, char in enumerate(self.classes)}
        self.images, self.labels = self._generate_data()

    @staticmethod
    def _unique_chars(chars):
        """Return the distinct items of *chars* in first-seen order."""
        return list(dict.fromkeys(chars))

    def _generate_data(self):
        """Render and perturb ``num_samples`` character images.

        Returns:
            A ``(images, labels)`` pair of equally long lists: PIL images and
            their integer class indices.
        """
        images = []
        labels = []

        # Keep only the font files that can actually be opened.
        usable_paths = []
        for path in self.font_paths:
            try:
                ImageFont.truetype(path, self.FONT_SIZE_RANGE[0])
            except (OSError, ImportError):
                # Missing/unreadable file, or Pillow built without FreeType.
                continue
            usable_paths.append(path)

        # Fall back to Pillow's built-in font when nothing could be loaded.
        fallback_font = None
        if not usable_paths:
            fallback_font = ImageFont.load_default()
            print("警告:未找到中文字体,使用默认字体")

        font_cache = {}  # (path, size) -> loaded font, to avoid reloading
        width, height = self.img_size
        for _ in range(self.num_samples):
            char = random.choice(self.classes)
            label = self.char_to_idx[char]

            # Pick the font AND its size per sample so sizes really vary.
            if fallback_font is not None:
                font = fallback_font
            else:
                path = random.choice(usable_paths)
                size = random.randint(*self.FONT_SIZE_RANGE)
                font = font_cache.get((path, size))
                if font is None:
                    font = ImageFont.truetype(path, size)
                    font_cache[(path, size)] = font

            # Light-gray background.
            bg_color = random.randint(240, 255)
            img = Image.new('L', self.img_size, color=bg_color)
            draw = ImageDraw.Draw(img)

            # Centre the glyph's ink box, then jitter by up to 8 px per axis.
            # The bbox origin offsets are subtracted because text drawn at
            # (x, y) puts its ink at (x + left, y + top).
            left, top, right, bottom = draw.textbbox((0, 0), char, font=font)
            x = (width - (right - left)) // 2 - left + random.randint(-8, 8)
            y = (height - (bottom - top)) // 2 - top + random.randint(-8, 8)

            # Near-black stroke with slight intensity variation.
            stroke_color = random.randint(0, 60)
            draw.text((x, y), char, font=font, fill=stroke_color)

            img = self._add_handwriting_effects(img)
            images.append(img)
            labels.append(label)
        return images, labels

    def _add_handwriting_effects(self, img):
        """Apply the random perturbation pipeline and return the new image."""
        # 1. Slight Gaussian blur (70% of samples).
        if random.random() > 0.3:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.3, 0.8)))

        # 2. Stretch or squeeze along one axis (50% of samples).
        if random.random() > 0.5:
            scale = random.uniform(0.9, 1.1)
            width, height = img.size
            if random.random() > 0.5:
                new_size = (max(1, int(width * scale)), height)
            else:
                new_size = (width, max(1, int(height * scale)))
            stretched = img.resize(new_size, Image.Resampling.LANCZOS)
            # Paste centred onto a fresh canvas (cropping or padding) instead
            # of resizing back, which would simply undo the deformation.
            canvas = Image.new('L', self.img_size, color=img.getpixel((0, 0)))
            offset = (
                (self.img_size[0] - new_size[0]) // 2,
                (self.img_size[1] - new_size[1]) // 2,
            )
            canvas.paste(stretched, offset)
            img = canvas

        # 3. Rotation by -15..15 degrees; exposed corners get a light fill.
        angle = random.randint(-15, 15)
        img = img.rotate(angle, expand=False, fillcolor=random.randint(240, 255))

        # 4. Pixel noise / paper texture (80% of samples).
        if random.random() > 0.2:
            img = self._add_complex_noise(img)

        # 5. Stroke breaks (30% of samples).
        if random.random() > 0.7:
            img = self._add_stroke_breaks(img)
        return img

    def _add_complex_noise(self, img):
        """Add sparse point noise and, half of the time, a faint texture."""
        img_np = np.array(img)
        # Point noise in [-30, 30) applied to roughly 5% of the pixels.
        noise = np.random.randint(-30, 30, size=img_np.shape)
        mask = np.random.choice([0, 1], size=img_np.shape, p=[0.95, 0.05])
        img_np = np.clip(img_np + noise * mask, 0, 255).astype(np.uint8)
        # Paper texture: small gray-level variation in [-10, 10) everywhere.
        if random.random() > 0.5:
            texture = np.random.randint(-10, 10, size=img_np.shape)
            img_np = np.clip(img_np + texture, 0, 255).astype(np.uint8)
        return Image.fromarray(img_np)

    def _add_stroke_breaks(self, img):
        """Draw 1-3 thin near-white lines across the image to break strokes."""
        img_np = np.array(img)
        height, width = img_np.shape
        for _ in range(random.randint(1, 3)):
            # Each line runs from the left half to the right half.
            x1 = random.randint(0, width // 2)
            y1 = random.randint(0, height)
            x2 = random.randint(width // 2, width)
            y2 = random.randint(0, height)
            # Sample 100 points along the segment; endpoints may fall just
            # outside the image, hence the bounds check.
            for i in range(100):
                x = int(x1 + (x2 - x1) * i / 100)
                y = int(y1 + (y2 - y1) * i / 100)
                if 0 <= x < width and 0 <= y < height:
                    img_np[y, x] = random.randint(230, 255)
                    # Randomly thicken the line to 2 px.
                    if random.random() > 0.5 and y + 1 < height:
                        img_np[y + 1, x] = random.randint(230, 255)
        return Image.fromarray(img_np)

    def __len__(self):
        """Return the number of generated samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(image, label)`` for *idx*, applying ``transform`` if set."""
        img = self.images[idx]
        label = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, label
# 2. Preprocessing and data loading
# Convert PIL images to tensors and map pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Independently generated train / test sets (train is generated first).
train_dataset = ChineseCharacterDataset(
    transform=transform, img_size=(64, 64), num_samples=15000
)
test_dataset = ChineseCharacterDataset(
    transform=transform, img_size=(64, 64), num_samples=3000
)

# Mini-batch loaders: shuffle only the training data.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
# 3. Enhanced model
class EnhancedChineseCNN(nn.Module):
    """Four-stage CNN classifier for single-channel 64x64 images.

    Each stage is conv(3x3, padding 1) -> batch norm -> ReLU -> 2x2 max pool,
    so a 64x64 input is reduced to 4x4 with 256 channels before the three
    fully connected layers.

    Args:
        num_classes: Size of the output (logit) vector.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # One batch-norm layer per convolution: sharing a layer between two
        # convolutions would mix their affine parameters and running stats.
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.batch_norm2 = nn.BatchNorm2d(128)
        self.batch_norm3 = nn.BatchNorm2d(256)
        self.batch_norm4 = nn.BatchNorm2d(256)
        # Classifier head; 64x64 input is 4x4 after four 2x2 poolings.
        self.fc1 = nn.Linear(256 * 4 * 4, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)
        self.dropout1 = nn.Dropout(0.3)  # lighter dropout, used after fc2
        self.dropout2 = nn.Dropout(0.5)  # heavier dropout, used after fc1

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)``.

        Args:
            x: Batch of shape ``(N, 1, 64, 64)``. Other spatial sizes raise a
                shape error in ``fc1``.
        """
        # conv -> batch norm -> ReLU -> pool, four times.
        x = self.pool(F.relu(self.batch_norm1(self.conv1(x))))
        x = self.pool(F.relu(self.batch_norm2(self.conv2(x))))
        x = self.pool(F.relu(self.batch_norm3(self.conv3(x))))
        x = self.pool(F.relu(self.batch_norm4(self.conv4(x))))
        # Flatten per sample. Unlike view(-1, ...), this keeps the batch
        # dimension fixed, so a wrongly sized input fails loudly instead of
        # being folded into a different batch size.
        x = torch.flatten(x, 1)
        # Fully connected head.
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = F.relu(self.fc2(x))
        x = self.dropout1(x)
        x = self.fc3(x)
        return x
# Build the model
num_classes = train_dataset.num_classes
model = EnhancedChineseCNN(num_classes)

# 4. Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0008, weight_decay=1e-5)
# Halve the learning rate when the monitored loss stops improving for more
# than 2 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)

# 5. Training (10 epochs)
num_epochs = 10
model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # argmax on the logits; no need to go through the legacy `.data`.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    train_accuracy = 100 * correct / total
    avg_loss = total_loss / len(train_loader)
    # The scheduler monitors the mean training loss of the epoch.
    scheduler.step(avg_loss)
    print(f"Epoch [{epoch + 1}/{num_epochs}], 损失: {avg_loss:.4f}, 训练准确率: {train_accuracy:.2f}%")

# 6. Evaluation on the test set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
test_accuracy = 100 * correct / total
print(f"测试准确率: {test_accuracy:.2f}%")

# 7. Visualise predictions for the first test batch
dataiter = iter(test_loader)
images, labels = next(dataiter)
# Inference only: skip building an autograd graph.
with torch.no_grad():
    outputs = model(images)
predictions = outputs.argmax(dim=1)
classes = train_dataset.classes
fig, axes = plt.subplots(2, 6, figsize=(18, 6))
for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= len(images):
        # Fewer samples than grid cells: leave the remaining cells blank.
        continue
    img = images[i].numpy().squeeze()
    img = (img * 0.5) + 0.5  # undo Normalize((0.5,), (0.5,))
    ax.imshow(img, cmap='gray')
    # .item() turns the 0-dim tensors into plain ints for list indexing.
    true_char = classes[labels[i].item()]
    pred_char = classes[predictions[i].item()]
    ax.set_title(f"真实: {true_char}\n预测: {pred_char}")
plt.tight_layout()
plt.show()