Implementing VGG-11 in PyTorch, trained on the CIFAR-10 dataset

The code below builds a slimmed-down VGG-11 (channel counts divided by 8 or 16 to suit 32x32 images), loads CIFAR-10 directly from the pickled batch files, and trains with Adam.

import pickle

import numpy as np
import torch
from torch import nn
class FlattenLayer(nn.Module):
    def forward(self, x):  # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)  # keep the batch dimension, flatten the rest
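A one-line check of the layer (not in the original post): a (2, 3, 4, 4) input flattens to (2, 48). Recent PyTorch versions also ship an equivalent built-in, nn.Flatten().

print(FlattenLayer()(torch.zeros(2, 3, 4, 4)).shape)  # torch.Size([2, 48])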
def vgg(conv_arch, fc_features, fc_hidden_units=4096):
    net = nn.Sequential()
    # Convolutional part: one VGG block per (num_convs, in_channels, out_channels) entry.
    for i, (num_convs, in_channels, out_channels) in enumerate(conv_arch):
        net.add_module("vgg_block_" + str(i + 1), vgg_block(num_convs, in_channels, out_channels))
    # Classifier part: flatten, then two fully connected layers with dropout.
    net.add_module("flatten", FlattenLayer())
    net.add_module("fc", nn.Sequential(
        nn.Linear(fc_features, fc_hidden_units),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(fc_hidden_units, 10)))  # 10 CIFAR-10 classes
    return net
def vgg_block(num_convs, in_channels, out_channels):
    blk = []
    for i in range(num_convs):
        # The first conv maps in_channels -> out_channels; later convs keep out_channels.
        conv_in = in_channels if i == 0 else out_channels
        blk.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1))
        blk.append(nn.ReLU())  # standard VGG applies ReLU after every conv
    blk.append(nn.MaxPool2d(kernel_size=2, stride=2))  # halves height and width
    return nn.Sequential(*blk)
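A toy example (sizes are arbitrary, not from the post): a block with two convs keeps the 32x32 spatial size through the convolutions, then the max-pool halves it.

blk = vgg_block(2, 3, 8)
print(blk(torch.rand(1, 3, 32, 32)).shape)  # torch.Size([1, 8, 16, 16])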
def unpickle(file):
    # Load one pickled CIFAR-10 batch file into a dict keyed by byte strings.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')
# Each tuple is (num_convs, in_channels, out_channels) for one VGG block; the
# channel counts of standard VGG-11 are divided by 8 or 16 to shrink the
# network for 32x32 CIFAR-10 images.
conv_arch = ((1, 3, 64 // 8), (1, 64 // 8, 128 // 8), (2, 128 // 8, 256 // 16),
             (2, 256 // 16, 512 // 16), (2, 512 // 16, 512 // 16))
# Five 2x2 max-pools reduce 32x32 to 1x1, so the flattened feature size is 32 * 1 * 1.
fc_features = 512 // 16 * 1 * 1
fc_hidden = 4096
net = vgg(conv_arch, fc_features, fc_hidden)
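A quick shape check (not part of the original post) confirms the architecture: each block halves the spatial size, so a 32x32 input reaches the classifier as a 32-dimensional vector.

X = torch.rand(1, 3, 32, 32)
for name, block in net.named_children():
    X = block(X)
    print(name, 'output shape:', X.shape)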
# Load the five CIFAR-10 training batches (10000 images each) and the test
# batch. Each row of b'data' is a flat 3072-byte image; reshape to (3, 32, 32).
# Pixel values stay in the raw 0-255 range, as in the original post.
train_parts, label_parts = [], []
for k in range(1, 6):
    batch = unpickle('data_batch_%d' % k)
    train_parts.append(np.reshape(batch[b'data'], (10000, 3, 32, 32)))
    label_parts.extend(batch[b'labels'])
train_1 = torch.tensor(np.concatenate(train_parts), dtype=torch.float)  # (50000, 3, 32, 32)
label1 = torch.tensor(label_parts, dtype=torch.long)

test1 = unpickle('test_batch')
test_1 = torch.tensor(np.reshape(test1[b'data'], (10000, 3, 32, 32)), dtype=torch.float)
test_label = torch.tensor(test1[b'labels'], dtype=torch.long)


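As an aside, torchvision ships a CIFAR-10 loader that could replace the manual unpickling above. A minimal sketch, not used in this post (the ./data path is arbitrary, and ToTensor() scales pixels to [0, 1] rather than the raw 0-255 values used here):

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

cifar_train = torchvision.datasets.CIFAR10(root='./data', train=True,
                                           download=True, transform=transforms.ToTensor())
train_iter = DataLoader(cifar_train, batch_size=500, shuffle=True)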
batch_size = 500
lr, num_epochs = 0.00005, 2500
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
# The whole dataset (~740 MB as float32) is moved to the device up front; with
# less GPU memory, move one mini-batch at a time inside the training loop.
train_1 = train_1.to(device)
label1 = label1.to(device)
test_1 = test_1.to(device)
test_label = test_label.to(device)
loss = torch.nn.CrossEntropyLoss()
for i in range(num_epochs):
    n, train_acc_sum, test_acc, train_l_sum, batch_count = 0, 0, 0, 0.0, 0
    net.train()  # enable dropout for training
    for j in range(50000 // batch_size):  # 100 mini-batches per epoch
        X = train_1[j * batch_size:(j + 1) * batch_size]
        y = label1[j * batch_size:(j + 1) * batch_size]
        y_hat = net(X)
        l = loss(y_hat, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
        n += y.shape[0]
        batch_count += 1
        train_l_sum += l.cpu().item()
    net.eval()  # disable dropout for evaluation
    with torch.no_grad():  # no gradients needed on the test set
        for j in range(10000 // batch_size):  # 20 mini-batches
            X = test_1[j * batch_size:(j + 1) * batch_size]
            y = test_label[j * batch_size:(j + 1) * batch_size]
            test_acc += (net(X).argmax(dim=1) == y).sum().cpu().item()
    print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
          % (i + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc / 10000))
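Once training finishes, the weights can be saved with PyTorch's standard state-dict mechanism (an addition, not in the original post):

torch.save(net.state_dict(), 'vgg11_cifar10.pt')  # hypothetical filename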