Handling Variable-Length Data with RNNs
Code first; I'll add the explanation later.
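```python
import torch
from torch.nn.utils.rnn import pad_sequence


def custom_collate(batch):
    # Separate data and labels; each item is a (sequence_tensor, label) pair
    data, labels = zip(*batch)
    # Pad sequences to have the same length (-1 marks padded positions)
    padded_data = pad_sequence(data, batch_first=True, padding_value=-1)
    # Convert labels to tensor
    labels = torch.tensor(labels, dtype=torch.long)
    return padded_data, labels


# trainDatasets: your training Dataset (not defined in this snippet)
dataloadersTrain = torch.utils.data.DataLoader(trainDatasets,
                                               batch_size=2,
                                               collate_fn=custom_collate,
                                               shuffle=True,
                                               num_workers=0)
```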
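Padding alone still feeds the padded positions (filled with -1 here) into the RNN as if they were real inputs. A common companion to this collate function is to also return each sequence's true length and pack the batch with `torch.nn.utils.rnn.pack_padded_sequence`, so the RNN stops at each sequence's real end. The sketch below is an illustration under assumptions, not the author's code: `toy_dataset`, `collate_with_lengths`, and the GRU sizes are hypothetical, and each dataset item is assumed to be a `(tensor of shape [seq_len, 3], label)` pair.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data import DataLoader

# Hypothetical toy dataset: each item is (sequence of shape [seq_len, 3], int label)
toy_dataset = [
    (torch.randn(5, 3), 0),
    (torch.randn(2, 3), 1),
    (torch.randn(4, 3), 0),
    (torch.randn(3, 3), 1),
]


def collate_with_lengths(batch):
    """Hypothetical variant of custom_collate that also returns the true lengths."""
    data, labels = zip(*batch)
    # Record each sequence's real length before padding
    lengths = torch.tensor([seq.size(0) for seq in data], dtype=torch.long)
    padded = pad_sequence(data, batch_first=True, padding_value=-1)
    return padded, lengths, torch.tensor(labels, dtype=torch.long)


loader = DataLoader(toy_dataset, batch_size=2, collate_fn=collate_with_lengths,
                    shuffle=False, num_workers=0)
rnn = nn.GRU(input_size=3, hidden_size=8, batch_first=True)

results = []
for padded, lengths, labels in loader:
    # Pack so the GRU only processes each sequence up to its real length;
    # enforce_sorted=False accepts batches that are not sorted by length.
    packed = pack_padded_sequence(padded, lengths, batch_first=True,
                                  enforce_sorted=False)
    _, h_n = rnn(packed)  # h_n: (num_layers, batch, hidden_size)
    results.append((tuple(padded.shape), lengths.tolist(), tuple(h_n.shape)))

for r in results:
    print(r)
```

With `shuffle=False` and `batch_size=2`, each printed tuple shows the padded batch shape, the real lengths, and the final hidden-state shape. Because the input was packed, `h_n` holds each sequence's hidden state at its last real time step rather than at a padded position.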