PyTorch MNIST CNN example: torch, Dataset/DataLoader, random_split, torchvision datasets/transforms, loss, optimizer, matplotlib.pyplot

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,random_split
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np

DEVICE=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE=64
EPOCHS=10
LEARNING_RATE=0.001

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])
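# The Normalize constants (0.1307, 0.3081) are the commonly used MNIST training-set
# mean and std. A minimal sketch (assumption: computed over the raw ToTensor()
# images; not part of this script) of how to re-derive them:
#
#   raw = datasets.MNIST(root='./data', train=True, download=True,
#                        transform=transforms.ToTensor())
#   imgs = torch.stack([img for img, _ in raw])   # (60000, 1, 28, 28), values in [0, 1]
#   print(imgs.mean().item(), imgs.std().item())  # approximately 0.1307, 0.3081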

full_train_dataset=datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset=datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

train_size=int(0.7*len(full_train_dataset))
val_size=len(full_train_dataset)-train_size
train_dataset,val_dataset=random_split(full_train_dataset,[train_size,val_size])
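# Note: random_split is unseeded here, so the train/val split changes between runs.
# An optional reproducible variant (sketch; seed value is illustrative) would pass
# an explicit generator:
#   train_dataset,val_dataset=random_split(full_train_dataset,[train_size,val_size],
#                                          generator=torch.Generator().manual_seed(42))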

train_loader=DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=2)
val_loader=DataLoader(val_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=2)
test_loader=DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=2)
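# Batch shape sanity check (a minimal sketch, not needed for training): each batch
# from train_loader yields images of shape (BATCH_SIZE, 1, 28, 28) and int64 labels
# of shape (BATCH_SIZE,); the last batch may be smaller since drop_last is False.
#   images,labels=next(iter(train_loader))
#   print(images.shape,labels.shape)   # torch.Size([64, 1, 28, 28]) torch.Size([64])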

class MNIST_CNN(nn.Module):
    def __init__(self):
        super(MNIST_CNN,self).__init__()
        self.conv1=nn.Conv2d(1,16,3,padding=1)
        self.relu=nn.ReLU()
        self.pool=nn.MaxPool2d(2,2)
        self.fc1=nn.Linear(16*14*14,10)

    def forward(self,x):
        x=self.pool(self.relu(self.conv1(x)))
        x=x.view(-1,16*14*14)
        x=self.fc1(x)
        return x
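# Shape walk-through (a minimal sketch): a 1x28x28 input becomes 16x28x28 after
# conv1 (padding=1 preserves H/W), 16x14x14 after the 2x2 max pool, and is then
# flattened to 16*14*14 features for fc1, which outputs 10 class logits.
#   with torch.no_grad():
#       print(MNIST_CNN()(torch.zeros(1,1,28,28)).shape)   # torch.Size([1, 10])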

def train_one_epoch(model,train_loader,criterion,optimizer,device):
    model.train()
    total_loss=0.0
    correct=0
    total=0
    for batch_idx,(data,target) in enumerate(train_loader):
        data,target=data.to(device),target.to(device)
        optimizer.zero_grad()
        # forward pass
        output=model(data)
        loss=criterion(output,target)

        # backward pass and parameter update
        loss.backward()
        optimizer.step()

        total_loss+=loss.item()
        predicted=output.argmax(dim=1)
        total+=target.size(0)
        correct+=(predicted==target).sum().item()

    train_loss=total_loss/len(train_loader)
    train_acc=100*correct/total
    return train_loss,train_acc

def evaluate(model,loader,criterion,device,is_test=False):
    model.eval()
    total_loss=0.0
    correct=0
    total=0
    with torch.no_grad():
        for data,target in loader:
            data,target=data.to(device),target.to(device)
            output=model(data)
            loss=criterion(output,target)

            total_loss+=loss.item()
            predicted=output.argmax(dim=1)
            total+=target.size(0)
            correct+=(predicted==target).sum().item()
    
    eval_loss=total_loss/len(loader)
    eval_acc=100*correct/total
    if is_test:
        print(f'Test set - loss: {eval_loss:.4f}, acc: {eval_acc:.2f}%')
    return eval_loss,eval_acc

if __name__=='__main__':        
    model=MNIST_CNN().to(DEVICE)

    criterion=nn.CrossEntropyLoss()
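    # CrossEntropyLoss expects raw logits and applies log-softmax internally,
    # which is why MNIST_CNN's forward() returns unnormalized scores.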
    optimizer=optim.Adam(model.parameters(),lr=LEARNING_RATE)
    train_losses=[]
    train_accs=[]
    val_losses=[]
    val_accs=[]

    print('start training...')
    for epoch in range(EPOCHS):
        train_loss,train_acc=train_one_epoch(model,train_loader,criterion, optimizer,DEVICE)
        val_loss,val_acc=evaluate(model,val_loader,criterion,DEVICE)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'Epoch [{epoch+1}/{EPOCHS}] - train loss: {train_loss:.4f}, acc: {train_acc:.2f}% | val loss: {val_loss:.4f}, val acc: {val_acc:.2f}%')

    test_loss,test_acc=evaluate(model,test_loader,criterion,DEVICE,is_test=True)

    torch.save(model.state_dict(),'mnist_cnn_model.pth')
    print('Model saved to mnist_cnn_model.pth')
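    # Reloading the checkpoint later (a minimal sketch for a fresh session; the
    # `images` variable below is illustrative):
    #   model=MNIST_CNN().to(DEVICE)
    #   model.load_state_dict(torch.load('mnist_cnn_model.pth',map_location=DEVICE))
    #   model.eval()
    #   with torch.no_grad():
    #       preds=model(images.to(DEVICE)).argmax(dim=1)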


    plt.figure(figsize=(10,6))
    plt.plot(range(1,EPOCHS+1),train_accs,label='train acc',marker='o')
    plt.plot(range(1,EPOCHS+1),val_accs,label='val acc',marker='s')

    plt.axhline(y=test_acc,color='r',linestyle='--',label=f'test acc:{test_acc:.2f}%')

    plt.xlabel('Epochs')
    plt.ylabel('Acc (%)')
    plt.title('MNIST Train/Validation/Test Accuracy')
    plt.legend()
    plt.grid(True)
    plt.xticks(range(1,EPOCHS+1))
    plt.ylim(90,100)
    # Try to maximize the plot window; the API differs per matplotlib backend.
    backend=plt.get_backend()
    plt_manager=plt.get_current_fig_manager()
    if backend=='TkAgg':
        plt_manager.window.state('zoomed')      # TkAgg (default on Windows)
    elif backend.startswith('Qt'):
        plt_manager.window.showMaximized()      # Qt backends
    # Other backends: keep the default window size.
    
    plt.show()
    input('Press Enter to exit...')

[Figure: training and validation accuracy per epoch, with the final test accuracy shown as a dashed reference line]