深度网络学习-PyTorch_自定义Dataset

PyTorch中的数据

Dataset、DataLoader、transforms
数据集的格式

分类生成标签

 制作训练和验证数据的.txt文件
 
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-

import os 

def list_dir(path):
    """Map each category sub-directory of *path* to its labelled file paths.

    Only immediate sub-directories of ``path`` are treated as categories;
    entries of ``path`` that are not directories are ignored.

    Args:
        path: Dataset root directory containing one sub-directory per category.

    Returns:
        A dict mapping each category name to a list with one path per entry
        of that sub-directory, built as ``os.path.join("/", category, entry)``.
        Categories and entries are both sorted by name.
    """
    res = {}
    # os.listdir() returns entries in arbitrary order; sort so the generated
    # label file is reproducible across runs and machines.
    for category in sorted(os.listdir(path)):
        temp_dir = os.path.join(path, category)
        if not os.path.isdir(temp_dir):
            continue
        res[category] = [
            os.path.join("/", category, data)
            for data in sorted(os.listdir(temp_dir))
        ]
    return res

def get_text(path, fil_dict, relation=None):
    """Write the ``<image path>TAB<label>`` index file for a dataset folder.

    The file is created inside ``path`` and named after the last component of
    ``path`` plus ``.txt`` (``.../train`` -> ``.../train/train.txt``). Every
    line written is also printed to stdout.

    Args:
        path: Dataset directory that receives the index file.
        fil_dict: Mapping of category name to a list of image paths.
        relation: Optional mapping of category name to integer label.
            Defaults to ``{"dog": 1, "cat": 2}``.

    Raises:
        KeyError: If a category that has files is missing from ``relation``.
    """
    if relation is None:
        relation = {"dog": 1, "cat": 2}
    # normpath drops a trailing separator so "data/train/" still yields
    # "train.txt" instead of a file that is just called ".txt".
    file_nm = os.path.basename(os.path.normpath(path)) + ".txt"
    with open(os.path.join(path, file_nm), mode="w", encoding="utf-8") as f:
        for category_key in fil_dict:
            for label_file in fil_dict[category_key]:
                labe_res = label_file + "\t" + str(relation[category_key])
                print(labe_res)
                # Text mode already translates "\n" to the platform line
                # ending; writing "\r\n" here would produce "\r\r\n" on
                # Windows and CRLF files on POSIX.
                f.write(labe_res + "\n")


if __name__ == '__main__':
    # Scan the category folders of the training split, then write the
    # resulting path/label index file into that same directory.
    train_root = "./pytorch/data/train"
    get_text(train_root, list_dir(train_root))

数据情况

分类-数据文件夹的结构
  ├── cat
  │   ├── 05.jpg
  │   ├── 06.jpg
  │   ├── 07.jpg
  │   └── 08.jpg
  ├── dog
  │   ├── 01.jpg
  │   └── 05.jpg
  └── train.txt

 其中 train.txt的内容是
    /dog/01.jpg     1
    /dog/05.jpg     1
    /cat/06.jpg     2
    /cat/07.jpg     2
    /cat/05.jpg     2
    /cat/08.jpg     2

自定义Dataset

 自定义Dataset:继承Dataset,重写抽象方法 __init__()、__len__()、__getitem__()

#!/usr/bin/env python3
# -*- coding: UTF-8 -*-

import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms


# step1: 定义MyDataset类, 继承Dataset, 重写抽象方法: __init__(), __len__(), __getitem__()
class MyDataset(Dataset):
    """Image dataset driven by a tab-separated index file.

    Every non-blank line of ``names_file`` is expected to look like
    ``<image path><TAB><integer label>``. The image path is appended verbatim
    to ``root_dir`` (plain string concatenation) to locate the image on disk.
    """

    def __init__(self, root_dir, names_file, transform=None):
        """Read the index file and remember one entry per sample.

        Args:
            root_dir: Prefix prepended to every image path of the index file.
            names_file: Path of the tab-separated index file.
            transform: Optional callable applied to each
                ``{'image': ..., 'label': ...}`` sample dict.

        Raises:
            FileNotFoundError: If ``names_file`` is not an existing file.
        """
        self.root_dir = root_dir
        self.names_file = names_file
        self.transform = transform

        if not os.path.isfile(self.names_file):
            raise FileNotFoundError(self.names_file + ' ## does not exist!')
        # Close the handle deterministically, and skip blank lines (for
        # instance the ones produced by "\r\r\n" line endings) so they are
        # not counted as samples.
        with open(self.names_file, encoding="utf-8") as file:
            self.names_list = [
                line.rstrip("\r\n") for line in file if line.strip()
            ]
        self.size = len(self.names_list)

    def __len__(self):
        """Return the number of entries read from the index file."""
        return self.size

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A dict with keys ``'image'`` (the value returned by
            ``cv2.imread``) and ``'label'`` (int), passed through
            ``transform`` when one was given; ``None`` if the image file
            does not exist.
        """
        fields = self.names_list[idx].split('\t')
        image_path = self.root_dir + fields[0]
        if not os.path.isfile(image_path):
            print(image_path +  '@does not exist!')
            return None
        # NOTE(review): cv2.imread is documented to return None instead of
        # raising when a file cannot be decoded; that case is not handled
        # here and would surface inside the transform.
        image = cv2.imread(image_path)

        label = int(fields[1])

        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample

# # 变换Resize
class Resize:
    """Sample transform that rescales the image with ``cv2.resize``."""

    def __init__(self, output_size: tuple):
        # Target size, handed to cv2.resize unchanged.
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict holding the resized image and the label."""
        resized = cv2.resize(sample['image'], self.output_size)
        return dict(image=resized, label=sample['label'])

# # 变换ToTensor
class ToTensor:
    """Sample transform that turns the image array into a torch tensor."""

    def __call__(self, sample):
        """Move the last image axis to the front and wrap it as a tensor.

        The label is passed through unchanged.
        """
        # Axis order (0, 1, 2) -> (2, 0, 1).
        channel_first = np.transpose(sample['image'], axes=(2, 0, 1))
        tensor = torch.from_numpy(channel_first)
        return dict(image=tensor, label=sample['label'])


if __name__ == "__main__":
    # Transform pipeline: resize each image to 224x224, then convert it to a
    # tensor.
    pipeline = transforms.Compose([Resize((224, 224)), ToTensor()])
    train_dataset = MyDataset(
        root_dir='./pytorch/data/train',
        names_file='./pytorch/data/train/train.txt',
        transform=pipeline,
    )

    # Walk the dataset one sample at a time and show every label.
    for sample in train_dataset:
        image = sample['image']
        label = sample['label']
        print(label)

    # Same data again, this time shuffled and batched by a DataLoader.
    trainset_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4,
    )

    for sample_batch in trainset_dataloader:
        images_batch = sample_batch['image']
        labels_batch = sample_batch['label']
        print(labels_batch.shape, labels_batch.dtype)
        print(images_batch.shape, images_batch.dtype)
        print(labels_batch)
        print(images_batch)

参考

     https://pytorch.org/docs/stable/data.html
posted @ 2021-09-07 17:50  辰令  阅读(75)  评论(0)    收藏  举报