Dog vs Cat复现

文件夹格式

 

 

 

仿照着陈云的书,自己做了些修改

dataset:

注意:自己写 Dataset 时至少要实现 `__init__(self)`、`__getitem__(self, index)` 和 `__len__(self)` 这几个函数。

`__init__` 用于创建图片地址列表、设置 transform 等操作。

 

`__getitem__` 返回第 index 个 data 和 label,DataLoader 取数据时会调用这个函数。

`__len__` 返回 data 的数量(即数据集的大小)。

import os
from PIL import Image
from torch.utils import data
import numpy as np
import torchvision.transforms as T


class DogCat(data.Dataset):
    """Dogs-vs-Cats image dataset read from a flat directory of image files.

    Labelled files are expected to be named like ``cat.0.jpg`` / ``dog.0.jpg``;
    test files like ``1.jpg``. Labelled files are sorted by their numeric id
    and split 70/30 into train/validation.

    Args:
        root: directory containing the image files.
        transforms: optional callable applied to each PIL image. When None, a
            default torchvision pipeline is chosen for the split.
        train: for labelled data, True selects the first 70% (training) and
            False the remaining 30% (validation).
        test: True for the unlabelled test directory; ``__getitem__`` then
            returns the numeric image id in place of a class label.
    """

    def __init__(self, root, transforms=None, train=True, test=False):
        self.test = test
        # os.listdir returns names in arbitrary order, so sort by the numeric
        # id in the file name to make the split deterministic. basename() keeps
        # this independent of the path separator and of dots inside `root`.
        imgs = [os.path.join(root, img) for img in os.listdir(root)]
        imgs = sorted(imgs, key=lambda x: int(os.path.basename(x).split('.')[-2]))
        imgs_num = len(imgs)

        # Validation : train = 3 : 7. The two labelled splits must not overlap,
        # otherwise validation metrics are measured on training images.
        if self.test:
            self.imgs = imgs
        elif train:
            self.imgs = imgs[:int(0.7 * imgs_num)]
        else:
            self.imgs = imgs[int(0.7 * imgs_num):]

        if transforms is not None:
            # Honour a caller-supplied pipeline (previously silently dropped).
            self.transforms = transforms
        elif self.test or not train:
            # Deterministic preprocessing for validation / test images.
            self.transforms = T.Compose([
                T.Resize((224, 224)),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])
        else:
            # Training augmentation: crop a random area / aspect ratio and
            # resize it to 224x224. T.Resize / T.RandomResizedCrop replace the
            # removed torchvision aliases T.Scale / T.RandomSizedCrop.
            self.transforms = T.Compose([
                T.Resize(256),
                T.RandomResizedCrop(224),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for the index-th file.

        label is 1 for a dog and 0 for a cat on labelled data, and the numeric
        image id on the test set. DataLoader calls this for every sample.
        """
        img_path = self.imgs[index]
        file_name = os.path.basename(img_path)
        if self.test:
            label = int(file_name.split('.')[-2])
        else:
            label = 1 if 'dog' in file_name else 0
        # convert('RGB') guards against non-RGB images (e.g. grayscale), whose
        # channel count would not match the 3-channel Normalize; the `with`
        # closes the underlying file deterministically.
        with Image.open(img_path) as img:
            img = img.convert('RGB')
        return self.transforms(img), label

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.imgs)

 

config:配置文件,用于保存一些超参数

class DefaultConfig(object):
    """Hyper-parameters and paths shared by the training and test scripts."""

    model = 'AlexNet'
    class_num = 2  # number of output classes
    train_data_root = 'C:/Users/Dell/PycharmProjects/DC/data/train'
    test_data_root = 'C:/Users/Dell/PycharmProjects/DC/data/test1/'
    checkpoint_root = 'C:/Users/Dell/PycharmProjects/DC/checkpoint'
    batch_size = 4
    use_gpu = True
    # Misspelled legacy name, kept as an alias so existing readers still work.
    use_gpy = use_gpu
    num_workers = 4
    print_freq = 20
    debug_file = '/tmp/debug'
    result_file = 'result.csv'

    num_epochs = 1
    # NOTE(review): if this is fed to Adam, 0.1 is far above Adam's usual
    # default of 1e-3 and may make training diverge -- confirm it is intended.
    lr = 0.1
    lr_decay = 0.95
    weight_decay = 1e-4

 

model:保存网络模型,此次用的是上次写的 AlexNet。

import torch
import torch.nn as nn
import torch.nn.functional as F
from config import DefaultConfig
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')




class AlexNet(nn.Module):
    """AlexNet-style CNN producing class logits for 3x224x224 inputs.

    Args:
        num_classes: number of output logits. When omitted, falls back to
            ``DefaultConfig().class_num`` (the original behaviour).
    """

    def __init__(self, num_classes=None):
        super(AlexNet, self).__init__()
        if num_classes is None:
            num_classes = DefaultConfig().class_num
        self.conv1 = nn.Conv2d(3, 96, 11, stride=4)
        self.conv2 = nn.Conv2d(96, 256, 5, padding=2)
        # A 2x2 kernel with padding 1 grows the 12x12 map to 13x13, so that a
        # 224x224 input ends at 6x6 after the last pool (matching fc1).
        self.conv3 = nn.Conv2d(256, 384, 2, padding=1)
        self.conv4 = nn.Conv2d(384, 384, 3, padding=1)
        self.conv5 = nn.Conv2d(384, 256, 3, padding=1)
        self.fc1 = nn.Linear(256 * 6 * 6, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for a batch of images."""
        x = F.max_pool2d(F.local_response_norm(F.relu(self.conv1(x)), size=5, alpha=0.0001, beta=0.75, k=2), kernel_size=3, stride=2)
        x = F.max_pool2d(F.local_response_norm(F.relu(self.conv2(x)), size=5, alpha=0.0001, beta=0.75, k=2), kernel_size=3, stride=2)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(F.relu(self.conv5(x)), kernel_size=3, stride=2)
        x = x.view(x.size(0), -1)
        # Dropout follows the module's train/eval mode so that model.eval()
        # gives deterministic predictions. It is not in-place: the ReLU output
        # it would overwrite is saved by autograd for the backward pass.
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc3(x)
        return x

 

train:训练脚本(train 文件)

import torch.utils.data.dataloader as Dataloader
import torch
import fire
from model import AlexNet
from config import DefaultConfig
from data.dataset import DogCat
import torchnet.meter as meter
from tqdm import tqdm
from tensorboardX import SummaryWriter
import os

# Module-wide hyper-parameters read by train().
opt = DefaultConfig()
# Device name as a plain string: 'cuda' when a GPU is visible, else 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def _evaluate(model, dataloader, criterion):
    """Return ``(mean batch loss, accuracy)`` of *model* over *dataloader*.

    Switches the model to eval mode and disables gradient tracking; the caller
    must switch it back to train mode before further training.

    Raises:
        ValueError: if *dataloader* yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_true = 0
    img_total = 0
    n_batches = 0
    with torch.no_grad():
        for data, label in dataloader:
            # Tensor.to() returns a new tensor; it does not move in place.
            data = data.to(device)
            label = label.to(device)
            output = model(data)
            total_loss += criterion(output, label).item()
            pred = output.argmax(dim=1)
            total_true += (pred == label).sum().item()
            img_total += label.size(0)
            n_batches += 1
    if n_batches == 0:
        raise ValueError('validation dataloader produced no batches')
    return total_loss / n_batches, total_true / img_total


def train(**kwargs):
    """Train AlexNet on the Dogs-vs-Cats data, validating after every epoch.

    Logs ``loss_mean`` and ``accuracy`` to TensorBoard under
    ``opt.checkpoint_root`` and saves ``{'model': state_dict}`` there after
    the final epoch.

    Keyword Args:
        checkpoint_dir: optional path of a checkpoint previously written by
            this function, to resume from.
    """
    tbwrite = SummaryWriter(logdir=opt.checkpoint_root)
    model = AlexNet()
    if 'checkpoint_dir' in kwargs:
        # nn.Module has no .load(); read the file with torch.load and map its
        # tensors onto the device we are about to train on.
        checkpoint = torch.load(kwargs['checkpoint_dir'], map_location=device)
        model.load_state_dict(checkpoint['model'])
        print('model load finished')
    model.to(device)

    train_data = DogCat(root=opt.train_data_root)
    train_dataloader = Dataloader.DataLoader(train_data, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
    val_data = DogCat(root=opt.train_data_root, train=False)
    # Validation metrics do not depend on order, so there is no need to shuffle.
    val_dataloader = Dataloader.DataLoader(val_data, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers)

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = torch.nn.CrossEntropyLoss()
    os.makedirs(opt.checkpoint_root, exist_ok=True)

    try:
        for epoch in range(opt.num_epochs):
            # _evaluate() leaves the model in eval mode; re-enable training.
            model.train()
            for data, label in tqdm(train_dataloader):
                data = data.to(device)
                label = label.to(device)
                optimizer.zero_grad()
                loss_value = criterion(model(data), label)
                loss_value.backward()
                optimizer.step()

            loss_mean, accuracy = _evaluate(model, val_dataloader, criterion)
            print('loss_mean:{:.4f}, accuracy:{:.2f}'.format(loss_mean, accuracy))
            tbwrite.add_scalar('loss_mean', loss_mean, epoch)
            tbwrite.add_scalar('accuracy', accuracy, epoch)

            # True only on the last epoch, so a single checkpoint is written.
            if (epoch + 1) % opt.num_epochs == 0:
                state = {'model': model.state_dict()}
                checkpoint_path = os.path.join(opt.checkpoint_root, 'checkpoint_{}.pkl'.format(epoch + 1))
                torch.save(state, checkpoint_path)
    finally:
        # Flush pending TensorBoard events even if training fails.
        tbwrite.close()


if __name__ == '__main__':
    import sys

    # With CLI arguments (e.g. `python train.py train --checkpoint_dir=...`)
    # let Fire dispatch them exactly once; with none, train with the defaults.
    # Calling train() unconditionally after fire.Fire() ran a second,
    # argument-less training pass after every Fire-dispatched command.
    if len(sys.argv) > 1:
        fire.Fire()
    else:
        train()

 

test:测试

from model import AlexNet
from data.dataset import DogCat
from config import DefaultConfig
from torch.utils.data.dataloader import DataLoader
import torch
import os
# Paths and hyper-parameters shared with the training script.
opt = DefaultConfig()
# Device name as a plain string: 'cuda' when a GPU is visible, else 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

if __name__ == '__main__':
    # Load the checkpoint written after epoch 1 and print one predicted class
    # index per test image, in ascending image-id order.
    model = AlexNet()
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    checkpoint = torch.load(os.path.join(opt.checkpoint_root, 'checkpoint_1.pkl'), map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    model.eval()

    test_data = DogCat(root=opt.test_data_root, train=False, test=True)
    # shuffle=False keeps the dataset's sorted order so output lines stay
    # aligned with the image ids.
    test_loader = DataLoader(test_data, batch_size=opt.batch_size, shuffle=False)
    with torch.no_grad():
        for data, _ in test_loader:
            # Tensor.to() returns a new tensor; without the assignment the
            # model would receive CPU data while its weights are on the GPU.
            data = data.to(device)
            preds = model(data).argmax(dim=1)
            for pred in preds.tolist():
                print(pred)

 

posted @ 2021-09-26 09:28  WTSRUVF  阅读(127)  评论(0)  编辑  收藏  举报