import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 1. Data preparation and preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),                      # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # standard MNIST mean/std
])

# Load the MNIST dataset (downloaded into ./data on first run)
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
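# Where do the Normalize constants come from? (0.1307, 0.3081) are the mean and
# std of the MNIST training pixels. A quick sketch to reproduce them
# (illustrative addition, not in the original script; train_dataset.data holds
# the raw uint8 images, and dividing by 255 matches what ToTensor does):
raw_pixels = train_dataset.data.float() / 255.0
print(raw_pixels.mean().item(), raw_pixels.std().item())  # ~0.1307, ~0.3081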
# 2. Define the neural network model
class HandwritingRecognizer(nn.Module):
    def __init__(self):
        super(HandwritingRecognizer, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(-1, 64 * 7 * 7)  # flatten the feature maps for the fully connected layers
        x = self.fc_layers(x)
        return x
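# Shape sanity check (illustrative addition, not part of the original script):
# 28x28 input -> conv (padding=1) keeps 28x28 -> 2x2 max pool gives 14x14
# -> conv keeps 14x14 -> 2x2 max pool gives 7x7, hence the 64 * 7 * 7 flatten size.
_probe = HandwritingRecognizer().conv_layers(torch.randn(1, 1, 28, 28))
assert _probe.shape == (1, 64, 7, 7)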
# 3. Initialize the model, loss function, and optimizer
model = HandwritingRecognizer()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
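# Optional sketch (commented out, not in the original script): run on a GPU when
# available. The training/evaluation loops below would also need data and target
# moved with .to(device) for this to work.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = model.to(device)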
# 4. Model training function
def train(model, train_loader, criterion, optimizer, epochs=1):
    for epoch in range(epochs):
        print(f"================ Starting epoch {epoch + 1} ================")
        model.train()  # set the model to training mode (enables dropout)
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Accumulate training-set accuracy
            _, predicted = torch.max(output.data, 1)
            total_train += target.size(0)
            correct_train += (predicted == target).sum().item()
        # Average training loss and accuracy for this epoch
        train_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct_train / total_train
        # Accuracy on the test set (test_loader is the global defined above)
        test_accuracy = test(model, test_loader)
        # Report training and test results
        print(f"Epoch {epoch + 1}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%\n")
# 5. Model evaluation function
def test(model, test_loader):
    model.eval()  # set the model to evaluation mode (disables dropout)
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100 * correct / total
# Run training and evaluation: train for 3 epochs, scoring the test set after each
train(model, train_loader, criterion, optimizer, epochs=3)
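Once training finishes, a natural next step is to persist the weights and try a single prediction. A minimal sketch (the filename mnist_cnn.pt is an arbitrary choice, not from the original post):

torch.save(model.state_dict(), 'mnist_cnn.pt')  # hypothetical filename
model.eval()
with torch.no_grad():
    sample, label = test_dataset[0]      # one preprocessed test image and its label
    logits = model(sample.unsqueeze(0))  # add a batch dimension: (1, 1, 28, 28)
    print(f"Predicted: {logits.argmax(dim=1).item()}, Actual: {label}")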
Note:
About the dataset:
MNIST (Modified National Institute of Standards and Technology database) is a classic public dataset in computer vision, built from handwriting databases collected by the US National Institute of Standards and Technology (NIST). Its train/test split is fixed and ships with the dataset:
Training set: 60,000 handwritten digit images (digits 0-9, each a 28×28 grayscale image).
Test set: 10,000 handwritten digit images.
The underlying NIST samples were written partly by US Census Bureau employees and partly by US high-school students; MNIST mixes both sources, but the writers in the training and test sets do not overlap, so the test set gives an objective measure of the model's generalization ability (its performance on handwriting it has never seen).
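Because the split ships with the dataset, the sizes can be confirmed directly on the objects loaded above:

print(len(train_dataset))  # 60000
print(len(test_dataset))   # 10000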