Dataset 类实战

本文是小土堆 PyTorch 课程的学习笔记：Dataset 类实战。

from torch.utils.data import Dataset
from  PIL  import Image
import  os

class MyData(Dataset):
    """Dataset whose samples are the files found in one label directory.

    Every entry listed in ``root_dir/label_dir`` is one sample, and the name
    of that directory is returned as the label for all of them.

    Args:
        root_dir: Directory containing the per-label sub-directories.
        label_dir: Name of the sub-directory to read; also used as the label.
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in an arbitrary, filesystem-dependent
        # order; sort so a given index always refers to the same file.
        self.img_path = sorted(os.listdir(self.path))

    def _item_path(self, idx):
        """Return the full path of the sample at position ``idx``."""
        return os.path.join(self.path, self.img_path[idx])

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened with ``PIL.Image.open``; the label is the label
        directory name passed to the constructor.
        """
        img = Image.open(self._item_path(idx))
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries in the label directory."""
        return len(self.img_path)

# Build one dataset per class folder under the training root.
root_dir = "练手数据集/train"
# Both class folders follow the "<class>_image" naming pattern.
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

# Dataset.__add__ concatenates the two datasets into a ConcatDataset:
# the first len(ants_dataset) indices are ants, the remaining ones are bees.
train_dataset = ants_dataset + bees_dataset

 

posted @ 2022-04-22 18:33  今天天气好极了  阅读(61)  评论(0)    收藏  举报