Handwritten Chinese Character Recognition
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("3018") #
IMAGE_SIZE = 64
NUM_CLASSES = 100
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.001
class HandwrittenChineseDataset(Dataset):
    """Synthetic stand-in dataset: random grayscale images with random class labels."""
    def __init__(self, num_samples=10000, transform=None):
        self.num_samples = num_samples
        self.transform = transform
        # Random single-channel images in [0, 1), shape (N, 1, H, W)
        self.images = np.random.rand(num_samples, 1, IMAGE_SIZE, IMAGE_SIZE).astype(np.float32)
        self.labels = np.random.randint(0, NUM_CLASSES, size=num_samples)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            # Convert the float array to an 8-bit grayscale PIL image before transforming
            image = Image.fromarray((image.squeeze() * 255).astype(np.uint8), mode='L')
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
full_dataset = HandwrittenChineseDataset(num_samples=10000, transform=transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
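# Note: HandwrittenChineseDataset above returns random arrays so the script runs
# end to end without real data. For actual handwritten-character images, one option
# (a sketch only, assuming a hypothetical folder layout like data/train/<char>/*.png)
# is torchvision.datasets.ImageFolder with a grayscale step added to the transform:
#
#   from torchvision import datasets
#   real_transform = transforms.Compose([
#       transforms.Grayscale(num_output_channels=1),
#       transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
#       transforms.ToTensor(),
#       transforms.Normalize(mean=[0.5], std=[0.5])
#   ])
#   full_dataset = datasets.ImageFolder('data/train', transform=real_transform)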
class ChineseCharCNN(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        # Three conv blocks, each halving the spatial size: 64 -> 32 -> 16 -> 8
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * (IMAGE_SIZE // 8) * (IMAGE_SIZE // 8), 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 128 * 8 * 8)
        x = self.classifier(x)
        return x
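# Optional sanity check (not part of the original script): a dummy forward pass on
# CPU confirms the flattened feature size matches the first Linear layer
# (128 * 8 * 8 = 8192 for 64x64 inputs) and that the output has one logit per class.
_sanity_out = ChineseCharCNN()(torch.zeros(1, 1, IMAGE_SIZE, IMAGE_SIZE))
assert _sanity_out.shape == (1, NUM_CLASSES)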
model = ChineseCharCNN(num_classes=NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        if (batch_idx + 1) % 50 == 0:
            print(f'Batch [{batch_idx + 1}/{len(loader)}], loss: {loss.item():.4f}, '
                  f'running accuracy: {(100 * correct / total):.2f}%')
    avg_loss = total_loss / len(loader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy
def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_loss = total_loss / len(loader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy
best_val_acc = 0.0
for epoch in range(EPOCHS):
    print(f'\nEpoch {epoch + 1}/{EPOCHS}')
    print('-' * 50)
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print('\nEpoch summary:')
    print(f'Train loss: {train_loss:.4f}, train accuracy: {train_acc:.2f}%')
    print(f'Val loss: {val_loss:.4f}, val accuracy: {val_acc:.2f}%')
    # Keep the checkpoint with the best validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_chinese_char_model.pth')
        print(f'Saved best model, validation accuracy: {best_val_acc:.2f}%')
print(f'\nTraining complete! Best validation accuracy: {best_val_acc:.2f}%')
print("3018")
def predict_single_image(model, image, device):
    """Predict the class index for a single preprocessed image tensor."""
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0).to(device)  # add a batch dimension
        output = model(image)
        _, predicted = torch.max(output, 1)
    return predicted.item()

sample_image, sample_label = val_dataset[0]
predicted_label = predict_single_image(model, sample_image, device)
print('\nPrediction example:')
print(f'True label: {sample_label.item()}, predicted label: {predicted_label}')
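# Reloading the saved checkpoint for later inference (a minimal sketch, not part of
# the original script): rebuild the architecture, load the saved state_dict, and
# reuse predict_single_image. 'best_chinese_char_model.pth' is the file saved above.
loaded_model = ChineseCharCNN(num_classes=NUM_CLASSES).to(device)
loaded_model.load_state_dict(torch.load('best_chinese_char_model.pth', map_location=device))
loaded_model.eval()
print(f'Reloaded model prediction: {predict_single_image(loaded_model, sample_image, device)}')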