# 手写汉字 (Handwritten Chinese characters): synthetic-data CNN classification demo
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import random
import os
# Character classes to recognise; a sample's label is the index of its
# character in this list.
CLASSES = ['三', '五', '的', '和', '好']
NUM_CLASSES = len(CLASSES)

# Synthetic data: canvas size as (width, height) and number of samples rendered.
IMG_SIZE = (64, 64)
GENERATE_SAMPLES = 5

# Optimisation settings.
BATCH_SIZE = 5
EPOCHS = 5
LR = 1e-3

# Train on CUDA when PyTorch can see a GPU, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def generate_chinese_handwriting(char, img_size=(64, 64), *,
                                 font_path="C:/Windows/Fonts/simhei.ttf",
                                 noise_prob=0.003):
    """Render *char* as a noisy synthetic "handwriting" image.

    The text is drawn in black (0) on a white (255) 8-bit grayscale canvas and
    centred on its rendered ink bounding box. Individual pixels are then
    randomly forced to black to imitate speckle noise.

    Args:
        char: Text to render (normally a single Chinese character).
        img_size: Canvas size as ``(width, height)`` in pixels.
        font_path: TrueType font file to render with, at size 38. If it cannot
            be opened, Pillow's built-in default font is used instead. That
            font may lack CJK glyphs, so the rendered character can come out
            blank or as boxes.
        noise_prob: Per-pixel probability of being forced to black.

    Returns:
        A mode-``'L'`` ``PIL.Image.Image`` of size *img_size*.
    """
    img = Image.new('L', img_size, 255)
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.truetype(font_path, 38)
    except (OSError, ImportError):
        # The font file is missing or unreadable (e.g. not on Windows), or
        # FreeType is unavailable: fall back to Pillow's bundled default font.
        try:
            font = ImageFont.load_default(size=35)
        except (TypeError, ImportError):
            # ``size=`` is only accepted by Pillow >= 10.1 (and needs FreeType);
            # otherwise use the fixed-size default font.
            font = ImageFont.load_default()

    left, top, right, bottom = draw.textbbox((0, 0), char, font=font)
    # The bbox is measured relative to the (0, 0) anchor, and the ink does not
    # necessarily start at the anchor. Subtract that offset so the ink itself,
    # not the anchor point, ends up centred on the canvas.
    x = (img_size[0] - (right - left)) // 2 - left
    y = (img_size[1] - (bottom - top)) // 2 - top
    draw.text((x, y), char, font=font, fill=0)

    pixels = np.array(img)
    # Entries equal to 0 mark pixels to blacken; each pixel is drawn independently.
    noise = np.random.choice([0, 255], size=pixels.shape,
                             p=[noise_prob, 1.0 - noise_prob])
    pixels = np.where(noise == 0, 0, pixels).astype(np.uint8, copy=False)
    # Pillow infers mode 'L' from a 2-D uint8 array, so passing the
    # deprecated ``mode=`` argument is unnecessary.
    return Image.fromarray(pixels)
class ChineseHandwritingDataset(Dataset):
    """In-memory dataset of synthetically rendered Chinese characters.

    All samples are rendered eagerly in the constructor, via
    ``generate_chinese_handwriting`` at ``IMG_SIZE``, and stored as
    ``(PIL image, label)`` pairs. ``transform`` is applied lazily in
    ``__getitem__``.

    Args:
        classes: Sequence of characters. A sample's label is the index of its
            character in this sequence.
        num_samples: Number of samples to render. Sample ``i`` uses
            ``classes[i % len(classes)]``, so classes repeat cyclically when
            ``num_samples`` exceeds ``len(classes)``.
        transform: Optional callable applied to each image on access.

    Raises:
        ValueError: If ``classes`` is empty while ``num_samples > 0``.
    """

    def __init__(self, classes, num_samples, transform=None):
        if num_samples > 0 and not classes:
            raise ValueError("classes must be non-empty when num_samples > 0")
        self.classes = classes
        self.num_samples = num_samples
        self.transform = transform
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        self.samples = self._generate_all_samples()

    def _generate_all_samples(self):
        """Render ``num_samples`` images, cycling through ``classes``; return (image, label) pairs."""
        print(f"生成{self.num_samples}个手写汉字样本({'、'.join(map(str, self.classes))})...")
        samples = []
        for i in range(self.num_samples):
            # Cycle through the classes instead of indexing past the end.
            char = self.classes[i % len(self.classes)]
            img = generate_chinese_handwriting(char, IMG_SIZE)
            samples.append((img, self.class_to_idx[char]))
        print("样本生成完成!")
        return samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(image, label)``, with ``transform`` applied to the image if one is set."""
        img, label = self.samples[idx]
        if self.transform:
            img = self.transform(img)
        return img, label
# Preprocessing pipeline: resize to IMG_SIZE, convert the PIL image to a float
# tensor, then standardise each pixel as (x - 0.8) / 0.2.
transform = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.8,), std=(0.2,)),
    ]
)

# Render the synthetic training set once, at import time, and wrap it in a
# loader that reshuffles every epoch.
train_dataset = ChineseHandwritingDataset(CLASSES, GENERATE_SAMPLES, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
class HandwritingCNN(nn.Module):
    """Small three-stage CNN classifier for character images.

    Each stage is a 3x3 conv (padding 1, size-preserving), then ReLU, then a
    2x2 max-pool. Channels go in_channels -> 16 -> 32 -> 64 and the spatial
    size is divided by 8 overall.

    An adaptive average pool then fixes the feature map at 8x8, so the head
    (64*8*8 -> 256 -> num_classes, with dropout 0.3) accepts any input of at
    least 8x8. For 64x64 inputs the features are already 8x8 and the pool is
    an exact identity.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of input image channels (1 for grayscale).
    """

    def __init__(self, num_classes, in_channels=1):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # Parameter-free, so state_dict keys (and saved checkpoints) are unchanged.
        self.pool = nn.AdaptiveAvgPool2d((8, 8))
        self.classifier = nn.Sequential(
            nn.Linear(64 * 8 * 8, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map a (batch, in_channels, H, W) tensor to (batch, num_classes) raw logits."""
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)
# Build the classifier on the selected device, then pair it with a
# cross-entropy loss and an Adam optimiser over all of its parameters.
model = HandwritingCNN(num_classes=NUM_CLASSES).to(device=DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LR)
def train_model(*, save_path="chinese_handwriting_model.pth"):
    """Train the module-level ``model`` on ``train_loader`` for ``EPOCHS`` epochs.

    Uses the module-level ``criterion``, ``optimizer`` and ``DEVICE``, prints
    the loss and accuracy after each epoch, and finally saves
    ``model.state_dict()`` to *save_path*.

    Args:
        save_path: File to write the trained weights to.

    Returns:
        ``(train_loss_list, train_acc_list)``: per-epoch mean per-sample loss,
        and per-epoch training accuracy in percent.

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    model.train()
    train_loss_list = []
    train_acc_list = []
    print(f"\n开始训练(设备:{DEVICE})")
    for epoch in range(EPOCHS):
        loss_sum = 0.0
        correct = 0
        total = 0
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # Assuming criterion averages over the batch (CrossEntropyLoss's
            # default reduction), weight by batch size. Otherwise a short
            # final batch would skew the epoch mean.
            loss_sum += loss.item() * batch_size
            predicted = outputs.detach().argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples; cannot compute epoch metrics")
        epoch_loss = loss_sum / total
        epoch_acc = 100 * correct / total
        train_loss_list.append(epoch_loss)
        train_acc_list.append(epoch_acc)
        print(f"===== Epoch [{epoch+1}/{EPOCHS}] | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}% =====")
    torch.save(model.state_dict(), save_path)
    print("\n训练完成!模型已保存")
    return train_loss_list, train_acc_list
def plot_train_curve(loss_list, acc_list):
    """Plot per-epoch training loss and accuracy, save to ``train_curve.png``, and show.

    Draws two stacked subplots, loss on top and accuracy below. The figure is
    written to ``train_curve.png`` at 150 dpi before ``plt.show()``.

    Args:
        loss_list: Per-epoch training loss values.
        acc_list: Per-epoch training accuracy values in percent; the y-axis
            is fixed to 0-105.
    """
    # Use SimHei for the Chinese titles and legends (assumes it is installed),
    # and draw minus signs as an ASCII hyphen.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

    # Derive the x-axis from the data rather than EPOCHS, so histories of any
    # length plot correctly.
    ax1.plot(range(1, len(loss_list) + 1), loss_list, 'r-', linewidth=2, marker='o', label='训练损失')
    ax1.set_title('训练损失曲线', fontsize=14)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.grid(True, alpha=0.3)
    ax1.legend()

    ax2.plot(range(1, len(acc_list) + 1), acc_list, 'b-', linewidth=2, marker='s', label='训练准确率')
    ax2.set_title('训练准确率曲线', fontsize=14)
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Accuracy (%)', fontsize=12)
    ax2.set_ylim(0, 105)
    ax2.grid(True, alpha=0.3)
    ax2.legend()

    plt.tight_layout()
    plt.savefig('train_curve.png', dpi=150, bbox_inches='tight')
    plt.show()
def visualize_generated_samples(max_samples=5):
    """Show the first images of one training batch, titled with their characters.

    Takes the first batch from the module-level ``train_loader`` and undoes
    the ``Normalize((0.8,), (0.2,))`` preprocessing. It then shows up to
    *max_samples* images side by side, each titled with its character from
    ``CLASSES``.

    Args:
        max_samples: Maximum number of images to show. Fewer are shown if
            the batch is smaller.

    Raises:
        ValueError: If ``train_loader`` yields no batch, or *max_samples* < 1.
    """
    # The script entry point calls this before plot_train_curve, so it must
    # select the Chinese-capable font itself, or the titles render as missing
    # glyphs. SimHei is assumed to be installed.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    batch = next(iter(train_loader), None)
    if batch is None:
        raise ValueError("train_loader is empty; nothing to visualize")
    images, labels = batch

    n = min(max_samples, len(images))
    if n < 1:
        raise ValueError("max_samples must be at least 1")
    # Invert Normalize (x * std + mean) so pixel values are back in [0, 1].
    images_vis = (images * 0.2 + 0.8).cpu().detach()

    # squeeze=False keeps `axes` 2-D even when n == 1.
    _, axes = plt.subplots(1, n, figsize=(3 * n, 3), squeeze=False)
    for ax, img, label in zip(axes.flat, images_vis[:n], labels[:n]):
        ax.imshow(img.squeeze(0).numpy(), cmap='gray', vmin=0, vmax=1)
        ax.set_title(f"手写汉字:{CLASSES[label.item()]}", fontsize=12)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    # Preview a batch of generated samples, train the model, then plot the
    # loss and accuracy curves.
    visualize_generated_samples()
    loss_list, acc_list = train_model()
    plot_train_curve(loss_list, acc_list)

# (Footer of the source web page, not part of the program) 浙公网安备 33010602011771号