Code:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time
# Set a random seed so the results are reproducible
torch.manual_seed(42)
# 1. Data loading and preprocessing
# Define the data transforms (augmentation is applied to the training set only)
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),            # random crop
    transforms.RandomHorizontalFlip(),               # random horizontal flip
    transforms.ToTensor(),                           # convert to a Tensor
    transforms.Normalize((0.4914, 0.4822, 0.4465),   # CIFAR-10 channel means
                         (0.2023, 0.1994, 0.2010))   # CIFAR-10 channel stds
])

# The test set is not augmented: only tensor conversion and normalization
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])

# Load the CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(
    testset, batch_size=128, shuffle=False, num_workers=2)
# CIFAR-10 class names
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
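# The normalization constants above are widely quoted CIFAR-10 training-set statistics.
# A minimal sketch of how per-channel statistics could be recomputed (added for reference,
# not part of the original script; published std values differ slightly depending on how
# they are aggregated, so the output may not match the figures above exactly):
def compute_cifar10_stats():
    raw = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                       transform=transforms.ToTensor())
    pixels = torch.stack([img for img, _ in raw])     # shape: (50000, 3, 32, 32)
    return pixels.mean(dim=(0, 2, 3)), pixels.std(dim=(0, 2, 3))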
# 2. Build the network
class CIFAR10Net(nn.Module):
    def __init__(self):
        super(CIFAR10Net, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        # Pooling layer (reused after each conv block)
        self.pool = nn.MaxPool2d(2, 2)
        # Batch normalization
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        # Fully connected layers; three 2x2 poolings reduce the 32x32 input
        # to 256 feature maps of size 4x4, hence 256 * 4 * 4 input features
        self.fc1 = nn.Linear(256 * 4 * 4, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 10)
        # Dropout to reduce overfitting
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # First conv block: conv -> batch norm -> ReLU -> pool
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        # Second conv block
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        # Third conv block
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        # Flatten the feature maps
        x = x.view(-1, 256 * 4 * 4)
        # Fully connected layers with dropout
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x
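# Quick shape sanity check (a minimal sketch added here, not part of the training flow):
# a dummy 32x32 RGB image should produce one logit per CIFAR-10 class.
with torch.no_grad():
    print(CIFAR10Net()(torch.randn(1, 3, 32, 32)).shape)  # expected: torch.Size([1, 10])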
# 3. Initialize the network, loss function, and optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
net = CIFAR10Net()
net.to(device)
# Cross-entropy loss function
criterion = nn.CrossEntropyLoss()
# Adam optimizer with L2 weight decay
optimizer = optim.Adam(net.parameters(), lr=0.001, weight_decay=1e-4)
# Learning-rate scheduler: multiply the learning rate by 0.1 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
# 4. Train the network
def train(net, trainloader, criterion, optimizer, epochs=20):
    start_time = time.time()
    train_losses = []
    train_accs = []
    print("Starting training...")
    for epoch in range(epochs):
        running_loss = 0.0    # loss accumulated since the last progress report
        epoch_loss_sum = 0.0  # loss accumulated over the whole epoch
        correct = 0
        total = 0
        net.train()  # training mode
        for i, data in enumerate(trainloader, 0):
            # Get the input data
            inputs, labels = data[0].to(device), data[1].to(device)
            # Zero the gradients
            optimizer.zero_grad()
            # Forward pass, compute the loss, backward pass, update parameters
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Accumulate training statistics
            running_loss += loss.item()
            epoch_loss_sum += loss.item()
            # Compute accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Print progress every 200 mini-batches
            if i % 200 == 199:
                print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 200:.3f}')
                running_loss = 0.0
        # Average loss and accuracy for this epoch (tracked separately from the
        # progress counter, which is reset every 200 batches)
        epoch_loss = epoch_loss_sum / len(trainloader)
        epoch_acc = 100 * correct / total
        train_losses.append(epoch_loss)
        train_accs.append(epoch_acc)
        print(f'Epoch {epoch+1} - loss: {epoch_loss:.4f}, accuracy: {epoch_acc:.2f}%')
        # Step the learning-rate scheduler
        scheduler.step()
    print(f'Training finished in {time.time() - start_time:.2f} seconds')
    return train_losses, train_accs
# Train for 20 epochs
train_losses, train_accs = train(net, trainloader, criterion, optimizer, epochs=20)
# 5. Test the network
def test(net, testloader):
    net.eval()  # evaluation mode
    correct = 0
    total = 0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    with torch.no_grad():  # no gradients needed during evaluation
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Per-class accuracy statistics
            c = (predicted == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    # Print the overall accuracy
    print(f'Test set accuracy: {100 * correct / total:.2f}%')
    # Print the accuracy for each class
    for i in range(10):
        print(f'Accuracy for class {classes[i]}: {100 * class_correct[i] / class_total[i]:.2f}%')
    return 100 * correct / total
# Evaluate on the test set
test_acc = test(net, testloader)
# 6. Save the model
torch.save(net.state_dict(), 'cifar10_cnn.pth')
print("模型已保存为 'cifar10_cnn.pth'")
# 7. Plot the training curves
plt.figure(figsize=(12, 4))
# Plot the loss curve
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training loss')
plt.title('Training loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
# Plot the accuracy curve
plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Training accuracy')
plt.axhline(y=test_acc, color='r', linestyle='--', label=f'Test accuracy: {test_acc:.2f}%')
plt.title('Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.tight_layout()
plt.show()
Output: