lightweight 中读取 label 和图片

class CocoTrainDataset(Dataset):
    """Training dataset that pairs pre-processed COCO labels with their images.

    All labels are loaded once from a pickle file in ``__init__``; each
    ``__getitem__`` call reads one image from disk with OpenCV and builds a
    float32 mask from that label's segmentations.

    Args:
        labels: Path to a pickle file holding an indexable sequence of label
            dicts. Each dict is read for at least the keys ``'img_paths'``,
            ``'img_height'``, ``'img_width'`` and ``'segmentations'``.
        images_folder: Directory that ``label['img_paths']`` is joined onto.
        stride: Stored as ``self._stride``; not used by the methods shown
            here (presumably consumed by target-map generation — TODO confirm).
        sigma: Stored as ``self._sigma``; same note as ``stride``.
        paf_thickness: Stored as ``self._paf_thickness``; same note as ``stride``.
        transform: Optional callable applied to the sample dict.
    """

    def __init__(self, labels, images_folder, stride, sigma, paf_thickness, transform=None):
        super().__init__()
        self._images_folder = images_folder
        self._stride = stride
        self._sigma = sigma
        self._paf_thickness = paf_thickness
        self._transform = transform
        # SECURITY: pickle.load can execute arbitrary code; only load label
        # files from a trusted source.
        with open(labels, 'rb') as f:
            self._labels = pickle.load(f)

    def __len__(self):
        """Return the number of labelled samples (required by DataLoader samplers)."""
        return len(self._labels)

    def __getitem__(self, idx):
        """Return the sample for index ``idx``.

        Returns:
            A dict with keys ``'label'``, ``'image'`` and ``'mask'``, or
            whatever ``transform`` returns for that dict when one is set.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        # Deep copy because the transform may modify the label in place; the
        # cached entry in self._labels must stay unchanged for later accesses.
        label = copy.deepcopy(self._labels[idx])
        image_path = os.path.join(self._images_folder, label['img_paths'])
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        mask = np.ones(shape=(label['img_height'], label['img_width']), dtype=np.float32)
        mask = get_mask(label['segmentations'], mask)
        sample = {
            'label': label,
            'image': image,
            'mask': mask,
        }
        if self._transform:
            sample = self._transform(sample)
        return sample

 

`pickle.load` 用于读取标签：`__init__` 中以二进制模式（`'rb'`）打开标签文件，一次性把全部标签反序列化到 `self._labels` 中。

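`cv2.imread` 用于读取图片：`__getitem__` 中按索引取出对应标签，将 `images_folder` 与 `label['img_paths']` 拼接成图片路径，再以彩色模式（`cv2.IMREAD_COLOR`）读取。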

posted @ 2023-02-01 11:08  祥瑞哈哈哈  阅读(18)  评论(0)    收藏  举报