Softmax 笔记
以下是 Markdown 格式 的完整代码与详细注释(保持代码逻辑不变,仅优化排版与可读性):
Softmax 回归简洁实现(d2l==1.0.3 + PyTorch)
基于 Fashion-MNIST 数据集,使用 PyTorch 手动实现数据加载、模型定义、训练与评估,解决 d2l 版本兼容问题,同时保留完整训练日志输出。
1. 导入依赖库
import torch # PyTorch核心库:张量操作、深度学习基础
from torch import nn # 神经网络模块:含层结构、损失函数等
from torch.utils.data import DataLoader # 数据加载器:批量处理数据
from torchvision import datasets, transforms # 计算机视觉工具:数据集+数据转换
from d2l import torch as d2l # D2L工具库:提供累加器、精度计算等辅助函数
2. 数据加载与预处理
核心目标
- 加载 Fashion-MNIST 数据集(训练集+测试集)
- 转换图像格式为张量并归一化
- 批量加载数据,避免 Windows 多进程冲突
# 数据转换规则:将图像转为Tensor(自动归一化像素值到 [0,1])
transform = transforms.Compose([transforms.ToTensor()])
batch_size = 256 # 每次训练/测试的样本数量(批量大小)
# 加载训练数据集
train_dataset = datasets.FashionMNIST(
root='./data', # 数据保存路径(当前目录下的data文件夹)
train=True, # 标记为训练集
download=True, # 本地无数据时自动下载(约30MB)
transform=transform # 应用上述数据转换
)
# 加载测试数据集(train=False 标记为测试集,其余参数同训练集)
test_dataset = datasets.FashionMNIST(
root='./data',
train=False,
download=True,
transform=transform
)
# 创建训练数据加载器(打乱数据顺序,单进程加载)
train_iter = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True, # 训练集需打乱:增强随机性,提升泛化能力
num_workers=0 # 单进程加载:解决Windows多进程冲突问题
)
# 创建测试数据加载器(无需打乱数据,单进程加载)
test_iter = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False, # 测试集无需打乱:仅需评估模型性能
num_workers=0
)
3. 模型定义与参数初始化
模型结构
Softmax 回归的核心是 “展平层 + 全连接层”:
- 展平层(
nn.Flatten()):将 28×28 的二维图像转为 784 维一维向量 - 全连接层(
nn.Linear(784, 10)):将 784 维输入映射到 10 维输出(对应 10 个类别)
# 定义Softmax回归模型(使用Sequential按顺序堆叠层)
net = nn.Sequential(
nn.Flatten(), # 展平层:(batch_size, 1, 28, 28) → (batch_size, 784)
nn.Linear(784, 10) # 全连接层:输入784维,输出10维(10个类别)
)
# 定义模型参数初始化函数
def init_weights(m):
"""
初始化全连接层的权重:均值为0,标准差为0.01的正态分布
参数 m: 模型中的单个层(如nn.Linear)
"""
if type(m) == nn.Linear: # 仅对全连接层执行初始化
nn.init.normal_(m.weight, std=0.01) # 权重初始化
# 偏置默认初始化为0,无需额外操作
# 对模型中所有层应用初始化函数
net.apply(init_weights)
4. 损失函数与优化器
关键选择
- 损失函数:
nn.CrossEntropyLoss
自动结合 Softmax 激活 与 交叉熵损失,避免手动计算 Softmax 导致的数值不稳定(如指数上溢/下溢)。 - 优化器:
torch.optim.SGD
随机梯度下降(SGD),学习率lr=0.1(经典参数,适合小批量训练)。
# 定义损失函数:CrossEntropyLoss(含Softmax)
# reduction='none':返回每个样本的损失(不自动求平均/求和)
loss = nn.CrossEntropyLoss(reduction='none')
# 定义优化器:SGD(随机梯度下降)
# net.parameters():需优化的模型参数(权重+偏置)
# lr=0.1:学习率(控制参数更新幅度)
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
5. 模型评估函数(计算准确率)
d2l==1.0.3 移除了 evaluate_accuracy,手动实现功能完全一致的评估逻辑:
- 切换模型到 评估模式(关闭 Dropout、批量归一化等训练特有的层)
- 关闭梯度计算(节省内存,避免反向传播干扰)
- 累加正确预测数与总样本数,最终返回准确率
def evaluate_accuracy(net, data_iter):
"""
计算模型在指定数据集上的准确率
参数:
net: 待评估的模型
data_iter: 数据集加载器(如test_iter)
返回:
准确率(正确预测数 / 总样本数)
"""
# 判断模型是否为PyTorch标准Module(确保可切换模式)
if isinstance(net, torch.nn.Module):
net.eval() # 切换到评估模式(关闭训练特有的层)
# 累加器:记录2个值——正确预测数、总样本数
metric = d2l.Accumulator(2)
# 关闭梯度计算(评估阶段无需反向传播)
with torch.no_grad():
for X, y in data_iter: # 遍历数据集中的每个批次
# 计算当前批次的正确预测数,并累加到metric
# d2l.accuracy(net(X), y):比较模型输出与真实标签,返回正确数
# y.numel():当前批次的总样本数(numel()=number of elements)
metric.add(d2l.accuracy(net(X), y), y.numel())
# 准确率 = 正确预测数 / 总样本数
return metric[0] / metric[1]
6. 模型训练循环(核心逻辑)
训练流程
- 遍历每个训练轮次(
num_epochs=10) - 每轮遍历所有训练批次:
- 梯度清零 → 前向传播(算预测)→ 算损失 → 反向传播(算梯度)→ 更新参数
- 每轮结束后,计算训练集损失、训练集准确率、测试集准确率
- 打印日志,监控训练进度
num_epochs = 10 # 训练总轮次(整个数据集重复训练10次)
for epoch in range(num_epochs): # 遍历每个轮次
# 累加器:记录3个值——总损失、正确预测数、总样本数
metric = d2l.Accumulator(3)
net.train() # 切换到训练模式(开启训练特有的层,当前模型无此类层但为规范写法)
# 遍历训练集中的每个批次
for X, y in train_iter:
trainer.zero_grad() # 梯度清零:避免上一轮梯度累积
y_hat = net(X) # 前向传播:输入X,得到模型预测y_hat(10维logits)
l = loss(y_hat, y) # 计算当前批次的损失(每个样本的损失)
l.mean().backward() # 损失求均值后反向传播:计算参数梯度
trainer.step() # 优化器更新参数:根据梯度调整权重和偏置
# 关闭梯度计算,累加当前批次的指标
with torch.no_grad():
metric.add(
l.sum(), # 累加当前批次的总损失
d2l.accuracy(y_hat, y), # 累加当前批次的正确预测数
y.numel() # 累加当前批次的总样本数
)
# 计算当前轮次的关键指标
train_loss = metric[0] / metric[2] # 平均训练损失 = 总损失 / 总样本数
train_acc = metric[1] / metric[2] # 训练集准确率 = 正确预测数 / 总样本数
test_acc = evaluate_accuracy(net, test_iter) # 测试集准确率
# 打印训练日志(格式化输出,保留3位小数)
print(f"epoch {epoch + 1:2d} | loss: {train_loss:.3f} | train_acc: {train_acc:.3f} | test_acc: {test_acc:.3f}")
7. 预期输出
训练过程中会打印如下日志(准确率略有波动,最终测试集准确率约 85%):
epoch 1 | loss: 0.785 | train_acc: 0.746 | test_acc: 0.796
epoch 2 | loss: 0.571 | train_acc: 0.812 | test_acc: 0.821
epoch 3 | loss: 0.525 | train_acc: 0.825 | test_acc: 0.833
epoch 4 | loss: 0.501 | train_acc: 0.832 | test_acc: 0.838
epoch 5 | loss: 0.485 | train_acc: 0.836 | test_acc: 0.841
epoch 6 | loss: 0.474 | train_acc: 0.840 | test_acc: 0.843
epoch 7 | loss: 0.465 | train_acc: 0.843 | test_acc: 0.845
epoch 8 | loss: 0.458 | train_acc: 0.845 | test_acc: 0.847
epoch 9 | loss: 0.452 | train_acc: 0.847 | test_acc: 0.849
epoch 10 | loss: 0.447 | train_acc: 0.849 | test_acc: 0.851
浙公网安备 33010602011771号