深度学习-加载数据集cv2.resize()报错
- 假如你确认数据集的制作没有问题,那么很可能是图片读取异常:图片文件存在,但尝试打开时会出现图片错误或加载失败(此时cv2.imread()返回None,再传给cv2.resize()就会报错).因此,我们需要找到这张图片并去掉它.
代码实例(在yolo-fastestv2的datasets.py)
自定义一个检查数据集中每张图片能否被正常读取的函数filte()
def filte(self):
    """Print and return the paths in ``self.data_list`` that OpenCV cannot decode.

    Returns:
        list: the unreadable image paths, in the order they appear in
        ``self.data_list`` (empty when every image loads).
    """
    unreadable = []
    for img_path in self.data_list:
        # cv2.imread signals a failed decode by returning None, not by raising.
        if cv2.imread(img_path) is None:
            print(img_path)
            unreadable.append(img_path)
    return unreadable
对datasets.py加以修改并运行
import os
import cv2
import random
import numpy as np
import torch
from torch.utils import data
from torch.utils.data import Dataset
def contrast_and_brightness(img):
    """Return *img* with a random gain and a small random offset applied.

    A gain and an offset are each drawn uniformly from [0.25, 1.75] and passed
    to ``cv2.addWeighted`` together with an all-zero image of the same shape
    and dtype, so the second weight (``1 - gain``) multiplies zeros only.
    """
    gain = random.uniform(0.25, 1.75)
    offset = random.uniform(0.25, 1.75)
    zeros = np.zeros(shape=img.shape, dtype=img.dtype)
    return cv2.addWeighted(img, gain, zeros, 1 - gain, offset)
def motion_blur(image):
    """Apply a small randomly oriented motion blur to *image* half of the time.

    When the coin flip fails, *image* is returned untouched. Otherwise a
    diagonal kernel of size 2 or 3 is rotated by a random angle, used to
    filter the image, and the result is min-max normalized to 0..255 and
    returned as ``uint8``.
    """
    if random.randint(1, 2) != 1:
        return image

    # A larger kernel size gives a stronger blur.
    kernel_size = random.randint(2, 3)
    rotation = random.uniform(-360, 360)
    pixels = np.array(image)

    # Rotate a diagonal line kernel to obtain a blur along an arbitrary angle.
    center = (kernel_size / 2, kernel_size / 2)
    rotation_matrix = cv2.getRotationMatrix2D(center, rotation, 1)
    kernel = cv2.warpAffine(
        np.diag(np.ones(kernel_size)), rotation_matrix, (kernel_size, kernel_size)
    )
    kernel = kernel / kernel_size

    result = cv2.filter2D(pixels, -1, kernel)

    # Stretch to the full 0..255 range in place, then convert to uint8.
    cv2.normalize(result, result, 0, 255, cv2.NORM_MINMAX)
    return np.array(result, dtype=np.uint8)
def augment_hsv(img, hgain = 0.0138, sgain = 0.678, vgain = 0.36):
    """Return a copy of BGR image *img* with randomly scaled hue, saturation and value.

    Each channel gets an independent multiplier drawn uniformly from
    ``1 - gain`` to ``1 + gain``. The multipliers are applied through 256-entry
    lookup tables: hue wraps modulo 180, saturation and value are clipped to
    0..255.
    """
    gains = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1
    h_chan, s_chan, v_chan = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    out_dtype = img.dtype

    # int16 keeps the table inputs exact before they are scaled by a float gain.
    levels = np.arange(0, 256, dtype=np.int16)
    hue_table = ((levels * gains[0]) % 180).astype(out_dtype)
    sat_table = np.clip(levels * gains[1], 0, 255).astype(out_dtype)
    val_table = np.clip(levels * gains[2], 0, 255).astype(out_dtype)

    remapped = cv2.merge(
        (
            cv2.LUT(h_chan, hue_table),
            cv2.LUT(s_chan, sat_table),
            cv2.LUT(v_chan, val_table),
        )
    ).astype(out_dtype)
    return cv2.cvtColor(remapped, cv2.COLOR_HSV2BGR)
def random_resize(img):
    """Degrade *img* by shrinking it to a random size and scaling it back.

    Width and height are each scaled independently by a factor drawn uniformly
    from [0.8, 1], then the image is resized back to its original dimensions.
    Both resizes use bilinear interpolation.
    """
    height, width, _ = img.shape
    # Width factor is drawn before the height factor.
    shrunk_w = int(width * random.uniform(0.8, 1))
    shrunk_h = int(height * random.uniform(0.8, 1))
    shrunk = cv2.resize(img, (shrunk_w, shrunk_h), interpolation=cv2.INTER_LINEAR)
    return cv2.resize(shrunk, (width, height), interpolation=cv2.INTER_LINEAR)
def img_aug(img):
    """Run the enabled augmentations on *img* and return the result.

    Only the contrast/brightness augmentation is active; the others are kept
    below, commented out, so they can be switched on.
    """
    augmented = contrast_and_brightness(img)
    # augmented = motion_blur(augmented)
    # augmented = random_resize(augmented)
    # augmented = augment_hsv(augmented)
    return augmented
def collate_fn(batch):
    """Merge a list of ``(image, label)`` pairs into two batched tensors.

    Column 0 of every non-empty label tensor is overwritten in place with the
    index of its sample within the batch, so rows can be traced back to their
    image after concatenation.

    Returns:
        tuple: the images stacked along a new first dimension, and all label
        tensors concatenated along dimension 0.
    """
    images, targets = zip(*batch)
    for sample_idx, target in enumerate(targets):
        if target.shape[0] == 0:
            continue
        target[:, 0] = sample_idx
    return torch.stack(images), torch.cat(targets, 0)
class TensorDataset():
    """Image/label dataset driven by a text file listing one image path per line.

    For every image the label file is looked up next to it, using the same
    path with the extension replaced by ``.txt``. Each non-blank label line
    must provide at least five numeric fields; only the first five are used.
    NOTE(review): the fields are presumably class index plus a normalized box,
    but that is not checked here.
    """

    def __init__(self, path, img_size_width = 352, img_size_height = 352, imgaug = False):
        """Read and validate the image list stored at *path*.

        Args:
            path: text file containing one image path per line.
            img_size_width: width every image is resized to.
            img_size_height: height every image is resized to.
            imgaug: when true, ``img_aug`` is applied to each loaded image.

        Raises:
            AssertionError: if *path* does not exist.
            Exception: if a listed image does not exist or its extension is
                not one of ``self.img_formats``.
        """
        assert os.path.exists(path), "%s文件路径错误或不存在" % path

        self.path = path
        self.data_list = []
        self.img_size_width = img_size_width
        self.img_size_height = img_size_height
        self.img_formats = ['bmp', 'jpg', 'jpeg', 'png']
        self.imgaug = imgaug

        # Validate every entry up front so a bad list fails at construction time.
        with open(self.path, 'r') as f:
            for line in f:
                data_path = line.strip()
                if not data_path:
                    # A blank line carries no path; skip it instead of
                    # reporting an empty file name as missing.
                    continue
                if not os.path.exists(data_path):
                    raise Exception("%s is not exist" % data_path)
                img_type = data_path.split(".")[-1]
                if img_type not in self.img_formats:
                    raise Exception("img type error:%s" % img_type)
                self.data_list.append(data_path)

        # Report images that exist on disk but cannot be decoded.
        self.filte()

    def filte(self):
        """Print and return the paths in ``self.data_list`` that OpenCV cannot decode.

        Returns:
            list: the unreadable image paths (empty when every image loads).
        """
        unreadable = []
        for img_path in self.data_list:
            # cv2.imread signals a failed decode by returning None, not by raising.
            if cv2.imread(img_path) is None:
                print(img_path)
                unreadable.append(img_path)
        return unreadable

    def __getitem__(self, index):
        """Load the image and label at position *index*.

        Returns:
            tuple: ``(image, label)`` tensors. The image is resized to the
            configured size with its last axis moved to the front. The label
            is a float32 tensor of shape ``(n, 6)`` whose column 0 is 0 and
            whose remaining columns are the first five fields of each label
            line.

        Raises:
            Exception: if the image cannot be decoded or the label file is
                missing.
        """
        img_path = self.data_list[index]
        # splitext removes only the final extension, so dots elsewhere in the
        # path (e.g. "./data/a.jpg") no longer truncate the label path.
        label_path = os.path.splitext(img_path)[0] + ".txt"

        img = cv2.imread(img_path)
        if img is None:
            # Fail with the offending path rather than an opaque cv2.resize error.
            raise Exception("%s can not be read as an image" % img_path)
        img = cv2.resize(img, (self.img_size_width, self.img_size_height), interpolation = cv2.INTER_LINEAR)

        # Optional data augmentation.
        if self.imgaug:
            img = img_aug(img)
        img = img.transpose(2, 0, 1)

        if not os.path.exists(label_path):
            raise Exception("%s is not exist" % label_path)

        rows = []
        with open(label_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                # Column 0 is a placeholder; it is set to 0 here.
                rows.append([0.0] + [float(fields[k]) for k in range(5)])
        # reshape keeps an empty label file at shape (0, 6) instead of (0,).
        label = np.array(rows, dtype=np.float32).reshape(-1, 6)

        return torch.from_numpy(img), torch.from_numpy(label)

    def __len__(self):
        """Return the number of validated image paths."""
        return len(self.data_list)
if __name__ == "__main__":
    # Building the dataset calls filte(), which prints every listed image that
    # OpenCV fails to decode.
    # The variable is named `dataset` so it does not rebind the `data` module
    # imported from torch.utils at the top of the file.
    dataset = TensorDataset("/home/frey/python/dataset/crash-datasets/crash/train.txt")
    # img, label = dataset[0]
    # print(img.shape)
    # print(label.shape)
结果
- 运行后程序打印出一张图片的路径,找到该图片后双击打开,发现无法加载,将其删除即可.