深度学习-加载数据集cv2.resize()报错

深度学习-加载数据集cv2.resize()报错

  • 假如你确认数据集的制作没有问题,那么报错很可能是图片读取异常导致的:图片文件虽然存在,但文件已损坏,cv2.imread() 读取失败时会返回 None,随后调用 cv2.resize() 就会报错。因此,我们需要找到这张图片并把它去掉。

代码实例(在yolo-fastestv2的datasets.py)

自定义一个遍历数据集、检查每张图片能否被正常读取的函数 filte()

  • 假如该图片无法被读取,就打印出它的名字
    def filte(self):
        """Print the path of every entry in ``self.data_list`` that fails to load.

        Each path is passed to ``cv2.imread``; a result without a ``shape``
        attribute (``cv2.imread`` yields ``None`` for an image it cannot
        decode) is treated as unreadable and its path is printed.
        """
        for image_path in self.data_list:
            decoded = cv2.imread(image_path)
            if hasattr(decoded, 'shape'):
                # Decoded into an array: nothing to report.
                continue
            print(image_path)

对datasets.py加以修改并运行

  • 在文件末尾的 if __name__ == "__main__": 代码块里,将数据集列表文件(train.txt)的路径改为自己的
import os
import cv2
import random
import numpy as np

import torch
from torch.utils import data
from torch.utils.data import Dataset

def contrast_and_brightness(img):
    """Apply a random contrast gain and a small additive offset to ``img``.

    Two values are drawn from ``random.uniform(0.25, 1.75)``: the first is the
    weight of ``img``, the second is the scalar added to the blend. The second
    blend operand is an all-zero array of the same shape and dtype, so the
    result of ``cv2.addWeighted`` is effectively ``gain * img + offset``.
    """
    gain = random.uniform(0.25, 1.75)
    offset = random.uniform(0.25, 1.75)
    # The zero image contributes nothing; it only satisfies addWeighted's
    # two-source signature.
    zeros = np.zeros(img.shape, dtype=img.dtype)
    return cv2.addWeighted(img, gain, zeros, 1 - gain, offset)

def motion_blur(image):
    """Blur ``image`` along a random direction when a 1-in-2 draw selects it.

    When ``random.randint(1, 2)`` is not 1 the input is returned untouched.
    Otherwise a diagonal line kernel of size 2 or 3 is rotated by a random
    angle in [-360, 360], divided by its size, and convolved with the image.
    The filtered result is min-max normalized to 0..255 and returned as a
    ``uint8`` array.
    """
    # Pass-through branch: the same object is returned, not a copy.
    if random.randint(1, 2) != 1:
        return image

    kernel_size = random.randint(2, 3)
    rotation_deg = random.uniform(-360, 360)
    pixels = np.array(image)

    # Build a motion-blur kernel for an arbitrary angle: start from a diagonal
    # line and rotate it about the kernel centre. A larger kernel size means a
    # stronger blur.
    centre = (kernel_size / 2, kernel_size / 2)
    rotation = cv2.getRotationMatrix2D(centre, rotation_deg, 1)
    line_kernel = np.diag(np.ones(kernel_size))
    kernel = cv2.warpAffine(line_kernel, rotation, (kernel_size, kernel_size))
    kernel = kernel / kernel_size

    filtered = cv2.filter2D(pixels, -1, kernel)

    # Normalize in place to the 0..255 range, then convert to uint8.
    cv2.normalize(filtered, filtered, 0, 255, cv2.NORM_MINMAX)
    return np.array(filtered, dtype=np.uint8)

def augment_hsv(img, hgain = 0.0138, sgain = 0.678, vgain = 0.36):
    """Randomly scale the hue, saturation and value channels of a BGR image.

    A gain is drawn per channel as ``1 + uniform(-1, 1) * <channel gain>``.
    Each gain is turned into a 256-entry lookup table (hue wraps modulo 180,
    saturation and value are clipped to 0..255), the tables are applied to the
    HSV channels, and the result is converted back to BGR and returned.
    """
    gains = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    h_chan, s_chan, v_chan = cv2.split(hsv)
    out_dtype = img.dtype

    # One table entry per possible 8-bit channel value.
    levels = np.arange(0, 256, dtype=np.int16)
    hue_table = ((levels * gains[0]) % 180).astype(out_dtype)
    sat_table = np.clip(levels * gains[1], 0, 255).astype(out_dtype)
    val_table = np.clip(levels * gains[2], 0, 255).astype(out_dtype)

    remapped = cv2.merge((
        cv2.LUT(h_chan, hue_table),
        cv2.LUT(s_chan, sat_table),
        cv2.LUT(v_chan, val_table),
    )).astype(out_dtype)
    return cv2.cvtColor(remapped, cv2.COLOR_HSV2BGR)


def random_resize(img):
    """Shrink ``img`` by random per-axis factors and scale it back up.

    Width and height are each multiplied by an independent factor drawn from
    ``random.uniform(0.8, 1)`` (width first), truncated to int, and the image
    is resized to that size and then back to its original size, both times
    with bilinear interpolation. The output has the input's height and width.
    """
    orig_h, orig_w, _channels = img.shape
    # Tuple elements are evaluated left to right: width factor, then height.
    shrunk_size = (
        int(orig_w * random.uniform(0.8, 1)),
        int(orig_h * random.uniform(0.8, 1)),
    )

    shrunk = cv2.resize(img, shrunk_size, interpolation=cv2.INTER_LINEAR)
    return cv2.resize(shrunk, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

def img_aug(img):
    """Run the enabled augmentation steps on ``img`` and return the result.

    Only the contrast/brightness jitter is currently active.
    """
    # Disabled steps, kept here for reference:
    #   motion_blur(img), random_resize(img), augment_hsv(img)
    return contrast_and_brightness(img)

def collate_fn(batch):
    """Merge a list of ``(image, label)`` samples into one batch.

    Column 0 of every non-empty label tensor is overwritten in place with the
    sample's position in the batch, so rows can be traced back to their image
    after concatenation. Returns the images stacked along a new first
    dimension and all labels concatenated along dimension 0.
    """
    images, targets = zip(*batch)
    for sample_idx, target in enumerate(targets):
        if target.shape[0] == 0:
            # Nothing to tag for a sample without label rows.
            continue
        target[:, 0] = sample_idx
    stacked_images = torch.stack(images)
    merged_targets = torch.cat(targets, 0)
    return stacked_images, merged_targets

class TensorDataset():
    """Image/label dataset driven by a text file that lists one image per line.

    For an image ``<stem>.<ext>`` the labels are read from ``<stem>.txt`` in
    the same directory. Every non-blank label line must hold at least five
    whitespace-separated numbers (presumably class id and a normalized box —
    confirm against the training code); each becomes a six-column row whose
    column 0 is a placeholder that ``collate_fn`` later fills with the
    sample's batch index.
    """

    def __init__(self, path, img_size_width = 352, img_size_height = 352, imgaug = False):
        """Read and validate the image list, then report unreadable images.

        Args:
            path: Text file containing one image path per line. Blank lines
                are ignored.
            img_size_width: Width every image is resized to.
            img_size_height: Height every image is resized to.
            imgaug: When truthy, ``img_aug`` is applied to each resized image.

        Raises:
            AssertionError: ``path`` does not exist.
            Exception: A listed image does not exist or its extension is not
                one of ``self.img_formats`` (compared case-insensitively).
        """
        assert os.path.exists(path), "%s文件路径错误或不存在" % path

        self.path = path
        self.data_list = []
        self.img_size_width = img_size_width
        self.img_size_height = img_size_height
        self.img_formats = ['bmp', 'jpg', 'jpeg', 'png']
        self.imgaug = imgaug

        # Validate every listed path before accepting it.
        with open(self.path, 'r') as f:
            for line in f:
                data_path = line.strip()
                if not data_path:
                    # A blank (e.g. trailing) line is not a missing image.
                    continue
                if not os.path.exists(data_path):
                    raise Exception("%s is not exist" % data_path)
                # splitext looks only at the final component, so dots in
                # directory names cannot be mistaken for the extension.
                img_type = os.path.splitext(data_path)[1][1:]
                if img_type.lower() not in self.img_formats:
                    raise Exception("img type error:%s" % img_type)
                self.data_list.append(data_path)
        self.filte()

    def filte(self):
        """Print the path of every listed image that ``cv2.imread`` cannot decode.

        ``cv2.imread`` yields ``None`` for a file it cannot decode, even when
        the file exists. Such files are only reported here, not removed.
        """
        for image_path in self.data_list:
            if cv2.imread(image_path) is None:
                print(image_path)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(img, label)`` tensors: ``img`` is the resized (and optionally
            augmented) image with the channel axis moved first; ``label`` is
            a float32 tensor with six columns and one row per label line
            (zero rows for an empty label file).

        Raises:
            IndexError: ``index`` is outside ``self.data_list``.
            Exception: The label file is missing, the image cannot be
                decoded, or a label line has fewer than five fields.
            ValueError: A label field is not a number.
        """
        img_path = self.data_list[index]
        # Replace only the real extension; splitting on the first "." broke
        # for paths such as "./data/img.jpg" or "frame.0001.jpg".
        label_path = os.path.splitext(img_path)[0] + ".txt"
        # Cheap check first, so a missing label fails before any decoding.
        if not os.path.exists(label_path):
            raise Exception("%s is not exist" % label_path)

        img = cv2.imread(img_path)
        if img is None:
            # Name the offending file instead of letting cv2.resize fail
            # with an opaque assertion on an empty image.
            raise Exception("%s cannot be read as an image" % img_path)

        img = cv2.resize(img, (self.img_size_width, self.img_size_height), interpolation = cv2.INTER_LINEAR)
        # Data augmentation.
        if self.imgaug:
            img = img_aug(img)
        # Move the channel axis to the front.
        img = img.transpose(2, 0, 1)

        label = self._load_label(label_path)
        return torch.from_numpy(img), torch.from_numpy(label)

    @staticmethod
    def _load_label(label_path):
        """Parse ``label_path`` into a float32 array with six columns."""
        rows = []
        with open(label_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                if len(fields) < 5:
                    # Raised as a plain Exception: an IndexError escaping
                    # __getitem__ would look like end-of-sequence to callers
                    # that iterate via the sequence protocol.
                    raise Exception("malformed label line in %s: %r" % (label_path, line))
                # Column 0 is a placeholder for the batch index.
                rows.append([0.0] + [float(value) for value in fields[:5]])
        # reshape keeps the six-column layout even when there are no rows.
        return np.array(rows, dtype=np.float32).reshape(-1, 6)

    def __len__(self):
        """Return the number of validated image paths."""
        return len(self.data_list)


if __name__ == "__main__":
    # Constructing the dataset calls TensorDataset.filte(), which prints the
    # path of every listed image that cannot be decoded. Point the path below
    # at your own list file.
    # Named `dataset` so it does not shadow the `data` module imported from
    # torch.utils at the top of the file.
    dataset = TensorDataset("/home/frey/python/dataset/crash-datasets/crash/train.txt")

    # Optional sanity check of a single sample:
    # img, label = dataset[0]
    # print(img.shape)
    # print(label.shape)

结果

  • 运行后打印出了一张图片的路径;找到该图片后双击打开,发现确实无法加载,将其删除即可
posted @ 2022-11-04 13:19  梧桐灯下江楚滢  阅读(374)  评论(0)    收藏  举报