1027-pytorch之手写体识别

pytorch手写体识别

代码

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision
from matplotlib import pyplot as plt

from torch_study.lesson5_minist_train.utils import plot_curve, plot_image, plt, one_hot

batch_size = 512

# step1. load dataset
# Both splits share one preprocessing pipeline: ToTensor(), then
# Normalize(mean=0.1307, std=0.3081) on the single channel.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# batch_size: how many samples are used per training step;
# shuffle: whether the training samples are reshuffled each epoch.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

# The test split is not shuffled, so batches always come in the same order.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=False)

# Inspect the shape and value range of one training batch.
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')


class Net(nn.Module):
    """Three-layer fully connected classifier over flattened 28x28 inputs.

    Layer sizes are 784 -> 256 -> 64 -> 10, with ReLU after the first two
    layers. ``forward`` returns raw logits (no softmax is applied).
    """

    def __init__(self):
        super().__init__()

        # Each nn.Linear computes x @ W.T + b.
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: flattened features ``[b, 784]`` (or a single ``[784]`` vector),
               or a batch with extra dimensions such as ``[b, 1, 28, 28]``.
               Any input with more than two dimensions is flattened per sample.

        Returns:
            Logits of shape ``[b, 10]`` (``[10]`` for a 1-D input).
        """
        if x.dim() > 2:
            # Flatten everything after the batch dimension so callers can pass
            # image batches directly; 1-D/2-D inputs are left untouched.
            x = torch.flatten(x, 1)
        # h1 = relu(x w1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1 w2 + b2)
        x = F.relu(self.fc2(x))
        # logits = h2 w3 + b3
        return self.fc3(x)


net = Net()
# net.parameters() yields [w1, b1, w2, b2, w3, b3]; SGD uses momentum 0.9.
optimizer = optim.SGD(net.parameters(), lr=0.05, momentum=0.9)

# Per-iteration training loss, kept for plotting the loss curve afterwards.
train_loss = []

net.train()
# Make 3 passes (epochs) over the training set.
for epoch in range(3):
    # Each iteration yields one batch of up to batch_size images.
    for batch_idx, (x, y) in enumerate(train_loader):
        # x: [b, 1, 28, 28], y: [b]  ->  flatten images to [b, 784]
        x = x.view(x.size(0), 28 * 28)
        # logits: [b, 10]
        out = net(x)
        # cross_entropy takes raw logits plus integer class labels. Passing y
        # directly avoids depending on float class-probability targets (what a
        # one-hot target is), which only newer PyTorch releases (1.10+) accept.
        loss = F.cross_entropy(out, y)

        optimizer.zero_grad()  # clear gradients left over from the last step
        loss.backward()        # compute gradients of loss w.r.t. parameters
        optimizer.step()       # apply the SGD (momentum) update to parameters

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

# Plot the per-iteration training loss recorded during training.
plot_curve(train_loss)
# net now holds the trained [w1, b1, w2, b2, w3, b3].

# Measure accuracy on the test set. Gradients are not needed here, so run in
# eval mode under no_grad to avoid building an autograd graph per batch.
net.eval()
total_correct = 0
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # out: [b, 10] -> pred: [b], index of the largest logit in each row
        pred = out.argmax(dim=1)
        total_correct += pred.eq(y).sum().item()

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)

# Visualize the model's predictions for the first test batch.
x, y = next(iter(test_loader))

# Inference only: skip autograd bookkeeping for this forward pass.
with torch.no_grad():
    out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')

结果

 

 

 

 

 

 

 

 模型提升

增加模型层数

调整损失函数（loss），例如在 MSE 与交叉熵（cross entropy）之间选择

调整学习率（lr）和批大小（batch_size）

posted @ 2021-10-27 23:06  清风紫雪  阅读(229)  评论(0)  编辑  收藏  举报