Handwritten Chinese Character Recognition

A small PyTorch demo: a three-block CNN classifier over 100 character classes, trained here on randomly generated stand-in data in place of a real handwritten-character dataset.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import numpy as np
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("3018") #

IMAGE_SIZE = 64        # side length of the (square) grayscale input images
NUM_CLASSES = 100      # number of distinct characters to classify
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.001

class HandwrittenChineseDataset(Dataset):
    """Synthetic stand-in dataset: random grayscale images with random labels."""
    def __init__(self, num_samples=10000, transform=None):
        self.num_samples = num_samples
        self.transform = transform
        # Random grayscale images in [0, 1), shape (N, 1, 64, 64)
        self.images = np.random.rand(num_samples, 1, IMAGE_SIZE, IMAGE_SIZE).astype(np.float32)
        self.labels = np.random.randint(0, NUM_CLASSES, size=num_samples)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            from PIL import Image
            # Convert to an 8-bit grayscale PIL image before applying the transforms
            image = Image.fromarray((image.squeeze() * 255).astype(np.uint8))
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

full_dataset = HandwrittenChineseDataset(num_samples=10000, transform=transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
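
As a quick sanity check (a small sketch that is not part of the original script; the variable names are just for illustration), pulling one batch from the loader confirms the tensor shapes the network will see:

check_images, check_labels = next(iter(train_loader))
print(check_images.shape)   # expected: torch.Size([32, 1, 64, 64])
print(check_labels.shape)   # expected: torch.Size([32])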

class ChineseCharCNN(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super(ChineseCharCNN, self).__init__()
        # Three conv blocks, each halving spatial resolution: 64 -> 32 -> 16 -> 8
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(128 * (IMAGE_SIZE // 8) * (IMAGE_SIZE // 8), 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)   # flatten to (batch, features)
        x = self.classifier(x)
        return x

model = ChineseCharCNN(num_classes=NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
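
Before the training loop, a throwaway forward pass (a minimal sketch, not in the original script) is a cheap way to confirm that the 128 * 8 * 8 flattened feature size assumed by the classifier matches what the convolutional stack actually produces for 64x64 inputs:

with torch.no_grad():
    dummy = torch.randn(1, 1, IMAGE_SIZE, IMAGE_SIZE).to(device)
    print(model(dummy).shape)   # expected: torch.Size([1, 100])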

def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if (batch_idx + 1) % 50 == 0:
            print(f'Batch [{batch_idx + 1}/{len(loader)}], loss: {loss.item():.4f}, '
                  f'running accuracy: {(100 * correct / total):.2f}%')

    avg_loss = total_loss / len(loader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy

def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(loader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy

best_val_acc = 0.0
for epoch in range(EPOCHS):
    print(f'\nEpoch {epoch + 1}/{EPOCHS}')
    print('-' * 50)

    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    print(f'\nEpoch summary:')
    print(f'Train loss: {train_loss:.4f}, train accuracy: {train_acc:.2f}%')
    print(f'Val loss: {val_loss:.4f}, val accuracy: {val_acc:.2f}%')

    # Keep only the checkpoint with the best validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_chinese_char_model.pth')
        print(f'Saved best model, validation accuracy: {best_val_acc:.2f}%')

print(f'\nTraining complete! Best validation accuracy: {best_val_acc:.2f}%')
print("3018")

def predict_single_image(model, image, device):
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0).to(device)   # add a batch dimension
        output = model(image)
        _, predicted = torch.max(output, 1)
    return predicted.item()
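
For inference it usually makes sense to reload the best checkpoint saved above rather than use whatever weights the final epoch left in memory. A minimal sketch (map_location keeps the load working on CPU-only machines):

model.load_state_dict(torch.load('best_chinese_char_model.pth', map_location=device))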

sample_image, sample_label = val_dataset[0]
predicted_label = predict_single_image(model, sample_image, device)
print(f'\nPrediction example:')
print(f'True label: {sample_label.item()}, predicted label: {predicted_label}')
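
For character recognition, the top few candidates are often more informative than a single argmax. A minimal sketch of a top-5 readout, reusing sample_image from above (the variable names introduced here are just for illustration):

model.eval()
with torch.no_grad():
    logits = model(sample_image.unsqueeze(0).to(device))
    probs = torch.softmax(logits, dim=1)
    top_probs, top_classes = probs.topk(5, dim=1)
print(f'Top-5 classes: {top_classes.squeeze(0).tolist()}')
print(f'Top-5 probabilities: {[round(p, 4) for p in top_probs.squeeze(0).tolist()]}')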
