
Understanding collate_fn in PyTorch

The DataLoader is one of the most commonly used classes in PyTorch, and one of the first you learn. It has a lot of parameters (14), but most likely you will use only about three of them (dataset, shuffle, and batch_size). Today I'd like to explain the meaning of collate_fn, which in my experience beginners find confusing. We will briefly explore how PyTorch creates a batch and see how we can modify the default behavior for our needs.

One of the first pieces of information in every deep learning course is that we perform training/inference in batches. Most of the time, a batch is just a number of stacked examples. But in some cases, we would like to modify how it is created.

First things first, let’s investigate what happens in the default case. Assume we have the following toy dataset. It contains four examples, three features each.

import torch
from torch.utils.data import DataLoader
import numpy as np

data = np.array([
    [0.1, 7.4, 0],
    [-0.2, 5.3, 0],
    [0.2, 8.2, 1],
    [0.2, 7.7, 1]])
print(data)

If we ask a loader for a batch, we will see the following (note that I set shuffle=False to eliminate randomness):

loader = DataLoader(data, batch_size=2, shuffle=False)
batch = next(iter(loader))
print(batch)


'''
tensor([[ 0.1000,  7.4000,  0.0000],
        [-0.2000,  5.3000,  0.0000]], dtype=torch.float64)
'''

No surprise, but let's formalize what has been done (a rough code equivalent follows the list):

  1. The loader selected 2 items from the dataset.
  2. Those items were converted into a tensor (2 items of size 3).
  3. A new tensor was created (2x3) and returned.
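
For this simple case we can mimic what the default collation effectively does in a couple of lines (a rough sketch, not the actual default_collate implementation):

# roughly what the default collation does for our toy dataset:
# convert each selected example to a tensor, then stack them into one batch tensor
manual_batch = torch.stack([torch.as_tensor(example) for example in data[:2]])
print(manual_batch)          # same values as the batch above
print(manual_batch.shape)    # torch.Size([2, 3])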

Default setup also allows us to use dictionaries. Let’s see an example:

from pprint import pprint


# now dataset is a list of dicts

dict_data = [
    {'x1': 0.1, 'x2': 7.4, 'y': 0},
    {'x1': -0.2, 'x2': 5.3, 'y': 0},
    {'x1': 0.2, 'x2': 8.2, 'y': 1},
    {'x1': 0.2, 'x2': 7.7, 'y': 10}]


pprint(dict_data)
# [{'x1': 0.1, 'x2': 7.4, 'y': 0},
# {'x1': -0.2, 'x2': 5.3, 'y': 0},
# {'x1': 0.2, 'x2': 8.2, 'y': 1},
# {'x1': 0.2, 'x2': 7.7, 'y': 10}]


loader = DataLoader(dict_data, batch_size=2, shuffle=False)
batch = next(iter(loader))
pprint(batch)
 

'''
{ 
	'x1': tensor([0.1000, -0.2000],    dtype=torch.float64),
	'x2': tensor([7.4000,  5.3000],    dtype=torch.float64),
  	'y':  tensor([0, 0])
}
'''

The loader was smart enough to correctly repack data from a list of dicts. This capability is handy when your data is in JSONL format (which I personally prefer over CSV).
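
For example, a JSONL file drops straight into this "list of dicts" shape (a minimal sketch; data.jsonl is a hypothetical file with one JSON object per line, e.g. {"x1": 0.1, "x2": 7.4, "y": 0}):

import json

with open('data.jsonl') as f:
    jsonl_data = [json.loads(line) for line in f]

loader = DataLoader(jsonl_data, batch_size=2, shuffle=False)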

Custom collate function

If the default collation is so smart, why might we need a custom one? The default one has a significant limitation: all examples in a batch must have the same dimensions. Imagine we have an NLP task, and the data is tokenized text.

# values are token indices but it does not matter 
# it can be any kind of variable-size data

nlp_data = [
    {'tokenized_input': [1, 4, 5, 9, 3, 2],
     'label':0},
    {'tokenized_input': [1, 7, 3, 14, 48, 7, 23, 154, 2],
     'label':0},
    {'tokenized_input': [1, 30, 67, 117, 21, 15, 2],
     'label':1},
    {'tokenized_input': [1, 17, 2],
     'label':0},
]


loader = DataLoader(nlp_data, batch_size=2, shuffle=False)
batch = next(iter(loader))

It will not work and raises an error:


/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
     80         elem_size = len(next(it))
     81         if not all(len(elem) == elem_size for elem in it):
---> 82             raise RuntimeError('each element in list of batch should be of equal size')
     83         transposed = zip(*batch)
     84         return [default_collate(samples) for samples in transposed]
	
	RuntimeError: each element in list of batch should be of equal size

The error message says that it is impossible to create a non-rectangular tensor. BTW, note that the error is raised from the default_collate function.

What can we do? There are two solutions:

  1. Pad the whole dataset to the longest example.
  2. Pad dynamically during batch creation.

The first solution might seem more straightforward: just expand all examples to the longest one. But there is an issue: we will waste memory and computing power (which are expensive on a GPU!) processing padding, which does not influence the result. It is especially painful if the data contains a few long sequences and most of the others are relatively short. In such a case, we mostly process padding instead of data!

If we pad the whole dataset to the longest sequence, there is a lot of wasted space!
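
We can put a rough number on it for our tiny nlp_data (just a back-of-the-envelope count, assuming we pad everything to the longest example):

lengths = [len(d['tokenized_input']) for d in nlp_data]
longest = max(lengths)                       # 9 tokens
real_tokens = sum(lengths)                   # 25 real tokens
padded_positions = longest * len(lengths)    # 36 positions after padding
print(padded_positions - real_tokens)        # 11 positions are pure padding

Almost a third of the processed positions would carry no information, and the ratio gets worse on real datasets with a long tail of sequence lengths.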

An alternative is to pad the data on the fly. When the samples for a batch are selected, we pad only those to the longest one among them. If we additionally order the data by length, the padding will be minimal. If there are a few very long sequences, they will only influence their own batches, not the whole dataset.

Per-batch padding reduces the number of unnecessary spaces used.
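
Before writing any collation code, note that the length-ordering trick mentioned above needs no special machinery. If strict shuffling is not essential, simply sorting the examples before building the loader already groups similar lengths together (a minimal sketch):

# group similar lengths together so that per-batch padding stays small
# (this trades random shuffling for less padding)
sorted_nlp_data = sorted(nlp_data, key=lambda d: len(d['tokenized_input']))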

Okay, but how to implement it? Just create a custom collate_fn. It is simple, I promise 😉

from torch.nn.utils.rnn import pad_sequence 

# a custom collate_fn
def custom_collate(data):
    
    inputs = [torch.tensor(d['tokenized_input']) for d in data] 
    labels = [d['label'] for d in data]    
    inputs = pad_sequence(inputs, batch_first=True) 
    labels = torch.tensor(labels) 

    return { 'tokenized_input': inputs, 
             'label': labels}

loader = DataLoader(nlp_data, batch_size=2, shuffle=False, collate_fn=custom_collate) 
batch = next(iter(loader))
print(batch)

Step by step:

  1. For padding we use pad_sequence.
  2. Collate function takes a single argument — a list of examples. In this case, it will be a list of dicts, but it also can be a list of tuples, etc. — depending on the dataset.
  3. As the data comes in the format of a list of dicts, we need to traverse it and build a separate list for the inputs and for the labels. Along the way, tokenized_input is converted to a 1-D tensor (it was a list of ints).
  4. Perform the padding.
  5. As the labels were a list of ints, we convert them into a tensor.
  6. Return formatted batch.
  7. Set our custom function in the loader.

As we can see, the batch is in the same format as for default collation with a dictionary. We can also clearly see that the amount of padding is minimal.
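
You can verify the per-batch behavior by printing the shape of every batch: each one is only as wide as its own longest example.

for i, b in enumerate(loader):
    print(i, b['tokenized_input'].shape)

# 0 torch.Size([2, 9])   <- padded to the longest example in the first batch
# 1 torch.Size([2, 7])   <- padded to the longest example in the second batch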

Creating a custom collate function may not be the most common task, but you definitely need to know how to do it.

If you are using 🤗 Transformers, try to write a collator that will tokenize the data on the fly.
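
Here is a minimal sketch of what such a collator could look like (assuming raw-text samples like {'text': ..., 'label': ...} and a bert-base-uncased tokenizer; the library also ships ready-made DataCollator classes):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenizing_collate(batch):
    # batch is a list of dicts with raw text, e.g. {'text': 'some sentence', 'label': 0}
    texts = [d['text'] for d in batch]
    labels = torch.tensor([d['label'] for d in batch])
    # tokenize and pad only this batch, returning PyTorch tensors
    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    encoded['labels'] = labels
    return encoded

text_data = [{'text': 'first example', 'label': 0},
             {'text': 'a slightly longer second example', 'label': 1}]
loader = DataLoader(text_data, batch_size=2, shuffle=False,
                    collate_fn=tokenizing_collate)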



A brief note on collate_fn:

from torch.utils.data import Dataset

data = torch.rand(5, 4)
label = torch.randint(0, 2, [5, 1])


class MyDataset(Dataset):
    def __init__(self, this_data, this_label):
        super(MyDataset, self).__init__()
        self.data = this_data
        self.label = this_label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]


dataset = MyDataset(data, label)

one_data = dataset[0]
"""

返回的是一个tuple, 分别有self.data 和 self.label
(tensor([0.7826, 0.2459, 0.8676, 0.0740]), tensor([0]))

"""


dataloader = DataLoader(dataset, batch_size = 2, shuffle = False)
batch = next(iter(dataloader))

"""
返回的是一个列表, 分别有两个tensor
一个tensor代表的是data,  经过torch.stack()处理的结果
一个tensor返回的是label,经过torch.stack()处理的结果


[tensor([[0.7826, 0.2459, 0.8676, 0.0740],
         [0.5236, 0.6498, 0.1828, 0.8470]]),
tensor([[0],
	    [0]])]

"""




Note, going back to the dict_data example from earlier: because each original sample is a dict with the three keys x1, x2, and y, the batch returned by the loader is also a dict with those same three keys, each mapping to a tensor with one entry per sample (shape [2] for batch_size=2), exactly as in the output shown above.


# the raw data before it enters collate_fn:

[(tensor([0.2897, 0.4179, 0.6962, 0.6197]), tensor([0])),
 (tensor([0.2985, 0.6163, 0.2804, 0.7559]), tensor([1]))]

## each tuple is exactly what __getitem__ returns for one sample
## collate_fn then stacks the data and the labels into batch tensors






def my_collate(batch):
    # batch is a list of (data, label) tuples, one tuple per sample
    print("------------")
    my_data = [item[0] for item in batch]
    print("my_data: {}".format(my_data))
    my_label = [item[1] for item in batch]
    print("my_label: {}".format(my_label))
    print("------------")

    # prepend two zeros to every data vector, then stack into a (batch_size, 6) tensor
    pad = torch.zeros(2)
    my_data = torch.stack([torch.cat([pad, d], dim=0) for d in my_data])
    # add 4 to every label, then stack into a (batch_size, 1) tensor
    my_label = torch.stack([torch.tensor(4) + lab for lab in my_label])

    batch = [my_data, my_label]

    return batch



  1. The main purpose of this function is to process the batch data before it is returned. In general, the incoming batch is a nested list (List[List]): the outer list represents the whole batch, and each of its elements represents one individual sample.
  2. collate_fn is what transforms this raw batch data into the final batch, as the usage below shows.
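
To see my_collate in action, plug it into a loader built from the MyDataset defined above; every data row gets two zeros prepended and every label gets 4 added:

dataloader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=my_collate)
batch = next(iter(dataloader))

print(batch[0].shape)   # torch.Size([2, 6]) -> 2 prepended zeros + 4 original features
print(batch[1].shape)   # torch.Size([2, 1]) -> labels with 4 added to each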
