
MNIST Dataset Training and Testing Code

Basic Concepts

MNIST is an image dataset of handwritten digits collected from 250 different writers. It is widely used to train and evaluate algorithms that let a machine recognize handwritten digits.
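
As a quick sanity check, the following minimal sketch (assuming torchvision is installed and using the same ./dataset/ path as the training code below) loads the training split and prints its size and the shape of one sample.

from torchvision import datasets, transforms

# Minimal sketch: download MNIST and inspect one sample.
mnist_train = datasets.MNIST(root='./dataset/', train=True, download=True,
                             transform=transforms.ToTensor())
print(len(mnist_train))   # 60000 training images
img, label = mnist_train[0]
print(img.shape, label)   # torch.Size([1, 28, 28]) and the corresponding digit label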


Training Code

import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.nn.functional as f
import torch.optim as optim
from time import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
best_accuracy = 0.0

# Dataset
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])  # normalization: dataset mean and standard deviation
train_dataset = datasets.MNIST(root='./dataset/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=2)
test_dataset = datasets.MNIST(root='./dataset/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, num_workers=2)


# Model
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)  # -1: infer the mini-batch dimension automatically
        x = f.relu(self.l1(x))
        x = f.relu(self.l2(x))
        x = f.relu(self.l3(x))
        x = f.relu(self.l4(x))
        return self.l5(x)  # no activation on the last layer: CrossEntropyLoss expects raw logits


model = Net().to(device)

# Loss function and optimizer
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


# Training
def train(curr_epoch):
    total_loss = 0.0
    total_batch = 0
    _start_time = time()
    model.train()
    for _, (inputs, target) in enumerate(train_loader, 0):
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()  # backpropagation
        optimizer.step()

        total_loss += loss.item()
        total_batch += 1

    if curr_epoch % 1 == 0:
        _end_time = time()
        print(f"epoch: {curr_epoch}, loss: {total_loss / 300 :.4f}, time: {_end_time - _start_time :.2f}s", end=', ')


def test():
    correct = 0
    total_size = 0
    model.eval()
    global best_accuracy

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total_size += labels.size(0)
            correct += (predicted == labels).sum().item()  # element-wise comparison between tensors

    accuracy = 100 * correct / total_size
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), f'./minist_{accuracy:.3f}%.pth')

    print(f'accuracy: {accuracy:.3f}%', end=', ')


if __name__ == '__main__':
    start_time = time()

    for epoch in range(1, 101):
        train(epoch)
        if epoch % 1 == 0:
            test()
            end_time = time()
            print(f'total time: {end_time - start_time:.2f}s')
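
The values (0.1307,) and (0.3081,) passed to Normalize are the mean and standard deviation of the MNIST training pixels. The sketch below (assuming the dataset is already under ./dataset/) shows one way to recompute them; it is an illustrative check, not part of the training script.

import torch
from torchvision import datasets, transforms

# Sketch: recompute the global pixel mean/std of the MNIST training set.
raw_train = datasets.MNIST(root='./dataset/', train=True, download=True,
                           transform=transforms.ToTensor())
pixels = torch.stack([img for img, _ in raw_train])  # shape [60000, 1, 28, 28], values in [0, 1]
print(pixels.mean().item(), pixels.std().item())     # roughly 0.1307 and 0.3081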

Testing and Inference Code

Load a local image and recognize the handwritten digit.

import torch
from torchvision import transforms
import torch.nn.functional as f
from PIL import Image
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose([
    transforms.ToTensor(),  # map pixel values from 0-255 into [0, 1] and convert to a tensor
    transforms.Normalize((0.1307,), (0.3081,))
])  # normalization: dataset mean and standard deviation


# Model
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)  # -1: infer the mini-batch dimension automatically
        x = f.relu(self.l1(x))
        x = f.relu(self.l2(x))
        x = f.relu(self.l3(x))
        x = f.relu(self.l4(x))
        return self.l5(x)  # no activation on the last layer (no nonlinear transform)


model = Net().to(device)
model.load_state_dict(torch.load('./minist_98.010%.pth', map_location='cpu'))

image = Image.open('./2.png').convert("L")  # load the image as grayscale
image = image.resize((28, 28))
image = np.expand_dims(image, 0)  # add a dimension
image = np.expand_dims(image, 0)  # add another dimension
image = 1 - image.astype(np.float32) / 255.0  # scale to [0, 1] and invert: the test image is black-on-white, while MNIST digits are white-on-black
# print(image.shape)  # 1*1*28*28: batch, channels, H, W
image = torch.from_numpy(image)  # convert to a torch tensor
image = (image - 0.1307) / 0.3081  # apply the same normalization used during training
image = image.to(device)  # keep the input on the same device as the model


def test():
    model.eval()

    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs.data, dim=1)
        print(predicted.item())


if __name__ == '__main__':
    test()
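
If a single predicted digit is not enough, the logits can also be turned into per-class confidences. The sketch below reuses the model and image defined above; the helper name test_with_confidence is hypothetical and not part of the original script.

# Hypothetical extension: print the top-3 candidate digits with their confidences.
def test_with_confidence():
    model.eval()
    with torch.no_grad():
        logits = model(image)
        probs = f.softmax(logits, dim=1)  # convert logits to probabilities
        top_prob, top_class = torch.topk(probs, k=3, dim=1)
        for p, c in zip(top_prob[0].tolist(), top_class[0].tolist()):
            print(f'digit {c}: confidence {p:.3f}')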