点此进入CSDN

点此添加QQ好友 加载失败时会显示




你的浏览器不支持播放哦!!nuttertools 您的浏览器不支持该播放!

pytorch 手写数字识别项目 增量式训练

dataset.py

 

'''
准备数据集
'''
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor,Compose,Normalize
import torchvision
import config

def mnist_dataset(train, root="../mnist", download=False):
    """Return the MNIST train or test split with the standard normalisation.

    Args:
        train: True for the training split, False for the test split.
        root: directory holding the MNIST files (default ``"../mnist"``).
        download: if True, let torchvision fetch the files into ``root``
            when they are missing (default False, i.e. files must exist).

    Returns:
        A ``torchvision.datasets.MNIST`` whose images are converted to
        tensors and normalised with the widely used MNIST mean/std.
    """
    transform = Compose([
        ToTensor(),
        Normalize(mean=(0.1307,), std=(0.3081,)),
    ])

    # Prepare the MNIST dataset
    return MNIST(root=root, train=train, download=download, transform=transform)

def get_dataloader(train = True):
    """Build a DataLoader over the MNIST train or test split.

    The batch size is taken from ``config.train_batch_size`` or
    ``config.test_batch_size``. Only the training split is shuffled; the test
    split is iterated in a fixed order so evaluation results are reproducible.

    Args:
        train: True for the training split, False for the test split.
    """
    dataset = mnist_dataset(train)
    batch_size = config.train_batch_size if train else config.test_batch_size
    # Shuffling helps SGD on the training data but adds nothing to evaluation.
    return DataLoader(dataset, batch_size=batch_size, shuffle=bool(train))

if __name__ == '__main__':
    # Smoke test: fetch the first training batch and show its shape and labels.
    first_batch = next(iter(get_dataloader()), None)
    if first_batch is not None:
        batch_images, batch_labels = first_batch
        print(batch_images.size())
        print(batch_labels)

 

  models.py

 

'''定义模型'''

import torch.nn as nn
import torch.nn.functional as F

class MnistModel(nn.Module):
    """Two-layer fully-connected classifier for 28x28 digit images.

    Args:
        hidden_size: width of the hidden layer (default 100, the size the
            saved checkpoints were trained with).
        num_classes: number of output classes (default 10).
    """

    def __init__(self, hidden_size=100, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, image):
        """Return log-probabilities of shape ``(batch, num_classes)``.

        ``image`` must hold a multiple of 28*28 elements; each group of
        28*28 values is treated as one flattened sample.
        """
        # reshape (unlike view) also works for non-contiguous input tensors.
        image_flat = image.reshape(-1, 28 * 28)
        hidden = F.relu(self.fc1(image_flat))
        out = self.fc2(hidden)
        # log_softmax pairs with F.nll_loss used by the training/eval code.
        return F.log_softmax(out, dim=-1)

 

  config.py

 

'''
项目配置
'''
import torch

# DataLoader batch sizes for the training and test splits.
train_batch_size = 128
test_batch_size = 128

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

 

  train.py

 

'''
进行模型的训练
'''
from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import config
from tqdm import tqdm
import numpy as np
import torch
import os
from eval import eval

# Instantiate the model and optimizer; the loss (F.nll_loss) is applied in train().
model = MnistModel().to(config.device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Incremental training: resume from the last checkpoint when one exists.
# map_location remaps tensors saved on another device (e.g. a GPU) onto
# config.device, so a checkpoint can be resumed on a CPU-only machine.
if os.path.exists("./model/mnist_net.pt"):
    model.load_state_dict(torch.load("./model/mnist_net.pt", map_location=config.device))
    # Restore the optimizer only if its state was saved too; otherwise keep
    # a fresh Adam state instead of crashing on a missing file.
    if os.path.exists("./model/mnist_optimizer.pt"):
        optimizer.load_state_dict(torch.load("./model/mnist_optimizer.pt", map_location=config.device))


#迭代训练

def _save_checkpoint():
    """Write the current model and optimizer state into the ``model/`` directory."""
    # torch.save does not create parent directories.
    os.makedirs("model", exist_ok=True)
    torch.save(model.state_dict(), "model/mnist_net.pt")
    torch.save(optimizer.state_dict(), "model/mnist_optimizer.pt")


def train(epoch):
    """Run one training epoch over the MNIST training split.

    Updates the module-level ``model`` with the module-level ``optimizer``
    using NLL loss. Both are checkpointed every 10 batches and once more at
    the end of the epoch. The final save matters because ``eval()`` reloads
    the weights from disk.

    Args:
        epoch: epoch index, used only in the progress-bar description.
    """
    model.train()
    train_dataloader = get_dataloader(train=True)
    bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    total_loss = []
    for idx, (images, target) in bar:
        images = images.to(config.device)
        target = target.to(config.device)
        # Reset gradients accumulated by the previous step.
        optimizer.zero_grad()
        output = model(images)
        loss = F.nll_loss(output, target)
        total_loss.append(loss.item())
        # Backpropagate, then update the parameters.
        loss.backward()
        optimizer.step()

        if idx % 10 == 0:
            bar.set_description("epoch:{} idx:{},loss:{}".format(epoch, idx, np.mean(total_loss)))
            _save_checkpoint()

    # The periodic save misses the batches after the last multiple of 10
    # (e.g. idx 461-468 of 469), so save the fully-updated weights here.
    _save_checkpoint()

if __name__ == '__main__':
    # Ten epochs, each followed by an evaluation of the saved checkpoint.
    num_epochs = 10
    epoch = 0
    while epoch < num_epochs:
        train(epoch)
        eval()
        epoch += 1

 

  eval.py

 

'''
Evaluate the trained model on the test set
'''
from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import config
import numpy as np
import torch
import os




# Evaluation on the test split

def eval():
    """Evaluate the latest saved checkpoint on the MNIST test split.

    Builds a fresh ``MnistModel``. If ``./model/mnist_net.pt`` exists, its
    weights are loaded; otherwise the randomly initialised model is
    evaluated. The function prints and returns the per-sample average NLL
    loss and the accuracy over the whole test split.

    Returns:
        tuple: ``(test_loss, test_acc)`` as floats (NaN if the split is empty).
    """
    model = MnistModel().to(config.device)
    if os.path.exists("./model/mnist_net.pt"):
        # map_location lets a checkpoint saved on another device load here.
        model.load_state_dict(torch.load("./model/mnist_net.pt", map_location=config.device))
    model.eval()

    test_dataloader = get_dataloader(train=False)
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, target in test_dataloader:
            images = images.to(config.device)
            target = target.to(config.device)
            output = model(images)
            # Sum per batch (rather than averaging batch means) so the smaller
            # final batch gets its correct weight in the overall metrics.
            loss_sum += F.nll_loss(output, target, reduction="sum").item()
            pred = output.max(dim=-1)[1]
            correct += pred.eq(target).sum().item()
            seen += target.size(0)
    test_loss = loss_sum / seen if seen else float("nan")
    test_acc = correct / seen if seen else float("nan")
    print("test loss:{},test acc:{}".format(test_loss, test_acc))
    return test_loss, test_acc

if __name__ == '__main__':
    # Stand-alone evaluation of whatever checkpoint is currently saved.
    eval()

 

  

D:\anaconda\python.exe C:/Users/liuxinyu/Desktop/pytorch_test/day3/手写数字识别/train.py
epoch:0 idx:460,loss:0.32289110562095413: 100%|██████████| 469/469 [00:24<00:00, 19.05it/s]
test loss:0.17968503131142147,test acc:0.9453125
epoch:1 idx:460,loss:0.15012750004513145: 100%|██████████| 469/469 [00:20<00:00, 22.52it/s]
test loss:0.12370304338916947,test acc:0.9624208860759493
epoch:2 idx:460,loss:0.10398845713577534: 100%|██████████| 469/469 [00:21<00:00, 21.82it/s]
test loss:0.10385569722592077,test acc:0.9697389240506329
epoch:3 idx:460,loss:0.07973297938720653: 100%|██████████| 469/469 [00:22<00:00, 20.84it/s]
test loss:0.08691684670652015,test acc:0.9754746835443038
epoch:4 idx:460,loss:0.0650228117158285: 100%|██████████| 469/469 [00:21<00:00, 21.79it/s]
test loss:0.0803159438309413,test acc:0.9760680379746836
epoch:5 idx:460,loss:0.05270117848966101: 100%|██████████| 469/469 [00:21<00:00, 21.92it/s]
test loss:0.08102699166423158,test acc:0.9759691455696202
epoch:6 idx:460,loss:0.04386751471317642: 100%|██████████| 469/469 [00:19<00:00, 24.58it/s]
test loss:0.07991968260347089,test acc:0.9769580696202531
epoch:7 idx:460,loss:0.03656852366544161: 100%|██████████| 469/469 [00:15<00:00, 31.20it/s]
test loss:0.07767781678917288,test acc:0.9774525316455697
epoch:8 idx:460,loss:0.03112584312896925: 100%|██████████| 469/469 [00:14<00:00, 32.41it/s]
test loss:0.07755146227494071,test acc:0.9773536392405063
epoch:9 idx:460,loss:0.025217091969725495: 100%|██████████| 469/469 [00:14<00:00, 31.53it/s]
test loss:0.07112929566845863,test acc:0.9802215189873418

  接口interface.py

 

'''
Prediction interface: classify a handwritten digit image with the trained model
'''
from models import MnistModel
from torch import optim
import config
import torch
import os
import cv2
import torchvision.transforms as transforms

# Preprocessing identical to training: convert to a tensor, then normalise
# with the same mean/std that dataset.py uses.
tranform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Instantiate the model on the CPU for single-image prediction.
model = MnistModel()
# The optimizer is not used for prediction; it is still created (and restored
# when its checkpoint exists) so the module-level name stays available.
optimizer = optim.Adam(model.parameters(), lr=0.01)

if os.path.exists("./model/mnist_net.pt"):
    # map_location="cpu" loads checkpoints saved on a GPU onto the CPU.
    model.load_state_dict(torch.load("./model/mnist_net.pt", map_location="cpu"))
    if os.path.exists("./model/mnist_optimizer.pt"):
        optimizer.load_state_dict(torch.load("./model/mnist_optimizer.pt", map_location="cpu"))
else:
    # Without a checkpoint the model keeps its random initial weights.
    print("warning: ./model/mnist_net.pt not found, predictions use untrained weights")
# Switch to inference mode. This has no effect on the current Linear-only
# model but stays correct if dropout or batch-norm layers are added.
model.eval()

#预测接口
def interface(pic_path):
    """Predict the digit shown in a single image file and print the result.

    The image is converted to grayscale and resized to 28x28 if necessary.
    It is then normalised with the same transform used in training.

    Args:
        pic_path: path to the image file.

    Returns:
        int: the predicted digit (the same value that is printed).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``pic_path``.
    """
    img = cv2.imread(pic_path)
    if img is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        raise FileNotFoundError("cannot read image: {}".format(pic_path))
    # cv2.imread yields BGR channel order, so BGR2GRAY applies the right weights.
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    if img_gray.shape != (28, 28):
        # The model flattens inputs to 28*28 features, so other sizes must be resized.
        img_gray = cv2.resize(img_gray, (28, 28), interpolation=cv2.INTER_AREA)
    # Add a batch dimension: (1, 28, 28) -> (1, 1, 28, 28).
    batch = tranform(img_gray).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    pred = output.max(dim=-1)[1]
    digit = int(pred[0].to("cpu"))
    print("识别结果为:", digit)
    return digit


if __name__ == '__main__':
    # Interactive loop: read an image name (no extension) from ./pic_test/ and classify it.
    while True:
        name = input("请输入图片地址:")
        pic_path = "./pic_test/{}.png".format(name)
        print(pic_path)
        interface(pic_path)

 

  

 

posted @ 2020-02-15 22:38  高颜值的殺生丸  阅读(615)  评论(0)    收藏  举报

作者信息

昵称:

刘新宇

园龄:4年6个月


粉丝:1209


QQ:522414928