PyTorch教程【四】PyTorch加载数据

代码示例（自定义 Dataset 子类需要实现 __init__、__getitem__ 和 __len__ 三个方法，注意方法名前后各有两个下划线）：

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    """Dataset over the files of one class folder ``root_dir/label_dir/``.

    Each item is a pair ``(image, label)``: the image is opened with
    ``PIL.Image.open`` and the label is the folder name ``label_dir`` itself
    (a string such as ``"ants"``, not an integer class index).

    The protocol methods must be the dunder names ``__init__``,
    ``__getitem__`` and ``__len__``. Without them the constructor cannot take
    arguments, and ``len(ds)``, ``ds[i]`` and ``ds_a + ds_b`` do not work.
    ``+`` goes through ``Dataset.__add__``/``ConcatDataset``, which needs
    ``__len__``.

    Args:
        root_dir: Directory containing one sub-folder per label.
        label_dir: Name of the sub-folder to load; also used as the label.
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in arbitrary, platform-dependent order;
        # sort so a given index always refers to the same file.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, item):
        """Return ``(image, label)`` for the ``item``-th file of the folder."""
        img_name = self.img_path[item]
        img_item_path = os.path.join(self.path, img_name)
        # Image.open is lazy: Pillow reads pixel data on first access, so the
        # file handle may stay open until the image is loaded or closed.
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries found in the label folder."""
        return len(self.img_path)

# Build one dataset per class folder under the relative path "dataset/train"
# (resolved against the current working directory), then join them with `+`,
# which delegates to Dataset.__add__ and yields a single combined dataset.
root_dir = "dataset/train"
ants_label_dir, bees_label_dir = "ants", "bees"

# Constructed in order: ants first, then bees.
ants_dataset, bees_dataset = [
    MyData(root_dir, label_dir) for label_dir in (ants_label_dir, bees_label_dir)
]

train_dataset = ants_dataset + bees_dataset

posted @ 2020-10-09 17:02  PT小陈  阅读(825)  评论(0)    收藏  举报