pytorch 识别 MNIST 数据集实现数字识别
零、导入数据集
数据集的目录结构如下:

import torch
import torchvision
from torch.utils.data import DataLoader
# ToTensor() turns every MNIST image into a float tensor whose pixel values lie in [0, 1].
# download=False: the MNIST files must already be present under ./dataset.
train_data, test_data = (
    torchvision.datasets.MNIST(
        root='dataset',
        train=split_is_train,
        transform=torchvision.transforms.ToTensor(),
        download=False,
    )
    for split_is_train in (True, False)  # training split first, then the test split
)

# Shuffle only the training set; the test set is evaluated in large batches of 1000.
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=1000)
如何查看数据集?
# Display one training image together with its label.
import matplotlib.pyplot as plt

idx = 4000  # named `idx` rather than `id` so the builtin id() is not shadowed
# `.data` indexes the dataset's stored images directly, so the ToTensor transform is not applied here.
features = train_data.data[idx]
# `targets` is the current attribute; `train_labels` is a deprecated alias that emits a warning.
label = int(train_data.targets[idx])
plt.imshow(features, cmap="gray")
plt.title(f"label: {label}")
plt.show()  # needed for the window to appear when run as a plain script
效果如下:

一、搭建网络
默认的网络结构如下(一个简单的全连接网络);文章后面还会给出一个 CNN 版本,作为进阶替换。
from torch import nn
class Net(nn.Module):
    """Two-layer fully connected classifier.

    Each input sample is flattened and must contain 28 * 28 = 784 values
    (e.g. a 1x28x28 MNIST image). It passes through a 784 -> 128 hidden
    layer with ReLU and comes out as 10 unnormalised class scores (logits).
    """

    def __init__(self):
        super().__init__()
        in_features, hidden, n_classes = 28 * 28, 128, 10
        layers = [
            nn.Flatten(),
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits
二、训练
import torch, torch.nn as nn, torch.optim as optim
from torchvision import datasets, transforms
EPOCHS = 5  # number of full passes over the training set


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over *loader*, updating *model* in place."""
    model.train()  # training mode, e.g. Dropout is active
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()


def evaluate(model, loader, device):
    """Return the classification accuracy of *model* on *loader*.

    Returns 0.0 when *loader* yields no samples, instead of dividing by zero.
    """
    model.eval()  # inference mode, e.g. Dropout is disabled
    correct, total = 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(1)  # index of the highest logit = predicted class
            total += y.size(0)
            correct += (pred == y).sum().item()
    return correct / total if total else 0.0


# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, EPOCHS + 1):
    train_one_epoch(model, train_loader, criterion, optimizer, device)
    acc = evaluate(model, test_loader, device)
    print(f'Epoch {epoch}: test acc = {acc:.4f}')
三、其它
1. 关于优化器
1.1 SGD
非常感谢 AI 的建议:使用下面这组参数(带动量和权重衰减)的 SGD,学习率确实可以设得很高(lr=0.1)。训练结果如下:
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.1,             # (1) a deliberately large learning rate
    momentum=0.9,       # (2) momentum term
    weight_decay=5e-4,  # (3) L2 penalty applied through weight decay
)
Epoch 1: test acc = 0.9572
Epoch 2: test acc = 0.9683
Epoch 3: test acc = 0.9692
Epoch 4: test acc = 0.9735
Epoch 5: test acc = 0.9705
1.2 Adam
其实一开始用的是 Adam(也就是上面训练代码中的设置),结果如下:
optimizer = optim.Adam(params=model.parameters(), lr=0.001)  # Adam with a learning rate of 1e-3
Epoch 1: test acc = 0.9464
Epoch 2: test acc = 0.9615
Epoch 3: test acc = 0.9687
Epoch 4: test acc = 0.9714
Epoch 5: test acc = 0.9729
2. 关于网络结构的改造
默认情况下是这样的:
from torch import nn
class Net(nn.Module):
    """Baseline fully connected classifier: 784 -> 128 -> ReLU -> 10 logits.

    Each input sample must flatten to 784 values (e.g. a 1x28x28 image).
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                                   # one flat vector per sample
            nn.Linear(in_features=784, out_features=128),   # 784 == 28 * 28
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),    # one score per digit class
        )

    def forward(self, x):
        out = self.net(x)
        return out
我希望在网络中加入一部分 CNN(卷积)结构,于是把网络替换成了下面的内容:
from torch import nn
class Net(nn.Module):
    """Small two-stage CNN classifier for 1x28x28 images.

    features:
        conv 5x5, 1 -> 32 channels, ReLU, 2x2 max-pool    (28 -> 24 -> 12)
        conv 5x5, 32 -> 64 channels, ReLU, 2x2 max-pool   (12 -> 8 -> 4)
    classifier:
        flatten (64 * 4 * 4 = 1024) -> linear 128 -> ReLU -> Dropout(0.5) -> linear 10

    forward() returns 10 unnormalised class scores (logits) per sample.
    """

    def __init__(self):
        super().__init__()

        def conv_pool(c_in, c_out):
            # A 5x5 convolution with no padding shrinks each side by 4;
            # the 2x2 max-pool then halves each side.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=5),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]

        self.features = nn.Sequential(*conv_pool(1, 32), *conv_pool(32, 64))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        feature_maps = self.features(x)
        return self.classifier(feature_maps)
Epoch 1: test acc = 0.9828
Epoch 2: test acc = 0.9890
Epoch 3: test acc = 0.9906
Epoch 4: test acc = 0.9911
Epoch 5: test acc = 0.9914

浙公网安备 33010602011771号