import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import random
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Matplotlib font setup so the Chinese titles/labels used by the plotting
# functions can render: prefer CJK-capable fonts, falling back to DejaVu Sans.
# The minus sign is drawn as a plain hyphen (commonly needed with CJK fonts,
# which may lack the Unicode minus glyph).
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

print("🚀 高精度手写汉字识别系统...")

class EnhancedChineseDataset(Dataset):
    """Synthetic dataset of procedurally drawn Chinese numerals 零 … 九.

    All samples are generated in memory when the dataset is constructed;
    nothing is read from ``data_dir``. Sample ``i`` gets label ``i % 10``.
    Each sample is a 64×64 grayscale PIL image with Gaussian noise and a
    light Gaussian blur applied.

    Args:
        data_dir: Stored on the instance but otherwise unused.
        transform: Optional callable applied to the PIL image in
            ``__getitem__``.
        train: Stored on the instance but otherwise unused; train and test
            splits are produced by the same random generator.
        num_samples: Number of samples to generate.
    """

    def __init__(self, data_dir, transform=None, train=True, num_samples=2000):
        self.data_dir = data_dir
        self.transform = transform
        self.train = train
        self.samples = []
        self.chinese_digits = ['零', '一', '二', '三', '四', '五', '六', '七', '八', '九']
        self._generate_enhanced_data(num_samples)

    def _generate_enhanced_data(self, num_samples):
        """Append ``num_samples`` generated samples (cycling labels 0-9) to ``self.samples``."""
        print(f"生成 {num_samples} 个增强手写汉字样本...")
        for i in range(num_samples):
            label = i % 10
            chinese_char = self.chinese_digits[label]
            self.samples.append({
                'image': self._create_enhanced_image(chinese_char, label),
                'label': label,
                'char': chinese_char
            })

    def _create_enhanced_image(self, char, label):
        """Render ``char`` as a noisy, slightly blurred 64×64 grayscale PIL image.

        零/一/二/三 are drawn here; every other character is delegated to
        ``_create_complex_character``. Stroke lengths, thicknesses and
        intensities are randomized with ``random``; the additive noise uses
        ``np.random``.
        """
        img_size = 64
        img = np.zeros((img_size, img_size), dtype=np.uint8)

        if char == '零':
            center_x, center_y = img_size // 2, img_size // 2
            for angle in np.linspace(0, 2 * np.pi, 100):
                x = int(center_x + 15 * np.cos(angle))
                y = int(center_y + 15 * np.sin(angle))
                if 0 <= x < img_size and 0 <= y < img_size:
                    img[y-1:y+2, x-1:x+2] = 200 + random.randint(0, 55)

        elif char == '一':
            thickness = random.randint(3, 6)
            start_y = img_size // 2 - thickness // 2
            for i in range(thickness):
                y_pos = start_y + i
                length = random.randint(40, 50)
                start_x = (img_size - length) // 2
                for x in range(length):
                    intensity = 150 + random.randint(0, 105)
                    # Fade both ends of the stroke over the same 5-pixel width.
                    if x < 5 or x >= length - 5:
                        intensity = max(100, intensity - 50)
                    img[y_pos, start_x + x] = intensity

        elif char == '二':
            thickness = random.randint(3, 5)
            upper_y = img_size // 3
            length1 = random.randint(35, 45)
            start_x1 = (img_size - length1) // 2
            for i in range(thickness):
                for x in range(length1):
                    intensity = 160 + random.randint(0, 95)
                    if x < 4 or x >= length1 - 4:
                        intensity = max(110, intensity - 50)
                    img[upper_y + i, start_x1 + x] = intensity

            lower_y = 2 * img_size // 3
            length2 = random.randint(40, 48)
            start_x2 = (img_size - length2) // 2
            for i in range(thickness):
                for x in range(length2):
                    intensity = 160 + random.randint(0, 95)
                    if x < 4 or x >= length2 - 4:
                        intensity = max(110, intensity - 50)
                    img[lower_y + i, start_x2 + x] = intensity

        elif char == '三':
            thickness = random.randint(2, 4)
            positions = [img_size // 4, img_size // 2, 3 * img_size // 4]
            for pos_y in positions:
                length = random.randint(35, 45)
                start_x = (img_size - length) // 2
                for i in range(thickness):
                    for x in range(length):
                        intensity = 170 + random.randint(0, 85)
                        if x < 3 or x >= length - 3:
                            intensity = max(120, intensity - 50)
                        img[pos_y + i, start_x + x] = intensity

        else:
            self._create_complex_character(img, char, label)

        # Additive Gaussian noise (std 20), computed in int16 to avoid uint8 wrap-around.
        noise = np.random.normal(0, 20, (img_size, img_size)).astype(np.int16)
        img = np.clip(img.astype(np.int16) + noise, 0, 255).astype(np.uint8)

        from scipy.ndimage import gaussian_filter
        img = gaussian_filter(img.astype(float), sigma=0.7).astype(np.uint8)

        return Image.fromarray(img)

    def _create_complex_character(self, img, char, label):
        """Draw 四 or 五 into ``img`` in place; delegate other labels to ``_draw_digit_shape``."""
        img_size = img.shape[0]
        center_x, center_y = img_size // 2, img_size // 2

        if char == '四':
            # Hollow square outline; each side gets its own random intensity.
            thickness = random.randint(3, 5)
            box_size = 20
            img[center_y-box_size:center_y-box_size+thickness,
                center_x-box_size:center_x+box_size] = 180 + random.randint(0, 75)
            img[center_y+box_size-thickness:center_y+box_size,
                center_x-box_size:center_x+box_size] = 180 + random.randint(0, 75)
            img[center_y-box_size:center_y+box_size,
                center_x-box_size:center_x-box_size+thickness] = 180 + random.randint(0, 75)
            img[center_y-box_size:center_y+box_size,
                center_x+box_size-thickness:center_x+box_size] = 180 + random.randint(0, 75)

        elif char == '五':
            thickness = random.randint(3, 4)
            img[center_y-15:center_y-15+thickness, center_x-20:center_x+20] = 190
            img[center_y-15:center_y+15, center_x:center_x+thickness] = 190
            img[center_y+10:center_y+10+thickness, center_x-15:center_x+15] = 190

        else:
            self._draw_digit_shape(img, label)

    def _draw_digit_shape(self, img, digit):
        """Draw the glyph for label 6, 7, 8 or 9 into ``img`` in place (value 200).

        Other digit values leave ``img`` unchanged.
        """
        img_size = img.shape[0]
        center_x, center_y = img_size // 2, img_size // 2

        if digit == 6:
            # NOTE(review): 六 is drawn as a plain circle, close to 零 (radius 12
            # vs 15); a more distinctive glyph would make the classes easier to separate.
            for angle in np.linspace(np.pi / 2, 5 * np.pi / 2, 80):
                x = int(center_x + 12 * np.cos(angle))
                y = int(center_y + 12 * np.sin(angle))
                if 0 <= x < img_size and 0 <= y < img_size:
                    img[y-1:y+2, x-1:x+2] = 200

        elif digit == 7:
            thickness = 4
            img[center_y-12:center_y-12+thickness, center_x-15:center_x+15] = 200
            for i in range(25):
                x = center_x + 15 - i
                y = center_y - 12 + i
                if 0 <= x < img_size and 0 <= y < img_size:
                    img[y:y+thickness, x:x+thickness] = 200

        elif digit == 8:
            # 八: two strokes starting close together at the top and spreading
            # apart towards the bottom (left-falling and right-falling).
            self._draw_stroke(img, center_x - 4, center_y - 16,
                              center_x - 18, center_y + 14, thickness=4)
            self._draw_stroke(img, center_x + 4, center_y - 12,
                              center_x + 18, center_y + 16, thickness=4)

        elif digit == 9:
            # 九: a left-falling stroke plus a horizontal-vertical-hook stroke.
            self._draw_stroke(img, center_x - 2, center_y - 18,
                              center_x - 16, center_y + 16, thickness=4)
            self._draw_stroke(img, center_x - 14, center_y - 6,
                              center_x + 8, center_y - 6, thickness=4)
            self._draw_stroke(img, center_x + 8, center_y - 6,
                              center_x + 8, center_y + 12, thickness=4)
            self._draw_stroke(img, center_x + 8, center_y + 12,
                              center_x + 18, center_y + 12, thickness=4)
            self._draw_stroke(img, center_x + 18, center_y + 12,
                              center_x + 18, center_y + 6, thickness=4)

    @staticmethod
    def _draw_stroke(img, x0, y0, x1, y1, thickness=3, value=200):
        """Draw a straight stroke from (x0, y0) to (x1, y1) into ``img`` in place.

        The segment is sampled at roughly one point per pixel and a
        ``thickness``×``thickness`` square of ``value`` is stamped at each
        sample (top-left corner on the sample); anything outside the image is
        clipped.
        """
        height, width = img.shape
        steps = int(max(abs(x1 - x0), abs(y1 - y0))) + 1
        for t in np.linspace(0.0, 1.0, steps):
            x = int(round(x0 + t * (x1 - x0)))
            y = int(round(y0 + t * (y1 - y0)))
            img[max(y, 0):min(y + thickness, height),
                max(x, 0):min(x + thickness, width)] = value

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, applying ``transform`` if set."""
        sample = self.samples[idx]
        image = sample['image']
        label = sample['label']

        if self.transform:
            image = self.transform(image)

        return image, label

class HighAccuracyChineseCNN(nn.Module):
    """VGG-style CNN classifier for single-channel images.

    Three convolutional stages (64 → 128 → 256 channels). Each stage has two
    3×3 conv + BatchNorm + ReLU layers, then 2×2 max-pooling and dropout. An
    adaptive average pool to 4×4 feeds a two-hidden-layer MLP head, so inputs
    other than 64×64 also work as long as they survive three 2× poolings.

    Args:
        num_classes: Number of output logits.

    Note:
        The head uses BatchNorm1d, so training mode needs batches of more
        than one sample.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout(0.3),

            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout(0.3),

            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout(0.4),
        )

        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))

        self.classifier = nn.Sequential(
            nn.Linear(256 * 4 * 4, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero BatchNorm, N(0, 0.01) linears, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a (batch, 1, H, W) tensor to (batch, num_classes) logits."""
        x = self.features(x)
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against label-smoothed target distributions.

    The true class gets probability ``1 - smoothing``. The remaining
    ``smoothing`` mass is spread evenly over the other ``classes - 1``
    classes. With ``smoothing=0`` this equals ordinary cross-entropy.

    Args:
        classes: Number of classes; must be >= 2 and match the last dimension
            of the logits passed to ``forward``.
        smoothing: Probability mass moved off the true class.

    Raises:
        ValueError: If ``classes`` < 2 (the off-target share would divide by zero).
    """

    def __init__(self, classes=10, smoothing=0.1):
        super().__init__()
        if classes < 2:
            raise ValueError(f"classes must be >= 2, got {classes}")
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.classes = classes

    def forward(self, pred, target):
        """Return the batch-mean smoothed cross-entropy.

        Args:
            pred: Raw logits of shape (batch, classes).
            target: Integer class indices of shape (batch,).

        Raises:
            ValueError: If ``pred``'s class dimension differs from ``classes``;
                the target distribution would otherwise not sum to 1.
        """
        if pred.size(-1) != self.classes:
            raise ValueError(
                f"pred has {pred.size(-1)} classes, loss configured for {self.classes}")
        log_probs = pred.log_softmax(dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(log_probs)
            true_dist.fill_(self.smoothing / (self.classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=-1))

def get_advanced_transforms():
    """Return the ``(train_transform, test_transform)`` preprocessing pair.

    Both pipelines finish by converting the image to a tensor and normalizing
    it with mean 0.5 / std 0.5. The training pipeline first applies random
    rotation, affine shift/scale, perspective warp, brightness/contrast jitter
    and Gaussian blur as augmentation. The test pipeline applies no
    augmentation.
    """
    def tensor_stages():
        # Fresh ToTensor/Normalize instances for each pipeline.
        return [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]

    augmentation = [
        transforms.RandomRotation(15),
        transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.85, 1.15)),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
    ]
    train_pipeline = transforms.Compose(augmentation + tensor_stages())
    test_pipeline = transforms.Compose(tensor_stages())
    return train_pipeline, test_pipeline

def train_high_accuracy_model(num_epochs=15):
    """Train ``HighAccuracyChineseCNN`` on synthetic numerals and report results.

    Builds 3000 training and 600 test samples and trains with AdamW, label
    smoothing, gradient clipping and a cosine learning-rate schedule spanning
    all epochs. It saves the weights with the best test accuracy to
    ``best_chinese_model.pth``, plots the training history, reloads the best
    weights and plots their confusion matrix.

    NOTE(review): the checkpoint is selected on the test set, so
    ``final_accuracy`` is optimistically biased; a separate validation split
    would avoid this.

    Args:
        num_epochs: Number of training epochs (default 15); must be >= 1.

    Returns:
        tuple: ``(model, test_loader, final_accuracy)``. ``model`` holds the
        best checkpoint, ``test_loader`` is the test DataLoader, and
        ``final_accuracy`` is that checkpoint's test accuracy in percent.

    Raises:
        ValueError: If ``num_epochs`` < 1.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    train_transform, test_transform = get_advanced_transforms()

    train_dataset = EnhancedChineseDataset('./chinese_data', transform=train_transform,
                                           train=True, num_samples=3000)
    test_dataset = EnhancedChineseDataset('./chinese_data', transform=test_transform,
                                          train=False, num_samples=600)

    print(f"训练集大小: {len(train_dataset)}")
    print(f"测试集大小: {len(test_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)

    model = HighAccuracyChineseCNN(num_classes=10).to(device)
    criterion = LabelSmoothingLoss(classes=10, smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
    # T_max must cover the whole run: with T_max shorter than the number of
    # epochs, the cosine schedule bottoms out early and then raises the LR again.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    print("开始训练高精度模型...")
    train_losses = []
    train_accs = []
    test_accs = []
    # Start below any achievable accuracy so the first epoch always writes a
    # checkpoint; otherwise torch.load below could find no file.
    best_accuracy = -1.0

    for epoch in range(num_epochs):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch)

        test_acc, _ = test_model_comprehensive(model, test_loader, device)
        scheduler.step()

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_accs.append(test_acc)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.2f}%')
        print(f'  测试准确率: {test_acc:.2f}%')
        print(f'  当前学习率: {scheduler.get_last_lr()[0]:.6f}')

        if test_acc > best_accuracy:
            best_accuracy = test_acc
            torch.save(model.state_dict(), 'best_chinese_model.pth')
            print(f'  ✅ 保存最佳模型,准确率: {best_accuracy:.2f}%')

    plot_detailed_training_history(train_losses, train_accs, test_accs)
    model.load_state_dict(torch.load('best_chinese_model.pth', map_location=device))
    final_accuracy, _ = test_model_comprehensive(model, test_loader, device)

    print(f"\n🎯 最终测试准确率: {final_accuracy:.2f}%")
    plot_confusion_matrix_comprehensive(model, test_loader, device)

    return model, test_loader, final_accuracy


def _train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Run one training epoch and return ``(mean batch loss, accuracy in percent)``.

    Prints a progress line every 20 batches; ``epoch`` is 0-based and only
    used for that message.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        # Clip the global gradient norm to 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        if batch_idx % 20 == 0:
            print(f'Epoch: {epoch+1} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    return running_loss / len(train_loader), 100. * correct / total

def test_model_comprehensive(model, test_loader, device, class_names=None):
    """Evaluate ``model`` on ``test_loader``; return overall and per-class accuracy.

    Puts the model in eval mode and runs without gradient tracking.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        test_loader: Iterable of ``(data, target)`` tensor batches.
        device: Device each batch is moved to.
        class_names: Keys for the per-class dict, indexed by label. Defaults
            to the Chinese numerals 零 … 九.

    Returns:
        tuple: ``(accuracy, class_accuracy)``. ``accuracy`` is the overall
        accuracy in percent (0.0 when the loader is empty). ``class_accuracy``
        maps each class name to its accuracy in percent (0 for classes with
        no samples).
    """
    if class_names is None:
        class_names = ['零', '一', '二', '三', '四', '五', '六', '七', '八', '九']

    model.eval()
    correct = 0
    total = 0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = output.max(1)

            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    # Guard the empty-loader case instead of raising ZeroDivisionError.
    accuracy = 100. * correct / total if total > 0 else 0.0

    # Vectorized per-class recall: one mask per class instead of re-scanning
    # the prediction list in Python for every class.
    predictions = np.asarray(all_predictions)
    targets = np.asarray(all_targets)
    class_accuracy = {}
    for idx, name in enumerate(class_names):
        in_class = targets == idx
        class_total = int(in_class.sum())
        class_correct = int((predictions[in_class] == idx).sum())
        class_accuracy[name] = 100. * class_correct / class_total if class_total > 0 else 0

    return accuracy, class_accuracy


# Despite its ``test_`` prefix this is an evaluation helper, not a pytest test;
# stop pytest from collecting it (and failing on its missing "fixtures").
test_model_comprehensive.__test__ = False

def plot_detailed_training_history(train_losses, train_accs, test_accs):
    """Plot training curves and the best test accuracy, then save and show them.

    Draws three panels: training loss, train/test accuracy, and a bar for the
    best test accuracy. Saves the figure to ``detailed_training_history.png``
    at 300 dpi and calls ``plt.show()``.

    Args:
        train_losses: Mean training loss for each epoch.
        train_accs: Training accuracy (%) for each epoch.
        test_accs: Test accuracy (%) for each epoch.

    Raises:
        ValueError: If ``test_accs`` is empty (there is no best epoch to show).
    """
    if not test_accs:
        raise ValueError("test_accs is empty; nothing to plot")

    # Epochs are numbered from 1 in the training log and in ``best_epoch``
    # below, so number the x axes the same way.
    loss_epochs = range(1, len(train_losses) + 1)
    train_epochs = range(1, len(train_accs) + 1)
    test_epochs = range(1, len(test_accs) + 1)

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.plot(loss_epochs, train_losses, 'b-', linewidth=2, marker='o', markersize=4)
    plt.title('训练损失曲线', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 2)
    plt.plot(train_epochs, train_accs, 'g-', label='训练准确率', linewidth=2, marker='o', markersize=4)
    plt.plot(test_epochs, test_accs, 'r-', label='测试准确率', linewidth=2, marker='s', markersize=4)
    plt.title('准确率曲线', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('准确率 (%)')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 3)
    best_test_acc = max(test_accs)
    best_epoch = test_accs.index(best_test_acc) + 1
    plt.bar(['最佳测试准确率'], [best_test_acc], color='orange', alpha=0.7)
    plt.text(0, best_test_acc + 1, f'{best_test_acc:.2f}%', ha='center', va='bottom', fontweight='bold')
    plt.title(f'最佳性能 (Epoch {best_epoch})', fontsize=14, fontweight='bold')
    plt.ylabel('准确率 (%)')
    # Headroom above 100% keeps the value label of a near-perfect score inside
    # the axes instead of running into the title.
    plt.ylim(0, 110)

    plt.tight_layout()
    plt.savefig('detailed_training_history.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("✅ 详细训练历史图已保存")

def plot_confusion_matrix_comprehensive(model, test_loader, device):
    """Plot the test-set confusion matrix as an annotated heatmap; save and show it.

    Puts ``model`` in eval mode and predicts every batch of ``test_loader``
    without gradient tracking. Saves ``comprehensive_confusion_matrix.png``
    at 300 dpi and calls ``plt.show()``. Rows are true labels and columns are
    predicted labels, both labelled with the Chinese numerals 零 … 九.
    """
    chinese_digits = ['零', '一', '二', '三', '四', '五', '六', '七', '八', '九']

    model.eval()
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = output.max(1)

            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    # Pin the label set. Without ``labels=``, confusion_matrix only includes
    # labels that occur in the targets or predictions, so the matrix can
    # shrink below 10x10 and stop lining up with the tick labels.
    cm = confusion_matrix(all_targets, all_predictions,
                          labels=list(range(len(chinese_digits))))

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=chinese_digits, yticklabels=chinese_digits)
    plt.title('混淆矩阵 - 手写汉字识别', fontsize=16, fontweight='bold')
    plt.xlabel('预测标签', fontsize=12)
    plt.ylabel('真实标签', fontsize=12)
    plt.tight_layout()
    plt.savefig('comprehensive_confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("✅ 综合混淆矩阵已保存")

# Script entry point: train, evaluate, and list the generated artifacts.
if __name__ == "__main__":
    print("=" * 60)
    print("🎯 高精度手写汉字识别系统")
    print("=" * 60)

    try:
        model, test_loader, final_accuracy = train_high_accuracy_model()
    except Exception as e:
        # Top-level boundary: report the failure with its traceback, then exit
        # with a non-zero status so the calling shell/CI can detect it.
        print(f"❌ 发生错误: {e}")
        import traceback
        traceback.print_exc()
        raise SystemExit(1)
    else:
        print(f"\n🎉 训练完成!最终准确率: {final_accuracy:.2f}%")
        print("📁 生成的文件:")
        print("   - detailed_training_history.png (详细训练历史)")
        print("   - comprehensive_confusion_matrix.png (混淆矩阵)")
        print("   - best_chinese_model.pth (最佳模型)")

# image
#
# posted on 2025-11-13 19:46  雨水啊  阅读(2)  评论(0)    收藏  举报