import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm  # progress bar for nicer training visualization
# ===================== 1. Basic configuration =====================
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
# Device configuration (prefer GPU when available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# ===================== 2. Data preprocessing and splitting =====================
# Preprocessing: ToTensor scales MNIST pixels from 0-255 down to 0-1, then Normalize standardizes them
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # official MNIST mean / std
])
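# Optional sanity check (a sketch, not part of the original pipeline): the 0.1307 / 0.3081
# constants can be reproduced from the raw training images once the dataset below has been
# downloaded. Uncomment to verify; it loads all 60,000 images into memory at once (~180 MB).
# raw_train = datasets.MNIST(root='./data', train=True, download=True,
#                            transform=transforms.ToTensor())
# raw_pixels = torch.stack([img for img, _ in raw_train])
# print(raw_pixels.mean().item(), raw_pixels.std().item())  # ~0.1307, ~0.3081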
# Download the MNIST dataset
full_train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)
# Split into training (70%), validation (15%), and test (15%) sets
# MNIST ships as 60,000 training and 10,000 test samples; here the combined data is re-split
total_dataset = torch.utils.data.ConcatDataset([full_train_dataset, test_dataset])
total_size = len(total_dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    total_dataset, [train_size, val_size, test_size]
)
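# Quick sanity check on the split sizes (a minimal sketch; the numbers follow from the
# 70/15/15 ratios above applied to the 70,000 combined samples: 49,000 / 10,500 / 10,500).
print(f"Split sizes -> train: {len(train_dataset)}, val: {len(val_dataset)}, test: {len(test_dataset)}")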
# ===================== 3. Custom dataset class =====================
class CustomMNISTDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset  # wrap the split subset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Core of a custom Dataset: return a single sample (features + label)
        # Note: moving tensors to the device here means the DataLoaders must stay single-process (num_workers=0)
        data, label = self.dataset[idx]
        return data.to(device), torch.tensor(label, dtype=torch.long).to(device)
# Wrap the split subsets in the custom dataset
train_custom_dataset = CustomMNISTDataset(train_dataset)
val_custom_dataset = CustomMNISTDataset(val_dataset)
test_custom_dataset = CustomMNISTDataset(test_dataset)
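# Optional check of a single wrapped sample (a sketch; assumes the split above succeeded):
# the image should be a 1x28x28 float tensor on `device`, the label a scalar int64 tensor.
sample_img, sample_label = train_custom_dataset[0]
print(sample_img.shape, sample_img.device, sample_label.dtype)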
# Build data loaders for batched iteration (kept single-process, since samples are already moved to `device` in __getitem__)
batch_size = 64
train_loader = DataLoader(train_custom_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_custom_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_custom_dataset, batch_size=batch_size, shuffle=False)
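# Optional: pull one batch to confirm what the loaders emit (a sketch, not required for
# training). With batch_size=64 this prints torch.Size([64, 1, 28, 28]) and torch.Size([64]).
# batch_imgs, batch_labels = next(iter(train_loader))
# print(batch_imgs.shape, batch_labels.shape)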
# ===================== 4. Define the deep learning model =====================
class MNISTNet(nn.Module):
    def __init__(self):
        super(MNISTNet, self).__init__()
        # Feature extractor: convolution + pooling
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),  # 1 input channel (grayscale), 32 output channels
            nn.ReLU(),                                   # ReLU activation (helps mitigate vanishing gradients)
            nn.MaxPool2d(kernel_size=2),                 # pooling halves the spatial size: 28 -> 14
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),                 # 14 -> 7
        )
        # Classifier: fully connected layers
        self.classifier = nn.Sequential(
            nn.Flatten(),                # flatten the feature maps
            nn.Linear(64 * 7 * 7, 128),  # 64 channels * 7 * 7 spatial size after pooling; 128 hidden units
            nn.ReLU(),
            nn.Dropout(0.5),             # Dropout to reduce overfitting
            nn.Linear(128, 10)           # 10 output classes (digits 0-9)
        )

    def forward(self, x):
        # Forward pass
        x = self.features(x)
        x = self.classifier(x)
        return x

# Instantiate the model and move it to the target device
model = MNISTNet().to(device)
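# Shape sanity check on the untrained network (a minimal sketch): a single fake 1x28x28
# grayscale input should yield one row of 10 logits, i.e. torch.Size([1, 10]).
with torch.no_grad():
    dummy = torch.randn(1, 1, 28, 28, device=device)
    print(model(dummy).shape)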
# ===================== 5. Loss function and optimizer =====================
criterion = nn.CrossEntropyLoss()  # cross-entropy loss (standard for multi-class classification)
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer (adaptive learning rates)
# Learning-rate scheduler (optional): multiplies the LR by 0.1 every 20 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
# ===================== 6. Training, validation, and test loops =====================
epochs = 100
# Track accuracy over training
train_acc_list = []
val_acc_list = []
test_acc_list = []
# Early stopping (optional, guards against overfitting)
best_val_acc = 0.0
patience = 10  # stop if validation accuracy does not improve for 10 consecutive epochs
patience_counter = 0
for epoch in range(epochs):
    # ---------------------- Training phase ----------------------
    model.train()  # training mode (enables Dropout, BatchNorm updates, etc.)
    train_correct = 0
    train_total = 0
    train_loss = 0.0
    # tqdm progress bar for training
    train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]')
    for data, labels in train_bar:
        # 1. Zero the gradients
        optimizer.zero_grad()
        # 2. Forward pass
        outputs = model(data)
        # 3. Compute the loss
        loss = criterion(outputs, labels)
        # 4. Backward pass
        loss.backward()
        # 5. Update the parameters
        optimizer.step()
        # Track training accuracy
        _, predicted = torch.max(outputs.data, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
        train_loss += loss.item() * labels.size(0)  # loss.item() is the batch mean, so weight it by the batch size
        # Update the progress bar (per-sample average loss and running accuracy)
        train_bar.set_postfix(loss=train_loss/train_total, acc=train_correct/train_total)
    train_acc = train_correct / train_total
    train_acc_list.append(train_acc)
    # ---------------------- Validation phase ----------------------
    model.eval()  # evaluation mode (disables Dropout; BatchNorm uses running statistics)
    val_correct = 0
    val_total = 0
    val_loss = 0.0
    with torch.no_grad():  # disable gradient tracking to save memory and time
        val_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} [Val]')
        for data, labels in val_bar:
            outputs = model(data)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * labels.size(0)  # weight the batch-mean loss by the batch size
            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
            val_bar.set_postfix(loss=val_loss/val_total, acc=val_correct/val_total)
    val_acc = val_correct / val_total
    val_acc_list.append(val_acc)
    # ---------------------- Test phase (evaluated after every epoch) ----------------------
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for data, labels in test_loader:
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
    test_acc = test_correct / test_total
    test_acc_list.append(test_acc)
    # Step the learning-rate scheduler
    scheduler.step()
    # Early-stopping check
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        # Save the best model so far
        torch.save(model.state_dict(), 'best_mnist_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f'Early stopping triggered! Epoch: {epoch+1}, best validation accuracy: {best_val_acc:.4f}')
            break
    # Print the results for this epoch
    print(f'Epoch {epoch+1} | train acc: {train_acc:.4f} | val acc: {val_acc:.4f} | test acc: {test_acc:.4f}')
# ===================== 7. Load the best model and run the final test =====================
model.load_state_dict(torch.load('best_mnist_model.pth'))
model.eval()
final_test_correct = 0
final_test_total = 0
with torch.no_grad():
    for data, labels in test_loader:
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        final_test_total += labels.size(0)
        final_test_correct += (predicted == labels).sum().item()
final_test_acc = final_test_correct / final_test_total
print(f'\nFinal test accuracy: {final_test_acc:.4f}')
# ===================== 8. Accuracy visualization =====================
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(train_acc_list)+1), train_acc_list, label='Train accuracy', marker='o')
plt.plot(range(1, len(val_acc_list)+1), val_acc_list, label='Validation accuracy', marker='s')
plt.plot(range(1, len(test_acc_list)+1), test_acc_list, label='Test accuracy', marker='^')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('MNIST training/validation/test accuracy over epochs')
plt.legend()
plt.grid(True)
plt.show()
# Dependencies (run these in a shell, not inside this script; the Tsinghua PyPI mirror is optional):
#   python -m pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple
#   python -m pip install tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple