Handling a Varying Number of Objects per Image in Object Detection

Problem Description

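  When you move from classification to object detection in PyTorch, one difference shows up quickly. In classification, once the image size is fixed, every sample returned by the Dataset has the same tensor shape. In detection, an image may contain a single object or several, so the per-image box tensors differ in shape, and batching them fails with an error such as: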

RuntimeError: stack expects each tensor to be equal size, but got [1, 4] at entry 0 and [2, 4] at entry 1
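
The error can be reproduced without a DataLoader. The following is a minimal sketch, assuming two images with one and two boxes respectively; the tensor values are placeholders made up for illustration.

```python
import torch

# Hypothetical box tensors: one box for the first image, two for the second.
boxes_a = torch.zeros(1, 4)
boxes_b = torch.zeros(2, 4)

try:
    # Stacking requires every tensor to have the same shape.
    torch.stack([boxes_a, boxes_b], dim=0)
    message = None
except RuntimeError as err:
    message = str(err)

print(message)

# Images of a fixed size stack without trouble.
images = torch.stack([torch.zeros(3, 8, 8), torch.zeros(3, 8, 8)], dim=0)
print(images.shape)
```

Only the box tensors are the problem: the images share one shape, so they can still be stacked into a single batch tensor.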

While working through 动手学CV-Pytorch, I found that the tutorial handles this problem as follows (some code is omitted or rewritten here):

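import torch
from torch.utils.data import Dataset, DataLoader


class VOCDataset(Dataset):
    # ...

    def collate_fn(self):
        # This is passed to DataLoader as VOCDataset.collate_fn (a plain
        # function), so `self` here is the list of samples in one batch.
        images = [i[0] for i in self]
        boxes = [i[1] for i in self]
        labels = [i[2] for i in self]
        difficulties = [i[3] for i in self]

        # Images share one size and can be stacked into a single tensor;
        # boxes, labels and difficulties vary per image and stay as lists.
        images = torch.stack(images, dim=0)

        return images, boxes, labels, difficulties

    # ...

train_loader = DataLoader(
    VOCDataset(),
    shuffle=True,
    batch_size=32,
    collate_fn=VOCDataset.collate_fn,
    pin_memory=True,
)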

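  As we can see, the snippet passes a method of the class through the collate_fn parameter to get around the problem. Let's look at what this parameter actually does.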

Solution

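  Let's start with the collate_fn parameter itself. According to the official documentation and a thread on the official PyTorch discussion forum, it lets you regroup the samples returned by __getitem__() before the batch is handed out, and the thread notes that this is especially helpful for variable-length data. An example makes this concrete; it is adapted, with some changes, from another article.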

  First we define a matrix and its corresponding labels:

import torch

li = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
matrix = torch.tensor([li[i:i + 3] for i in range(10)])
label = torch.tensor([li[i:i + 1] for i in range(10)])
print('matrix:', matrix)
print('label:', label)


# >>> matrix: tensor([[ 0,  1,  2],
#                     [ 1,  2,  3],
#                     [ 2,  3,  4],
#                     [ 3,  4,  5],
#                     [ 4,  5,  6],
#                     [ 5,  6,  7],
#                     [ 6,  7,  8],
#                     [ 7,  8,  9],
#                     [ 8,  9, 10],
#                     [ 9, 10, 11]])
# >>> label: tensor([[0],
#                    [1],
#                    [2],
#                    [3],
#                    [4],
#                    [5],
#                    [6],
#                    [7],
#                    [8],
#                    [9]])

  Next, we write a very simple Dataset:

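from torch.utils.data import Dataset, DataLoader


class LiDataset(Dataset):
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2

    def __getitem__(self, item):
        # One sample: (row of matrix, row of label)
        return self.param1[item], self.param2[item]

    def __len__(self):
        return len(self.param1)

    def collect_fn(self):
        # Used later as collate_fn=LiDataset.collect_fn, so `self` is
        # one batch: a list of (matrix row, label) tuples.
        p1 = [i[0] for i in self]
        p2 = [i[1] for i in self]

        return p1, p2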

  We load this Dataset with a DataLoader. This is the most basic usage and the one typically seen in classification tasks:

print('WITHOUT collate_fn:')
dataset1 = DataLoader(
    LiDataset(matrix, label),
    batch_size=3
)

for i in dataset1:
    print(i)


# >>> WITHOUT collate_fn:
# [
#     tensor([[0, 1, 2],
#             [1, 2, 3],
#             [2, 3, 4]]),
#     tensor([[0],
#             [1],
#             [2]])
# ]
# [
#     tensor([[3, 4, 5],
#             [4, 5, 6],
#             [5, 6, 7]]),
#     tensor([[3],
#             [4],
#             [5]])
# ]
# [
#     tensor([[6,  7,  8],
#             [7,  8,  9],
#             [8,  9, 10]]),
#     tensor([[6],
#             [7],
#             [8]])
# ]
# [
#     tensor([[9, 10, 11]]),
#     tensor([[9]])
# ]

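  This is the output we usually see: the dataset is split into batches of 3 samples each, with the last batch holding the single remaining sample. Although we did not specify collate_fn, the DataLoader used PyTorch's default collate function (default_collate), which has already regrouped the data into the form shown above. To see what the data looks like before that regrouping, we set the parameter to lambda x: x: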

print('WITH lambda collate_fn:')
dataset2 = DataLoader(
    LiDataset(matrix, label),
    batch_size=3,
    collate_fn=lambda x: x
)

for i in dataset2:
    print(i)


# >>> WITH lambda collate_fn:
# [(tensor([0, 1, 2]), tensor([0])), (tensor([1, 2, 3]), tensor([1])), (tensor([2, 3, 4]), tensor([2]))]
# [(tensor([3, 4, 5]), tensor([3])), (tensor([4, 5, 6]), tensor([4])), (tensor([5, 6, 7]), tensor([5]))]
# [(tensor([6, 7, 8]), tensor([6])), (tensor([7, 8, 9]), tensor([7])), (tensor([ 8,  9, 10]), tensor([8]))]
# [(tensor([ 9, 10, 11]), tensor([9]))]
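
The regrouping done by the default collate function can also be checked by hand. The following is a minimal sketch, assuming the first batch of raw tuples shown above; the fallback import path is there for older PyTorch releases and is an assumption about the installed version.

```python
import torch

try:
    # Public import path in recent PyTorch releases.
    from torch.utils.data import default_collate
except ImportError:
    # Older releases keep it in a private module.
    from torch.utils.data._utils.collate import default_collate

# Three samples shaped like the (matrix row, label) pairs from LiDataset.
samples = [
    (torch.tensor([0, 1, 2]), torch.tensor([0])),
    (torch.tensor([1, 2, 3]), torch.tensor([1])),
    (torch.tensor([2, 3, 4]), torch.tensor([2])),
]

# default_collate transposes the list of tuples and stacks each field.
rows, labels = default_collate(samples)
print(rows.shape)    # torch.Size([3, 3])
print(labels.shape)  # torch.Size([3, 1])
```

Stacking each field is what requires every sample to have the same shape, which is why the default behaviour breaks on detection data.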

  Now the raw form is clear. For each batch, the DataLoader collects the values returned by __getitem__() into a list of batch_size tuples, one tuple per sample, as shown below:

[(tensor([0, 1, 2]), tensor([0])), (tensor([1, 2, 3]), tensor([1])), (tensor([2, 3, 4]), tensor([2]))]
  ^                  ^              ^                  ^              ^                  ^
 (matrix             label)        (matrix             label)        (matrix             label)

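  In other words, the raw data has the form above, and the DataLoader's collate_fn is where its output format can be reorganized. Now let's return to collect_fn, the custom method defined in LiDataset earlier: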

def collect_fn(self):
    p1 = [i[0] for i in self]
    p2 = [i[1] for i in self]

    return p1, p2

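  The method is passed as collate_fn=LiDataset.collect_fn; the complete code is given below. Based on what we saw above, here is how it works. Because the method is accessed through the class rather than through an instance, it is a plain function, and the DataLoader passes each batch as its only argument, so self is the batch. for i in self iterates over the tuples in that batch; i[0] is the first value returned by __getitem__(), here a row of matrix, and i[1] is likewise its label. The method gathers the matrix rows into one list and the labels into another, so each batch becomes a tuple of two lists and its length is always 2.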
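
The role of self can be seen without PyTorch at all. The following is a minimal sketch in plain Python; the class Demo and the method collate_batch are hypothetical names introduced only for illustration, and plain lists stand in for tensors.

```python
class Demo:
    # Written like an instance method, as in LiDataset.
    def collect_fn(self):
        p1 = [i[0] for i in self]
        p2 = [i[1] for i in self]
        return p1, p2

    # Equivalent logic, with the argument named for what it really is.
    @staticmethod
    def collate_batch(batch):
        p1, p2 = zip(*batch)  # transpose [(a, b), (a, b), ...]
        return list(p1), list(p2)


# Stand-in for one batch: a list of (matrix row, label) tuples.
batch = [([0, 1, 2], [0]), ([1, 2, 3], [1]), ([2, 3, 4], [2])]

# Accessed through the class, collect_fn is a plain function, so the
# batch is bound to the parameter named `self`.
via_self = Demo.collect_fn(batch)
via_static = Demo.collate_batch(batch)
print(via_self)
print(via_self == via_static)  # True
```

Declaring the function as a @staticmethod (or as a module-level function) with a parameter named batch behaves the same way and is arguably easier to read than reusing the name self.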

print('WITH modified collate_fn:')
dataset3 = DataLoader(
    LiDataset(matrix, label),
    batch_size=3,
    collate_fn=LiDataset.collect_fn
)

for i in dataset3:
    print(i)
    

# >>> WITH modified collate_fn:
# ([tensor([0, 1, 2]), tensor([1, 2, 3]), tensor([2, 3, 4])], [tensor([0]), tensor([1]), tensor([2])])
# ([tensor([3, 4, 5]), tensor([4, 5, 6]), tensor([5, 6, 7])], [tensor([3]), tensor([4]), tensor([5])])
# ([tensor([6, 7, 8]), tensor([7, 8, 9]), tensor([ 8,  9, 10])], [tensor([6]), tensor([7]), tensor([8])])
# ([tensor([ 9, 10, 11])], [tensor([9])])
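
The same idea carries back to the detection problem from the beginning. The following is a rough sketch, not the tutorial's actual VOCDataset: ToyDetectionDataset is a hypothetical dataset whose images are all-zero placeholders and whose box counts cycle through 1, 2 and 3.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDetectionDataset(Dataset):
    """Hypothetical dataset: image i has (i % 3) + 1 boxes."""

    def __init__(self, num_images=5):
        self.num_images = num_images

    def __len__(self):
        return self.num_images

    def __getitem__(self, index):
        num_boxes = index % 3 + 1
        image = torch.zeros(3, 8, 8)       # fixed-size image
        boxes = torch.zeros(num_boxes, 4)  # first dimension varies
        labels = torch.ones(num_boxes, dtype=torch.long)
        return image, boxes, labels

    @staticmethod
    def collate_fn(batch):
        images, boxes, labels = zip(*batch)
        # Images share a shape, so they can be stacked into one tensor.
        images = torch.stack(images, dim=0)
        # Boxes and labels differ in length, so they stay as lists.
        return images, list(boxes), list(labels)


loader = DataLoader(
    ToyDetectionDataset(),
    batch_size=2,
    collate_fn=ToyDetectionDataset.collate_fn,
)

batches = list(loader)
for images, boxes, labels in batches:
    # Batch shape of the images, then the number of boxes per image.
    print(images.shape, [b.shape[0] for b in boxes])
```

Each batch then holds one stacked image tensor plus per-image lists, so the model or loss function has to iterate over the boxes and labels image by image.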

posted @ 2023-06-26 19:02  絵守辛玥