YOLOv7 Source Code Walkthrough: Data Loading

1. How the YOLOv7 code is organized


YOLOv7 code structure
.
├── cfg (network architectures defined in `yaml` format)
│   ├── baseline (models used for comparison)
│   │   ├── r50-csp.yaml
│   │   ├── x50-csp.yaml
│   │   ├── yolor-csp-x.yaml
│   │   ├── yolor-csp.yaml
│   │   ├── yolor-d6.yaml
│   │   ├── yolor-e6.yaml
│   │   ├── yolor-p6.yaml
│   │   ├── yolor-w6.yaml
│   │   ├── yolov3-spp.yaml
│   │   ├── yolov3.yaml
│   │   └── yolov4-csp.yaml
│   ├── deploy (used at deployment time)
│   │   ├── yolov7-d6.yaml
│   │   ├── yolov7-e6e.yaml
│   │   ├── yolov7-e6.yaml
│   │   ├── yolov7-tiny-silu.yaml
│   │   ├── yolov7-tiny.yaml
│   │   ├── yolov7-w6.yaml
│   │   ├── yolov7x.yaml
│   │   └── yolov7.yaml
│   └── training (used at training time)
│       ├── yolov7-d6.yaml
│       ├── yolov7-e6e.yaml
│       ├── yolov7-e6.yaml
│       ├── yolov7-tiny.yaml
│       ├── yolov7-w6.yaml
│       ├── yolov7x.yaml
│       └── yolov7.yaml
├── data (dataset configs and training hyperparameters)
│   ├── coco.yaml (COCO dataset info)
│   ├── hyp.scratch.custom.yaml (these four files hold the training hyperparameters)
│   ├── hyp.scratch.p5.yaml
│   ├── hyp.scratch.p6.yaml
│   └── hyp.scratch.tiny.yaml
├── deploy (deployment related)
│   └── triton-inference-server
│       ├── boundingbox.py
│       ├── client.py
│       ├── data
│       │   ├── dog.jpg
│       │   └── dog_result.jpg
│       ├── labels.py
│       ├── processing.py
│       ├── README.md
│       └── render.py
├── detect.py (detection script that can be run directly)
├── export.py (exports to TorchScript, CoreML, TorchScript-Lite, and ONNX)
├── figure (some images)
│   ├── horses_prediction.jpg
│   ├── mask.png
│   ├── performance.png
│   ├── pose.png
│   ├── tennis_caption.png
│   ├── tennis.jpg
│   ├── tennis_panoptic.png
│   └── tennis_semantic.jpg
├── hubconf.py (PyTorch Hub entry point; not needed for now)
├── inference
│   └── images
│       ├── bus.jpg
│       ├── horses.jpg
│       ├── image1.jpg
│       ├── image2.jpg
│       ├── image3.jpg
│       └── zidane.jpg
├── LICENSE.md
├── models (**key**: network architecture code)
│   ├── common.py (building blocks used by the networks)
│   ├── experimental.py (experimental components)
│   ├── __init__.py
│   └── yolo.py (network definition, including yaml parsing)
├── README.md
├── requirements.txt
├── runs (output produced while the model runs)
│   └── train
│       ├── yolov7
│       │   ├── events.out.tfevents.1661764925.ai-ai-dev-az3-01.13526.0
│       │   ├── hyp.yaml
│       │   ├── opt.yaml
│       │   └── weights
├── scripts (download the COCO dataset)
│   └── get_coco.sh
├── test.py (evaluate model metrics)
├── tools (Jupyter notebooks)
│   ├── compare_YOLOv7e6_vs_YOLOv5x6_half.ipynb
│   ├── compare_YOLOv7e6_vs_YOLOv5x6.ipynb
│   ├── compare_YOLOv7_vs_YOLOv5m6_half.ipynb
│   ├── compare_YOLOv7_vs_YOLOv5m6.ipynb
│   ├── compare_YOLOv7_vs_YOLOv5s6.ipynb
│   ├── instance.ipynb
│   ├── keypoint.ipynb
│   ├── reparameterization.ipynb
│   ├── visualization.ipynb
│   ├── YOLOv7CoreML.ipynb
│   ├── YOLOv7-Dynamic-Batch-ONNXRUNTIME.ipynb
│   ├── YOLOv7-Dynamic-Batch-TENSORRT.ipynb
│   ├── YOLOv7onnx.ipynb
│   └── YOLOv7trt.ipynb
├── train_aux.py (train p6 models)
├── train.py (train p5 models)
├── utils
│   ├── activations.py (defines many activation functions)
│   ├── add_nms.py
│   ├── autoanchor.py
│   ├── aws
│   │   ├── __init__.py
│   │   ├── mime.sh
│   │   ├── resume.py
│   │   └── userdata.sh
│   ├── datasets.py (**key**: data reading and loading)
│   ├── general.py (general-purpose helpers)
│   ├── google_app_engine
│   │   ├── additional_requirements.txt
│   │   ├── app.yaml
│   │   └── Dockerfile
│   ├── google_utils.py
│   ├── __init__.py
│   ├── loss.py (**loss definitions**)
│   ├── metrics.py (evaluation metrics)
│   ├── plots.py (plotting)
│   ├── torch_utils.py (YOLOR PyTorch utils)
│   └── wandb_logging
│       ├── __init__.py
│       ├── log_dataset.py
│       └── wandb_utils.py
└── yolov7.pt (pretrained model)


2. Data loading

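The goal of this post is to debug data loading on the COCO2017 dataset. I previously wrote up the YOLOv5 code for training on the VOC dataset: https://blog.csdn.net/hymn1993/article/details/123664708
This post walks through the same code path again for YOLOv7; there is little real difference.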

2.1 Code walkthrough

Entry point: train.py

# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                        hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                        world_size=opt.world_size, workers=opt.workers,
                                        image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))

Note: augment = True and rect is False:

parser.add_argument('--rect', action='store_true', help='rectangular training')

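We do not pass this flag when training, so rect is False. rect controls whether rectangular train/test is enabled: by default it is off for the training set and on for the validation set, where it speeds things up. When self.rect=True, self.batch_shapes records the shape of each batch (all images in the same batch share one shape).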
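To make the rect behavior concrete, here is a minimal sketch of how per-batch shapes can be derived from image aspect ratios. It follows the rectangular-training branch of LoadImagesAndLabels quoted later in this post; the function name rect_batch_shapes and the sample image sizes are made up for illustration and are not part of YOLOv7.

```python
import numpy as np


def rect_batch_shapes(wh, batch_size, img_size=640, stride=32, pad=0.0):
    """Sort images by aspect ratio and compute one (h, w) shape per batch.

    wh: list of (width, height) pairs, one per image (hypothetical input).
    Returns (order, batch_shapes): the sort order and an (nb, 2) int array.
    """
    wh = np.asarray(wh, dtype=np.float64)
    n = len(wh)
    bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index per image
    nb = bi[-1] + 1  # number of batches

    ar = wh[:, 1] / wh[:, 0]  # aspect ratio h / w
    order = ar.argsort()      # similar aspect ratios end up in the same batch
    ar = ar[order]

    shapes = [[1, 1]] * nb
    for i in range(nb):
        ari = ar[bi == i]
        mini, maxi = ari.min(), ari.max()
        if maxi < 1:       # every image in the batch is wider than tall
            shapes[i] = [maxi, 1]
        elif mini > 1:     # every image in the batch is taller than wide
            shapes[i] = [1, 1 / mini]

    # Round each side up to a multiple of the model stride.
    batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride
    return order, batch_shapes


order, shapes = rect_batch_shapes(
    [(640, 427), (427, 640), (640, 427), (427, 640)], batch_size=2)
print(shapes.tolist())  # [[448, 640], [640, 448]]
```

The two landscape images share a 448x640 batch and the two portrait images share a 640x448 batch, so neither batch is padded out to a full 640x640 square.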

Definition of utils/datasets.py::create_dataloader:

def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
					  rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
	
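	"""Called from train.py to build the dataloader and dataset (also used for the test loader).
	Custom dataloader factory: builds the dataset with LoadImagesAndLabels (including augmentation),
	uses DistributedSampler for distributed training, and uses the custom InfiniteDataLoader to keep sampling indefinitely.
	:param path: path to the train/test image data, e.g. '../COCO2017/train2017.txt'
	:param imgsz: train/test image size (size after augmentation), e.g. 640
	:param batch_size: batch size, e.g. 32
	:param stride: maximum stride of the model, e.g. 32
	:param opt.single_cls: whether the dataset has a single class, default False
	:param hyp: dict of training hyperparameters (learning rate, etc.); here mainly the augmentation coefficients (rotation, translation, ...) are used. Set with `--hyp` on the command line
	:param augment: whether to apply data augmentation; True for training
	:param cache: whether to cache images (cache_images), default False
	:param pad: padding applied when computing the rectangular-training shapes, default 0.0
	:param rect: whether to enable rectangular train/test; off for the training set and on for the validation set by default
	:param rank: process rank for multi-GPU training. -1 with one GPU means no distributed training; -1 with several GPUs means DataParallel mode. Default -1. The (global) rank of the current process.
	:param world_size: The total number of processes. Should be equal to the total number of devices (GPU) used for distributed training.
	:param workers: num_workers of the dataloader, i.e. the number of CPU worker processes used to load data
	:param image_weights: whether to pick training images weighted by the distribution of their ground-truth boxes, default False
	:param quad: whether the dataloader uses collate_fn4 instead of collate_fn, default False
	:param prefix: label shown in log messages, usually train/val; also used when the label cache file is saved
	"""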

	# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
	with torch_distributed_zero_first(rank):
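		dataset = LoadImagesAndLabels(path, imgsz, batch_size,
					  augment=augment,  # augment images
					  hyp=hyp,  # augmentation hyperparameters
					  rect=rect,  # rectangular training
					  cache_images=cache,
					  single_cls=opt.single_cls,
					  stride=int(stride),
					  pad=pad,
					  image_weights=image_weights,
					  prefix=prefix)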

	batch_size = min(batch_size, len(dataset))
	nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
	sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
	loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
	# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
	dataloader = loader(dataset,
						batch_size=batch_size,
						num_workers=nw,
						sampler=sampler,
						pin_memory=True,
						collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
	return dataloader, dataset
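The worker count is easy to misread, so here is the nw expression from create_dataloader pulled out into a small standalone function. This is only a sketch: the function name and the cpu_count argument (added so the result does not depend on the machine) are hypothetical.

```python
import os


def num_dataloader_workers(batch_size, workers, world_size=1, cpu_count=None):
    """Pick the number of DataLoader workers the way create_dataloader does.

    cpu_count is a hypothetical override for os.cpu_count(), which can
    differ per machine (and may return None on some platforms).
    """
    if cpu_count is None:
        cpu_count = os.cpu_count() or 1
    return min([
        cpu_count // world_size,             # CPU cores available per process
        batch_size if batch_size > 1 else 0, # a batch of 1 loads in the main process
        workers,                             # upper bound requested by the user
    ])


print(num_dataloader_workers(32, 8, world_size=1, cpu_count=16))  # 8
print(num_dataloader_workers(32, 8, world_size=4, cpu_count=16))  # 4
print(num_dataloader_workers(1, 8, world_size=1, cpu_count=16))   # 0
```

So --workers is only an upper bound: with four processes sharing 16 cores, each process gets at most 4 workers regardless of the value requested.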

Next, focus on the LoadImagesAndLabels call inside utils/datasets.py::create_dataloader:

def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
					  rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
	# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
	with torch_distributed_zero_first(rank):
		dataset = LoadImagesAndLabels(path, imgsz, batch_size,
					  augment=augment,  # augment images
					  hyp=hyp,  # augmentation hyperparameters
					  rect=rect,  # rectangular training
					  cache_images=cache,
					  single_cls=opt.single_cls,
					  stride=int(stride),
					  pad=pad,
					  image_weights=image_weights,
					  prefix=prefix)

Below is the code of the LoadImagesAndLabels class, which defines the dataset:

class LoadImagesAndLabels(Dataset):  # for training/testing
	def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
				 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
		self.img_size = img_size
		self.augment = augment
		self.hyp = hyp
		self.image_weights = image_weights
		self.rect = False if image_weights else rect
		self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
		self.mosaic_border = [-img_size // 2, -img_size // 2]
		self.stride = stride
		self.path = path        
		#self.albumentations = Albumentations() if augment else None

		try:
			f = []  # image files
			for p in path if isinstance(path, list) else [path]:
				p = Path(p)  # os-agnostic # PosixPath('../COCO2017/train2017.txt')
				if p.is_dir():  # dir
					f += glob.glob(str(p / '**' / '*.*'), recursive=True)
					# f = list(p.rglob('**/*.*'))  # pathlib
				elif p.is_file():  # file (this is the branch taken here)
					with open(p, 'r') as t:
						# t: ['./images/train2017/000000109622.jpg', './images/train2017/000000160694.jpg', ...]
						t = t.read().strip().splitlines()
						parent = str(p.parent) + os.sep  # '../COCO2017/', the parent directory
						f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path: each element of t becomes a full file path
						# f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
				else:
					raise Exception(f'{prefix}{p} does not exist')
			self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])  # keep files whose lowercased extension is one of the 9 supported image formats, then sort
			# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
			assert self.img_files, f'{prefix}No images found'
		except Exception as e:
			raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')

		# Check cache
		self.label_files = img2label_paths(self.img_files)  # labels: list of annotation files, one per entry in img_files above
		# cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels PosixPath('../COCO2017/train2017.cache'), saved as a cache file
		cache_path = Path('/data/hyz/datasets/COCO2017').with_suffix('.cache')
		if cache_path.is_file():
			cache, exists = torch.load(cache_path), True  # load
			#if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  # changed
			#    cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
		else:
			cache, exists = self.cache_labels(cache_path, prefix), False  # cache

		# Display cache
		nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
		if exists:
			d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
			tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
		assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

		# Read cache
		cache.pop('hash')  # remove hash
		cache.pop('version')  # remove version
		labels, shapes, self.segments = zip(*cache.values())
		self.labels = list(labels)
		self.shapes = np.array(shapes, dtype=np.float64)
		self.img_files = list(cache.keys())  # update
		self.label_files = img2label_paths(cache.keys())  # update
		if single_cls:
			for x in self.labels:
				x[:, 0] = 0

		n = len(shapes)  # number of images
		bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index (np.int was removed in NumPy 1.24)
		nb = bi[-1] + 1  # number of batches
		self.batch = bi  # batch index of image
		self.n = n
		self.indices = range(n)

		# Rectangular Training
		if self.rect:
			# Sort by aspect ratio
			s = self.shapes  # wh
			ar = s[:, 1] / s[:, 0]  # aspect ratio
			irect = ar.argsort()
			self.img_files = [self.img_files[i] for i in irect]  # list of image paths
			self.label_files = [self.label_files[i] for i in irect]  # list of annotation paths
			self.labels = [self.labels[i] for i in irect]  # list of the matching annotation arrays
			self.shapes = s[irect]  # wh
			ar = ar[irect]

			# Set training image shapes
			shapes = [[1, 1]] * nb
			for i in range(nb):
				ari = ar[bi == i]
				mini, maxi = ari.min(), ari.max()
				if maxi < 1:
					shapes[i] = [maxi, 1]
				elif mini > 1:
					shapes[i] = [1, 1 / mini]

			self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

		# Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
		self.imgs = [None] * n
		if cache_images:
			if cache_images == 'disk':
				self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')
				self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]
				self.im_cache_dir.mkdir(parents=True, exist_ok=True)
			gb = 0  # Gigabytes of cached images
			self.img_hw0, self.img_hw = [None] * n, [None] * n
			results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
			pbar = tqdm(enumerate(results), total=n)
			for i, x in pbar:
				if cache_images == 'disk':
					if not self.img_npy[i].exists():
						np.save(self.img_npy[i].as_posix(), x[0])
					gb += self.img_npy[i].stat().st_size
				else:
					self.imgs[i], self.img_hw0[i], self.img_hw[i] = x
					gb += self.imgs[i].nbytes
				pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
			pbar.close()

	def cache_labels(self, path=Path('./labels.cache'), prefix=''):  # the method to rewrite for custom data
		# Cache dataset labels, check images and read shapes
		x = {}  # dict
		nm, nf, ne, nc = 0, 0, 0, 0  # number missing (images with no label file), found (label files found), empty (label file exists but has no content), corrupted (samples that failed while being read)
		pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))  # creates a progress bar like: Scanning images:   0%|                          | 0/118287 [00:00<?, ?it/s]
		for i, (im_file, lb_file) in enumerate(pbar):  # loop over every sample, i.e. each (image jpg, label txt) pair
			try:
				# verify images
				im = Image.open(im_file)  # check that the image can be opened
				im.verify()  # PIL verify: check file integrity
				shape = exif_size(im)  # image size
				segments = []  # instance segments
				assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
				assert im.format.lower() in img_formats, f'invalid image format {im.format}'

				# verify labels
				if os.path.isfile(lb_file):
					nf += 1  # label found
					with open(lb_file, 'r') as f:
						l = [x.split() for x in f.read().strip().splitlines()]  # read every line of the label txt (one annotation per line) into a list
						if any([len(x) > 8 for x in l]):  # is segment: more than 8 fields means the annotation is a segmentation polygon
							classes = np.array([x[0] for x in l], dtype=np.float32)  # the first column is the class, stored as a numeric string such as '45'; this builds the class list for the file, e.g. [45.0, 45.0, 50.0, 45.0, 49.0, 49.0, 49.0, 49.0]
							segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # after the first column, every two numbers are one point; reshape each instance polygon into (n, 2) points
							l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh), e.g. shape (8, 5)
						l = np.array(l, dtype=np.float32)
					if len(l):
						assert l.shape[1] == 5, 'labels require 5 columns each'  # i.e. cls, xywh
						assert (l >= 0).all(), 'negative labels'  # every value must be >= 0
						assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'  # bbox coordinates must not fall outside the image
						assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'  # fails if the file contains duplicate boxes
					else:
						ne += 1  # label empty
						l = np.zeros((0, 5), dtype=np.float32)
				else:
					nm += 1  # label missing
					l = np.zeros((0, 5), dtype=np.float32)
				x[im_file] = [l, shape, segments]  # x is a dict: key is the image path, value is [labels for that image (e.g. shape (8, 5)), image width and height, segment coordinates]
			except Exception as e:
				nc += 1
				print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

			pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
						f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"  # update the progress bar
		pbar.close()

		if nf == 0:
			print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

		x['hash'] = get_hash(self.label_files + self.img_files)
		x['results'] = nf, nm, ne, nc, i + 1  # the collected counts
		x['version'] = 0.1  # cache version
		torch.save(x, path)  # save for next time
		logging.info(f'{prefix}New cache created: {path}')
		return x
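The label-parsing part of cache_labels can be tried in isolation. The sketch below reproduces that logic on a string instead of a file, assuming the same convention as the code above (each line is "cls cx cy w h", or "cls x1 y1 x2 y2 ..." for a polygon, all normalized to [0, 1]). parse_label_text and segment_to_xywh are hypothetical names; segment_to_xywh stands in for segments2boxes and handles one polygon at a time.

```python
import numpy as np


def segment_to_xywh(points):
    """Convert one (n, 2) polygon into a (cx, cy, w, h) box."""
    x, y = points[:, 0], points[:, 1]
    x1, y1, x2, y2 = x.min(), y.min(), x.max(), y.max()
    return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]


def parse_label_text(text):
    """Parse the content of one label file into (labels, segments).

    labels is an (n, 5) float32 array of (cls, cx, cy, w, h).
    As in the code above, a file is treated as all-polygons as soon as
    any line has more than 8 fields.
    """
    rows = [line.split() for line in text.strip().splitlines()]
    if not rows:  # empty label file
        return np.zeros((0, 5), dtype=np.float32), []

    segments = []
    if any(len(r) > 8 for r in rows):
        classes = np.array([r[0] for r in rows], dtype=np.float32)
        segments = [np.array(r[1:], dtype=np.float32).reshape(-1, 2) for r in rows]
        boxes = np.array([segment_to_xywh(s) for s in segments], dtype=np.float32)
        labels = np.concatenate((classes.reshape(-1, 1), boxes), 1)
    else:
        labels = np.array(rows, dtype=np.float32)

    # Same sanity checks as cache_labels.
    assert labels.shape[1] == 5, 'labels require 5 columns each'
    assert (labels >= 0).all(), 'negative labels'
    assert (labels[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
    assert np.unique(labels, axis=0).shape[0] == labels.shape[0], 'duplicate labels'
    return labels, segments


labels, segments = parse_label_text("45 0.5 0.5 0.2 0.4\n49 0.3 0.3 0.1 0.1\n")
print(labels.shape)  # (2, 5)
```

Feeding it a polygon line such as "3 0.1 0.1 0.5 0.1 0.5 0.3 0.1 0.3 0.1 0.1" yields a single row whose box is the polygon's bounding rectangle.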

If we want to train YOLOv7 on data stored in our own format, this is the part that has to be modified.
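Whatever the source format, a replacement for cache_labels has to return a dict that __init__ can unpack. The sketch below spells out that contract as read from the code above: one entry per image plus the 'hash', 'results' and 'version' keys. make_cache and read_cache are hypothetical helpers, and the hash value is a placeholder, not the real get_hash result.

```python
import numpy as np


def make_cache(entries):
    """Build a cache dict from (image_path, labels_or_None, (w, h)) tuples."""
    cache = {}
    nf = nm = ne = nc = 0  # found, missing, empty, corrupted
    for im_file, labels, shape in entries:
        if labels is None:  # no annotation for this image
            nm += 1
            labels = np.zeros((0, 5), dtype=np.float32)
        else:
            nf += 1
            labels = np.asarray(labels, dtype=np.float32).reshape(-1, 5)
            if len(labels) == 0:
                ne += 1
        cache[im_file] = [labels, shape, []]  # [labels, (w, h), segments]
    total = len(cache)
    cache['hash'] = 0  # placeholder value (assumption)
    cache['results'] = (nf, nm, ne, nc, total)
    cache['version'] = 0.1
    return cache


def read_cache(cache):
    """Unpack the cache the same way LoadImagesAndLabels.__init__ does."""
    cache = dict(cache)  # shallow copy so the caller's dict is untouched
    nf, nm, ne, nc, n = cache.pop('results')
    cache.pop('hash')     # raises KeyError if the writer left this key out
    cache.pop('version')
    labels, shapes, segments = zip(*cache.values())
    return list(cache.keys()), list(labels), np.array(shapes, dtype=np.float64)


cache = make_cache([
    ('a.jpg', [[0, 0.5, 0.5, 0.2, 0.2]], (640, 480)),
    ('b.jpg', None, (640, 480)),
])
files, labels, shapes = read_cache(cache)
print(files, shapes.shape)  # ['a.jpg', 'b.jpg'] (2, 2)
```

The three metadata keys must be popped before zip(*cache.values()), otherwise they would be unpacked as if they were image entries.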

So how do we modify it? Below are the changes I made for the InterHand dataset.

    cache_path = Path('/datasets/' + Path(self.path).stem).with_suffix('.cache')
    if cache_path.is_file():
        cache, exists = torch.load(cache_path), True  # load
    else:
        cache, exists = self.interhand_cache_labels(cache_path, prefix), False  # cache

Below is the modified loading function; it caches the labels to the specified directory.

def interhand_cache_labels(self, path=Path('./labels.cache'), prefix=''):  # rewritten version of cache_labels
    # Cache dataset labels, check images and read shapes
    db = COCO(self.path) 
    hand_cls = {'right': 0, 'left': 1, 'interacting': 2}
    in_img_path =  "/data/InterHand2.6M/images/InterHand2.6M_5fps_batch1/images"
    x={}
    segments = []  # instance segments
    nm, nf, ne, nc = 0, 0, 0, 0 
    pbar = tqdm(db.dataset.items(), desc='Scanning images', total=len(db.dataset.items()))
    for i, (key, value) in enumerate(pbar):
        try:
            im_file = os.path.join(in_img_path, key)
            # verify images
            im = Image.open(im_file)  # check that the image can be opened
            im.verify()  # PIL verify: check file integrity
            shape = exif_size(im)  # image size
            assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
            assert im.format.lower() in img_formats, f'invalid image format {im.format}'
            if value['bbox']!=[]:
                nf += 1  # label found
                classes = np.array(hand_cls[value['hand_type']], dtype=np.float32).reshape(-1, 1)
                bbox = np.array(value['bbox'], dtype=np.float32).reshape(-1, 4)  # top-left x, top-left y, w, h

                # original_img = cv2.imread(im_file)
                # bbx = value['bbox']
                # x1 = int(bbx[0])
                # y1 = int(bbx[1])
                # x2 = int(bbx[0] + bbx[2]) 
                # y2 = int(bbx[1] + bbx[3])
                # temp_image = cv2.rectangle(original_img, (x1, y1), (x2, y2), (0, 0, 255), 2) # (top-left, bottom-right)
                # cv2.imwrite('./test.jpg', cv2.cvtColor(temp_image, cv2.COLOR_BGR2RGB))  # cv2 save

                l = np.concatenate((classes, xywh2cxcywh(bbox, shape)), 1)
                l = np.array(l, dtype=np.float32)

                # original_img = cv2.imread(im_file)
                # bbx = xywh2cxcywh(bbox, shape)
                # bbx = xywh2xyxy(bbx)
                # x1 = int(bbx[:, 0]*shape[0])
                # y1 = int(bbx[:, 1]*shape[1])
                # x2 = int(bbx[:, 2]*shape[0]) 
                # y2 = int(bbx[:, 3]*shape[1])
                # temp_image = cv2.rectangle(original_img, (x1, y1), (x2, y2), (0, 0, 255), 2) # (top-left, bottom-right)
                # cv2.imwrite('./test.jpg', cv2.cvtColor(temp_image, cv2.COLOR_BGR2RGB))  # cv2 save

                if len(l):
                    assert l.shape[1] == 5, 'labels require 5 columns each'  # i.e. cls, xywh
                    assert (l >= 0).all(), 'negative labels'  # every value must be >= 0
                    assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'  # bbox coordinates must not fall outside the image
                    assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'  # fails if there are duplicate boxes
                else:
                    ne += 1  # label empty
                    l = np.zeros((0, 5), dtype=np.float32)
            else:
                nm += 1  # label missing
                l = np.zeros((0, 5), dtype=np.float32)

            x[im_file] = [l, shape, segments]
        except Exception as e:
            nc += 1
            print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

        pbar.desc = f"{prefix}Scanning '{in_img_path}' images and labels... " \
                f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"  # update the progress bar
    pbar.close()

    if nf == 0:
        print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

    # x['hash'] = get_hash(imgs_path_list)  # no hash is stored here, so a bare cache.pop('hash') in __init__ would raise KeyError
    x['results'] = nf, nm, ne, nc, i + 1  # the collected counts
    x['version'] = 0.1  # cache version
    torch.save(x, path)  # save for next time
    logging.info(f'{prefix}New cache created: {path}')

    return x
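The function above calls xywh2cxcywh, which is not shown in this post. Here is a minimal sketch of what such a helper could look like, assuming bbox holds pixel boxes as (top-left x, top-left y, w, h) and shape is (width, height) as returned by exif_size. This is a guess at the helper's behavior, not the original implementation.

```python
import numpy as np


def xywh2cxcywh(bbox, shape):
    """Convert pixel (x_top_left, y_top_left, w, h) boxes to normalized
    (cx, cy, w, h), the format the YOLO label checks expect.

    bbox:  array-like of shape (n, 4) or (4,), in pixels
    shape: (image_width, image_height)
    """
    bbox = np.asarray(bbox, dtype=np.float32).reshape(-1, 4)
    img_w, img_h = shape
    out = np.empty_like(bbox)
    out[:, 0] = (bbox[:, 0] + bbox[:, 2] / 2) / img_w  # center x
    out[:, 1] = (bbox[:, 1] + bbox[:, 3] / 2) / img_h  # center y
    out[:, 2] = bbox[:, 2] / img_w                      # width
    out[:, 3] = bbox[:, 3] / img_h                      # height
    return out


print(xywh2cxcywh([100, 50, 200, 100], (400, 200)).tolist())
# [[0.5, 0.5, 0.5, 0.5]]
```

Values are not clipped here, so a box with negative coordinates, or one whose normalized center or size exceeds 1, will still trip the 'negative labels' or 'non-normalized or out of bounds coordinate labels' assertions and the sample will be counted as corrupted.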
posted @ 2022-08-30 17:02  Zenith_Hugh