1 import torch
2 import os,glob
3 import random,csv
4
5 from torch.utils.data import Dataset,DataLoader
6
7 from torchvision import transforms
8 from PIL import Image
9
class Pokemon(Dataset):
    """Image-classification dataset read from a folder-per-class directory.

    Every sub-directory of ``root`` is treated as one class; class labels are
    assigned in sorted directory-name order. The shuffled list of
    ``(image_path, label)`` pairs is cached in ``<root>/images.csv`` so that
    the train/val/test splits stay stable across runs.

    Args:
        root: Root directory that contains one sub-directory per class.
        resize: Side length (pixels) of the square image handed to the network.
        mode: ``'train'`` (first 60%), ``'val'`` (next 20%) or ``'test'``
            (last 20%). Any other value leaves the full sample list in place.
    """

    # Per-channel statistics applied by Normalize in __getitem__ and inverted
    # by denormalize; kept in one place so the two cannot drift apart.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, mode):
        super().__init__()

        self.root = root
        self.resize = resize

        # Map class name -> integer label. os.listdir order is not
        # guaranteed, so sort to keep the mapping reproducible.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        # Parallel lists: image paths and their integer labels.
        self.images, self.labels = self.load_csv('images.csv')

        # Compute the split boundaries once, before either list is sliced.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if mode == 'train':
            bounds = slice(None, train_end)
        elif mode == 'val':
            bounds = slice(train_end, val_end)
        elif mode == 'test':
            bounds = slice(val_end, None)
        else:
            # Unrecognised mode: keep every sample.
            bounds = slice(None)
        self.images = self.images[bounds]
        self.labels = self.labels[bounds]

    def load_csv(self, filename):
        """Return ``(images, labels)`` from the CSV index, creating it if absent.

        Args:
            filename: Name of the CSV file, relative to ``self.root``.

        Returns:
            A pair of equal-length lists: image paths (str) and labels (int).
        """
        csv_path = os.path.join(self.root, filename)

        # Only build the index when it does not exist yet; an existing file
        # is reused so the shuffled order (and thus the splits) is preserved.
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                class_dir = os.path.join(self.root, name)
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(class_dir, pattern))

            print(len(images), images)

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
            print('writen into csv file:', filename)

        # Read the index back; newline='' is what the csv module expects.
        images, labels = [], []
        with open(csv_path, newline='') as f:
            for row in csv.reader(f):
                if not row:
                    # Tolerate blank lines instead of failing on unpacking.
                    continue
                img, label = row
                images.append(img)
                labels.append(int(label))

        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Invert the Normalize step: ``x = x_hat * std + mean``.

        Args:
            x_hat: Normalized image tensor with channels in the third-from-last
                dimension, e.g. ``[c, h, w]`` (the ``[3, 1, 1]`` statistics
                broadcast over any leading dimensions).

        Returns:
            The de-normalized tensor, same shape as ``x_hat``.
        """
        # Build the statistics on the input's device so GPU tensors work too.
        device = getattr(x_hat, 'device', None)
        # [3] -> [3, 1, 1] so the values broadcast over height and width.
        mean = torch.tensor(self._MEAN, device=device).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(self._STD, device=device).unsqueeze(1).unsqueeze(1)
        return x_hat * std + mean

    def _build_transform(self):
        """Build the resize / rotate / crop / tensor / normalize pipeline."""
        enlarged = int(self.resize * 1.25)
        return transforms.Compose([
            transforms.Resize((enlarged, enlarged)),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``.

        Args:
            idx: Index into the selected split, ``0 <= idx < len(self)``.

        Returns:
            The transformed image tensor and the label as a scalar tensor.
        """
        path, label = self.images[idx], self.labels[idx]

        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of lingering until garbage collection.
        with Image.open(path) as handle:
            img = handle.convert('RGB')

        img = self._build_transform()(img)
        label = torch.tensor(label)

        return img, label