文字识别系统代码（基于 CNN 的 MNIST 手写数字识别）

点击查看代码
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import os
import time

# Pick the compute device once at import time: CUDA when PyTorch can see a GPU,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Data loading and preprocessing.
# Convert each image to a float tensor, then standardize it with
# mean 0.1307 / std 0.3081. Named `transform` (singular) so it does not
# shadow the `torchvision.transforms` module imported at the top of the file.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])


def load_data(batch_size=128, root='./data'):
    """Build the MNIST train/test data loaders.

    Downloads MNIST into ``root`` if needed and applies the module-level
    ``transform`` to every image.

    Args:
        batch_size: Samples per batch for both loaders (default 128).
        root: Directory where the MNIST files are stored / downloaded.

    Returns:
        ``(train_loader, test_loader, test_dataset)``. The training loader
        shuffles (DataLoader reshuffles on every pass); the test loader keeps
        a fixed order so evaluation is reproducible.
    """
    train_dataset = datasets.MNIST(root=root, train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root=root, train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader, test_dataset

# CNN model definition.
class MyCNN(nn.Module):
    """Two conv -> ReLU -> 2x2 max-pool stages, then dropout and a linear classifier.

    Expects single-channel images whose spatial size becomes 7x7 after the two
    2x2 poolings (28x28 MNIST digits satisfy this), because the classifier
    input is fixed at ``16 * 7 * 7`` features.

    ``forward`` returns ``(logits, features)``: ``logits`` is ``(batch, 10)``
    and ``features`` is the flattened activation (after dropout) that was fed
    to the classifier.
    """

    def __init__(self):
        super().__init__()
        # Keep this construction order: it fixes parameter initialisation
        # order and the state_dict key names.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)   # 3x3 with padding 1 keeps H, W
        self.pool1 = nn.MaxPool2d(2, 2)                           # halves H and W
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.output = nn.Linear(16 * 7 * 7, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        stage1 = self.pool1(F.relu(self.conv1(x)))
        stage2 = self.pool2(F.relu(self.conv2(stage1)))
        flat = stage2.view(stage2.size(0), -1)
        features = self.dropout(flat)
        logits = self.output(features)
        return logits, features
    
# Training setup: a single module-level model, optimizer and loss function.
model = MyCNN().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# NOTE: the misspelling "critirion" is kept because the __main__ block passes
# this exact name to train().
critirion = nn.CrossEntropyLoss()

def train(epochs, train_loader, optimizer, critirion):
    """Train the module-level ``model`` for ``epochs`` passes over ``train_loader``.

    Uses the global ``model`` and ``device`` (not parameters) and prints one
    progress line per epoch.

    Args:
        epochs: Number of full passes over the training data.
        train_loader: Iterable of ``(images, labels)`` batches; ``len()`` must
            give the number of batches (used to average the loss).
        optimizer: Optimizer over ``model``'s parameters.
        critirion: Loss called as ``critirion(logits, labels)`` (misspelled
            name kept so keyword callers keep working).

    Returns:
        ``(train_loss, train_acc)``: per-epoch mean batch loss and training
        accuracy in percent, one entry per epoch.
    """
    model.train()
    train_loss = []
    train_acc = []

    for epoch in range(epochs):
        start_time = time.perf_counter()
        running_loss = 0.0
        total = 0
        correct = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            # The model returns (logits, features); only the logits feed the loss.
            output = model(images)[0]
            loss = critirion(output, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Accuracy comes from the same training-mode forward pass used for
            # the update (dropout active), so it is a running estimate.
            predicted = output.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)

        elapsed = time.perf_counter() - start_time
        # Report epochs 1-based so the final line reads "Epoch [N/N]".
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.2f}%, Time: {elapsed:.2f}s")

    return train_loss, train_acc

# Model evaluation.
def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader``; print accuracy and a per-digit report.

    Args:
        model: Network whose forward returns a tuple with logits first.
        test_loader: Iterable of ``(images, labels)`` batches with labels 0-9.

    Returns:
        Overall test accuracy in percent, as a float.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            # Only the images need to go to the device; labels are compared on the CPU.
            output = model(images.to(device))[0]
            predicted = output.argmax(dim=1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    accuracy = (all_preds == all_labels).mean() * 100
    print(f"测试准确率: {accuracy:.2f}%")

    print("分类效果评估:")
    class_ids = list(range(10))
    target_names = [str(i) for i in class_ids]
    # Pin the label set so the report always has one row per target name;
    # otherwise sklearn infers labels from the data and raises ValueError when
    # a digit appears in neither the predictions nor the ground truth.
    report = classification_report(all_labels, all_preds, labels=class_ids, target_names=target_names)
    print(report)

    return float(accuracy)

if __name__ == '__main__':
    print("(24信计2班 王晶莹 2024310143126)")
    print(f"device: {device}")

    epochs = 20
    train_loader, test_loader, test_dataset = load_data()
    train_loss, train_acc = train(epochs, train_loader, optimizer, critirion)
    test(model, test_loader)

    # Plot the per-epoch training curves side by side: loss left, accuracy right.
    epoch_axis = range(1, epochs + 1)
    curves = (
        (train_loss, "Training Loss", "Loss"),
        (train_acc, "Training Accuracy", "Accuracy (%)"),
    )
    plt.figure(figsize=(10, 5))
    for slot, (values, title, y_label) in enumerate(curves, start=1):
        plt.subplot(1, 2, slot)
        plt.plot(epoch_axis, values)
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)

    plt.tight_layout()
    plt.show()

屏幕截图 2025-11-12 201725

posted @ 2025-11-12 20:20  等雾语  阅读(0)  评论(0)    收藏  举报