Character Recognition System
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
Custom dataset class
class CustomDataset(Dataset):
    def __init__(self, data_directory, transform=None):
        self.data_directory = data_directory
        self.transform = transform
        # Sort the class folders so the class-to-index mapping is
        # deterministic across runs (os.listdir order is filesystem-dependent).
        self.classes = sorted(os.listdir(data_directory))
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.samples = []
        for cls in self.classes:
            cls_dir = os.path.join(data_directory, cls)
            for img_name in os.listdir(cls_dir):
                img_path = os.path.join(cls_dir, img_name)
                self.samples.append((img_path, self.class_to_idx[cls]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label
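A minimal usage sketch of the dataset class, assuming the on-disk layout it expects: one sub-folder per class, with the folder names becoming the class labels. The 'data/train' path and file names below are hypothetical.

# Expected layout (hypothetical paths):
#   data/train/0/img001.png
#   data/train/0/img002.png
#   data/train/1/img001.png
#   ...
demo_transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
demo_dataset = CustomDataset('data/train', transform=demo_transform)
print(demo_dataset.class_to_idx)   # e.g. {'0': 0, '1': 1, ...}
image, label = demo_dataset[0]
print(image.shape, label)          # torch.Size([3, 32, 32]) and an integer label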
Improved CNN model (more layers, plus BatchNorm and Dropout)
class ImprovedCNN(nn.Module):
    def __init__(self, num_classes):
        super(ImprovedCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)  # assumes 32x32 inputs; adjust for other sizes
        self.dropout1 = nn.Dropout(0.5)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool1(self.relu1(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu2(self.bn2(self.conv2(x))))
        x = self.pool3(self.relu3(self.bn3(self.conv3(x))))
        x = x.view(x.size(0), -1)  # flatten to (batch, 128 * 4 * 4)
        x = self.relu4(self.dropout1(self.fc1(x)))
        x = self.fc2(x)
        return x
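To confirm the 128 * 4 * 4 figure in fc1: three 2x2 max-pools shrink a 32x32 input to 16, then 8, then 4 pixels per side, so the flattened vector has 128 * 4 * 4 = 2048 features. A quick sanity check with a dummy batch (num_classes=10 is an arbitrary example):

check_model = ImprovedCNN(num_classes=10)  # 10 classes chosen only for the check
dummy_batch = torch.randn(2, 3, 32, 32)    # batch of 2 fake 32x32 RGB images
print(check_model(dummy_batch).shape)      # torch.Size([2, 10])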
Early stopping class (to prevent overfitting)
class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0; np.inf works everywhere
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
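To make the patience mechanics concrete, a small sketch with made-up loss values, a throwaway linear model, and the hypothetical path 'demo_checkpoint.pt': the loss improves twice (each improvement saves a checkpoint and resets the counter), then stalls for two epochs, which trips the stop flag.

dummy_model = nn.Linear(4, 2)  # stand-in model so save_checkpoint has a state_dict to save
stopper = EarlyStopping(patience=2, verbose=True, path='demo_checkpoint.pt')
for val_loss in [0.90, 0.70, 0.71, 0.72]:  # made-up losses: improves twice, then stalls
    stopper(val_loss, dummy_model)
    if stopper.early_stop:
        print('Stopped: no improvement for', stopper.patience, 'epochs')
        break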
Training, validation, and test function
def train_validate_test(data_directory, epochs=100):
    # Augmented preprocessing for training
    train_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    test_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = CustomDataset(os.path.join(data_directory, 'train'), transform=train_transform)
    test_dataset = CustomDataset(os.path.join(data_directory, 'test'), transform=test_transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)
    num_classes = len(train_dataset.classes)
    model = ImprovedCNN(num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)  # Adam optimizer with weight decay
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)  # cosine annealing LR schedule
    early_stopping = EarlyStopping(patience=10, verbose=True)  # early stopping
    class_mapping = train_dataset.class_to_idx
    # Training loop
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        train_loss = running_loss / len(train_loader)
        # Validation pass (for simplicity the test set doubles as the
        # validation set here; a separate validation split is preferable)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, labels in test_loader:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        val_loss = val_loss / len(test_loader)
        scheduler.step()
        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.3f}, Val Loss: {val_loss:.3f}, LR: {scheduler.get_last_lr()[0]:.6f}')
        # Early stopping check
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break
    # Reload the best checkpoint saved by early stopping
    model.load_state_dict(torch.load('checkpoint.pt'))
    # Final evaluation on the test set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)  # .data is unnecessary under no_grad
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    test_accuracy = 100 * correct / total
    print(f'Test Accuracy: {test_accuracy:.2f}%')
    return class_mapping, test_accuracy
Entry point

# With num_workers > 0 the DataLoader spawns worker processes, which
# requires a __main__ guard on Windows (the path below suggests one).
if __name__ == '__main__':
    data_directory = 'D:/pytorch/shuzi'  # replace with your data directory
    class_mapping, test_accuracy = train_validate_test(data_directory, epochs=100)
    print("Class Mapping:", class_mapping)
    print("Test Accuracy:", test_accuracy)
