【LeNet5神经网络模型搭建、参数训练、结果测试】

搭建LeNet5网络模型(Model.py)

代码如下:
import torch
import torch.nn as nn


class LeNet5Model(nn.Module):
    """LeNet-5 style convolutional network for digit classification.

    Expects input of shape ``(N, in_channels, 28, 28)``. ``c1`` pads by 2 so a
    28x28 input keeps its size. The two 2x2 average pools reduce it to 5x5.
    ``c5`` (5x5 kernel) then collapses it to 1x1, so flattening yields exactly
    the 120 features ``f6`` expects. Other spatial sizes will not match
    ``f6``'s ``in_features``.

    ``forward`` returns raw logits of shape ``(N, num_classes)``. No softmax is
    applied, so pair it with a logit-based loss such as
    ``nn.CrossEntropyLoss``.

    Args:
        in_channels: number of input image channels (1 for MNIST).
        num_classes: number of output logits (10 for MNIST digits).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()

        self.sigmoid = nn.Sigmoid()

        # Attribute names are kept as-is: they are the state_dict keys, so
        # renaming them would break loading of previously saved weights.
        self.c1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=(5, 5), padding=2)
        self.s2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5))
        self.s4 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.c5 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=(5, 5))

        self.flatten = nn.Flatten()

        self.f6 = nn.Linear(in_features=120, out_features=84)
        self.output = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        """Map a batch of images to class logits."""
        x = self.sigmoid(self.c1(x))
        x = self.s2(x)
        x = self.sigmoid(self.c3(x))
        x = self.s4(x)
        x = self.sigmoid(self.c5(x))
        x = self.flatten(x)
        # NOTE(review): f6 feeds `output` with no activation in between, so the
        # two Linear layers compose into a single affine map (classic LeNet-5
        # applies a nonlinearity here). Left unchanged so weights trained with
        # this forward pass keep producing the same predictions.
        x = self.f6(x)
        x = self.output(x)
        return x

# model = LeNet5Model()
# print(model)

训练网络并验证(train.py)

代码如下:
import os
import torch
import torch.nn as nn
from Model import LeNet5Model
from torch.utils import data
from torchvision import datasets, transforms

# Convert PIL images to float tensors in [0, 1].
data_transform = transforms.ToTensor()

BATCH_SIZE = 16

# MNIST train/test splits, downloaded into ./data on first use.
train_data = datasets.MNIST('./data', train=True, transform=data_transform, download=True)
train_load = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = datasets.MNIST('./data', train=False, transform=data_transform, download=True)
# Evaluation does not benefit from shuffling; a fixed order keeps validation
# runs reproducible.
test_load = data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

Model = LeNet5Model()

# Loss function: cross-entropy applied to the model's raw logits.
Loss_fn = nn.CrossEntropyLoss()
# Optimizer: SGD with momentum over all model parameters.
Optimizer = torch.optim.SGD(Model.parameters(), lr=1e-2, momentum=0.9)


def train(model, optimizer, loss_fn, loader=None):
    """Run one training epoch and print the mean loss and accuracy.

    Args:
        model: network to optimize; it is switched to training mode first.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(output, y)``.
        loader: iterable of ``(X, y)`` batches. Defaults to the module-level
            ``train_load``.

    Returns:
        Accuracy over all samples seen this epoch (0.0 for an empty loader).
    """
    if loader is None:
        loader = train_load
    print('训练:')
    # val() switches the model to eval mode; switch back so mode-dependent
    # layers (dropout, batch-norm) behave correctly during training.
    model.train()

    total_loss, total_correct, total_seen = 0.0, 0, 0
    for X, y in loader:
        # Forward pass and loss.
        output = model(X)
        cur_loss = loss_fn(output, y)

        # Backward pass: clear gradients, backpropagate, update weights.
        optimizer.zero_grad()
        cur_loss.backward()
        optimizer.step()

        # Accumulate per-sample totals so a smaller final batch is not
        # over-weighted. Assumes loss_fn averages over the batch
        # (the nn.CrossEntropyLoss default).
        batch_size = len(y)
        total_loss += cur_loss.item() * batch_size
        total_correct += (output.argmax(dim=1) == y).sum().item()
        total_seen += batch_size

    mean_loss = total_loss / total_seen if total_seen else 0.0
    mean_acc = total_correct / total_seen if total_seen else 0.0
    print('loss:' + str(mean_loss))
    print('acc:' + str(mean_acc))
    return mean_acc


def val(model, loss_fn, loader=None):
    """Evaluate the model without updating weights; print and return metrics.

    The model is put in eval mode for the pass. Its previous train/eval mode
    is restored afterwards, so calling this has no lasting side effect on the
    model.

    Args:
        model: network to evaluate.
        loss_fn: criterion called as ``loss_fn(output, y)``.
        loader: iterable of ``(X, y)`` batches. Defaults to the module-level
            ``test_load``.

    Returns:
        Accuracy over all evaluated samples (0.0 for an empty loader).
    """
    if loader is None:
        loader = test_load
    print('验证:')
    was_training = model.training
    model.eval()

    total_loss, total_correct, total_seen = 0.0, 0, 0
    try:
        # No gradients are needed: weights are not updated during validation.
        with torch.no_grad():
            for X, y in loader:
                output = model(X)
                cur_loss = loss_fn(output, y)

                # Per-sample totals; assumes loss_fn averages over the batch
                # (the nn.CrossEntropyLoss default).
                batch_size = len(y)
                total_loss += cur_loss.item() * batch_size
                total_correct += (output.argmax(dim=1) == y).sum().item()
                total_seen += batch_size
    finally:
        model.train(was_training)

    mean_loss = total_loss / total_seen if total_seen else 0.0
    mean_acc = total_correct / total_seen if total_seen else 0.0
    print('loss:' + str(mean_loss))
    print('acc:' + str(mean_acc))
    return mean_acc


epoch = 10
# Best validation accuracy seen so far; a checkpoint is written only when it improves.
best_acc = 0
for t in range(epoch):
    print('第{}轮训练与验证:'.format(t+1))
    train(Model, Optimizer, Loss_fn)
    a = val(Model, Loss_fn)

    if a > best_acc:
        folder = 'save_model'
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(folder, exist_ok=True)
        best_acc = a
        print('save best model:')
        # Save the weights (state_dict only, not the whole module).
        torch.save(Model.state_dict(), os.path.join(folder, 'best_model.pth'))
        print('The save is complete!')
print("Done!")

用GPU服务器训练好的权重文件见附件:save_model.rar(test.py 从 ./save_model/best_model.pth 加载权重)。

测试网络训练效果(test.py)

代码如下:
import random

import torch
from Model import LeNet5Model
from torchvision import datasets, transforms

# Prepare the test split.
data_transform = transforms.ToTensor()
Test_data = datasets.MNIST('./data', train=False, transform=data_transform, download=True)

# Build the model and load the trained weights.
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine. The model itself lives on the CPU here.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
Model = LeNet5Model()
Model.load_state_dict(torch.load("./save_model/best_model.pth", map_location='cpu'))
# Switch to inference mode.
Model.eval()

# Converts a tensor image back to a PIL image for display.
show = transforms.ToPILImage()


def test(model):
    """Classify one random image from the first 1000 test samples.

    Displays the chosen image via ``show`` and prints the predicted and the
    true label.
    """
    # randrange(1000) draws from 0..999. The original randint(0, 1000) is
    # inclusive and could also pick index 1000.
    i = random.randrange(1000)
    # Index the dataset once; each access re-runs the transform.
    X, y = Test_data[i]
    show(X).show()

    # Add a leading batch dimension before feeding the model.
    X = torch.unsqueeze(X, dim=0).float()
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output = model(X)
    # .item() turns the 0-d tensor into a plain int, so it prints as "7", not "tensor(7)".
    predict, actual = torch.argmax(output[0]).item(), y
    print('predict:{}, actual:{}'.format(predict, actual))


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    test(Model)
posted @ 2021-12-12 21:02  Touming  阅读(459)  评论(0)    收藏  举报